diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py index 08a2f5f6b9..e5ebb828f3 100644 --- a/machine-learning/app/models/base.py +++ b/machine-learning/app/models/base.py @@ -14,7 +14,7 @@ import ann.ann from app.models.constants import SUPPORTED_PROVIDERS from ..config import get_cache_dir, get_hf_model_name, log, settings -from ..schemas import ModelType +from ..schemas import ModelRuntime, ModelType from .ann import AnnSession @@ -28,6 +28,7 @@ class InferenceModel(ABC): providers: list[str] | None = None, provider_options: list[dict[str, Any]] | None = None, sess_options: ort.SessionOptions | None = None, + preferred_runtime: ModelRuntime | None = None, **model_kwargs: Any, ) -> None: self.loaded = False @@ -36,6 +37,7 @@ class InferenceModel(ABC): self.providers = providers if providers is not None else self.providers_default self.provider_options = provider_options if provider_options is not None else self.provider_options_default self.sess_options = sess_options if sess_options is not None else self.sess_options_default + self.preferred_runtime = preferred_runtime if preferred_runtime is not None else self.preferred_runtime_default def download(self) -> None: if not self.cached: @@ -66,11 +68,13 @@ class InferenceModel(ABC): pass def _download(self) -> None: + ignore_patterns = [] if self.preferred_runtime == ModelRuntime.ARMNN else ["*.armnn"] snapshot_download( get_hf_model_name(self.model_name), cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False, + ignore_patterns=ignore_patterns, ) @abstractmethod @@ -100,18 +104,28 @@ class InferenceModel(ABC): self.cache_dir.mkdir(parents=True, exist_ok=True) def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession: - armnn_path = model_path.with_suffix(".armnn") - if settings.ann and ann.ann.is_available and armnn_path.is_file(): - session = AnnSession(armnn_path) - elif model_path.is_file(): - session = ort.InferenceSession( - model_path.as_posix(), - sess_options=self.sess_options, - providers=self.providers, - provider_options=self.provider_options, + if not model_path.is_file(): + onnx_path = model_path.with_suffix(".onnx") + if not onnx_path.is_file(): + raise ValueError(f"Model path '{model_path}' does not exist") + + log.warning( + f"Could not find model path '{model_path}'. " f"Falling back to ONNX model path '{onnx_path}' instead.", ) - else: - raise ValueError(f"the file model_path='{model_path}' does not exist") + model_path = onnx_path + + match model_path.suffix: + case ".armnn": + session = AnnSession(model_path) + case ".onnx": + session = ort.InferenceSession( + model_path.as_posix(), + sess_options=self.sess_options, + providers=self.providers, + provider_options=self.provider_options, + ) + case _: + raise ValueError(f"Unsupported model file type: {model_path.suffix}") return session @property @@ -132,7 +146,7 @@ class InferenceModel(ABC): @property def cached(self) -> bool: - return self.cache_dir.exists() and any(self.cache_dir.iterdir()) + return self.cache_dir.is_dir() and any(self.cache_dir.iterdir()) @property def providers(self) -> list[str]: @@ -215,6 +229,19 @@ class InferenceModel(ABC): return sess_options + @property + def preferred_runtime(self) -> ModelRuntime: + return self._preferred_runtime + + @preferred_runtime.setter + def preferred_runtime(self, preferred_runtime: ModelRuntime) -> None: + log.debug(f"Setting preferred runtime to {preferred_runtime}") + self._preferred_runtime = preferred_runtime + + @property + def preferred_runtime_default(self) -> ModelRuntime: + return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX + # HF deep copies configs, so we need to make session options picklable class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc] diff --git a/machine-learning/app/models/clip.py b/machine-learning/app/models/clip.py index 895786cde1..469789155e 100644 --- a/machine-learning/app/models/clip.py +++ b/machine-learning/app/models/clip.py @@ -81,11 +81,11 @@ class BaseCLIPEncoder(InferenceModel): @property def textual_path(self) -> Path: - return self.textual_dir / "model.onnx" + return self.textual_dir / f"model.{self.preferred_runtime}" @property def visual_path(self) -> Path: - return self.visual_dir / "model.onnx" + return self.visual_dir / f"model.{self.preferred_runtime}" @property def tokenizer_file_path(self) -> Path: diff --git a/machine-learning/app/models/facial_recognition.py b/machine-learning/app/models/facial_recognition.py index 6b66e57c47..072fc807f9 100644 --- a/machine-learning/app/models/facial_recognition.py +++ b/machine-learning/app/models/facial_recognition.py @@ -77,11 +77,11 @@ class FaceRecognizer(InferenceModel): @property def det_file(self) -> Path: - return self.cache_dir / "detection" / "model.onnx" + return self.cache_dir / "detection" / f"model.{self.preferred_runtime}" @property def rec_file(self) -> Path: - return self.cache_dir / "recognition" / "model.onnx" + return self.cache_dir / "recognition" / f"model.{self.preferred_runtime}" def configure(self, **model_kwargs: Any) -> None: self.det_model.det_thresh = model_kwargs.pop("minScore", self.det_model.det_thresh) diff --git a/machine-learning/app/schemas.py b/machine-learning/app/schemas.py index 0be2bb8a3d..f9e64bd259 100644 --- a/machine-learning/app/schemas.py +++ b/machine-learning/app/schemas.py @@ -26,6 +26,11 @@ class ModelType(str, Enum): FACIAL_RECOGNITION = "facial-recognition" +class ModelRuntime(str, Enum): + ONNX = "onnx" + ARMNN = "armnn" + + class HasProfiling(Protocol): profiling: dict[str, float] diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py index b45f4bf554..adcb5f43c0 100644 --- a/machine-learning/app/test_main.py +++ b/machine-learning/app/test_main.py @@ -18,7 +18,7 @@ from .models.base import InferenceModel, PicklableSessionOptions from .models.cache import ModelCache from .models.clip import OpenCLIPEncoder from .models.facial_recognition import FaceRecognizer -from .schemas import ModelType +from .schemas import ModelRuntime, ModelType class TestBase: @@ -127,6 +127,30 @@ class TestBase: assert encoder.cache_dir == cache_dir + def test_sets_default_preferred_runtime(self, mocker: MockerFixture) -> None: + mocker.patch.object(settings, "ann", True) + mocker.patch("ann.ann.is_available", False) + + encoder = OpenCLIPEncoder("ViT-B-32__openai") + + assert encoder.preferred_runtime == ModelRuntime.ONNX + + def test_sets_default_preferred_runtime_to_armnn_if_available(self, mocker: MockerFixture) -> None: + mocker.patch.object(settings, "ann", True) + mocker.patch("ann.ann.is_available", True) + + encoder = OpenCLIPEncoder("ViT-B-32__openai") + + assert encoder.preferred_runtime == ModelRuntime.ARMNN + + def test_sets_preferred_runtime_kwarg(self, mocker: MockerFixture) -> None: + mocker.patch.object(settings, "ann", False) + mocker.patch("ann.ann.is_available", False) + + encoder = OpenCLIPEncoder("ViT-B-32__openai", preferred_runtime=ModelRuntime.ARMNN) + + assert encoder.preferred_runtime == ModelRuntime.ARMNN + def test_casts_cache_dir_string_to_path(self) -> None: cache_dir = "/test_cache" encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir=cache_dir) @@ -195,46 +219,79 @@ class TestBase: warning.assert_called_once() def test_make_session_return_ann_if_available(self, mocker: MockerFixture) -> None: - mock_cache_dir = mocker.Mock() - mock_cache_dir.is_file.return_value = True - mock_cache_dir.with_suffix.return_value = mock_cache_dir - mocker.patch.object(settings, "ann", True) - mocker.patch("ann.ann.is_available", True) + mock_model_path = mocker.Mock() + mock_model_path.is_file.return_value = True + mock_model_path.suffix = ".armnn" + mock_model_path.with_suffix.return_value = mock_model_path mock_session = mocker.patch("app.models.base.AnnSession") encoder = OpenCLIPEncoder("ViT-B-32__openai") - encoder._make_session(mock_cache_dir) + encoder._make_session(mock_model_path) mock_session.assert_called_once() def test_make_session_return_ort_if_available_and_ann_is_not(self, mocker: MockerFixture) -> None: - mock_cache_dir = mocker.Mock() - mock_cache_dir.is_file.return_value = True - mock_cache_dir.with_suffix.return_value = mock_cache_dir - mocker.patch.object(settings, "ann", False) - mocker.patch("ann.ann.is_available", False) - mock_session = mocker.patch("app.models.base.ort.InferenceSession") + mock_armnn_path = mocker.Mock() + mock_armnn_path.is_file.return_value = False + mock_armnn_path.suffix = ".armnn" + + mock_onnx_path = mocker.Mock() + mock_onnx_path.is_file.return_value = True + mock_onnx_path.suffix = ".onnx" + mock_armnn_path.with_suffix.return_value = mock_onnx_path + + mock_ann = mocker.patch("app.models.base.AnnSession") + mock_ort = mocker.patch("app.models.base.ort.InferenceSession") encoder = OpenCLIPEncoder("ViT-B-32__openai") - encoder._make_session(mock_cache_dir) + encoder._make_session(mock_armnn_path) - mock_session.assert_called_once() + mock_ort.assert_called_once() + mock_ann.assert_not_called() def test_make_session_raises_exception_if_path_does_not_exist(self, mocker: MockerFixture) -> None: - mock_cache_dir = mocker.Mock() - mock_cache_dir.is_file.return_value = False - mock_cache_dir.with_suffix.return_value = mock_cache_dir - mocker.patch("ann.ann.is_available", False) - mock_ann = mocker.patch("app.models.base.ort.InferenceSession") + mock_model_path = mocker.Mock() + mock_model_path.is_file.return_value = False + mock_model_path.suffix = ".onnx" + mock_model_path.with_suffix.return_value = mock_model_path + mock_ann = mocker.patch("app.models.base.AnnSession") mock_ort = mocker.patch("app.models.base.ort.InferenceSession") encoder = OpenCLIPEncoder("ViT-B-32__openai") with pytest.raises(ValueError): - encoder._make_session(mock_cache_dir) + encoder._make_session(mock_model_path) mock_ann.assert_not_called() mock_ort.assert_not_called() + def test_download(self, mocker: MockerFixture) -> None: + mock_snapshot_download = mocker.patch("app.models.base.snapshot_download") + + encoder = OpenCLIPEncoder("ViT-B-32__openai") + encoder.download() + + mock_snapshot_download.assert_called_once_with( + "immich-app/ViT-B-32__openai", + cache_dir=encoder.cache_dir, + local_dir=encoder.cache_dir, + local_dir_use_symlinks=False, + ignore_patterns=["*.armnn"], + ) + + def test_download_downloads_armnn_if_preferred_runtime(self, mocker: MockerFixture) -> None: + mock_snapshot_download = mocker.patch("app.models.base.snapshot_download") + + encoder = OpenCLIPEncoder("ViT-B-32__openai", preferred_runtime=ModelRuntime.ARMNN) + encoder.download() + + mock_snapshot_download.assert_called_once_with( + "immich-app/ViT-B-32__openai", + cache_dir=encoder.cache_dir, + local_dir=encoder.cache_dir, + local_dir_use_symlinks=False, + ignore_patterns=[], + ) + class TestCLIP: embedding = np.random.rand(512).astype(np.float32)