diff --git a/server/src/app.module.ts b/server/src/app.module.ts index aa7dbd94fe..247958b686 100644 --- a/server/src/app.module.ts +++ b/server/src/app.module.ts @@ -124,6 +124,7 @@ export class MaintenanceModule { @Inject(IWorker) private worker: ImmichWorker, logger: LoggingRepository, private maintenanceWorkerService: MaintenanceWorkerService, + private maintenanceWebsocketRepository: MaintenanceWebsocketRepository, ) { logger.setAppName(this.worker); } @@ -131,6 +132,10 @@ export class MaintenanceModule { async onModuleInit() { StorageCore.setMediaLocation(this.maintenanceWorkerService.detectMediaLocation()); + this.maintenanceWebsocketRepository.setAuthFn(async (client) => + this.maintenanceWorkerService.authenticate(client.request.headers), + ); + await this.maintenanceWorkerService.logSecret(); } } diff --git a/server/src/maintenance/maintenance-websocket.repository.ts b/server/src/maintenance/maintenance-websocket.repository.ts index 5d8368cf69..6bc57fa71e 100644 --- a/server/src/maintenance/maintenance-websocket.repository.ts +++ b/server/src/maintenance/maintenance-websocket.repository.ts @@ -7,6 +7,7 @@ import { WebSocketServer, } from '@nestjs/websockets'; import { Server, Socket } from 'socket.io'; +import { MaintenanceAuthDto } from 'src/dtos/maintenance.dto'; import { AppRepository } from 'src/repositories/app.repository'; import { AppRestartEvent, ArgsOf } from 'src/repositories/event.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; @@ -18,6 +19,8 @@ export interface ClientEventMap { AppRestartV1: [AppRestartEvent]; } +type AuthFn = (client: Socket) => Promise; + @WebSocketGateway({ cors: true, path: '/api/socket.io', @@ -25,6 +28,8 @@ export interface ClientEventMap { }) @Injectable() export class MaintenanceWebsocketRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit { + private authFn?: AuthFn; + @WebSocketServer() private websocketServer?: Server; @@ -49,11 +54,23 @@ export class MaintenanceWebsocketRepository implements OnGatewayConnection, OnGa this.websocketServer?.serverSideEmit(event, ...args); } - handleConnection(client: Socket) { - this.logger.log(`Websocket Connect: ${client.id}`); + async handleConnection(client: Socket) { + try { + await this.authFn!(client); + await client.join('private'); + this.logger.log(`Websocket Connect: ${client.id} (private)`); + } catch { + await client.join('public'); + this.logger.log(`Websocket Connect: ${client.id} (public)`); + } } - handleDisconnect(client: Socket) { + async handleDisconnect(client: Socket) { this.logger.log(`Websocket Disconnect: ${client.id}`); + await Promise.allSettled([client.leave('private'), client.leave('public')]); + } + + setAuthFn(fn: (client: Socket) => Promise) { + this.authFn = fn; } }