mirror of
https://github.com/immich-app/immich.git
synced 2025-12-06 09:13:13 +03:00
fix(ml): ocr inputs not resized correctly (#23541)
* fix resizing, use pillow * unused import * linting * lanczos * optimizations fused operations unused import
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from PIL import Image
|
||||
from rapidocr.ch_ppocr_det import TextDetector as RapidTextDetector
|
||||
from rapidocr.ch_ppocr_det.utils import DBPostProcess
|
||||
from rapidocr.inference_engine.base import FileInfo, InferSession
|
||||
from rapidocr.utils import DownloadFile, DownloadFileInput
|
||||
from rapidocr.utils.typings import EngineType, LangDet, OCRVersion, TaskType
|
||||
@@ -10,11 +12,10 @@ from rapidocr.utils.typings import ModelType as RapidModelType
|
||||
|
||||
from immich_ml.config import log
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.transforms import decode_cv2
|
||||
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
|
||||
from immich_ml.sessions.ort import OrtSession
|
||||
|
||||
from .schemas import OcrOptions, TextDetectionOutput
|
||||
from .schemas import TextDetectionOutput
|
||||
|
||||
|
||||
class TextDetector(InferenceModel):
|
||||
@@ -24,13 +25,20 @@ class TextDetector(InferenceModel):
|
||||
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
|
||||
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
|
||||
self.max_resolution = 736
|
||||
self.min_score = 0.5
|
||||
self.score_mode = "fast"
|
||||
self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
|
||||
self.std_inv = np.float32(1.0) / (np.array([0.5, 0.5, 0.5], dtype=np.float32) * 255.0)
|
||||
self._empty: TextDetectionOutput = {
|
||||
"image": np.empty(0, dtype=np.float32),
|
||||
"boxes": np.empty(0, dtype=np.float32),
|
||||
"scores": np.empty(0, dtype=np.float32),
|
||||
}
|
||||
self.postprocess = DBPostProcess(
|
||||
thresh=0.3,
|
||||
box_thresh=model_kwargs.get("minScore", 0.5),
|
||||
max_candidates=1000,
|
||||
unclip_ratio=1.6,
|
||||
use_dilation=True,
|
||||
score_mode="fast",
|
||||
)
|
||||
|
||||
def _download(self) -> None:
|
||||
model_info = InferSession.get_model_url(
|
||||
@@ -52,35 +60,65 @@ class TextDetector(InferenceModel):
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
# TODO: support other runtime sessions
|
||||
session = OrtSession(self.model_path)
|
||||
self.model = RapidTextDetector(
|
||||
OcrOptions(
|
||||
session=session.session,
|
||||
limit_side_len=self.max_resolution,
|
||||
limit_type="min",
|
||||
box_thresh=self.min_score,
|
||||
score_mode=self.score_mode,
|
||||
)
|
||||
)
|
||||
return session
|
||||
return OrtSession(self.model_path)
|
||||
|
||||
def _predict(self, inputs: bytes | Image.Image) -> TextDetectionOutput:
|
||||
results = self.model(decode_cv2(inputs))
|
||||
if results.boxes is None or results.scores is None or results.img is None:
|
||||
# partly adapted from RapidOCR
|
||||
def _predict(self, inputs: Image.Image) -> TextDetectionOutput:
|
||||
w, h = inputs.size
|
||||
if w < 32 or h < 32:
|
||||
return self._empty
|
||||
out = self.session.run(None, {"x": self._transform(inputs)})[0]
|
||||
boxes, scores = self.postprocess(out, (h, w))
|
||||
if len(boxes) == 0:
|
||||
return self._empty
|
||||
return {
|
||||
"image": results.img,
|
||||
"boxes": np.array(results.boxes, dtype=np.float32),
|
||||
"scores": np.array(results.scores, dtype=np.float32),
|
||||
"boxes": self.sorted_boxes(boxes),
|
||||
"scores": np.array(scores, dtype=np.float32),
|
||||
}
|
||||
|
||||
# adapted from RapidOCR
|
||||
def _transform(self, img: Image.Image) -> NDArray[np.float32]:
|
||||
if img.height < img.width:
|
||||
ratio = float(self.max_resolution) / img.height
|
||||
else:
|
||||
ratio = float(self.max_resolution) / img.width
|
||||
|
||||
resize_h = int(img.height * ratio)
|
||||
resize_w = int(img.width * ratio)
|
||||
|
||||
resize_h = int(round(resize_h / 32) * 32)
|
||||
resize_w = int(round(resize_w / 32) * 32)
|
||||
resized_img = img.resize((int(resize_w), int(resize_h)), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
img_np: NDArray[np.float32] = cv2.cvtColor(np.array(resized_img, dtype=np.float32), cv2.COLOR_RGB2BGR) # type: ignore
|
||||
img_np -= self.mean
|
||||
img_np *= self.std_inv
|
||||
img_np = np.transpose(img_np, (2, 0, 1))
|
||||
return np.expand_dims(img_np, axis=0)
|
||||
|
||||
def sorted_boxes(self, dt_boxes: NDArray[np.float32]) -> NDArray[np.float32]:
|
||||
if len(dt_boxes) == 0:
|
||||
return dt_boxes
|
||||
|
||||
# Sort by y, then identify lines, then sort by (line, x)
|
||||
y_order = np.argsort(dt_boxes[:, 0, 1], kind="stable")
|
||||
sorted_y = dt_boxes[y_order, 0, 1]
|
||||
|
||||
line_ids = np.empty(len(dt_boxes), dtype=np.int32)
|
||||
line_ids[0] = 0
|
||||
np.cumsum(np.abs(np.diff(sorted_y)) >= 10, out=line_ids[1:])
|
||||
|
||||
# Create composite sort key for final ordering
|
||||
# Shift line_ids by large factor, add x for tie-breaking
|
||||
sort_key = line_ids[y_order] * 1e6 + dt_boxes[y_order, 0, 0]
|
||||
final_order = np.argsort(sort_key, kind="stable")
|
||||
sorted_boxes: NDArray[np.float32] = dt_boxes[y_order[final_order]]
|
||||
return sorted_boxes
|
||||
|
||||
def configure(self, **kwargs: Any) -> None:
|
||||
if (max_resolution := kwargs.get("maxResolution")) is not None:
|
||||
self.max_resolution = max_resolution
|
||||
self.model.limit_side_len = max_resolution
|
||||
if (min_score := kwargs.get("minScore")) is not None:
|
||||
self.min_score = min_score
|
||||
self.model.postprocess_op.box_thresh = min_score
|
||||
self.postprocess.box_thresh = min_score
|
||||
if (score_mode := kwargs.get("scoreMode")) is not None:
|
||||
self.score_mode = score_mode
|
||||
self.model.postprocess_op.score_mode = score_mode
|
||||
self.postprocess.score_mode = score_mode
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image
|
||||
from PIL import Image
|
||||
from rapidocr.ch_ppocr_rec import TextRecInput
|
||||
from rapidocr.ch_ppocr_rec import TextRecognizer as RapidTextRecognizer
|
||||
from rapidocr.inference_engine.base import FileInfo, InferSession
|
||||
@@ -14,6 +13,7 @@ from rapidocr.utils.vis_res import VisRes
|
||||
|
||||
from immich_ml.config import log, settings
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.transforms import pil_to_cv2
|
||||
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
|
||||
from immich_ml.sessions.ort import OrtSession
|
||||
|
||||
@@ -65,17 +65,16 @@ class TextRecognizer(InferenceModel):
|
||||
)
|
||||
return session
|
||||
|
||||
def _predict(self, _: Image, texts: TextDetectionOutput) -> TextRecognitionOutput:
|
||||
boxes, img, box_scores = texts["boxes"], texts["image"], texts["scores"]
|
||||
def _predict(self, img: Image.Image, texts: TextDetectionOutput) -> TextRecognitionOutput:
|
||||
boxes, box_scores = texts["boxes"], texts["scores"]
|
||||
if boxes.shape[0] == 0:
|
||||
return self._empty
|
||||
rec = self.model(TextRecInput(img=self.get_crop_img_list(img, boxes)))
|
||||
if rec.txts is None:
|
||||
return self._empty
|
||||
|
||||
height, width = img.shape[0:2]
|
||||
boxes[:, :, 0] /= width
|
||||
boxes[:, :, 1] /= height
|
||||
boxes[:, :, 0] /= img.width
|
||||
boxes[:, :, 1] /= img.height
|
||||
|
||||
text_scores = np.array(rec.scores)
|
||||
valid_text_score_idx = text_scores > self.min_score
|
||||
@@ -87,7 +86,7 @@ class TextRecognizer(InferenceModel):
|
||||
"textScore": text_scores[valid_text_score_idx],
|
||||
}
|
||||
|
||||
def get_crop_img_list(self, img: NDArray[np.float32], boxes: NDArray[np.float32]) -> list[NDArray[np.float32]]:
|
||||
def get_crop_img_list(self, img: Image.Image, boxes: NDArray[np.float32]) -> list[NDArray[np.uint8]]:
|
||||
img_crop_width = np.maximum(
|
||||
np.linalg.norm(boxes[:, 1] - boxes[:, 0], axis=1), np.linalg.norm(boxes[:, 2] - boxes[:, 3], axis=1)
|
||||
).astype(np.int32)
|
||||
@@ -98,22 +97,55 @@ class TextRecognizer(InferenceModel):
|
||||
pts_std[:, 1:3, 0] = img_crop_width[:, None]
|
||||
pts_std[:, 2:4, 1] = img_crop_height[:, None]
|
||||
|
||||
img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1).tolist()
|
||||
imgs: list[NDArray[np.float32]] = []
|
||||
for box, pts_std, dst_size in zip(list(boxes), list(pts_std), img_crop_sizes):
|
||||
M = cv2.getPerspectiveTransform(box, pts_std)
|
||||
dst_img: NDArray[np.float32] = cv2.warpPerspective(
|
||||
img,
|
||||
M,
|
||||
dst_size,
|
||||
borderMode=cv2.BORDER_REPLICATE,
|
||||
flags=cv2.INTER_CUBIC,
|
||||
) # type: ignore
|
||||
dst_height, dst_width = dst_img.shape[0:2]
|
||||
img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1)
|
||||
all_coeffs = self._get_perspective_transform(pts_std, boxes)
|
||||
imgs: list[NDArray[np.uint8]] = []
|
||||
for coeffs, dst_size in zip(all_coeffs, img_crop_sizes):
|
||||
dst_img = img.transform(
|
||||
size=tuple(dst_size),
|
||||
method=Image.Transform.PERSPECTIVE,
|
||||
data=tuple(coeffs),
|
||||
resample=Image.Resampling.BICUBIC,
|
||||
)
|
||||
|
||||
dst_width, dst_height = dst_img.size
|
||||
if dst_height * 1.0 / dst_width >= 1.5:
|
||||
dst_img = np.rot90(dst_img)
|
||||
imgs.append(dst_img)
|
||||
dst_img = dst_img.rotate(90, expand=True)
|
||||
imgs.append(pil_to_cv2(dst_img))
|
||||
|
||||
return imgs
|
||||
|
||||
def _get_perspective_transform(self, src: NDArray[np.float32], dst: NDArray[np.float32]) -> NDArray[np.float32]:
|
||||
N = src.shape[0]
|
||||
x, y = src[:, :, 0], src[:, :, 1]
|
||||
u, v = dst[:, :, 0], dst[:, :, 1]
|
||||
A = np.zeros((N, 8, 9), dtype=np.float32)
|
||||
|
||||
# Fill even rows (0, 2, 4, 6): [x, y, 1, 0, 0, 0, -u*x, -u*y, -u]
|
||||
A[:, ::2, 0] = x
|
||||
A[:, ::2, 1] = y
|
||||
A[:, ::2, 2] = 1
|
||||
A[:, ::2, 6] = -u * x
|
||||
A[:, ::2, 7] = -u * y
|
||||
A[:, ::2, 8] = -u
|
||||
|
||||
# Fill odd rows (1, 3, 5, 7): [0, 0, 0, x, y, 1, -v*x, -v*y, -v]
|
||||
A[:, 1::2, 3] = x
|
||||
A[:, 1::2, 4] = y
|
||||
A[:, 1::2, 5] = 1
|
||||
A[:, 1::2, 6] = -v * x
|
||||
A[:, 1::2, 7] = -v * y
|
||||
A[:, 1::2, 8] = -v
|
||||
|
||||
# Solve using SVD for all matrices at once
|
||||
_, _, Vt = np.linalg.svd(A)
|
||||
H = Vt[:, -1, :].reshape(N, 3, 3)
|
||||
H = H / H[:, 2:3, 2:3]
|
||||
|
||||
# Extract the 8 coefficients for each transformation
|
||||
return np.column_stack(
|
||||
[H[:, 0, 0], H[:, 0, 1], H[:, 0, 2], H[:, 1, 0], H[:, 1, 1], H[:, 1, 2], H[:, 2, 0], H[:, 2, 1]]
|
||||
) # pyright: ignore[reportReturnType]
|
||||
|
||||
def configure(self, **kwargs: Any) -> None:
|
||||
self.min_score = kwargs.get("minScore", self.min_score)
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class TextDetectionOutput(TypedDict):
|
||||
image: npt.NDArray[np.float32]
|
||||
boxes: npt.NDArray[np.float32]
|
||||
scores: npt.NDArray[np.float32]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user