feat: authenticate websocket requests in maintenance mode

This commit is contained in:
izzy
2025-11-19 15:27:44 +00:00
parent d040de2d52
commit c090a1a9d9
2 changed files with 25 additions and 3 deletions

View File

@@ -124,6 +124,7 @@ export class MaintenanceModule {
@Inject(IWorker) private worker: ImmichWorker, @Inject(IWorker) private worker: ImmichWorker,
logger: LoggingRepository, logger: LoggingRepository,
private maintenanceWorkerService: MaintenanceWorkerService, private maintenanceWorkerService: MaintenanceWorkerService,
private maintenanceWebsocketRepository: MaintenanceWebsocketRepository,
) { ) {
logger.setAppName(this.worker); logger.setAppName(this.worker);
} }
@@ -131,6 +132,10 @@ export class MaintenanceModule {
async onModuleInit() { async onModuleInit() {
StorageCore.setMediaLocation(this.maintenanceWorkerService.detectMediaLocation()); StorageCore.setMediaLocation(this.maintenanceWorkerService.detectMediaLocation());
this.maintenanceWebsocketRepository.setAuthFn(async (client) =>
this.maintenanceWorkerService.authenticate(client.request.headers),
);
await this.maintenanceWorkerService.logSecret(); await this.maintenanceWorkerService.logSecret();
} }
} }

View File

@@ -7,6 +7,7 @@ import {
WebSocketServer, WebSocketServer,
} from '@nestjs/websockets'; } from '@nestjs/websockets';
import { Server, Socket } from 'socket.io'; import { Server, Socket } from 'socket.io';
import { MaintenanceAuthDto } from 'src/dtos/maintenance.dto';
import { AppRepository } from 'src/repositories/app.repository'; import { AppRepository } from 'src/repositories/app.repository';
import { AppRestartEvent, ArgsOf } from 'src/repositories/event.repository'; import { AppRestartEvent, ArgsOf } from 'src/repositories/event.repository';
import { LoggingRepository } from 'src/repositories/logging.repository'; import { LoggingRepository } from 'src/repositories/logging.repository';
@@ -18,6 +19,8 @@ export interface ClientEventMap {
AppRestartV1: [AppRestartEvent]; AppRestartV1: [AppRestartEvent];
} }
type AuthFn = (client: Socket) => Promise<MaintenanceAuthDto>;
@WebSocketGateway({ @WebSocketGateway({
cors: true, cors: true,
path: '/api/socket.io', path: '/api/socket.io',
@@ -25,6 +28,8 @@ export interface ClientEventMap {
}) })
@Injectable() @Injectable()
export class MaintenanceWebsocketRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit { export class MaintenanceWebsocketRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit {
private authFn?: AuthFn;
@WebSocketServer() @WebSocketServer()
private websocketServer?: Server; private websocketServer?: Server;
@@ -49,11 +54,23 @@ export class MaintenanceWebsocketRepository implements OnGatewayConnection, OnGa
this.websocketServer?.serverSideEmit(event, ...args); this.websocketServer?.serverSideEmit(event, ...args);
} }
handleConnection(client: Socket) { async handleConnection(client: Socket) {
this.logger.log(`Websocket Connect: ${client.id}`); try {
await this.authFn!(client);
await client.join('private');
this.logger.log(`Websocket Connect: ${client.id} (private)`);
} catch {
await client.join('public');
this.logger.log(`Websocket Connect: ${client.id} (public)`);
}
} }
handleDisconnect(client: Socket) { async handleDisconnect(client: Socket) {
this.logger.log(`Websocket Disconnect: ${client.id}`); this.logger.log(`Websocket Disconnect: ${client.id}`);
await Promise.allSettled([client.leave('private'), client.leave('public')]);
}
setAuthFn(fn: (client: Socket) => Promise<MaintenanceAuthDto>) {
this.authFn = fn;
} }
} }