mirror of
https://github.com/immich-app/immich.git
synced 2024-11-15 18:08:48 -07:00
feat: preloading of machine learning models (#7540)
This commit is contained in:
parent
762c4684f8
commit
e8b001f62f
@ -124,16 +124,18 @@ Redis (Sentinel) URL example JSON before encoding:
|
||||
|
||||
## Machine Learning
|
||||
|
||||
| Variable | Description | Default | Services |
|
||||
| :----------------------------------------------- | :----------------------------------------------------------------- | :-----------------: | :--------------- |
|
||||
| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if \<= 0) | `300` | machine learning |
|
||||
| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if \<= 0) | `10` | machine learning |
|
||||
| `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning |
|
||||
| `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*1</sup> | Thread count of the request thread pool (disabled if \<= 0) | number of CPU cores | machine learning |
|
||||
| `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning |
|
||||
| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning |
|
||||
| `MACHINE_LEARNING_WORKERS`<sup>\*2</sup> | Number of worker processes to spawn | `1` | machine learning |
|
||||
| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning |
|
||||
| Variable | Description | Default | Services |
|
||||
| :----------------------------------------------- | :------------------------------------------------------------------- | :-----------------: | :--------------- |
|
||||
| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if \<= 0) | `300` | machine learning |
|
||||
| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if \<= 0) | `10` | machine learning |
|
||||
| `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning |
|
||||
| `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*1</sup> | Thread count of the request thread pool (disabled if \<= 0) | number of CPU cores | machine learning |
|
||||
| `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning |
|
||||
| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning |
|
||||
| `MACHINE_LEARNING_WORKERS`<sup>\*2</sup> | Number of worker processes to spawn | `1` | machine learning |
|
||||
| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning |
|
||||
| `MACHINE_LEARNING_PRELOAD__CLIP` | Name of a CLIP model to be preloaded and kept in cache | | machine learning |
|
||||
| `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION` | Name of a facial recognition model to be preloaded and kept in cache | | machine learning |
|
||||
|
||||
\*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones.
|
||||
|
||||
|
2
machine-learning/.gitignore
vendored
2
machine-learning/.gitignore
vendored
@ -167,6 +167,8 @@ cython_debug/
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# VS Code
|
||||
.vscode
|
||||
|
||||
*.onnx
|
||||
*.zip
|
@ -6,7 +6,7 @@ from pathlib import Path
|
||||
from socket import socket
|
||||
|
||||
from gunicorn.arbiter import Arbiter
|
||||
from pydantic import BaseSettings
|
||||
from pydantic import BaseModel, BaseSettings
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from uvicorn import Server
|
||||
@ -15,6 +15,11 @@ from uvicorn.workers import UvicornWorker
|
||||
from .schemas import ModelType
|
||||
|
||||
|
||||
class PreloadModelData(BaseModel):
|
||||
clip: str | None
|
||||
facial_recognition: str | None
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
cache_folder: str = "/cache"
|
||||
model_ttl: int = 300
|
||||
@ -27,10 +32,12 @@ class Settings(BaseSettings):
|
||||
model_inter_op_threads: int = 0
|
||||
model_intra_op_threads: int = 0
|
||||
ann: bool = True
|
||||
preload: PreloadModelData | None = None
|
||||
|
||||
class Config:
|
||||
env_prefix = "MACHINE_LEARNING_"
|
||||
case_sensitive = False
|
||||
env_nested_delimiter = "__"
|
||||
|
||||
|
||||
class LogSettings(BaseSettings):
|
||||
|
@ -17,7 +17,7 @@ from starlette.formparsers import MultiPartParser
|
||||
|
||||
from app.models.base import InferenceModel
|
||||
|
||||
from .config import log, settings
|
||||
from .config import PreloadModelData, log, settings
|
||||
from .models.cache import ModelCache
|
||||
from .schemas import (
|
||||
MessageResponse,
|
||||
@ -27,7 +27,7 @@ from .schemas import (
|
||||
|
||||
MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger
|
||||
|
||||
model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
|
||||
model_cache = ModelCache(revalidate=settings.model_ttl > 0)
|
||||
thread_pool: ThreadPoolExecutor | None = None
|
||||
lock = threading.Lock()
|
||||
active_requests = 0
|
||||
@ -51,6 +51,8 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
|
||||
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
|
||||
if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
|
||||
asyncio.ensure_future(idle_shutdown_task())
|
||||
if settings.preload is not None:
|
||||
await preload_models(settings.preload)
|
||||
yield
|
||||
finally:
|
||||
log.handlers.clear()
|
||||
@ -61,6 +63,14 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
|
||||
gc.collect()
|
||||
|
||||
|
||||
async def preload_models(preload_models: PreloadModelData) -> None:
|
||||
log.info(f"Preloading models: {preload_models}")
|
||||
if preload_models.clip is not None:
|
||||
await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
|
||||
if preload_models.facial_recognition is not None:
|
||||
await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))
|
||||
|
||||
|
||||
def update_state() -> Iterator[None]:
|
||||
global active_requests, last_called
|
||||
active_requests += 1
|
||||
@ -103,7 +113,7 @@ async def predict(
|
||||
except orjson.JSONDecodeError:
|
||||
raise HTTPException(400, f"Invalid options JSON: {options}")
|
||||
|
||||
model = await load(await model_cache.get(model_name, model_type, **kwargs))
|
||||
model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
|
||||
model.configure(**kwargs)
|
||||
outputs = await run(model.predict, inputs)
|
||||
return ORJSONResponse(outputs)
|
||||
|
@ -2,7 +2,7 @@ from typing import Any
|
||||
|
||||
from aiocache.backends.memory import SimpleMemoryCache
|
||||
from aiocache.lock import OptimisticLock
|
||||
from aiocache.plugins import BasePlugin, TimingPlugin
|
||||
from aiocache.plugins import TimingPlugin
|
||||
|
||||
from app.models import from_model_type
|
||||
|
||||
@ -15,28 +15,25 @@ class ModelCache:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ttl: float | None = None,
|
||||
revalidate: bool = False,
|
||||
timeout: int | None = None,
|
||||
profiling: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
ttl: Unloads model after this duration. Disabled if None. Defaults to None.
|
||||
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
|
||||
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
|
||||
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
|
||||
"""
|
||||
|
||||
self.ttl = ttl
|
||||
plugins = []
|
||||
|
||||
if revalidate:
|
||||
plugins.append(RevalidationPlugin())
|
||||
if profiling:
|
||||
plugins.append(TimingPlugin())
|
||||
|
||||
self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)
|
||||
self.revalidate_enable = revalidate
|
||||
|
||||
self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)
|
||||
|
||||
async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
|
||||
"""
|
||||
@ -49,11 +46,14 @@ class ModelCache:
|
||||
"""
|
||||
|
||||
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
|
||||
|
||||
async with OptimisticLock(self.cache, key) as lock:
|
||||
model: InferenceModel | None = await self.cache.get(key)
|
||||
if model is None:
|
||||
model = from_model_type(model_type, model_name, **model_kwargs)
|
||||
await lock.cas(model, ttl=self.ttl)
|
||||
await lock.cas(model, ttl=model_kwargs.get("ttl", None))
|
||||
elif self.revalidate_enable:
|
||||
await self.revalidate(key, model_kwargs.get("ttl", None))
|
||||
return model
|
||||
|
||||
async def get_profiling(self) -> dict[str, float] | None:
|
||||
@ -62,21 +62,6 @@ class ModelCache:
|
||||
|
||||
return self.cache.profiling
|
||||
|
||||
|
||||
class RevalidationPlugin(BasePlugin): # type: ignore[misc]
|
||||
"""Revalidates cache item's TTL after cache hit."""
|
||||
|
||||
async def post_get(
|
||||
self,
|
||||
client: SimpleMemoryCache,
|
||||
key: str,
|
||||
ret: Any | None = None,
|
||||
namespace: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if ret is None:
|
||||
return
|
||||
if namespace is not None:
|
||||
key = client.build_key(key, namespace)
|
||||
if key in client._handlers:
|
||||
await client.expire(key, client.ttl)
|
||||
async def revalidate(self, key: str, ttl: int | None) -> None:
|
||||
if ttl is not None and key in self.cache._handlers:
|
||||
await self.cache.expire(key, ttl)
|
||||
|
@ -13,11 +13,12 @@ import onnxruntime as ort
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from PIL import Image
|
||||
from pytest import MonkeyPatch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from app.main import load
|
||||
from app.main import load, preload_models
|
||||
|
||||
from .config import log, settings
|
||||
from .config import Settings, log, settings
|
||||
from .models.base import InferenceModel
|
||||
from .models.cache import ModelCache
|
||||
from .models.clip import MCLIPEncoder, OpenCLIPEncoder
|
||||
@ -509,20 +510,20 @@ class TestCache:
|
||||
|
||||
@mock.patch("app.models.cache.OptimisticLock", autospec=True)
|
||||
async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
|
||||
model_cache = ModelCache(ttl=100)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||
model_cache = ModelCache()
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
|
||||
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
|
||||
|
||||
@mock.patch("app.models.cache.SimpleMemoryCache.expire")
|
||||
async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
|
||||
model_cache = ModelCache(ttl=100, revalidate=True)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||
model_cache = ModelCache(revalidate=True)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
|
||||
mock_cache_expire.assert_called_once_with(mock.ANY, 100)
|
||||
|
||||
async def test_profiling(self, mock_get_model: mock.Mock) -> None:
|
||||
model_cache = ModelCache(ttl=100, profiling=True)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||
model_cache = ModelCache(profiling=True)
|
||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
|
||||
profiling = await model_cache.get_profiling()
|
||||
assert isinstance(profiling, dict)
|
||||
assert profiling == model_cache.cache.profiling
|
||||
@ -548,6 +549,25 @@ class TestCache:
|
||||
with pytest.raises(ValueError):
|
||||
await model_cache.get("test_model_name", ModelType.CLIP, mode="text")
|
||||
|
||||
async def test_preloads_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
|
||||
os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
|
||||
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
|
||||
|
||||
settings = Settings()
|
||||
assert settings.preload is not None
|
||||
assert settings.preload.clip == "ViT-B-32__openai"
|
||||
assert settings.preload.facial_recognition == "buffalo_s"
|
||||
|
||||
model_cache = ModelCache()
|
||||
monkeypatch.setattr("app.main.model_cache", model_cache)
|
||||
|
||||
await preload_models(settings.preload)
|
||||
assert len(model_cache.cache._cache) == 2
|
||||
assert mock_get_model.call_count == 2
|
||||
await model_cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=100)
|
||||
await model_cache.get("buffalo_s", ModelType.FACIAL_RECOGNITION, ttl=100)
|
||||
assert mock_get_model.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLoad:
|
||||
|
Loading…
Reference in New Issue
Block a user