Compare commits

...

8 Commits

Author    SHA1        Message                                     Date
mertalev  72269ab58c  add cli                                     2024-07-12 16:50:48 -04:00
mertalev  3db69b94ed  support resnet models, test failed models   2024-07-12 16:50:48 -04:00
mertalev  b5acb71b05  prevent multidimensional bias               2024-07-12 16:50:48 -04:00
mertalev  b39cca1b43  fixes                                       2024-07-12 16:50:47 -04:00
mertalev  3d62011ae3  handle gather at the end                    2024-07-12 16:50:47 -04:00
mertalev  1ad348c407  gather -> slice                             2024-07-12 16:50:47 -04:00
mertalev  5dae920ac6  onnx2tf, 4d transpose                       2024-07-12 16:50:47 -04:00
mertalev  956480ab2c  enhance armnn conversion                    2024-07-12 16:50:46 -04:00
27 changed files with 2531 additions and 390 deletions

View File

@@ -1,3 +0,0 @@
#!/usr/bin/env sh
g++ -shared -O3 -o libann.so -fuse-ld=gold -std=c++17 -I"$ARMNN_PATH"/include -larmnn -larmnnDeserializer -larmnnTfLiteParser -larmnnOnnxParser -L"$ARMNN_PATH" ann.cpp

View File

@@ -1,4 +0,0 @@
#!/usr/bin/env sh
cd armnn-23.11/ || exit
g++ -o ../armnnconverter -O1 -DARMNN_ONNX_PARSER -DARMNN_SERIALIZER -DARMNN_TF_LITE_PARSER -fuse-ld=gold -std=c++17 -Iinclude -Isrc/armnnUtils -Ithird-party -larmnn -larmnnDeserializer -larmnnTfLiteParser -larmnnOnnxParser -larmnnSerializer -L../armnn src/armnnConverter/ArmnnConverter.cpp

View File

@@ -1,201 +0,0 @@
name: annexport
channels:
- pytorch
- nvidia
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_kmp_llvm
- aiohttp=3.9.1=py310h2372a71_0
- aiosignal=1.3.1=pyhd8ed1ab_0
- arpack=3.8.0=nompi_h0baa96a_101
- async-timeout=4.0.3=pyhd8ed1ab_0
- attrs=23.1.0=pyh71513ae_1
- aws-c-auth=0.7.3=h28f7589_1
- aws-c-cal=0.6.1=hc309b26_1
- aws-c-common=0.9.0=hd590300_0
- aws-c-compression=0.2.17=h4d4d85c_2
- aws-c-event-stream=0.3.1=h2e3709c_4
- aws-c-http=0.7.11=h00aa349_4
- aws-c-io=0.13.32=he9a53bd_1
- aws-c-mqtt=0.9.3=hb447be9_1
- aws-c-s3=0.3.14=hf3aad02_1
- aws-c-sdkutils=0.1.12=h4d4d85c_1
- aws-checksums=0.1.17=h4d4d85c_1
- aws-crt-cpp=0.21.0=hb942446_5
- aws-sdk-cpp=1.10.57=h85b1a90_19
- blas=2.120=openblas
- blas-devel=3.9.0=20_linux64_openblas
- brotli-python=1.0.9=py310hd8f1fbe_9
- bzip2=1.0.8=hd590300_5
- c-ares=1.23.0=hd590300_0
- ca-certificates=2023.11.17=hbcca054_0
- certifi=2023.11.17=pyhd8ed1ab_0
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- click=8.1.7=unix_pyh707e725_0
- colorama=0.4.6=pyhd8ed1ab_0
- coloredlogs=15.0.1=pyhd8ed1ab_3
- cuda-cudart=11.7.99=0
- cuda-cupti=11.7.101=0
- cuda-libraries=11.7.1=0
- cuda-nvrtc=11.7.99=0
- cuda-nvtx=11.7.91=0
- cuda-runtime=11.7.1=0
- dataclasses=0.8=pyhc8e2a94_3
- datasets=2.14.7=pyhd8ed1ab_0
- dill=0.3.7=pyhd8ed1ab_0
- filelock=3.13.1=pyhd8ed1ab_0
- flatbuffers=23.5.26=h59595ed_1
- freetype=2.12.1=h267a509_2
- frozenlist=1.4.0=py310h2372a71_1
- fsspec=2023.10.0=pyhca7485f_0
- ftfy=6.1.3=pyhd8ed1ab_0
- gflags=2.2.2=he1b5a44_1004
- glog=0.6.0=h6f12383_0
- glpk=5.0=h445213a_0
- gmp=6.3.0=h59595ed_0
- gmpy2=2.1.2=py310h3ec546c_1
- huggingface_hub=0.17.3=pyhd8ed1ab_0
- humanfriendly=10.0=pyhd8ed1ab_6
- icu=73.2=h59595ed_0
- idna=3.6=pyhd8ed1ab_0
- importlib-metadata=7.0.0=pyha770c72_0
- importlib_metadata=7.0.0=hd8ed1ab_0
- joblib=1.3.2=pyhd8ed1ab_0
- keyutils=1.6.1=h166bdaf_0
- krb5=1.21.2=h659d440_0
- lcms2=2.15=h7f713cb_2
- ld_impl_linux-64=2.40=h41732ed_0
- lerc=4.0.0=h27087fc_0
- libabseil=20230125.3=cxx17_h59595ed_0
- libarrow=12.0.1=hb87d912_8_cpu
- libblas=3.9.0=20_linux64_openblas
- libbrotlicommon=1.0.9=h166bdaf_9
- libbrotlidec=1.0.9=h166bdaf_9
- libbrotlienc=1.0.9=h166bdaf_9
- libcblas=3.9.0=20_linux64_openblas
- libcrc32c=1.1.2=h9c3ff4c_0
- libcublas=11.10.3.66=0
- libcufft=10.7.2.124=h4fbf590_0
- libcufile=1.8.1.2=0
- libcurand=10.3.4.101=0
- libcurl=8.5.0=hca28451_0
- libcusolver=11.4.0.1=0
- libcusparse=11.7.4.91=0
- libdeflate=1.19=hd590300_0
- libedit=3.1.20191231=he28a2e2_2
- libev=4.33=hd590300_2
- libevent=2.1.12=hf998b51_1
- libffi=3.4.2=h7f98852_5
- libgcc-ng=13.2.0=h807b86a_3
- libgfortran-ng=13.2.0=h69a702a_3
- libgfortran5=13.2.0=ha4646dd_3
- libgoogle-cloud=2.12.0=hac9eb74_1
- libgrpc=1.54.3=hb20ce57_0
- libhwloc=2.9.3=default_h554bfaf_1009
- libiconv=1.17=hd590300_1
- libjpeg-turbo=2.1.5.1=hd590300_1
- liblapack=3.9.0=20_linux64_openblas
- liblapacke=3.9.0=20_linux64_openblas
- libnghttp2=1.58.0=h47da74e_1
- libnpp=11.7.4.75=0
- libnsl=2.0.1=hd590300_0
- libnuma=2.0.16=h0b41bf4_1
- libnvjpeg=11.8.0.2=0
- libopenblas=0.3.25=pthreads_h413a1c8_0
- libpng=1.6.39=h753d276_0
- libprotobuf=3.21.12=hfc55251_2
- libsentencepiece=0.1.99=h180e1df_0
- libsqlite=3.44.2=h2797004_0
- libssh2=1.11.0=h0841786_0
- libstdcxx-ng=13.2.0=h7e041cc_3
- libthrift=0.18.1=h8fd135c_2
- libtiff=4.6.0=h29866fb_1
- libutf8proc=2.8.0=h166bdaf_0
- libuuid=2.38.1=h0b41bf4_0
- libwebp-base=1.3.2=hd590300_0
- libxcb=1.15=h0b41bf4_0
- libxml2=2.11.6=h232c23b_0
- libzlib=1.2.13=hd590300_5
- llvm-openmp=17.0.6=h4dfa4b3_0
- lz4-c=1.9.4=hcb278e6_0
- mkl=2022.2.1=h84fe81f_16997
- mkl-devel=2022.2.1=ha770c72_16998
- mkl-include=2022.2.1=h84fe81f_16997
- mpc=1.3.1=hfe3b2da_0
- mpfr=4.2.1=h9458935_0
- mpmath=1.3.0=pyhd8ed1ab_0
- multidict=6.0.4=py310h2372a71_1
- multiprocess=0.70.15=py310h2372a71_1
- ncurses=6.4=h59595ed_2
- numpy=1.26.2=py310hb13e2d6_0
- onnx=1.14.0=py310ha3deec4_1
- onnx2torch=1.5.13=pyhd8ed1ab_0
- onnxruntime=1.16.3=py310hd4b7fbc_1_cpu
- open-clip-torch=2.23.0=pyhd8ed1ab_1
- openblas=0.3.25=pthreads_h7a3da1a_0
- openjpeg=2.5.0=h488ebb8_3
- openssl=3.2.0=hd590300_1
- orc=1.9.0=h2f23424_1
- packaging=23.2=pyhd8ed1ab_0
- pandas=2.1.4=py310hcc13569_0
- pillow=10.0.1=py310h29da1c1_1
- pip=23.3.1=pyhd8ed1ab_0
- protobuf=4.21.12=py310heca2aa9_0
- pthread-stubs=0.4=h36c2ea0_1001
- pyarrow=12.0.1=py310h0576679_8_cpu
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.10.13=hd12c33a_0_cpython
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python-flatbuffers=23.5.26=pyhd8ed1ab_0
- python-tzdata=2023.3=pyhd8ed1ab_0
- python-xxhash=3.4.1=py310h2372a71_0
- python_abi=3.10=4_cp310
- pytorch=1.13.1=cpu_py310hd11e9c7_1
- pytorch-cuda=11.7=h778d358_5
- pytorch-mutex=1.0=cuda
- pytz=2023.3.post1=pyhd8ed1ab_0
- pyyaml=6.0.1=py310h2372a71_1
- rdma-core=28.9=h59595ed_1
- re2=2023.03.02=h8c504da_0
- readline=8.2=h8228510_1
- regex=2023.10.3=py310h2372a71_0
- requests=2.31.0=pyhd8ed1ab_0
- s2n=1.3.49=h06160fa_0
- sacremoses=0.0.53=pyhd8ed1ab_0
- safetensors=0.3.3=py310hcb5633a_1
- sentencepiece=0.1.99=hff52083_0
- sentencepiece-python=0.1.99=py310hebdb9f0_0
- sentencepiece-spm=0.1.99=h180e1df_0
- setuptools=68.2.2=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- sleef=3.5.1=h9b69904_2
- snappy=1.1.10=h9fff704_0
- sympy=1.12=pypyh9d50eac_103
- tbb=2021.11.0=h00ab1b0_0
- texttable=1.7.0=pyhd8ed1ab_0
- timm=0.9.12=pyhd8ed1ab_0
- tk=8.6.13=noxft_h4845f30_101
- tokenizers=0.14.1=py310h320607d_2
- torchvision=0.14.1=cpu_py310hd3d2ac3_1
- tqdm=4.66.1=pyhd8ed1ab_0
- transformers=4.35.2=pyhd8ed1ab_0
- typing-extensions=4.9.0=hd8ed1ab_0
- typing_extensions=4.9.0=pyha770c72_0
- tzdata=2023c=h71feb2d_0
- ucx=1.14.1=h64cca9d_5
- urllib3=2.1.0=pyhd8ed1ab_0
- wcwidth=0.2.12=pyhd8ed1ab_0
- wheel=0.42.0=pyhd8ed1ab_0
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xxhash=0.8.2=hd590300_0
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- yarl=1.9.3=py310h2372a71_0
- zipp=3.17.0=pyhd8ed1ab_0
- zlib=1.2.13=hd590300_5
- zstd=1.5.5=hfc55251_0
- pip:
  - git+https://github.com/fyfrey/TinyNeuralNetwork.git

View File

@@ -1,157 +0,0 @@
import logging
import os
import platform
import subprocess
from abc import abstractmethod

import onnx
import open_clip
import torch
from onnx2torch import convert
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
from tinynn.converter import TFLiteConverter


class ExportBase(torch.nn.Module):
    input_shape: tuple[int, ...]

    def __init__(self, device: torch.device, name: str):
        super().__init__()
        self.device = device
        self.name = name
        self.optimize = 5
        self.nchw_transpose = False

    @abstractmethod
    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
        pass

    def dummy_input(self) -> torch.FloatTensor:
        return torch.rand((1, 3, 224, 224), device=self.device)


class ArcFace(ExportBase):
    input_shape = (1, 3, 112, 112)

    def __init__(self, onnx_model_path: str, device: torch.device):
        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
        super().__init__(device, name)
        onnx_model = onnx.load_model(onnx_model_path)
        make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
        fix_output_shapes(onnx_model)
        self.model = convert(onnx_model).to(device)
        if self.device.type == "cuda":
            self.model = self.model.half()

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        embedding: torch.FloatTensor = self.model(
            input_tensor.half() if self.device.type == "cuda" else input_tensor
        ).float()
        assert isinstance(embedding, torch.FloatTensor)
        return embedding

    def dummy_input(self) -> torch.FloatTensor:
        return torch.rand(self.input_shape, device=self.device)


class RetinaFace(ExportBase):
    input_shape = (1, 3, 640, 640)

    def __init__(self, onnx_model_path: str, device: torch.device):
        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
        super().__init__(device, name)
        self.optimize = 3
        self.model = convert(onnx_model_path).eval().to(device)
        if self.device.type == "cuda":
            self.model = self.model.half()

    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.FloatTensor]:
        out: torch.Tensor = self.model(input_tensor.half() if self.device.type == "cuda" else input_tensor)
        return tuple(o.float() for o in out)

    def dummy_input(self) -> torch.FloatTensor:
        return torch.rand(self.input_shape, device=self.device)


class ClipVision(ExportBase):
    input_shape = (1, 3, 224, 224)

    def __init__(self, model_name: str, weights: str, device: torch.device):
        super().__init__(device, model_name + "__" + weights)
        self.model = open_clip.create_model(
            model_name,
            weights,
            precision="fp16" if device.type == "cuda" else "fp32",
            jit=False,
            require_pretrained=True,
            device=device,
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        embedding: torch.Tensor = self.model.encode_image(
            input_tensor.half() if self.device.type == "cuda" else input_tensor,
            normalize=True,
        ).float()
        return embedding


def export(model: ExportBase) -> None:
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    dummy_input = model.dummy_input()
    model(dummy_input)
    jit = torch.jit.trace(model, dummy_input)  # type: ignore[no-untyped-call,attr-defined]
    tflite_model_path = f"output/{model.name}.tflite"
    os.makedirs("output", exist_ok=True)

    converter = TFLiteConverter(
        jit,
        dummy_input,
        tflite_model_path,
        optimize=model.optimize,
        nchw_transpose=model.nchw_transpose,
    )
    # segfaults on ARM, must run on x86_64 / AMD64
    converter.convert()

    armnn_model_path = f"output/{model.name}.armnn"
    os.environ["LD_LIBRARY_PATH"] = "armnn"
    subprocess.run(
        [
            "./armnnconverter",
            "-f",
            "tflite-binary",
            "-m",
            tflite_model_path,
            "-i",
            "input_tensor",
            "-o",
            "output_tensor",
            "-p",
            armnn_model_path,
        ]
    )


def main() -> None:
    if platform.machine() not in ("x86_64", "AMD64"):
        raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type != "cuda":
        logging.warning(
            "No CUDA available, cannot create fp16 model! proceeding to create a fp32 model (use only for testing)"
        )
    models = [
        ClipVision("ViT-B-32", "openai", device),
        ArcFace("buffalo_l_rec.onnx", device),
        RetinaFace("buffalo_l_det.onnx", device),
    ]
    for model in models:
        export(model)


if __name__ == "__main__":
    with torch.no_grad():
        main()

View File

@@ -0,0 +1,35 @@
FROM mambaorg/micromamba:bookworm-slim@sha256:333f7598ff2c2400fb10bfe057709c68b7daab5d847143af85abcf224a07271a as builder
USER root
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
curl \
git
USER $MAMBA_USER
WORKDIR /home/mambauser
ENV ARMNN_PATH=armnn
COPY --chown=$MAMBA_USER:$MAMBA_USER scripts/* .
RUN ./download-armnn.sh && \
./build-converter.sh && \
./build.sh
COPY --chown=$MAMBA_USER:$MAMBA_USER conda-lock.yml .
RUN micromamba create -y -p /home/mambauser/venv -f conda-lock.yml && \
micromamba clean --all --yes
ENV PATH="/home/mambauser/venv/bin:${PATH}"
FROM gcr.io/distroless/base-debian12
# FROM mambaorg/micromamba:bookworm-slim@sha256:333f7598ff2c2400fb10bfe057709c68b7daab5d847143af85abcf224a07271a
WORKDIR /export/ann
ENV PYTHONDONTWRITEBYTECODE=1 \
LD_LIBRARY_PATH=/export/ann/armnn \
PATH="/opt/venv/bin:${PATH}"
COPY --from=builder /home/mambauser/armnnconverter /home/mambauser/armnn ./
COPY --from=builder /home/mambauser/venv /opt/venv
COPY --chown=$MAMBA_USER:$MAMBA_USER onnx2ann onnx2ann
ENTRYPOINT ["python", "-m", "onnx2ann"]

File diff suppressed because it is too large

View File

@@ -0,0 +1,21 @@
name: onnx2ann
channels:
- conda-forge
dependencies:
- python>=3.11,<4.0
- onnx>=1.16.1
# - onnxruntime>=1.18.1 # conda only has gpu version
- psutil>=6.0.0
- flatbuffers>=24.3.25
- ml_dtypes>=0.3.1
- typer-slim>=0.12.3
- huggingface_hub>=0.23.4
- pip
- pip:
  - onnxruntime>=1.18.1 # conda only has gpu version
  - onnxsim>=0.4.36
  - onnx2tf>=1.24.1
  - onnx_graphsurgeon>=0.5.2
  - simple_onnx_processing_tools>=1.1.32
  - tf_keras>=2.16.0
  - git+https://github.com/microsoft/onnxconverter-common.git

View File

@@ -0,0 +1,99 @@
import os
import platform
from typing import Annotated, Optional

import typer

from onnx2ann.export import Exporter, ModelType, Precision

app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)


@app.command()
def export(
    model_name: Annotated[
        str, typer.Argument(..., help="The name of the model to be exported as it exists in Hugging Face.")
    ],
    model_type: Annotated[ModelType, typer.Option(..., "--type", "-t", help="The type of model to be exported.")],
    input_shapes: Annotated[
        list[str],
        typer.Option(
            ...,
            "--input-shape",
            "-s",
            help="The shape of an input tensor to the model, each dimension separated by commas. "
            "Multiple shapes can be provided for multiple inputs.",
        ),
    ],
    precision: Annotated[
        Precision,
        typer.Option(
            ...,
            "--precision",
            "-p",
            help="The precision of the exported model. `float16` requires a GPU.",
        ),
    ] = Precision.FLOAT32,
    cache_dir: Annotated[
        str,
        typer.Option(
            ...,
            "--cache-dir",
            "-c",
            help="Directory where pre-export models will be stored.",
            envvar="CACHE_DIR",
            show_envvar=True,
        ),
    ] = "~/.cache/huggingface",
    output_dir: Annotated[
        str,
        typer.Option(
            ...,
            "--output-dir",
            "-o",
            help="Directory where exported models will be stored.",
        ),
    ] = "output",
    auth_token: Annotated[
        Optional[str],
        typer.Option(
            ...,
            "--auth-token",
            "-t",
            help="If uploading models to Hugging Face, the auth token of the user or organisation.",
            envvar="HF_AUTH_TOKEN",
            show_envvar=True,
        ),
    ] = None,
    force_export: Annotated[
        bool,
        typer.Option(
            ...,
            "--force-export",
            "-f",
            help="Export the model even if an exported model already exists in the output directory.",
        ),
    ] = False,
) -> None:
    if platform.machine() not in ("x86_64", "AMD64"):
        msg = f"Can only run on x86_64 / AMD64, not {platform.machine()}"
        raise RuntimeError(msg)
    os.environ.setdefault("LD_LIBRARY_PATH", "armnn")
    parsed_input_shapes = [tuple(map(int, shape.split(","))) for shape in input_shapes]
    model = Exporter(
        model_name, model_type, input_shapes=parsed_input_shapes, cache_dir=cache_dir, force_export=force_export
    )
    model_dir = os.path.join(output_dir, model_name)
    output_dir = os.path.join(model_dir, model_type)
    armnn_model = model.to_armnn(output_dir, precision)
    if not auth_token:
        return
    from huggingface_hub import upload_file

    relative_path = os.path.relpath(armnn_model, start=model_dir)
    upload_file(path_or_fileobj=armnn_model, path_in_repo=relative_path, repo_id=model.repo_name, token=auth_token)


app()

View File

@@ -0,0 +1,129 @@
import os
import subprocess
from enum import StrEnum

from onnx2ann.helpers import onnx_make_armnn_compatible, onnx_make_inputs_fixed


class ModelType(StrEnum):
    VISUAL = "visual"
    TEXTUAL = "textual"
    RECOGNITION = "recognition"
    DETECTION = "detection"


class Precision(StrEnum):
    FLOAT16 = "float16"
    FLOAT32 = "float32"


class Exporter:
    def __init__(
        self,
        model_name: str,
        model_type: str,
        input_shapes: list[tuple[int, ...]],
        optimization_level: int = 5,
        cache_dir: str = os.environ.get("CACHE_DIR", "~/.cache/huggingface"),
        force_export: bool = False,
    ):
        self.model_name = model_name.split("/")[-1]
        self.model_type = model_type
        self.optimize = optimization_level
        self.input_shapes = input_shapes
        self.cache_dir = os.path.join(cache_dir, self.repo_name)
        self.force_export = force_export

    def download(self) -> str:
        model_path = os.path.join(self.cache_dir, self.model_type, "model.onnx")
        if os.path.isfile(model_path):
            print(f"Model is already downloaded at {model_path}")
            return model_path
        from huggingface_hub import snapshot_download

        snapshot_download(
            self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False
        )
        return model_path

    def to_onnx_static(self, precision: Precision) -> str:
        import onnx
        from onnxconverter_common import float16

        onnx_path_original = self.download()
        static_dir = os.path.join(self.cache_dir, self.model_type, "static")
        static_path = os.path.join(static_dir, "model.onnx")
        if self.force_export or not os.path.isfile(static_path):
            print(f"Making {self} static")
            os.makedirs(static_dir, exist_ok=True)
            onnx_make_inputs_fixed(onnx_path_original, static_path, self.input_shapes)
            onnx_make_armnn_compatible(static_path)
            print(f"Finished making {self} static")
        model = onnx.load(static_path)
        self.inputs = [input_.name for input_ in model.graph.input]
        self.outputs = [output_.name for output_ in model.graph.output]
        if precision == Precision.FLOAT16:
            static_path = os.path.join(static_dir, f"model_{precision}.onnx")
            print(f"Converting {self} to {precision} precision")
            model = float16.convert_float_to_float16(model, keep_io_types=True, disable_shape_infer=True)
            onnx.save(model, static_path)
            print(f"Finished converting {self} to {precision} precision")
        # self.inputs, self.outputs = onnx_get_inputs_outputs(static_path)
        return static_path

    def to_tflite(self, output_dir: str, precision: Precision) -> str:
        onnx_model = self.to_onnx_static(precision)
        tflite_dir = os.path.join(output_dir, precision)
        tflite_model = os.path.join(tflite_dir, f"model_{precision}.tflite")
        if self.force_export or not os.path.isfile(tflite_model):
            import onnx2tf

            print(f"Exporting {self} to TFLite with {precision} precision (this might take a few minutes)")
            onnx2tf.convert(
                input_onnx_file_path=onnx_model,
                output_folder_path=tflite_dir,
                keep_shape_absolutely_input_names=self.inputs,
                # verbosity="warn",
                copy_onnx_input_output_names_to_tflite=True,
                output_signaturedefs=True,
                not_use_onnxsim=True,
            )
            print(f"Finished exporting {self} to TFLite with {precision} precision")
        return tflite_model

    def to_armnn(self, output_dir: str, precision: Precision) -> str:
        armnn_model = os.path.join(output_dir, "model.armnn")
        if not self.force_export and os.path.isfile(armnn_model):
            return armnn_model
        tflite_model_dir = os.path.join(output_dir, "tflite")
        tflite_model = self.to_tflite(tflite_model_dir, precision)
        args = ["./armnnconverter", "-f", "tflite-binary", "-m", tflite_model, "-p", armnn_model]
        args.append("-i")
        args.extend(self.inputs)
        args.append("-o")
        args.extend(self.outputs)
        print(f"Exporting {self} to ARM NN with {precision} precision")
        try:
            if stdout := subprocess.check_output(args, stderr=subprocess.STDOUT).decode():
                print(stdout)
            print(f"Finished exporting {self} to ARM NN with {precision} precision")
        except subprocess.CalledProcessError as e:
            print(e.output.decode())
            try:
                from shutil import rmtree

                rmtree(tflite_model_dir, ignore_errors=True)
            finally:
                raise e
        return armnn_model

    @property
    def repo_name(self) -> str:
        return f"immich-app/{self.model_name}"

    def __repr__(self) -> str:
        return f"{self.model_name} ({self.model_type})"

View File

@@ -0,0 +1,260 @@
from typing import Any


def onnx_make_armnn_compatible(model_path: str) -> None:
    """
    I can explain.
    ARM NN only supports up to 4D transposes, but the model has a 5D transpose due to a redundant unsqueeze;
    this function folds the unsqueeze + transpose + squeeze into a single 4D transpose.
    It also switches gather ops to slices, since ARM NN has different dimension semantics for gathers,
    and fixes batch normalization being left in training mode.
    """
    import numpy as np
    import onnx
    from onnx_graphsurgeon import Constant, Node, Variable, export_onnx, import_onnx

    proto = onnx.load(model_path)
    graph = import_onnx(proto)

    gather_idx = 1
    squeeze_idx = 1
    for node in graph.nodes:
        for link1 in node.outputs:
            if "Unsqueeze" in link1.name:
                for node1 in link1.outputs:
                    for link2 in node1.outputs:
                        if "Transpose" in link2.name:
                            for node2 in link2.outputs:
                                if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
                                    node2.attrs["perm"] = [2, 0, 1, 3]
                                    link2.shape = link1.shape
                                    for link3 in node2.outputs:
                                        if "Squeeze" in link3.name:
                                            link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
                                            for node3 in link3.outputs:
                                                for link4 in node3.outputs:
                                                    link4.shape = link3.shape
                                                    try:
                                                        idx = link2.inputs.index(node1)
                                                        link2.inputs[idx] = node
                                                    except ValueError:
                                                        pass
                                                    node.outputs = [link2]
                                                    if "Gather" in link4.name:
                                                        for node4 in link4.outputs:
                                                            axis = node1.attrs.get("axis", 0)
                                                            index = node4.inputs[1].values
                                                            slice_link = Variable(
                                                                f"onnx::Slice_123{gather_idx}",
                                                                dtype=link4.dtype,
                                                                shape=[1] + link3.shape[1:],
                                                            )
                                                            slice_node = Node(
                                                                op="Slice",
                                                                inputs=[
                                                                    link3,
                                                                    Constant(
                                                                        f"SliceStart_123{gather_idx}",
                                                                        np.array([index]),
                                                                    ),
                                                                    Constant(
                                                                        f"SliceEnd_123{gather_idx}",
                                                                        np.array([index + 1]),
                                                                    ),
                                                                    Constant(
                                                                        f"SliceAxis_123{gather_idx}",
                                                                        np.array([axis]),
                                                                    ),
                                                                ],
                                                                outputs=[slice_link],
                                                                name=f"Slice_123{gather_idx}",
                                                            )
                                                            graph.nodes.append(slice_node)
                                                            gather_idx += 1
                                                            for link5 in node4.outputs:
                                                                for node5 in link5.outputs:
                                                                    try:
                                                                        idx = node5.inputs.index(link5)
                                                                        node5.inputs[idx] = slice_link
                                                                    except ValueError:
                                                                        pass
            elif node.op == "LayerNormalization":
                for node1 in link1.outputs:
                    if node1.op == "Gather":
                        for link2 in node1.outputs:
                            for node2 in link2.outputs:
                                axis = node1.attrs.get("axis", 0)
                                index = node1.inputs[1].values
                                slice_link = Variable(
                                    f"onnx::Slice_123{gather_idx}",
                                    dtype=link2.dtype,
                                    shape=[1, *link2.shape],
                                )
                                slice_node = Node(
                                    op="Slice",
                                    inputs=[
                                        node1.inputs[0],
                                        Constant(
                                            f"SliceStart_123{gather_idx}",
                                            np.array([index]),
                                        ),
                                        Constant(
                                            f"SliceEnd_123{gather_idx}",
                                            np.array([index + 1]),
                                        ),
                                        Constant(
                                            f"SliceAxis_123{gather_idx}",
                                            np.array([axis]),
                                        ),
                                    ],
                                    outputs=[slice_link],
                                    name=f"Slice_123{gather_idx}",
                                )
                                graph.nodes.append(slice_node)
                                gather_idx += 1
                                squeeze_link = Variable(
                                    f"onnx::Squeeze_123{squeeze_idx}",
                                    dtype=link2.dtype,
                                    shape=link2.shape,
                                )
                                squeeze_node = Node(
                                    op="Squeeze",
                                    inputs=[
                                        slice_link,
                                        Constant(
                                            f"SqueezeAxis_123{squeeze_idx}",
                                            np.array([0]),
                                        ),
                                    ],
                                    outputs=[squeeze_link],
                                    name=f"Squeeze_123{squeeze_idx}",
                                )
                                graph.nodes.append(squeeze_node)
                                squeeze_idx += 1
                                try:
                                    idx = node2.inputs.index(link2)
                                    node2.inputs[idx] = squeeze_link
                                except ValueError:
                                    pass
            elif node.op == "Reshape":
                for node1 in link1.outputs:
                    if node1.op == "Gather":
                        node2s = [n for link in node1.outputs for n in link.outputs]
                        if any(n.op == "Abs" for n in node2s):
                            axis = node1.attrs.get("axis", 0)
                            index = node1.inputs[1].values
                            slice_link = Variable(
                                f"onnx::Slice_123{gather_idx}",
                                dtype=node1.outputs[0].dtype,
                                shape=[1, *node1.outputs[0].shape],
                            )
                            slice_node = Node(
                                op="Slice",
                                inputs=[
                                    node1.inputs[0],
                                    Constant(
                                        f"SliceStart_123{gather_idx}",
                                        np.array([index]),
                                    ),
                                    Constant(
                                        f"SliceEnd_123{gather_idx}",
                                        np.array([index + 1]),
                                    ),
                                    Constant(
                                        f"SliceAxis_123{gather_idx}",
                                        np.array([axis]),
                                    ),
                                ],
                                outputs=[slice_link],
                                name=f"Slice_123{gather_idx}",
                            )
                            graph.nodes.append(slice_node)
                            gather_idx += 1
                            squeeze_link = Variable(
                                f"onnx::Squeeze_123{squeeze_idx}",
                                dtype=node1.outputs[0].dtype,
                                shape=node1.outputs[0].shape,
                            )
                            squeeze_node = Node(
                                op="Squeeze",
                                inputs=[
                                    slice_link,
                                    Constant(
                                        f"SqueezeAxis_123{squeeze_idx}",
                                        np.array([0]),
                                    ),
                                ],
                                outputs=[squeeze_link],
                                name=f"Squeeze_123{squeeze_idx}",
                            )
                            graph.nodes.append(squeeze_node)
                            squeeze_idx += 1
                            for node2 in node2s:
                                node2.inputs[0] = squeeze_link
            elif node.op == "BatchNormalization" and node.attrs.get("training_mode") == 1:
                node.attrs["training_mode"] = 0
                node.outputs = node.outputs[:1]

    graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
    graph.toposort()
    graph.fold_constants()
    updated = export_onnx(graph)
    onnx_save(updated, model_path)

    # for some reason, reloading the model is necessary to apply the correct shape
    proto = onnx.load(model_path)
    graph = import_onnx(proto)
    for node in graph.nodes:
        if node.op == "Slice":
            for link in node.outputs:
                if "Slice_123" in link.name and link.shape[0] == 3:  # noqa: PLR2004
                    link.shape[0] = 1
    graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
    graph.toposort()
    graph.fold_constants()
    updated = export_onnx(graph)
    onnx_save(updated, model_path)
    onnx.shape_inference.infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)


def onnx_make_inputs_fixed(input_path: str, output_path: str, input_shapes: list[tuple[int, ...]]) -> None:
    import onnx
    import onnxsim
    from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed

    model, success = onnxsim.simplify(input_path)
    if not success:
        msg = f"Failed to simplify {input_path}"
        raise RuntimeError(msg)
    onnx_save(model, output_path)
    onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
    model = onnx.load_model(output_path)
    for input_node, shape in zip(model.graph.input, input_shapes, strict=False):
        make_input_shape_fixed(model.graph, input_node.name, shape)
    fix_output_shapes(model)
    onnx_save(model, output_path)
    onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)


def onnx_get_inputs_outputs(model_path: str) -> tuple[list[str], list[str]]:
    import onnx

    model = onnx.load(model_path)
    inputs = [input_.name for input_ in model.graph.input]
    outputs = [output_.name for output_ in model.graph.output]
    return inputs, outputs


def onnx_save(model: Any, output_path: str) -> None:
    import onnx

    try:
        onnx.save(model, output_path)
    except Exception:
        onnx.save(
            model, output_path, save_as_external_data=True, all_tensors_to_one_file=False, size_threshold=1_000_000
        )
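
The gather-to-slice rewrite above leans on a simple equivalence: a Gather with a scalar index along an axis selects the same values as a width-1 Slice on that axis followed by a Squeeze. A quick illustrative check in numpy (not part of the exported graph):

import numpy as np

x = np.random.rand(2, 5, 4)
axis, index = 1, 3
gathered = np.take(x, index, axis=axis)  # Gather(axis=1) with a scalar index drops the axis
sliced = np.squeeze(np.take(x, [index], axis=axis), axis=axis)  # Slice [3:4] on axis 1, then Squeeze
assert np.allclose(gathered, sliced)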

View File

@@ -0,0 +1,56 @@
[project]
name = "onnx2ann"
version = "1.107.2"
dependencies = [
"onnx>=1.16.1",
"psutil>=6.0.0",
"flatbuffers>=24.3.25",
"ml_dtypes>=0.3.1,<1.0.0",
"typer-slim>=0.12.3,<1.0.0",
"huggingface_hub>=0.23.4,<1.0.0",
"onnxruntime>=1.18.1",
"onnxsim>=0.4.36,<1.0.0",
"onnx2tf>=1.24.0",
"onnx_graphsurgeon>=0.5.2,<1.0.0",
"simple_onnx_processing_tools>=1.1.32",
"tf_keras>=2.16.0",
"onnxconverter-common @ git+https://github.com/microsoft/onnxconverter-common"
]
requires-python = ">=3.11"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.sdist]
only-include = ["onnx2ann"]
[tool.hatch.metadata]
allow-direct-references = true
[tool.mypy]
python_version = "3.12"
follow_imports = "silent"
warn_redundant_casts = true
disallow_any_generics = true
check_untyped_defs = true
disallow_untyped_defs = true
ignore_missing_imports = true
[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
warn_untyped_fields = true
[tool.ruff]
line-length = 120
target-version = "py312"
[tool.ruff.lint]
extend-select = ["E", "F", "I"]
extend-ignore = ["FBT001", "FBT002"]
[tool.black]
line-length = 120
target-version = ['py312']

View File

@@ -0,0 +1,281 @@
#include <fstream>
#include <mutex>
#include <atomic>
#include "armnn/IRuntime.hpp"
#include "armnn/INetwork.hpp"
#include "armnn/Types.hpp"
#include "armnnDeserializer/IDeserializer.hpp"
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#include "armnnOnnxParser/IOnnxParser.hpp"
using namespace armnn;
struct IOInfos
{
std::vector<BindingPointInfo> inputInfos;
std::vector<BindingPointInfo> outputInfos;
};
// from https://rigtorp.se/spinlock/
struct SpinLock
{
std::atomic<bool> lock_ = {false};
void lock()
{
for (;;)
{
if (!lock_.exchange(true, std::memory_order_acquire))
{
break;
}
while (lock_.load(std::memory_order_relaxed))
;
}
}
void unlock() { lock_.store(false, std::memory_order_release); }
};
class Ann
{
public:
int load(const char *modelPath,
bool fastMath,
bool fp16,
bool saveCachedNetwork,
const char *cachedNetworkPath)
{
INetworkPtr network = loadModel(modelPath);
IOptimizedNetworkPtr optNet = OptimizeNetwork(network.get(), fastMath, fp16, saveCachedNetwork, cachedNetworkPath);
const IOInfos infos = getIOInfos(optNet.get());
NetworkId netId;
mutex.lock();
Status status = runtime->LoadNetwork(netId, std::move(optNet));
mutex.unlock();
if (status != Status::Success)
{
return -1;
}
spinLock.lock();
ioInfos[netId] = infos;
mutexes.emplace(netId, std::make_unique<std::mutex>());
spinLock.unlock();
return netId;
}
void execute(NetworkId netId, const void **inputData, void **outputData)
{
spinLock.lock();
const IOInfos *infos = &ioInfos[netId];
auto m = mutexes[netId].get();
spinLock.unlock();
InputTensors inputTensors;
inputTensors.reserve(infos->inputInfos.size());
size_t i = 0;
for (const BindingPointInfo &info : infos->inputInfos)
inputTensors.emplace_back(info.first, ConstTensor(info.second, inputData[i++]));
OutputTensors outputTensors;
outputTensors.reserve(infos->outputInfos.size());
i = 0;
for (const BindingPointInfo &info : infos->outputInfos)
outputTensors.emplace_back(info.first, Tensor(info.second, outputData[i++]));
m->lock();
runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
m->unlock();
}
void unload(NetworkId netId)
{
mutex.lock();
runtime->UnloadNetwork(netId);
mutex.unlock();
}
int tensors(NetworkId netId, bool isInput = false)
{
spinLock.lock();
const IOInfos *infos = &ioInfos[netId];
spinLock.unlock();
return (int)(isInput ? infos->inputInfos.size() : infos->outputInfos.size());
}
unsigned long shape(NetworkId netId, bool isInput = false, int index = 0)
{
spinLock.lock();
const IOInfos *infos = &ioInfos[netId];
spinLock.unlock();
const TensorShape shape = (isInput ? infos->inputInfos : infos->outputInfos)[index].second.GetShape();
unsigned long s = 0;
for (unsigned int d = 0; d < shape.GetNumDimensions(); d++)
s |= ((unsigned long)shape[d]) << (d * 16); // stores up to 4 16-bit values in a 64-bit value
return s;
}
Ann(int tuningLevel, const char *tuningFile)
{
IRuntime::CreationOptions runtimeOptions;
BackendOptions backendOptions{"GpuAcc",
{
{"TuningLevel", tuningLevel},
{"MemoryOptimizerStrategy", "ConstantMemoryStrategy"}, // SingleAxisPriorityList or ConstantMemoryStrategy
}};
if (tuningFile)
backendOptions.AddOption({"TuningFile", tuningFile});
runtimeOptions.m_BackendOptions.emplace_back(backendOptions);
runtime = IRuntime::CreateRaw(runtimeOptions);
};
~Ann()
{
IRuntime::Destroy(runtime);
};
private:
INetworkPtr loadModel(const char *modelPath)
{
const auto path = std::string(modelPath);
if (path.rfind(".tflite") == path.length() - 7) // endsWith()
{
auto parser = armnnTfLiteParser::ITfLiteParser::CreateRaw();
return parser->CreateNetworkFromBinaryFile(modelPath);
}
else if (path.rfind(".onnx") == path.length() - 5) // endsWith()
{
auto parser = armnnOnnxParser::IOnnxParser::CreateRaw();
return parser->CreateNetworkFromBinaryFile(modelPath);
}
else
{
std::ifstream ifs(path, std::ifstream::in | std::ifstream::binary);
auto parser = armnnDeserializer::IDeserializer::CreateRaw();
return parser->CreateNetworkFromBinary(ifs);
}
}
static BindingPointInfo getInputTensorInfo(LayerBindingId inputBindingId, TensorInfo info)
{
const auto newInfo = TensorInfo{info.GetShape(), info.GetDataType(),
info.GetQuantizationScale(),
info.GetQuantizationOffset(),
true};
return {inputBindingId, newInfo};
}
IOptimizedNetworkPtr OptimizeNetwork(INetwork *network, bool fastMath, bool fp16, bool saveCachedNetwork, const char *cachedNetworkPath)
{
const bool allowExpandedDims = false;
const ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly;
OptimizerOptionsOpaque options;
options.SetReduceFp32ToFp16(fp16);
options.SetShapeInferenceMethod(shapeInferenceMethod);
options.SetAllowExpandedDims(allowExpandedDims);
BackendOptions gpuAcc("GpuAcc", {{"FastMathEnabled", fastMath}});
if (cachedNetworkPath)
{
gpuAcc.AddOption({"SaveCachedNetwork", saveCachedNetwork});
gpuAcc.AddOption({"CachedNetworkFilePath", cachedNetworkPath});
}
options.AddModelOption(gpuAcc);
// No point in using ARMNN for CPU, use ONNX (quantized) instead.
// BackendOptions cpuAcc("CpuAcc",
// {
// {"FastMathEnabled", fastMath},
// {"NumberOfThreads", 0},
// });
// options.AddModelOption(cpuAcc);
BackendOptions allowExDimOpt("AllowExpandedDims",
{{"AllowExpandedDims", allowExpandedDims}});
options.AddModelOption(allowExDimOpt);
BackendOptions shapeInferOpt("ShapeInferenceMethod",
{{"InferAndValidate", shapeInferenceMethod == ShapeInferenceMethod::InferAndValidate}});
options.AddModelOption(shapeInferOpt);
std::vector<BackendId> backends = {
BackendId("GpuAcc"),
// BackendId("CpuAcc"),
// BackendId("CpuRef"),
};
return Optimize(*network, backends, runtime->GetDeviceSpec(), options);
}
IOInfos getIOInfos(IOptimizedNetwork *optNet)
{
struct InfoStrategy : IStrategy
{
void ExecuteStrategy(const IConnectableLayer *layer,
const BaseDescriptor &descriptor,
const std::vector<ConstTensor> &constants,
const char *name,
const LayerBindingId id = 0) override
{
IgnoreUnused(descriptor, constants, id);
const LayerType lt = layer->GetType();
if (lt == LayerType::Input)
ioInfos.inputInfos.push_back(getInputTensorInfo(id, layer->GetOutputSlot(0).GetTensorInfo()));
else if (lt == LayerType::Output)
ioInfos.outputInfos.push_back({id, layer->GetInputSlot(0).GetTensorInfo()});
}
IOInfos ioInfos;
};
InfoStrategy infoStrategy;
optNet->ExecuteStrategy(infoStrategy);
return infoStrategy.ioInfos;
}
IRuntime *runtime;
std::map<NetworkId, IOInfos> ioInfos;
std::map<NetworkId, std::unique_ptr<std::mutex>> mutexes; // mutex per network to not execute the same network concurrently
std::mutex mutex; // global mutex for load/unload calls to the runtime
SpinLock spinLock; // fast spin lock to guard access to the ioInfos and mutexes maps
};
extern "C" void *init(int logLevel, int tuningLevel, const char *tuningFile)
{
LogSeverity level = static_cast<LogSeverity>(logLevel);
ConfigureLogging(true, true, level);
Ann *ann = new Ann(tuningLevel, tuningFile);
return ann;
}
extern "C" void destroy(void *ann)
{
delete ((Ann *)ann);
}
extern "C" int load(void *ann,
const char *path,
bool fastMath,
bool fp16,
bool saveCachedNetwork,
const char *cachedNetworkPath)
{
return ((Ann *)ann)->load(path, fastMath, fp16, saveCachedNetwork, cachedNetworkPath);
}
extern "C" void unload(void *ann, NetworkId netId)
{
((Ann *)ann)->unload(netId);
}
extern "C" void execute(void *ann, NetworkId netId, const void **inputData, void **outputData)
{
((Ann *)ann)->execute(netId, inputData, outputData);
}
extern "C" unsigned long shape(void *ann, NetworkId netId, bool isInput, int index)
{
return ((Ann *)ann)->shape(netId, isInput, index);
}
extern "C" int tensors(void *ann, NetworkId netId, bool isInput)
{
return ((Ann *)ann)->tensors(netId, isInput);
}
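
The shape() method above packs up to four 16-bit dimensions into a single 64-bit return value, least-significant dimension first. A hedged sketch of how a caller (for example, Python bindings over libann.so) might decode it, assuming every dimension fits in 16 bits:

def unpack_shape(packed: int) -> tuple[int, ...]:
    # each dimension occupies 16 bits, dimension 0 in the lowest bits (see shape() in ann.cpp)
    dims = []
    while packed:
        dims.append(packed & 0xFFFF)
        packed >>= 16
    return tuple(dims)

# e.g. a (1, 3, 224, 224) tensor packed by shape():
packed = 1 | (3 << 16) | (224 << 32) | (224 << 48)
assert unpack_shape(packed) == (1, 3, 224, 224)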

View File

@@ -0,0 +1,4 @@
#!/usr/bin/env sh
cd armnn-23.11/ || exit
g++ -o ../armnnconverter -fPIC -O1 -DARMNN_ONNX_PARSER -DARMNN_SERIALIZER -DARMNN_TF_LITE_PARSER -fuse-ld=gold -std=c++17 -Iinclude -Isrc/armnnUtils -Ithird-party -larmnn -larmnnDeserializer -larmnnTfLiteParser -larmnnOnnxParser -larmnnSerializer -L../armnn src/armnnConverter/ArmnnConverter.cpp

View File

@@ -0,0 +1,3 @@
#!/usr/bin/env sh
g++ -shared -O3 -fPIC -o libann.so -fuse-ld=gold -std=c++17 -I"$ARMNN_PATH"/include -larmnn -larmnnDeserializer -larmnnTfLiteParser -larmnnOnnxParser -L"$ARMNN_PATH" ann.cpp

View File

@@ -19,37 +19,44 @@ _MCLIP_TO_OPENCLIP = {
}
def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
embs = self.transformer(input_ids, attention_mask)[0]
embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
embs = self.LinearTransformation(embs)
return torch.nn.functional.normalize(embs, dim=-1)
# unfortunately need to monkeypatch for tracing to work here
# otherwise it hits the 2GiB protobuf serialization limit
MultilingualCLIP.forward = forward
def to_torchscript(model_name: str) -> torch.jit.ScriptModule:
with tempfile.TemporaryDirectory() as tmpdir:
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
model.eval()
for param in model.parameters():
param.requires_grad_(False)
return model
def to_onnx(
model_name: str,
output_dir_visual: Path | str,
output_dir_textual: Path | str,
) -> None:
textual_path = get_model_path(output_dir_textual)
with tempfile.TemporaryDirectory() as tmpdir:
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)
model = to_torchscript(model_name)
AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)
for param in model.parameters():
param.requires_grad_(False)
export_text_encoder(model, textual_path)
openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
optimize(textual_path)
_text_encoder_to_onnx(model, textual_path)
openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
optimize(textual_path)
def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
def _text_encoder_to_onnx(model: MultilingualCLIP, output_path: Path | str) -> None:
output_path = Path(output_path)
def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
embs = self.transformer(input_ids, attention_mask)[0]
embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
embs = self.LinearTransformation(embs)
return torch.nn.functional.normalize(embs, dim=-1)
# unfortunately need to monkeypatch for tracing to work here
# otherwise it hits the 2GiB protobuf serialization limit
MultilingualCLIP.forward = forward
args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32))
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)

View File

@@ -26,6 +26,17 @@ class OpenCLIPModelConfig:
self.sequence_length = open_clip_cfg["text_cfg"]["context_length"]
def to_torchscript(model_name: str) -> torch.jit.ScriptModule:
with tempfile.TemporaryDirectory() as tmpdir:
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
model.eval()
for param in model.parameters():
param.requires_grad_(False)
return model
def to_onnx(
model_cfg: OpenCLIPModelConfig,
output_dir_visual: Path | str | None = None,
@@ -51,7 +62,7 @@ def to_onnx(
save_config(open_clip.get_model_preprocess_cfg(model), output_dir_visual / "preprocess_cfg.json")
save_config(text_vision_cfg, output_dir_visual.parent / "config.json")
export_image_encoder(model, model_cfg, visual_path)
_image_encoder_to_onnx(model, model_cfg, visual_path)
optimize(visual_path)
@@ -61,11 +72,11 @@ def to_onnx(
tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32")
AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual)
export_text_encoder(model, model_cfg, textual_path)
_text_encoder_to_onnx(model, model_cfg, textual_path)
optimize(textual_path)
def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
def _image_encoder_to_onnx(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
output_path = Path(output_path)
def encode_image(image: torch.Tensor) -> torch.Tensor:
@@ -89,7 +100,7 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig,
)
def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
def _text_encoder_to_onnx(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
output_path = Path(output_path)
def encode_text(text: torch.Tensor) -> torch.Tensor: