feat(ml): support multiple urls (#14347)

* support multiple url

* update api

* styling

unnecessary `?.`

* update docs, make new url field go first

add load balancing section

* update tests

doc formatting

wording

wording

linting

* small styling

* `url` -> `urls`

* fix tests

* update docs

* make docusaurus happy

---------

Co-authored-by: Alex <alex.tran1502@gmail.com>
This commit is contained in:
Mert
2024-12-04 15:17:47 -05:00
committed by GitHub
parent 411878c0aa
commit 4bf1b84cc2
22 changed files with 202 additions and 73 deletions

View File

@@ -155,7 +155,7 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
this.emitHandlers[event].push(item);
}
async emit<T extends EmitEvent>(event: T, ...args: ArgsOf<T>): Promise<void> {
emit<T extends EmitEvent>(event: T, ...args: ArgsOf<T>): Promise<void> {
return this.onEvent({ name: event, args, server: false });
}

View File

@@ -1,6 +1,7 @@
import { Injectable } from '@nestjs/common';
import { Inject, Injectable } from '@nestjs/common';
import { readFile } from 'node:fs/promises';
import { CLIPConfig } from 'src/dtos/model-config.dto';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
import {
ClipTextualResponse,
ClipVisualResponse,
@@ -13,33 +14,42 @@ import {
ModelType,
} from 'src/interfaces/machine-learning.interface';
const errorPrefix = 'Machine learning request';
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
const formData = await this.getFormData(payload, config);
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
(error: Error | any) => {
throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
},
);
if (res.status >= 400) {
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
}
return res.json();
constructor(@Inject(ILoggerRepository) private logger: ILoggerRepository) {
this.logger.setContext(MachineLearningRepository.name);
}
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
private async predict<T>(urls: string[], payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
const formData = await this.getFormData(payload, config);
for (const url of urls) {
try {
const response = await fetch(new URL('/predict', url), { method: 'POST', body: formData });
if (response.ok) {
return response.json();
}
this.logger.warn(
`Machine learning request to "${url}" failed with status ${response.status}: ${response.statusText}`,
);
} catch (error: Error | unknown) {
this.logger.warn(
`Machine learning request to "${url}" failed: ${error instanceof Error ? error.message : error}`,
);
}
}
throw new Error(`Machine learning request '${JSON.stringify(config)}' failed for all URLs`);
}
async detectFaces(urls: string[], imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
const request = {
[ModelTask.FACIAL_RECOGNITION]: {
[ModelType.DETECTION]: { modelName, options: { minScore } },
[ModelType.RECOGNITION]: { modelName },
},
};
const response = await this.predict<FacialRecognitionResponse>(url, { imagePath }, request);
const response = await this.predict<FacialRecognitionResponse>(urls, { imagePath }, request);
return {
imageHeight: response.imageHeight,
imageWidth: response.imageWidth,
@@ -47,15 +57,15 @@ export class MachineLearningRepository implements IMachineLearningRepository {
};
}
async encodeImage(url: string, imagePath: string, { modelName }: CLIPConfig) {
async encodeImage(urls: string[], imagePath: string, { modelName }: CLIPConfig) {
const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
const response = await this.predict<ClipVisualResponse>(url, { imagePath }, request);
const response = await this.predict<ClipVisualResponse>(urls, { imagePath }, request);
return response[ModelTask.SEARCH];
}
async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
async encodeText(urls: string[], text: string, { modelName }: CLIPConfig) {
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
const response = await this.predict<ClipTextualResponse>(url, { text }, request);
const response = await this.predict<ClipTextualResponse>(urls, { text }, request);
return response[ModelTask.SEARCH];
}