From c090a1a9d960a33d7ba92da2c213e25c601c3ed7 Mon Sep 17 00:00:00 2001 From: izzy Date: Wed, 19 Nov 2025 15:27:44 +0000 Subject: [PATCH] feat: authenticate websocket requests in maintenance mode --- server/src/app.module.ts | 5 ++++ .../maintenance-websocket.repository.ts | 23 ++++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/server/src/app.module.ts b/server/src/app.module.ts index aa7dbd94fe..247958b686 100644 --- a/server/src/app.module.ts +++ b/server/src/app.module.ts @@ -124,6 +124,7 @@ export class MaintenanceModule { @Inject(IWorker) private worker: ImmichWorker, logger: LoggingRepository, private maintenanceWorkerService: MaintenanceWorkerService, + private maintenanceWebsocketRepository: MaintenanceWebsocketRepository, ) { logger.setAppName(this.worker); } @@ -131,6 +132,10 @@ export class MaintenanceModule { async onModuleInit() { StorageCore.setMediaLocation(this.maintenanceWorkerService.detectMediaLocation()); + this.maintenanceWebsocketRepository.setAuthFn(async (client) => + this.maintenanceWorkerService.authenticate(client.request.headers), + ); + await this.maintenanceWorkerService.logSecret(); } } diff --git a/server/src/maintenance/maintenance-websocket.repository.ts b/server/src/maintenance/maintenance-websocket.repository.ts index 5d8368cf69..6bc57fa71e 100644 --- a/server/src/maintenance/maintenance-websocket.repository.ts +++ b/server/src/maintenance/maintenance-websocket.repository.ts @@ -7,6 +7,7 @@ import { WebSocketServer, } from '@nestjs/websockets'; import { Server, Socket } from 'socket.io'; +import { MaintenanceAuthDto } from 'src/dtos/maintenance.dto'; import { AppRepository } from 'src/repositories/app.repository'; import { AppRestartEvent, ArgsOf } from 'src/repositories/event.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; @@ -18,6 +19,8 @@ export interface ClientEventMap { AppRestartV1: [AppRestartEvent]; } +type AuthFn = (client: Socket) => Promise; + @WebSocketGateway({ cors: true, path: '/api/socket.io', @@ -25,6 +28,8 @@ export interface ClientEventMap { }) @Injectable() export class MaintenanceWebsocketRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit { + private authFn?: AuthFn; + @WebSocketServer() private websocketServer?: Server; @@ -49,11 +54,23 @@ export class MaintenanceWebsocketRepository implements OnGatewayConnection, OnGa this.websocketServer?.serverSideEmit(event, ...args); } - handleConnection(client: Socket) { - this.logger.log(`Websocket Connect: ${client.id}`); + async handleConnection(client: Socket) { + try { + await this.authFn!(client); + await client.join('private'); + this.logger.log(`Websocket Connect: ${client.id} (private)`); + } catch { + await client.join('public'); + this.logger.log(`Websocket Connect: ${client.id} (public)`); + } } - handleDisconnect(client: Socket) { + async handleDisconnect(client: Socket) { this.logger.log(`Websocket Disconnect: ${client.id}`); + await Promise.allSettled([client.leave('private'), client.leave('public')]); + } + + setAuthFn(fn: (client: Socket) => Promise) { + this.authFn = fn; } }