Commit b3dbbb8

Merge pull request #15 from MITLibraries/USE-112-refactor-CLI-model-load
USE 112 - refactor model load
2 parents 8db533a + 39ef93e commit b3dbbb8

9 files changed: +311 -130 lines changed

Dockerfile

Lines changed: 5 additions & 5 deletions

@@ -15,16 +15,16 @@ COPY pyproject.toml uv.lock* ./
 # Copy CLI application
 COPY embeddings ./embeddings

-# Install package into system python, includes "marimo-launcher" script
+# Install package into system python
 RUN uv pip install --system .

 # Download the model and include in the Docker image
-# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_DOWNLOAD_PATH" are set here to support
-# the downloading of the model into this image build, but persist in the container and
-# effectively also set this as the default model.
+# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_PATH" are set here to support
+# the downloading of the model during image build, but also persist in the container and
+# effectively set the default model.
 ENV HF_HUB_DISABLE_PROGRESS_BARS=true
 ENV TE_MODEL_URI=opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
-ENV TE_MODEL_DOWNLOAD_PATH=/model
+ENV TE_MODEL_PATH=/model
 RUN python -m embeddings.cli --verbose download-model

 ENTRYPOINT ["python", "-m", "embeddings.cli"]
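
Because the two `ENV` values persist into the running container, the `--model-uri` and `--model-path` options (declared with `envvar` fallbacks in `embeddings/cli.py` below) are satisfied automatically, which is why the build-time `RUN python -m embeddings.cli --verbose download-model` needs no flags. A minimal sketch of that fallback using Click's test runner, assuming the package is installed; invoking this for real would download the model from HuggingFace:

# Sketch only: the env dict stands in for the Dockerfile ENV lines, so no
# --model-uri / --model-path flags are needed. Running this performs a real
# HuggingFace download.
from click.testing import CliRunner

from embeddings.cli import main

runner = CliRunner()
result = runner.invoke(
    main,
    ["download-model"],
    env={
        "TE_MODEL_URI": "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte",
        "TE_MODEL_PATH": "/model",
    },
)
print(result.output)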

README.md

Lines changed: 29 additions & 9 deletions

@@ -24,7 +24,7 @@ WORKSPACE=### Set to `dev` for local development, this will be set to `stage` an

 ```shell
 TE_MODEL_URI=# HuggingFace model URI
-TE_MODEL_DOWNLOAD_PATH=# Download location for model
+TE_MODEL_PATH=# Path where the model will be downloaded to and loaded from
 HF_HUB_DISABLE_PROGRESS_BARS=#boolean to use progress bars for HuggingFace model downloads; defaults to 'true' in deployed contexts
 ```

@@ -34,7 +34,7 @@ This CLI application is designed to create embeddings for input texts. To do th

 To this end, there is a base embedding class `BaseEmbeddingModel` that is designed to be extended and customized for a particular embedding model.

-Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_DOWNLOAD_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI.
+Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI.

 This allows invoking the CLI without specifying a model URI or local location, allowing this model to serve as the default, e.g.:

@@ -61,18 +61,38 @@ Usage: embeddings ping [OPTIONS]
 ```text
 Usage: embeddings download-model [OPTIONS]

-  Download a model from HuggingFace and save as zip file.
+  Download a model from HuggingFace and save locally.

 Options:
-  --model-uri TEXT  HuggingFace model URI (e.g., 'org/model-name') [required]
-  --output PATH     Output path for zipped model (e.g., '/path/to/model.zip')
-                    [required]
-  --help            Show this message and exit.
+  --model-uri TEXT   HuggingFace model URI (e.g., 'org/model-name')
+                     [required]
+  --model-path PATH  Path where the model will be downloaded to and loaded
+                     from, e.g. '/path/to/model'.  [required]
+  --help             Show this message and exit.
 ```

-### `create-embeddings`
+### `test-model-load`
 ```text
-TODO...
+Usage: embeddings test-model-load [OPTIONS]
+
+  Test loading of embedding class and local model based on env vars.
+
+  In a deployed context, the following env vars are expected: -
+  TE_MODEL_URI - TE_MODEL_PATH
+
+  With these set, the embedding class should be registered successfully and
+  initialized, and the model loaded from a local copy.
+
+  This CLI command is NOT used during normal workflows. This is used primarily
+  during development and after model downloading/loading changes to ensure the
+  model loads correctly.
+
+Options:
+  --model-uri TEXT   HuggingFace model URI (e.g., 'org/model-name')
+                     [required]
+  --model-path PATH  Path where the model will be downloaded to and loaded
+                     from, e.g. '/path/to/model'.  [required]
+  --help             Show this message and exit.
 ```

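The new `test-model-load` help shown above can be reproduced without touching a model. A quick sketch, assuming the package is installed locally:

# Sketch: print the test-model-load help text; --help short-circuits the
# required options, so nothing is downloaded or loaded.
from click.testing import CliRunner

from embeddings.cli import main

runner = CliRunner()
result = runner.invoke(main, ["test-model-load", "--help"])
print(result.output)  # should match the README excerpt above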

embeddings/cli.py

Lines changed: 78 additions & 41 deletions

@@ -1,10 +1,10 @@
 import functools
 import logging
-import os
 import time
 from collections.abc import Callable
 from datetime import timedelta
 from pathlib import Path
+from typing import TYPE_CHECKING

 import click

@@ -13,21 +13,8 @@

 logger = logging.getLogger(__name__)

-
-def model_required(f: Callable) -> Callable:
-    """Decorator for commands that require a specific model."""
-
-    @click.option(
-        "--model-uri",
-        envvar="TE_MODEL_URI",
-        required=True,
-        help="HuggingFace model URI (e.g., 'org/model-name')",
-    )
-    @functools.wraps(f)
-    def wrapper(*args: list, **kwargs: dict) -> Callable:
-        return f(*args, **kwargs)
-
-    return wrapper
+if TYPE_CHECKING:
+    from embeddings.models.base import BaseEmbeddingModel


 @click.group("embeddings")
@@ -60,6 +47,60 @@ def _log_command_elapsed_time() -> None:
     ctx.call_on_close(_log_command_elapsed_time)


+def model_required(f: Callable) -> Callable:
+    """Middleware decorator for commands that require an embedding model.
+
+    This decorator adds two CLI options:
+      - "--model-uri": defaults to environment variable "TE_MODEL_URI"
+      - "--model-path": defaults to environment variable "TE_MODEL_PATH"
+
+    The decorator intercepts these parameters, uses the model URI to identify and
+    instantiate the appropriate embedding model class with the provided model path,
+    and stores the model instance in the Click context at ctx.obj["model"].
+
+    Both model_uri and model_path parameters are consumed by the decorator and not
+    passed to the decorated command function.
+    """
+
+    @click.option(
+        "--model-uri",
+        envvar="TE_MODEL_URI",
+        required=True,
+        help="HuggingFace model URI (e.g., 'org/model-name')",
+    )
+    @click.option(
+        "--model-path",
+        required=True,
+        envvar="TE_MODEL_PATH",
+        type=click.Path(path_type=Path),
+        help=(
+            "Path where the model will be downloaded to and loaded from, "
+            "e.g. '/path/to/model'."
+        ),
+    )
+    @functools.wraps(f)
+    def wrapper(*args: tuple, **kwargs: dict[str, str | Path]) -> Callable:
+        # pop "model_uri" and "model_path" from CLI args
+        model_uri: str = str(kwargs.pop("model_uri"))
+        model_path: str | Path = str(kwargs.pop("model_path"))
+
+        # initialize embedding class
+        model_class = get_model_class(str(model_uri))
+        model: BaseEmbeddingModel = model_class(model_path)
+        logger.info(
+            f"Embedding class '{model.__class__.__name__}' "
+            f"initialized from model URI '{model_uri}'."
+        )
+
+        # save embedding class instance to Context
+        ctx: click.Context = args[0]  # type: ignore[assignment]
+        ctx.obj["model"] = model
+
+        return f(*args, **kwargs)
+
+    return wrapper
+
+
 @main.command()
 def ping() -> None:
     """Emit 'pong' to debug logs and stdout."""
@@ -68,53 +109,49 @@ def ping() -> None:


 @main.command()
+@click.pass_context
 @model_required
-@click.option(
-    "--output",
-    required=True,
-    envvar="TE_MODEL_DOWNLOAD_PATH",
-    type=click.Path(path_type=Path),
-    help="Output path for zipped model (e.g., '/path/to/model.zip')",
-)
-def download_model(model_uri: str, output: Path) -> None:
-    """Download a model from HuggingFace and save as zip file."""
-    # load embedding model class
-    model_class = get_model_class(model_uri)
-    model = model_class()
+def download_model(
+    ctx: click.Context,
+) -> None:
+    """Download a model from HuggingFace and save locally."""
+    model: BaseEmbeddingModel = ctx.obj["model"]

-    # download model assets
-    logger.info(f"Downloading model: {model_uri}")
-    result_path = model.download(output)
+    logger.info(f"Downloading model: {model.model_uri}")
+    result_path = model.download()

     message = f"Model downloaded and saved to: {result_path}"
     logger.info(message)
     click.echo(result_path)


 @main.command()
-def test_model_load() -> None:
+@click.pass_context
+@model_required
+def test_model_load(ctx: click.Context) -> None:
     """Test loading of embedding class and local model based on env vars.

     In a deployed context, the following env vars are expected:
     - TE_MODEL_URI
-    - TE_MODEL_DOWNLOAD_PATH
+    - TE_MODEL_PATH

     With these set, the embedding class should be registered successfully and initialized,
     and the model loaded from a local copy.
-    """
-    # load embedding model class
-    model_class = get_model_class(os.environ["TE_MODEL_URI"])
-    model = model_class()

-    model.load(os.environ["TE_MODEL_DOWNLOAD_PATH"])
+    This CLI command is NOT used during normal workflows. This is used primarily
+    during development and after model downloading/loading changes to ensure the model
+    loads correctly.
+    """
+    model: BaseEmbeddingModel = ctx.obj["model"]
+    model.load()
     click.echo("OK")


 @main.command()
+@click.pass_context
 @model_required
-def create_embeddings(_model_uri: str) -> None:
-    # TODO: docstring  # noqa: FIX002
-    raise NotImplementedError
+def create_embedding(ctx: click.Context) -> None:
+    """Create a single embedding for a single input text."""


 if __name__ == "__main__":  # pragma: no cover
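
The practical upshot is that model resolution becomes shared middleware rather than per-command boilerplate. A sketch of a hypothetical new command built on this pattern (the command name and body are illustrative, not part of this commit); note `@click.pass_context` sits above `@model_required` so the wrapper receives `ctx` as its first positional argument:

# Hypothetical example command; "echo_model" is illustrative only.
import click

from embeddings.cli import main, model_required
from embeddings.models.base import BaseEmbeddingModel


@main.command()
@click.pass_context
@model_required
def echo_model(ctx: click.Context) -> None:
    """Echo the class name of the model resolved by model_required."""
    model: BaseEmbeddingModel = ctx.obj["model"]  # placed here by the decorator
    click.echo(model.__class__.__name__)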

embeddings/models/base.py

Lines changed: 14 additions & 10 deletions

@@ -12,6 +12,14 @@ class BaseEmbeddingModel(ABC):

     MODEL_URI: str  # Type hint to document the requirement

+    def __init__(self, model_path: str | Path) -> None:
+        """Initialize the embedding model with a model path.
+
+        Args:
+            model_path: Path where the model will be downloaded to and loaded from.
+        """
+        self.model_path = Path(model_path)
+
     def __init_subclass__(cls, **kwargs: dict) -> None:  # noqa: D105
         super().__init_subclass__(**kwargs)

@@ -28,17 +36,13 @@ def model_uri(self) -> str:
         return self.MODEL_URI

     @abstractmethod
-    def download(self, output_path: str | Path) -> Path:
-        """Download and prepare model, saving to output_path.
+    def download(self) -> Path:
+        """Download and prepare model, saving to self.model_path.

-        Args:
-            output_path: Path where the model zip should be saved.
+        Returns:
+            Path where the model was saved.
         """

     @abstractmethod
-    def load(self, model_path: str | Path) -> None:
-        """Load model from local, downloaded instance.
-
-        Args:
-            model_path: Path of local model directory.
-        """
+    def load(self) -> None:
+        """Load model from self.model_path."""

embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 28 additions & 28 deletions

@@ -30,24 +30,27 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):

     MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"

-    def __init__(self) -> None:
-        """Initialize the model."""
-        super().__init__()
+    def __init__(self, model_path: str | Path) -> None:
+        """Initialize the model.
+
+        Args:
+            model_path: Path where the model will be downloaded to and loaded from.
+        """
+        super().__init__(model_path)
         self._model: PreTrainedModel | None = None
         self._tokenizer: DistilBertTokenizerFast | None = None
         self._special_token_ids: list | None = None
         self._id_to_token: list | None = None

-    def download(self, output_path: str | Path) -> Path:
-        """Download and prepare model, saving to output_path.
+    def download(self) -> Path:
+        """Download and prepare model, saving to self.model_path.

-        Args:
-            output_path: Path where the model should be saved.
+        Returns:
+            Path where the model was saved.
         """
         start_time = time.perf_counter()

-        output_path = Path(output_path)
-        logger.info(f"Downloading model: {self.model_uri}, saving to: {output_path}.")
+        logger.info(f"Downloading model: {self.model_uri}, saving to: {self.model_path}.")

         with tempfile.TemporaryDirectory() as temp_dir:
             temp_path = Path(temp_dir)
@@ -60,19 +63,21 @@ def download(self, output_path: str | Path) -> Path:
             self._patch_local_model_with_alibaba_new_impl(temp_path)

             # compress model directory as a zip file
-            if output_path.suffix.lower() == ".zip":
+            if self.model_path.suffix.lower() == ".zip":
                 logger.debug("Creating zip file of model contents.")
-                shutil.make_archive(str(output_path.with_suffix("")), "zip", temp_path)
+                shutil.make_archive(
+                    str(self.model_path.with_suffix("")), "zip", temp_path
+                )

             # copy to output directory without zipping
             else:
-                logger.debug(f"Copying model contents to {output_path}")
-                if output_path.exists():
-                    shutil.rmtree(output_path)
-                shutil.copytree(temp_path, output_path)
+                logger.debug(f"Copying model contents to {self.model_path}")
+                if self.model_path.exists():
+                    shutil.rmtree(self.model_path)
+                shutil.copytree(temp_path, self.model_path)

         logger.info(f"Model downloaded successfully, {time.perf_counter() - start_time}s")
-        return output_path
+        return self.model_path

     def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> None:
         """Patch downloaded model with required assets from Alibaba-NLP/new-impl.
@@ -124,28 +129,23 @@ def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> Non

         logger.debug("Dependency model Alibaba-NLP/new-impl downloaded and used.")

-    def load(self, model_path: str | Path) -> None:
-        """Load the model from the specified path.
-
-        Args:
-            model_path: Path to the model directory.
-        """
+    def load(self) -> None:
+        """Load the model from self.model_path."""
         start_time = time.perf_counter()
-        logger.info(f"Loading model from: {model_path}")
-        model_path = Path(model_path)
+        logger.info(f"Loading model from: {self.model_path}")

         # ensure model exists locally
-        if not model_path.exists():
-            raise FileNotFoundError(f"Model not found at path: {model_path}")
+        if not self.model_path.exists():
+            raise FileNotFoundError(f"Model not found at path: {self.model_path}")

         # load local model and tokenizer
         self._model = AutoModelForMaskedLM.from_pretrained(
-            model_path,
+            self.model_path,
             trust_remote_code=True,
             local_files_only=True,
         )
         self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]
-            model_path,
+            self.model_path,
             local_files_only=True,
        )
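
End to end, the refactor means the path is supplied once at construction and reused by both operations. A usage sketch, assuming the package and its transformers dependencies are installed; running it performs a real download:

# Usage sketch; the local path is hypothetical, and download() makes a
# real HuggingFace request if run. A ".zip" suffix on the path would
# produce an archive instead of a directory copy.
from embeddings.models.os_neural_sparse_doc_v3_gte import OSNeuralSparseDocV3GTE

model = OSNeuralSparseDocV3GTE("/tmp/os-neural-sparse")
model.download()  # saves a patched snapshot to /tmp/os-neural-sparse
model.load()      # loads model + tokenizer from that same path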
