feat(ml)!: customizable ML settings (#3891)
* consolidated endpoints, added live configuration
* added ml settings to server
* added settings dashboard
* updated deps, fixed typos
* simplified modelconfig, updated tests
* added ml setting accordion for admin page, updated tests
* merged `clipText` and `clipVision`
* added face distance setting, clarified setting
* added clip mode in request, dropdown for face models
* polished ml settings, updated descriptions
* updated clip field on error
* removed unused import
* added description for image classification threshold
* pinned safetensors for arm wheel, updated poetry lock
* moved dto
* set model type only in ml repository
* reverted form-data package install, use fetch instead of axios
* added slotted description with link, updated facial recognition description, clarified effect of disabling tasks
* validation before model load
* removed unnecessary getConfig call
* added migration
* updated api

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
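For reference, the nested machineLearning settings introduced below are live-configurable through the system-config admin API. A minimal sketch of adjusting one of the new fields (the payload shape comes from the SystemConfig interface in this diff; the /api/system-config route, base URL, and authentication handling are assumptions, not part of this commit):

    // Sketch: tighten the facial recognition match threshold at runtime.
    async function setMaxDistance(maxDistance: number): Promise<void> {
      const base = 'http://localhost:2283/api'; // assumed server address, auth omitted
      const config = await fetch(`${base}/system-config`).then((res) => res.json());
      config.machineLearning.facialRecognition.maxDistance = maxDistance; // range 0-2, default 0.6
      await fetch(`${base}/system-config`, {
        method: 'PUT',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(config),
      });
    }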
@@ -115,6 +115,7 @@ describe(FacialRecognitionService.name, () => {
     personMock = newPersonRepositoryMock();
     searchMock = newSearchRepositoryMock();
     storageMock = newStorageRepositoryMock();
+    configMock = newSystemConfigRepositoryMock();

     mediaMock.crop.mockResolvedValue(croppedFace);

@@ -179,9 +180,18 @@ describe(FacialRecognitionService.name, () => {
       machineLearningMock.detectFaces.mockResolvedValue([]);
       assetMock.getByIds.mockResolvedValue([assetStub.image]);
       await sut.handleRecognizeFaces({ id: assetStub.image.id });
-      expect(machineLearningMock.detectFaces).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
-        imagePath: assetStub.image.resizePath,
-      });
+      expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
+        'http://immich-machine-learning:3003',
+        {
+          imagePath: assetStub.image.resizePath,
+        },
+        {
+          enabled: true,
+          maxDistance: 0.6,
+          minScore: 0.7,
+          modelName: 'buffalo_l',
+        },
+      );
       expect(faceMock.create).not.toHaveBeenCalled();
       expect(jobMock.queue).not.toHaveBeenCalled();
     });

@@ -32,7 +32,7 @@ export class FacialRecognitionService {

   async handleQueueRecognizeFaces({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.facialRecognitionEnabled) {
+    if (!machineLearning.enabled || !machineLearning.facialRecognition.enabled) {
       return true;
     }

@@ -59,7 +59,7 @@ export class FacialRecognitionService {

   async handleRecognizeFaces({ id }: IEntityJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.facialRecognitionEnabled) {
+    if (!machineLearning.enabled || !machineLearning.facialRecognition.enabled) {
       return true;
     }

@@ -68,7 +68,11 @@ export class FacialRecognitionService {
       return false;
     }

-    const faces = await this.machineLearning.detectFaces(machineLearning.url, { imagePath: asset.resizePath });
+    const faces = await this.machineLearning.detectFaces(
+      machineLearning.url,
+      { imagePath: asset.resizePath },
+      machineLearning.facialRecognition,
+    );

     this.logger.debug(`${faces.length} faces detected in ${asset.resizePath}`);
     this.logger.verbose(faces.map((face) => ({ ...face, embedding: `float[${face.embedding.length}]` })));

@@ -80,7 +84,7 @@ export class FacialRecognitionService {

     // try to find a matching face and link to the associated person
     // The closer to 0, the better the match. Range is from 0 to 2
-    if (faceSearchResult.total && faceSearchResult.distances[0] < 0.6) {
+    if (faceSearchResult.total && faceSearchResult.distances[0] <= machineLearning.facialRecognition.maxDistance) {
       this.logger.verbose(`Match face with distance ${faceSearchResult.distances[0]}`);
       personId = faceSearchResult.items[0].personId;
     }

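The 0-to-2 range in the comment above matches cosine distance over L2-normalized face embeddings: d = 1 - cos(theta), so identical embeddings score 0 and opposite ones score 2, and maxDistance is compared against that value. A one-line illustration (not part of this change, and assuming the inputs are already normalized):

    // Cosine distance for L2-normalized embeddings: 0 = same direction, 2 = opposite.
    const cosineDistance = (a: number[], b: number[]): number =>
      1 - a.reduce((sum, x, i) => sum + x * b[i], 0);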
@@ -115,7 +119,7 @@ export class FacialRecognitionService {

   async handleGenerateFaceThumbnail(data: IFaceThumbnailJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.facialRecognitionEnabled) {
+    if (!machineLearning.enabled || !machineLearning.facialRecognition.enabled) {
       return true;
     }

@@ -86,6 +86,7 @@ export interface ISearchRepository {
   deleteAssets(ids: string[]): Promise<void>;
   deleteFaces(ids: string[]): Promise<void>;
   deleteAllFaces(): Promise<number>;
+  updateCLIPField(num_dim: number): Promise<void>;

   searchAlbums(query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>>;
   searchAssets(query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>>;

@@ -121,15 +121,18 @@ export class SearchService {
     await this.configCore.requireFeature(FeatureFlag.SEARCH);

     const query = dto.q || dto.query || '*';
-    const hasClip = machineLearning.enabled && machineLearning.clipEncodeEnabled;
+    const hasClip = machineLearning.enabled && machineLearning.clip.enabled;
     const strategy = dto.clip && hasClip ? SearchStrategy.CLIP : SearchStrategy.TEXT;
     const filters = { userId: authUser.id, ...dto };

     let assets: SearchResult<AssetEntity>;
     switch (strategy) {
       case SearchStrategy.CLIP:
-        const clip = await this.machineLearning.encodeText(machineLearning.url, query);
-        assets = await this.searchRepository.vectorSearch(clip, filters);
+        const {
+          machineLearning: { clip },
+        } = await this.configCore.getConfig();
+        const embedding = await this.machineLearning.encodeText(machineLearning.url, { text: query }, clip);
+        assets = await this.searchRepository.vectorSearch(embedding, filters);
         break;
       case SearchStrategy.TEXT:
       default:

server/src/domain/smart-info/dto/index.ts (new file, 1 line)
@@ -0,0 +1 @@
+export * from './model-config.dto';

server/src/domain/smart-info/dto/model-config.dto.ts (new file, 50 lines)
@@ -0,0 +1,50 @@
+import { ApiProperty } from '@nestjs/swagger';
+import { Type } from 'class-transformer';
+import { IsBoolean, IsEnum, IsNotEmpty, IsNumber, IsOptional, IsString, Max, Min } from 'class-validator';
+import { CLIPMode, ModelType } from '../machine-learning.interface';
+
+export class ModelConfig {
+  @IsBoolean()
+  enabled!: boolean;
+
+  @IsString()
+  @IsNotEmpty()
+  modelName!: string;
+
+  @IsEnum(ModelType)
+  @IsOptional()
+  @ApiProperty({ enumName: 'ModelType', enum: ModelType })
+  modelType?: ModelType;
+}
+
+export class ClassificationConfig extends ModelConfig {
+  @IsNumber()
+  @Min(0)
+  @Max(1)
+  @Type(() => Number)
+  @ApiProperty({ type: 'number', format: 'float' })
+  minScore!: number;
+}
+
+export class CLIPConfig extends ModelConfig {
+  @IsEnum(CLIPMode)
+  @IsOptional()
+  @ApiProperty({ enumName: 'CLIPMode', enum: CLIPMode })
+  mode?: CLIPMode;
+}
+
+export class RecognitionConfig extends ModelConfig {
+  @IsNumber()
+  @Min(0)
+  @Max(1)
+  @Type(() => Number)
+  @ApiProperty({ type: 'number', format: 'float' })
+  minScore!: number;
+
+  @IsNumber()
+  @Min(0)
+  @Max(2)
+  @Type(() => Number)
+  @ApiProperty({ type: 'number', format: 'float' })
+  maxDistance!: number;
+}

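Since these configs are class-validator/class-transformer DTOs, a plain settings object can be coerced and range-checked before any model is loaded (the commit message's "validation before model load"). A minimal sketch of how that validation behaves, using the standard library calls rather than code from this commit:

    import { plainToClass } from 'class-transformer';
    import { validateSync } from 'class-validator';
    import { RecognitionConfig } from './model-config.dto';

    const config = plainToClass(RecognitionConfig, {
      enabled: true,
      modelName: 'buffalo_l',
      minScore: '0.7', // string is coerced to a number by @Type(() => Number)
      maxDistance: 2.5, // violates @Max(2)
    });
    const errors = validateSync(config);
    console.log(errors.length); // 1: maxDistance is out of range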
@@ -1,3 +1,4 @@
+export * from './dto';
 export * from './machine-learning.interface';
 export * from './smart-info.repository';
 export * from './smart-info.service';

@@ -1,9 +1,15 @@
+import { ClassificationConfig, CLIPConfig, RecognitionConfig } from './dto';
+
 export const IMachineLearningRepository = 'IMachineLearningRepository';

-export interface MachineLearningInput {
+export interface VisionModelInput {
   imagePath: string;
 }

+export interface TextModelInput {
+  text: string;
+}
+
 export interface BoundingBox {
   x1: number;
   y1: number;

@@ -19,9 +25,20 @@ export interface DetectFaceResult {
   embedding: number[];
 }

-export interface IMachineLearningRepository {
-  classifyImage(url: string, input: MachineLearningInput): Promise<string[]>;
-  encodeImage(url: string, input: MachineLearningInput): Promise<number[]>;
-  encodeText(url: string, input: string): Promise<number[]>;
-  detectFaces(url: string, input: MachineLearningInput): Promise<DetectFaceResult[]>;
+export enum ModelType {
+  IMAGE_CLASSIFICATION = 'image-classification',
+  FACIAL_RECOGNITION = 'facial-recognition',
+  CLIP = 'clip',
+}
+
+export enum CLIPMode {
+  VISION = 'vision',
+  TEXT = 'text',
+}
+
+export interface IMachineLearningRepository {
+  classifyImage(url: string, input: VisionModelInput, config: ClassificationConfig): Promise<string[]>;
+  encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]>;
+  encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]>;
+  detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]>;
 }

@@ -84,9 +84,13 @@ describe(SmartInfoService.name, () => {

      await sut.handleClassifyImage({ id: asset.id });

-     expect(machineMock.classifyImage).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
-       imagePath: 'path/to/resize.ext',
-     });
+     expect(machineMock.classifyImage).toHaveBeenCalledWith(
+       'http://immich-machine-learning:3003',
+       {
+         imagePath: 'path/to/resize.ext',
+       },
+       { enabled: true, minScore: 0.9, modelName: 'microsoft/resnet-50' },
+     );
      expect(smartMock.upsert).toHaveBeenCalledWith({
        assetId: 'asset-1',
        tags: ['tag1', 'tag2', 'tag3'],

@@ -141,13 +145,16 @@ describe(SmartInfoService.name, () => {
     });

     it('should save the returned objects', async () => {
       smartMock.upsert.mockResolvedValue();
       machineMock.encodeImage.mockResolvedValue([0.01, 0.02, 0.03]);

       await sut.handleEncodeClip({ id: asset.id });

-      expect(machineMock.encodeImage).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
-        imagePath: 'path/to/resize.ext',
-      });
+      expect(machineMock.encodeImage).toHaveBeenCalledWith(
+        'http://immich-machine-learning:3003',
+        { imagePath: 'path/to/resize.ext' },
+        { enabled: true, modelName: 'ViT-B-32::openai' },
+      );
       expect(smartMock.upsert).toHaveBeenCalledWith({
         assetId: 'asset-1',
         clipEmbedding: [0.01, 0.02, 0.03],

@@ -22,7 +22,7 @@ export class SmartInfoService {

   async handleQueueObjectTagging({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.tagImageEnabled) {
+    if (!machineLearning.enabled || !machineLearning.classification.enabled) {
       return true;
     }

@@ -43,7 +43,7 @@ export class SmartInfoService {

   async handleClassifyImage({ id }: IEntityJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.tagImageEnabled) {
+    if (!machineLearning.enabled || !machineLearning.classification.enabled) {
       return true;
     }

@@ -52,7 +52,11 @@ export class SmartInfoService {
       return false;
     }

-    const tags = await this.machineLearning.classifyImage(machineLearning.url, { imagePath: asset.resizePath });
+    const tags = await this.machineLearning.classifyImage(
+      machineLearning.url,
+      { imagePath: asset.resizePath },
+      machineLearning.classification,
+    );
     await this.repository.upsert({ assetId: asset.id, tags });

     return true;

@@ -60,7 +64,7 @@ export class SmartInfoService {

   async handleQueueEncodeClip({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.clipEncodeEnabled) {
+    if (!machineLearning.enabled || !machineLearning.clip.enabled) {
       return true;
     }

@@ -81,7 +85,7 @@ export class SmartInfoService {

   async handleEncodeClip({ id }: IEntityJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.clipEncodeEnabled) {
+    if (!machineLearning.enabled || !machineLearning.clip.enabled) {
       return true;
     }

@@ -90,7 +94,12 @@ export class SmartInfoService {
       return false;
     }

-    const clipEmbedding = await this.machineLearning.encodeImage(machineLearning.url, { imagePath: asset.resizePath });
+    const clipEmbedding = await this.machineLearning.encodeImage(
+      machineLearning.url,
+      { imagePath: asset.resizePath },
+      machineLearning.clip,
+    );
+
     await this.repository.upsert({ assetId: asset.id, clipEmbedding: clipEmbedding });

     return true;

@@ -1,4 +1,6 @@
-import { IsBoolean, IsUrl, ValidateIf } from 'class-validator';
+import { ClassificationConfig, CLIPConfig, RecognitionConfig } from '@app/domain';
+import { Type } from 'class-transformer';
+import { IsBoolean, IsObject, IsUrl, ValidateIf, ValidateNested } from 'class-validator';

 export class SystemConfigMachineLearningDto {
   @IsBoolean()

@@ -8,12 +10,18 @@ export class SystemConfigMachineLearningDto {
   @ValidateIf((dto) => dto.enabled)
   url!: string;

-  @IsBoolean()
-  clipEncodeEnabled!: boolean;
+  @Type(() => ClassificationConfig)
+  @ValidateNested()
+  @IsObject()
+  classification!: ClassificationConfig;

-  @IsBoolean()
-  facialRecognitionEnabled!: boolean;
+  @Type(() => CLIPConfig)
+  @ValidateNested()
+  @IsObject()
+  clip!: CLIPConfig;

-  @IsBoolean()
-  tagImageEnabled!: boolean;
+  @Type(() => RecognitionConfig)
+  @ValidateNested()
+  @IsObject()
+  facialRecognition!: RecognitionConfig;
 }

@@ -47,12 +47,25 @@ export const defaults = Object.freeze<SystemConfig>({
     [QueueName.THUMBNAIL_GENERATION]: { concurrency: 5 },
     [QueueName.VIDEO_CONVERSION]: { concurrency: 1 },
   },

   machineLearning: {
     enabled: process.env.IMMICH_MACHINE_LEARNING_ENABLED !== 'false',
     url: process.env.IMMICH_MACHINE_LEARNING_URL || 'http://immich-machine-learning:3003',
-    facialRecognitionEnabled: true,
-    tagImageEnabled: true,
-    clipEncodeEnabled: true,
+    classification: {
+      enabled: true,
+      modelName: 'microsoft/resnet-50',
+      minScore: 0.9,
+    },
+    clip: {
+      enabled: true,
+      modelName: 'ViT-B-32::openai',
+    },
+    facialRecognition: {
+      enabled: true,
+      modelName: 'buffalo_l',
+      minScore: 0.7,
+      maxDistance: 0.6,
+    },
   },
   oauth: {
     enabled: false,

@@ -143,9 +156,9 @@ export class SystemConfigCore {
     const mlEnabled = config.machineLearning.enabled;

     return {
-      [FeatureFlag.CLIP_ENCODE]: mlEnabled && config.machineLearning.clipEncodeEnabled,
-      [FeatureFlag.FACIAL_RECOGNITION]: mlEnabled && config.machineLearning.facialRecognitionEnabled,
-      [FeatureFlag.TAG_IMAGE]: mlEnabled && config.machineLearning.tagImageEnabled,
+      [FeatureFlag.CLIP_ENCODE]: mlEnabled && config.machineLearning.clip.enabled,
+      [FeatureFlag.FACIAL_RECOGNITION]: mlEnabled && config.machineLearning.facialRecognition.enabled,
+      [FeatureFlag.TAG_IMAGE]: mlEnabled && config.machineLearning.classification.enabled,
       [FeatureFlag.SIDECAR]: true,
       [FeatureFlag.SEARCH]: process.env.TYPESENSE_ENABLED !== 'false',

@@ -230,7 +243,7 @@ export class SystemConfigCore {
       _.set(config, key, value);
     }

-    return _.defaultsDeep(config, defaults) as SystemConfig;
+    return plainToClass(SystemConfigDto, _.defaultsDeep(config, defaults));
   }

   private async loadFromFile(filepath: string, force = false) {

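Each SystemConfigKey stores a flat, dotted path in the database; the `_.set` call in this hunk is what folds those rows back into the nested object that plainToClass then turns into a validated DTO. A small illustration of that expansion (plain lodash usage, not new behavior):

    import _ from 'lodash';

    const config = {};
    _.set(config, 'machineLearning.facialRecognition.maxDistance', 0.5);
    // config is now { machineLearning: { facialRecognition: { maxDistance: 0.5 } } };
    // anything not stored in the database is filled in by _.defaultsDeep(config, defaults).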
@@ -49,9 +49,21 @@ const updatedConfig = Object.freeze<SystemConfig>({
   machineLearning: {
     enabled: true,
     url: 'http://immich-machine-learning:3003',
-    facialRecognitionEnabled: true,
-    tagImageEnabled: true,
-    clipEncodeEnabled: true,
+    classification: {
+      enabled: true,
+      modelName: 'microsoft/resnet-50',
+      minScore: 0.9,
+    },
+    clip: {
+      enabled: true,
+      modelName: 'ViT-B-32::openai',
+    },
+    facialRecognition: {
+      enabled: true,
+      modelName: 'buffalo_l',
+      minScore: 0.7,
+      maxDistance: 0.6,
+    },
   },
   oauth: {
     autoLaunch: true,

@@ -1,4 +1,4 @@
-import { QueueName } from '@app/domain/job/job.constants';
+import { QueueName } from '@app/domain';
 import { Column, Entity, PrimaryColumn } from 'typeorm';

 @Entity('system_config')

@@ -39,9 +39,18 @@ export enum SystemConfigKey {

   MACHINE_LEARNING_ENABLED = 'machineLearning.enabled',
   MACHINE_LEARNING_URL = 'machineLearning.url',
-  MACHINE_LEARNING_FACIAL_RECOGNITION_ENABLED = 'machineLearning.facialRecognitionEnabled',
-  MACHINE_LEARNING_TAG_IMAGE_ENABLED = 'machineLearning.tagImageEnabled',
-  MACHINE_LEARNING_CLIP_ENCODE_ENABLED = 'machineLearning.clipEncodeEnabled',
+
+  MACHINE_LEARNING_CLASSIFICATION_ENABLED = 'machineLearning.classification.enabled',
+  MACHINE_LEARNING_CLASSIFICATION_MODEL_NAME = 'machineLearning.classification.modelName',
+  MACHINE_LEARNING_CLASSIFICATION_MIN_SCORE = 'machineLearning.classification.minScore',
+
+  MACHINE_LEARNING_CLIP_ENABLED = 'machineLearning.clip.enabled',
+  MACHINE_LEARNING_CLIP_MODEL_NAME = 'machineLearning.clip.modelName',
+
+  MACHINE_LEARNING_FACIAL_RECOGNITION_ENABLED = 'machineLearning.facialRecognition.enabled',
+  MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL_NAME = 'machineLearning.facialRecognition.modelName',
+  MACHINE_LEARNING_FACIAL_RECOGNITION_MIN_SCORE = 'machineLearning.facialRecognition.minScore',
+  MACHINE_LEARNING_FACIAL_RECOGNITION_MAX_DISTANCE = 'machineLearning.facialRecognition.maxDistance',

   OAUTH_ENABLED = 'oauth.enabled',
   OAUTH_ISSUER_URL = 'oauth.issuerUrl',

@@ -114,9 +123,21 @@ export interface SystemConfig {
   machineLearning: {
     enabled: boolean;
     url: string;
-    clipEncodeEnabled: boolean;
-    facialRecognitionEnabled: boolean;
-    tagImageEnabled: boolean;
+    classification: {
+      enabled: boolean;
+      modelName: string;
+      minScore: number;
+    };
+    clip: {
+      enabled: boolean;
+      modelName: string;
+    };
+    facialRecognition: {
+      enabled: boolean;
+      modelName: string;
+      minScore: number;
+      maxDistance: number;
+    };
   };
   oauth: {
     enabled: boolean;

@@ -0,0 +1,25 @@
+import { MigrationInterface, QueryRunner } from 'typeorm';
+
+export class RenameMLEnableFlags1693236627291 implements MigrationInterface {
+  public async up(queryRunner: QueryRunner): Promise<void> {
+    await queryRunner.query(`
+      UPDATE system_config SET key = CASE
+        WHEN key = 'machineLearning.tagImageEnabled' THEN 'machineLearning.classification.enabled'
+        WHEN key = 'machineLearning.clipEncodeEnabled' THEN 'machineLearning.clip.enabled'
+        WHEN key = 'machineLearning.facialRecognitionEnabled' THEN 'machineLearning.facialRecognition.enabled'
+        ELSE key
+      END
+    `);
+  }
+
+  public async down(queryRunner: QueryRunner): Promise<void> {
+    await queryRunner.query(`
+      UPDATE system_config SET key = CASE
+        WHEN key = 'machineLearning.classification.enabled' THEN 'machineLearning.tagImageEnabled'
+        WHEN key = 'machineLearning.clip.enabled' THEN 'machineLearning.clipEncodeEnabled'
+        WHEN key = 'machineLearning.facialRecognition.enabled' THEN 'machineLearning.facialRecognitionEnabled'
+        ELSE key
+      END
+    `);
+  }
+}

@@ -1,29 +1,65 @@
-import { DetectFaceResult, IMachineLearningRepository, MachineLearningInput } from '@app/domain';
+import {
+  ClassificationConfig,
+  CLIPConfig,
+  CLIPMode,
+  DetectFaceResult,
+  IMachineLearningRepository,
+  ModelConfig,
+  ModelType,
+  RecognitionConfig,
+  TextModelInput,
+  VisionModelInput,
+} from '@app/domain';
 import { Injectable } from '@nestjs/common';
-import axios from 'axios';
-import { createReadStream } from 'fs';
-
-const client = axios.create();
+import { readFile } from 'fs/promises';

 @Injectable()
 export class MachineLearningRepository implements IMachineLearningRepository {
-  private post<T>(input: MachineLearningInput, endpoint: string): Promise<T> {
-    return client.post<T>(endpoint, createReadStream(input.imagePath)).then((res) => res.data);
+  private async post<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
+    const formData = await this.getFormData(input, config);
+    const res = await fetch(`${url}/predict`, { method: 'POST', body: formData });
+    return res.json();
   }

-  classifyImage(url: string, input: MachineLearningInput): Promise<string[]> {
-    return this.post<string[]>(input, `${url}/image-classifier/tag-image`);
+  classifyImage(url: string, input: VisionModelInput, config: ClassificationConfig): Promise<string[]> {
+    return this.post<string[]>(url, input, { ...config, modelType: ModelType.IMAGE_CLASSIFICATION });
   }

-  detectFaces(url: string, input: MachineLearningInput): Promise<DetectFaceResult[]> {
-    return this.post<DetectFaceResult[]>(input, `${url}/facial-recognition/detect-faces`);
+  detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
+    return this.post<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
   }

-  encodeImage(url: string, input: MachineLearningInput): Promise<number[]> {
-    return this.post<number[]>(input, `${url}/sentence-transformer/encode-image`);
+  encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
+    return this.post<number[]>(url, input, {
+      ...config,
+      modelType: ModelType.CLIP,
+      mode: CLIPMode.VISION,
+    } as CLIPConfig);
   }

-  encodeText(url: string, input: string): Promise<number[]> {
-    return client.post<number[]>(`${url}/sentence-transformer/encode-text`, { text: input }).then((res) => res.data);
+  encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
+    return this.post<number[]>(url, input, { ...config, modelType: ModelType.CLIP, mode: CLIPMode.TEXT } as CLIPConfig);
+  }
+
+  async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
+    const formData = new FormData();
+    const { modelName, modelType, ...options } = config;
+
+    formData.append('modelName', modelName);
+    if (modelType) {
+      formData.append('modelType', modelType);
+    }
+    if (options) {
+      formData.append('options', JSON.stringify(options));
+    }
+    if ('imagePath' in input) {
+      formData.append('image', new Blob([await readFile(input.imagePath)]));
+    } else if ('text' in input) {
+      formData.append('text', input.text);
+    } else {
+      throw new Error('Invalid input');
+    }
+
+    return formData;
+  }
 }

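As getFormData above shows, every model call now goes to a single /predict endpoint as multipart form data: modelName, an optional modelType, the remaining settings JSON-encoded under options, and either an image blob or a text field. An equivalent standalone request, sketched outside the NestJS repository (assumes Node 18+, where fetch, FormData, and Blob are globals, as the new code itself does):

    import { readFile } from 'fs/promises';

    // Sketch: a facial-recognition request as the repository now sends it.
    async function predictFaces(url: string, imagePath: string): Promise<unknown> {
      const formData = new FormData();
      formData.append('modelName', 'buffalo_l');
      formData.append('modelType', 'facial-recognition');
      formData.append('options', JSON.stringify({ enabled: true, minScore: 0.7, maxDistance: 0.6 }));
      formData.append('image', new Blob([await readFile(imagePath)]));
      const res = await fetch(`${url}/predict`, { method: 'POST', body: formData });
      return res.json(); // DetectFaceResult[]
    }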
@@ -52,6 +52,8 @@ export class TypesenseRepository implements ISearchRepository {
   private logger = new Logger(TypesenseRepository.name);

   private _client: Client | null = null;
+  private _updateCLIPLock = false;

   private get client(): Client {
     if (!this._client) {
       throw new Error('Typesense client not available (no apiKey was provided)');

@@ -141,7 +143,7 @@ export class TypesenseRepository implements ISearchRepository {
         await this.updateAlias(collection);
       }
     } catch (error: any) {
-      this.handleError(error);
+      await this.handleError(error);
     }
   }

@@ -221,6 +223,30 @@ export class TypesenseRepository implements ISearchRepository {
     return records.num_deleted;
   }

+  async deleteAllAssets(): Promise<number> {
+    const records = await this.client.collections(assetSchema.name).documents().delete({ filter_by: 'ownerId:!=null' });
+    return records.num_deleted;
+  }
+
+  async updateCLIPField(num_dim: number): Promise<void> {
+    const clipField = assetSchema.fields?.find((field) => field.name === 'smartInfo.clipEmbedding');
+    if (clipField && !this._updateCLIPLock) {
+      try {
+        this._updateCLIPLock = true;
+        clipField.num_dim = num_dim;
+        await this.deleteAllAssets();
+        await this.client
+          .collections(assetSchema.name)
+          .update({ fields: [{ name: 'smartInfo.clipEmbedding', drop: true } as any, clipField] });
+        this.logger.log(`Successfully updated CLIP dimensions to ${num_dim}`);
+      } catch (err: any) {
+        this.logger.error(`Error while updating CLIP field: ${err.message}`);
+      } finally {
+        this._updateCLIPLock = false;
+      }
+    }
+  }
+
   async delete(collection: SearchCollection, ids: string[]): Promise<void> {
     await this.client
       .collections(schemaMap[collection].name)

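The _updateCLIPLock flag added above is a single-flight guard: handleError can fire for many concurrent import batches, but only one caller should drop and rebuild the embedding field. The general shape of that pattern, reduced to its essentials (illustrative only, not code from this commit):

    // Single-flight guard: while one run is in progress, other callers return immediately.
    let running = false;
    async function runOnce(work: () => Promise<void>): Promise<void> {
      if (running) return;
      try {
        running = true;
        await work();
      } finally {
        running = false;
      }
    }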
@@ -326,21 +352,34 @@ export class TypesenseRepository implements ISearchRepository {
     } as SearchResult<T>;
   }

-  private handleError(error: any) {
+  private async handleError(error: any) {
     this.logger.error('Unable to index documents');
     const results = error.importResults || [];
+    let dimsChanged = false;
     for (const result of results) {
       try {
         result.document = JSON.parse(result.document);
+        if (result.error.includes('Field `smartInfo.clipEmbedding` must have')) {
+          dimsChanged = true;
+          this.logger.warn(
+            `CLIP embedding dimensions have changed, now ${result.document.smartInfo.clipEmbedding.length} dims. Updating schema...`,
+          );
+          await this.updateCLIPField(result.document.smartInfo.clipEmbedding.length);
+          break;
+        }
+
         if (result.document?.smartInfo?.clipEmbedding) {
           result.document.smartInfo.clipEmbedding = '<truncated>';
         }
-      } catch {}
+      } catch (err: any) {
+        this.logger.error(`Error while updating CLIP field: ${err.message}`, err.stack);
+      }
     }

-    this.logger.verbose(JSON.stringify(results, null, 2));
+    if (!dimsChanged) {
+      this.logger.log(JSON.stringify(results, null, 2));
+    }
   }

   private async updateAlias(collection: SearchCollection) {
     const schema = schemaMap[collection];
     const alias = await this.client