mirror of
https://github.com/immich-app/immich.git
synced 2024-11-15 09:59:00 -07:00
feat(ml): improve test coverage (#7041)
* update e2e * tokenizer tests * more tests, remove unnecessary code * fix e2e setting * add tests for loading model * update workflow * fixed test
This commit is contained in:
parent
6e853e2a9d
commit
0c4df216d7
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@ -247,7 +247,7 @@ jobs:
|
||||
poetry run mypy --install-types --non-interactive --strict app/
|
||||
- name: Run tests and coverage
|
||||
run: |
|
||||
poetry run pytest --cov app
|
||||
poetry run pytest app --cov=app --cov-report term-missing
|
||||
|
||||
generated-api-up-to-date:
|
||||
name: OpenAPI Clients
|
||||
|
@ -119,16 +119,12 @@ async def load(model: InferenceModel) -> InferenceModel:
|
||||
if model.loaded:
|
||||
return model
|
||||
|
||||
def _load() -> None:
|
||||
def _load(model: InferenceModel) -> None:
|
||||
with lock:
|
||||
model.load()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
if thread_pool is None:
|
||||
model.load()
|
||||
else:
|
||||
await loop.run_in_executor(thread_pool, _load)
|
||||
await run(_load, model)
|
||||
return model
|
||||
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
|
||||
log.warning(
|
||||
@ -138,10 +134,7 @@ async def load(model: InferenceModel) -> InferenceModel:
|
||||
)
|
||||
)
|
||||
model.clear_cache()
|
||||
if thread_pool is None:
|
||||
model.load()
|
||||
else:
|
||||
await loop.run_in_executor(thread_pool, _load)
|
||||
await run(_load, model)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -21,4 +21,4 @@ def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any)
|
||||
case _:
|
||||
raise ValueError(f"Unknown model type {model_type}")
|
||||
|
||||
raise ValueError(f"Unknown ${model_type} model {model_name}")
|
||||
raise ValueError(f"Unknown {model_type} model {model_name}")
|
||||
|
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
@ -11,7 +10,6 @@ import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||
from typing_extensions import Buffer
|
||||
|
||||
import ann.ann
|
||||
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
|
||||
@ -200,7 +198,7 @@ class InferenceModel(ABC):
|
||||
|
||||
@providers.setter
|
||||
def providers(self, providers: list[str]) -> None:
|
||||
log.debug(
|
||||
log.info(
|
||||
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
|
||||
)
|
||||
self._providers = providers
|
||||
@ -217,7 +215,7 @@ class InferenceModel(ABC):
|
||||
|
||||
@provider_options.setter
|
||||
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
|
||||
log.info(f"Setting execution provider options to {provider_options}")
|
||||
log.debug(f"Setting execution provider options to {provider_options}")
|
||||
self._provider_options = provider_options
|
||||
|
||||
@property
|
||||
@ -255,7 +253,7 @@ class InferenceModel(ABC):
|
||||
|
||||
@property
|
||||
def sess_options_default(self) -> ort.SessionOptions:
|
||||
sess_options = PicklableSessionOptions()
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.enable_cpu_mem_arena = False
|
||||
|
||||
# avoid thread contention between models
|
||||
@ -287,15 +285,3 @@ class InferenceModel(ABC):
|
||||
@property
|
||||
def preferred_runtime_default(self) -> ModelRuntime:
|
||||
return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
|
||||
|
||||
|
||||
# HF deep copies configs, so we need to make session options picklable
|
||||
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
|
||||
def __getstate__(self) -> bytes:
|
||||
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
|
||||
|
||||
def __setstate__(self, state: Buffer) -> None:
|
||||
self.__init__() # type: ignore[misc]
|
||||
attrs: list[tuple[str, Any]] = pickle.loads(state)
|
||||
for attr, val in attrs:
|
||||
setattr(self, attr, val)
|
||||
|
@ -80,20 +80,3 @@ class RevalidationPlugin(BasePlugin): # type: ignore[misc]
|
||||
key = client.build_key(key, namespace)
|
||||
if key in client._handlers:
|
||||
await client.expire(key, client.ttl)
|
||||
|
||||
async def post_multi_get(
|
||||
self,
|
||||
client: SimpleMemoryCache,
|
||||
keys: list[str],
|
||||
ret: list[Any] | None = None,
|
||||
namespace: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if ret is None:
|
||||
return
|
||||
|
||||
for key, val in zip(keys, ret):
|
||||
if namespace is not None:
|
||||
key = client.build_key(key, namespace)
|
||||
if val is not None and key in client._handlers:
|
||||
await client.expire(key, client.ttl)
|
||||
|
@ -144,9 +144,7 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
|
||||
|
||||
def _load(self) -> None:
|
||||
super()._load()
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
context_length: int = text_cfg.get("context_length", 77)
|
||||
pad_token: int = self.tokenizer_cfg["pad_token"]
|
||||
self._load_tokenizer()
|
||||
|
||||
size: list[int] | int = self.preprocess_cfg["size"]
|
||||
self.size = size[0] if isinstance(size, list) else size
|
||||
@ -155,11 +153,19 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
|
||||
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
|
||||
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
|
||||
|
||||
def _load_tokenizer(self) -> Tokenizer:
|
||||
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
||||
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
context_length: int = text_cfg.get("context_length", 77)
|
||||
pad_token: str = self.tokenizer_cfg["pad_token"]
|
||||
|
||||
self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
|
||||
|
||||
pad_id: int = self.tokenizer.token_to_id(pad_token)
|
||||
self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
|
||||
self.tokenizer.enable_truncation(max_length=context_length)
|
||||
|
||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
|
@ -1,7 +1,8 @@
|
||||
import json
|
||||
import pickle
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable
|
||||
from unittest import mock
|
||||
|
||||
@ -13,10 +14,12 @@ from fastapi.testclient import TestClient
|
||||
from PIL import Image
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from app.main import load
|
||||
|
||||
from .config import log, settings
|
||||
from .models.base import InferenceModel, PicklableSessionOptions
|
||||
from .models.base import InferenceModel
|
||||
from .models.cache import ModelCache
|
||||
from .models.clip import OpenCLIPEncoder
|
||||
from .models.clip import MCLIPEncoder, OpenCLIPEncoder
|
||||
from .models.facial_recognition import FaceRecognizer
|
||||
from .schemas import ModelRuntime, ModelType
|
||||
|
||||
@ -72,6 +75,17 @@ class TestBase:
|
||||
{"arena_extend_strategy": "kSameAsRequested"},
|
||||
]
|
||||
|
||||
def test_sets_openvino_device_id_if_possible(self, mocker: MockerFixture) -> None:
|
||||
mocked = mocker.patch("app.models.base.ort.capi._pybind_state")
|
||||
mocked.get_available_openvino_device_ids.return_value = ["GPU.0", "CPU"]
|
||||
|
||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"])
|
||||
|
||||
assert encoder.provider_options == [
|
||||
{"device_id": "GPU.0"},
|
||||
{"arena_extend_strategy": "kSameAsRequested"},
|
||||
]
|
||||
|
||||
def test_sets_provider_options_kwarg(self) -> None:
|
||||
encoder = OpenCLIPEncoder(
|
||||
"ViT-B-32__openai",
|
||||
@ -119,7 +133,7 @@ class TestBase:
|
||||
def test_sets_default_cache_dir(self) -> None:
|
||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
||||
|
||||
assert encoder.cache_dir == Path("/cache/clip/ViT-B-32__openai")
|
||||
assert encoder.cache_dir == Path(settings.cache_folder) / "clip" / "ViT-B-32__openai"
|
||||
|
||||
def test_sets_cache_dir_kwarg(self) -> None:
|
||||
cache_dir = Path("/test_cache")
|
||||
@ -170,7 +184,7 @@ class TestBase:
|
||||
encoder.clear_cache()
|
||||
|
||||
mock_rmtree.assert_called_once_with(encoder.cache_dir)
|
||||
assert info.call_count == 2
|
||||
info.assert_called_with(f"Cleared cache directory for model '{encoder.model_name}'.")
|
||||
|
||||
def test_clear_cache_warns_if_path_does_not_exist(self, mocker: MockerFixture) -> None:
|
||||
mock_rmtree = mocker.patch("app.models.base.rmtree", autospec=True)
|
||||
@ -267,7 +281,7 @@ class TestBase:
|
||||
def test_download(self, mocker: MockerFixture) -> None:
|
||||
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
|
||||
|
||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="/path/to/cache")
|
||||
encoder.download()
|
||||
|
||||
mock_snapshot_download.assert_called_once_with(
|
||||
@ -348,6 +362,60 @@ class TestCLIP:
|
||||
assert embedding.dtype == np.float32
|
||||
mocked.run.assert_called_once()
|
||||
|
||||
def test_openclip_tokenizer(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
) -> None:
|
||||
mocker.patch.object(OpenCLIPEncoder, "download")
|
||||
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
|
||||
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mock_tokenizer = mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="text")
|
||||
clip_encoder._load_tokenizer()
|
||||
tokens = clip_encoder.tokenize("test search query")
|
||||
|
||||
assert "text" in tokens
|
||||
assert isinstance(tokens["text"], np.ndarray)
|
||||
assert tokens["text"].shape == (1, 77)
|
||||
assert tokens["text"].dtype == np.int32
|
||||
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||
|
||||
def test_mclip_tokenizer(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
) -> None:
|
||||
mocker.patch.object(OpenCLIPEncoder, "download")
|
||||
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
|
||||
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mock_tokenizer = mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_attention_mask = [randint(0, 1) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids, attention_mask=mock_attention_mask)
|
||||
|
||||
clip_encoder = MCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="text")
|
||||
clip_encoder._load_tokenizer()
|
||||
tokens = clip_encoder.tokenize("test search query")
|
||||
|
||||
assert "input_ids" in tokens
|
||||
assert "attention_mask" in tokens
|
||||
assert isinstance(tokens["input_ids"], np.ndarray)
|
||||
assert isinstance(tokens["attention_mask"], np.ndarray)
|
||||
assert tokens["input_ids"].shape == (1, 77)
|
||||
assert tokens["attention_mask"].shape == (1, 77)
|
||||
assert np.allclose(tokens["input_ids"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||
assert np.allclose(tokens["attention_mask"], np.array([mock_attention_mask], dtype=np.int32), atol=0)
|
||||
|
||||
|
||||
class TestFaceRecognition:
|
||||
def test_set_min_score(self, mocker: MockerFixture) -> None:
|
||||
@ -420,12 +488,75 @@ class TestCache:
|
||||
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
|
||||
|
||||
@mock.patch("app.models.cache.SimpleMemoryCache.expire")
|
||||
async def test_revalidate(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
|
||||
async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
|
||||
model_cache = ModelCache(ttl=100, revalidate=True)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||
mock_cache_expire.assert_called_once_with(mock.ANY, 100)
|
||||
|
||||
async def test_profiling(self, mock_get_model: mock.Mock) -> None:
|
||||
model_cache = ModelCache(ttl=100, profiling=True)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||
profiling = await model_cache.get_profiling()
|
||||
assert isinstance(profiling, dict)
|
||||
assert profiling == model_cache.cache.profiling
|
||||
|
||||
async def test_loads_mclip(self) -> None:
|
||||
model_cache = ModelCache()
|
||||
|
||||
model = await model_cache.get("XLM-Roberta-Large-Vit-B-32", ModelType.CLIP, mode="text")
|
||||
|
||||
assert isinstance(model, MCLIPEncoder)
|
||||
assert model.model_name == "XLM-Roberta-Large-Vit-B-32"
|
||||
|
||||
async def test_raises_exception_if_invalid_model_type(self) -> None:
|
||||
invalid: Any = SimpleNamespace(value="invalid")
|
||||
model_cache = ModelCache()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await model_cache.get("XLM-Roberta-Large-Vit-B-32", invalid, mode="text")
|
||||
|
||||
async def test_raises_exception_if_unknown_model_name(self) -> None:
|
||||
model_cache = ModelCache()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await model_cache.get("test_model_name", ModelType.CLIP, mode="text")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLoad:
|
||||
async def test_load(self) -> None:
|
||||
mock_model = mock.Mock(spec=InferenceModel)
|
||||
mock_model.loaded = False
|
||||
|
||||
res = await load(mock_model)
|
||||
|
||||
assert res is mock_model
|
||||
mock_model.load.assert_called_once()
|
||||
mock_model.clear_cache.assert_not_called()
|
||||
|
||||
async def test_load_returns_model_if_loaded(self) -> None:
|
||||
mock_model = mock.Mock(spec=InferenceModel)
|
||||
mock_model.loaded = True
|
||||
|
||||
res = await load(mock_model)
|
||||
|
||||
assert res is mock_model
|
||||
mock_model.load.assert_not_called()
|
||||
|
||||
async def test_load_clears_cache_and_retries_if_os_error(self) -> None:
|
||||
mock_model = mock.Mock(spec=InferenceModel)
|
||||
mock_model.model_name = "test_model_name"
|
||||
mock_model.model_type = ModelType.CLIP
|
||||
mock_model.load.side_effect = [OSError, None]
|
||||
mock_model.loaded = False
|
||||
|
||||
res = await load(mock_model)
|
||||
|
||||
assert res is mock_model
|
||||
mock_model.clear_cache.assert_called_once()
|
||||
assert mock_model.load.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.test_full,
|
||||
@ -437,15 +568,21 @@ class TestEndpoints:
|
||||
) -> None:
|
||||
byte_image = BytesIO()
|
||||
pil_image.save(byte_image, format="jpeg")
|
||||
expected = responses["clip"]["image"]
|
||||
|
||||
response = deployed_app.post(
|
||||
"http://localhost:3003/predict",
|
||||
data={"modelName": "ViT-B-32__openai", "modelType": "clip", "options": json.dumps({"mode": "vision"})},
|
||||
files={"image": byte_image.getvalue()},
|
||||
)
|
||||
|
||||
actual = response.json()
|
||||
assert response.status_code == 200
|
||||
assert response.json() == responses["clip"]["image"]
|
||||
assert np.allclose(expected, actual)
|
||||
|
||||
def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
||||
expected = responses["clip"]["text"]
|
||||
|
||||
response = deployed_app.post(
|
||||
"http://localhost:3003/predict",
|
||||
data={
|
||||
@ -455,12 +592,15 @@ class TestEndpoints:
|
||||
"options": json.dumps({"mode": "text"}),
|
||||
},
|
||||
)
|
||||
|
||||
actual = response.json()
|
||||
assert response.status_code == 200
|
||||
assert response.json() == responses["clip"]["text"]
|
||||
assert np.allclose(expected, actual)
|
||||
|
||||
def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
||||
byte_image = BytesIO()
|
||||
pil_image.save(byte_image, format="jpeg")
|
||||
expected = responses["facial-recognition"]
|
||||
|
||||
response = deployed_app.post(
|
||||
"http://localhost:3003/predict",
|
||||
@ -471,15 +611,13 @@ class TestEndpoints:
|
||||
},
|
||||
files={"image": byte_image.getvalue()},
|
||||
)
|
||||
|
||||
actual = response.json()
|
||||
assert response.status_code == 200
|
||||
assert response.json() == responses["facial-recognition"]
|
||||
|
||||
|
||||
def test_sess_options() -> None:
|
||||
sess_options = PicklableSessionOptions()
|
||||
sess_options.intra_op_num_threads = 1
|
||||
sess_options.inter_op_num_threads = 1
|
||||
pickled = pickle.dumps(sess_options)
|
||||
unpickled = pickle.loads(pickled)
|
||||
assert unpickled.intra_op_num_threads == 1
|
||||
assert unpickled.inter_op_num_threads == 1
|
||||
assert len(expected) == len(actual)
|
||||
for expected_face, actual_face in zip(expected, actual):
|
||||
assert expected_face["imageHeight"] == actual_face["imageHeight"]
|
||||
assert expected_face["imageWidth"] == actual_face["imageWidth"]
|
||||
assert expected_face["boundingBox"] == actual_face["boundingBox"]
|
||||
assert np.allclose(expected_face["embedding"], actual_face["embedding"])
|
||||
assert np.allclose(expected_face["score"], actual_face["score"])
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user