
Commit 8db533a

Merge pull request #14 from MITLibraries/USE-113-download-model
USE 113 - OSNeuralSparseDocV3GTE download and load
2 parents 462aef5 + 8899d51 commit 8db533a

File tree: 9 files changed, +1006 −7 lines

Dockerfile

Lines changed: 10 additions & 1 deletion

```diff
@@ -18,4 +18,13 @@ COPY embeddings ./embeddings
 # Install package into system python, includes "marimo-launcher" script
 RUN uv pip install --system .

-ENTRYPOINT ["embeddings"]
+# Download the model and include in the Docker image
+# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_DOWNLOAD_PATH" are set here to support
+# the downloading of the model into this image build, but persist in the container and
+# effectively also set this as the default model.
+ENV HF_HUB_DISABLE_PROGRESS_BARS=true
+ENV TE_MODEL_URI=opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
+ENV TE_MODEL_DOWNLOAD_PATH=/model
+RUN python -m embeddings.cli --verbose download-model
+
+ENTRYPOINT ["python", "-m", "embeddings.cli"]
```

README.md

Lines changed: 14 additions & 0 deletions

````diff
@@ -28,6 +28,20 @@ TE_MODEL_DOWNLOAD_PATH=# Download location for model
 HF_HUB_DISABLE_PROGRESS_BARS=#boolean to use progress bars for HuggingFace model downloads; defaults to 'true' in deployed contexts
 ```

+## Configuring an Embedding Model
+
+This CLI application is designed to create embeddings for input texts. To do this, a pre-trained model must be identified and configured for use.
+
+To this end, there is a base embedding class `BaseEmbeddingModel` that is designed to be extended and customized for a particular embedding model.
+
+Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_DOWNLOAD_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI.
+
+This allows invoking the CLI without specifying a model URI or local location, allowing this model to serve as the default, e.g.:
+
+```shell
+uv run --env-file .env embeddings test-model-load
+```
+
 ## CLI Commands

 For local development, all CLI commands should be invoked with the following format to pickup environment variables from `.env`:
````
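The README addition above describes the extension pattern without showing it; a minimal sketch of a custom subclass, based on the abstract interface in `embeddings/models/base.py` below, might look like this (the class name `MyModel` and its `MODEL_URI` are illustrative, not part of this commit):

```python
from pathlib import Path

from embeddings.models.base import BaseEmbeddingModel


class MyModel(BaseEmbeddingModel):
    """Hypothetical embedding model; the URI below is illustrative only."""

    MODEL_URI = "example-org/example-model"

    def download(self, output_path: str | Path) -> Path:
        # fetch a local snapshot of the model and save it to output_path
        ...
        return Path(output_path)

    def load(self, model_path: str | Path) -> None:
        # load the model and tokenizer from the local snapshot
        ...
```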

embeddings/cli.py

Lines changed: 20 additions & 0 deletions

```diff
@@ -1,5 +1,6 @@
 import functools
 import logging
+import os
 import time
 from collections.abc import Callable
 from datetime import timedelta
@@ -90,6 +91,25 @@ def download_model(model_uri: str, output: Path) -> None:
     click.echo(result_path)


+@main.command()
+def test_model_load() -> None:
+    """Test loading of embedding class and local model based on env vars.
+
+    In a deployed context, the following env vars are expected:
+      - TE_MODEL_URI
+      - TE_MODEL_DOWNLOAD_PATH
+
+    With these set, the embedding class should be registered successfully and
+    initialized, and the model loaded from a local copy.
+    """
+    # load embedding model class
+    model_class = get_model_class(os.environ["TE_MODEL_URI"])
+    model = model_class()
+
+    model.load(os.environ["TE_MODEL_DOWNLOAD_PATH"])
+    click.echo("OK")
+
+
 @main.command()
 @model_required
 def create_embeddings(_model_uri: str) -> None:
```
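Since `test_model_load` reads both values straight from `os.environ`, local runs need them present in `.env`; a minimal sketch (values mirror the Dockerfile defaults, and the local path is an assumption):

```shell
# hypothetical .env for local testing of `embeddings test-model-load`
TE_MODEL_URI=opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
# any local directory the model has already been downloaded to
TE_MODEL_DOWNLOAD_PATH=/model
```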

embeddings/models/base.py

Lines changed: 9 additions & 1 deletion

```diff
@@ -28,9 +28,17 @@ def model_uri(self) -> str:
         return self.MODEL_URI

     @abstractmethod
-    def download(self, output_path: Path) -> Path:
+    def download(self, output_path: str | Path) -> Path:
         """Download and prepare model, saving to output_path.

         Args:
             output_path: Path where the model zip should be saved.
         """
+
+    @abstractmethod
+    def load(self, model_path: str | Path) -> None:
+        """Load model from local, downloaded instance.
+
+        Args:
+            model_path: Path of local model directory.
+        """
```
Lines changed: 141 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
"""OpenSearch Neural Sparse Doc v3 GTE model."""
22

3+
import json
34
import logging
5+
import shutil
6+
import tempfile
7+
import time
48
from pathlib import Path
9+
from typing import TYPE_CHECKING
10+
11+
from huggingface_hub import snapshot_download
12+
from transformers import AutoModelForMaskedLM, AutoTokenizer
513

614
from embeddings.models.base import BaseEmbeddingModel
715

16+
if TYPE_CHECKING:
17+
from transformers import PreTrainedModel
18+
from transformers.models.distilbert.tokenization_distilbert_fast import (
19+
DistilBertTokenizerFast,
20+
)
21+
822
logger = logging.getLogger(__name__)
923

1024

@@ -16,11 +30,134 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
1630

1731
MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
1832

19-
def download(self, output_path: Path) -> Path:
33+
def __init__(self) -> None:
34+
"""Initialize the model."""
35+
super().__init__()
36+
self._model: PreTrainedModel | None = None
37+
self._tokenizer: DistilBertTokenizerFast | None = None
38+
self._special_token_ids: list | None = None
39+
self._id_to_token: list | None = None
40+
41+
def download(self, output_path: str | Path) -> Path:
2042
"""Download and prepare model, saving to output_path.
2143
2244
Args:
23-
output_path: Path where the model zip should be saved.
45+
output_path: Path where the model should be saved.
2446
"""
25-
logger.info(f"Downloading model: { self.model_uri}, saving to: {output_path}.")
26-
raise NotImplementedError
47+
start_time = time.perf_counter()
48+
49+
output_path = Path(output_path)
50+
logger.info(f"Downloading model: {self.model_uri}, saving to: {output_path}.")
51+
52+
with tempfile.TemporaryDirectory() as temp_dir:
53+
temp_path = Path(temp_dir)
54+
55+
# download snapshot of HuggingFace model
56+
snapshot_download(repo_id=self.model_uri, local_dir=temp_path)
57+
logger.debug("Model download complete.")
58+
59+
# patch local model with files from dependency model "Alibaba-NLP/new-impl"
60+
self._patch_local_model_with_alibaba_new_impl(temp_path)
61+
62+
# compress model directory as a zip file
63+
if output_path.suffix.lower() == ".zip":
64+
logger.debug("Creating zip file of model contents.")
65+
shutil.make_archive(str(output_path.with_suffix("")), "zip", temp_path)
66+
67+
# copy to output directory without zipping
68+
else:
69+
logger.debug(f"Copying model contents to {output_path}")
70+
if output_path.exists():
71+
shutil.rmtree(output_path)
72+
shutil.copytree(temp_path, output_path)
73+
74+
logger.info(f"Model downloaded successfully, {time.perf_counter() - start_time}s")
75+
return output_path
76+
77+
def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> None:
78+
"""Patch downloaded model with required assets from Alibaba-NLP/new-impl.
79+
80+
Our main model, opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte,
81+
has configurations that attempt dynamic downloading of another model for files.
82+
This can be seen here: https://huggingface.co/opensearch-project/opensearch-
83+
neural-sparse-encoding-doc-v3-gte/blob/main/config.json#L6-L14.
84+
85+
To avoid our deployed CLI application making requests to the HuggingFace API to
86+
retrieve these required files, which is problematic during high concurrency, we
87+
manually download these files and patch the model during our local download and
88+
save.
89+
90+
This allows us to load the primary model without any HuggingFace API calls.
91+
"""
92+
logger.info("Downloading custom code from Alibaba-NLP/new-impl")
93+
with tempfile.TemporaryDirectory() as temp_dir:
94+
temp_path = Path(temp_dir)
95+
snapshot_download(
96+
repo_id="Alibaba-NLP/new-impl",
97+
local_dir=str(temp_path),
98+
)
99+
100+
logger.info("Copying Alibaba code and updating config.json")
101+
shutil.copy(temp_path / "modeling.py", model_temp_path / "modeling.py")
102+
shutil.copy(
103+
temp_path / "configuration.py",
104+
model_temp_path / "configuration.py",
105+
)
106+
107+
with open(model_temp_path / "config.json") as f:
108+
config_json = json.load(f)
109+
config_json["auto_map"] = {
110+
"AutoConfig": "configuration.NewConfig",
111+
"AutoModel": "modeling.NewModel",
112+
"AutoModelForMaskedLM": "modeling.NewForMaskedLM",
113+
"AutoModelForMultipleChoice": "modeling.NewForMultipleChoice",
114+
"AutoModelForQuestionAnswering": "modeling.NewForQuestionAnswering",
115+
"AutoModelForSequenceClassification": (
116+
"modeling.NewForSequenceClassification"
117+
),
118+
"AutoModelForTokenClassification": (
119+
"modeling.NewForTokenClassification"
120+
),
121+
}
122+
with open(model_temp_path / "config.json", "w") as f:
123+
f.write(json.dumps(config_json))
124+
125+
logger.debug("Dependency model Alibaba-NLP/new-impl downloaded and used.")
126+
127+
def load(self, model_path: str | Path) -> None:
128+
"""Load the model from the specified path.
129+
130+
Args:
131+
model_path: Path to the model directory.
132+
"""
133+
start_time = time.perf_counter()
134+
logger.info(f"Loading model from: {model_path}")
135+
model_path = Path(model_path)
136+
137+
# ensure model exists locally
138+
if not model_path.exists():
139+
raise FileNotFoundError(f"Model not found at path: {model_path}")
140+
141+
# load local model and tokenizer
142+
self._model = AutoModelForMaskedLM.from_pretrained(
143+
model_path,
144+
trust_remote_code=True,
145+
local_files_only=True,
146+
)
147+
self._tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call]
148+
model_path,
149+
local_files_only=True,
150+
)
151+
152+
# setup special tokens
153+
self._special_token_ids = [
154+
self._tokenizer.vocab[str(token)]
155+
for token in self._tokenizer.special_tokens_map.values()
156+
]
157+
158+
# setup id_to_token mapping
159+
self._id_to_token = ["" for _ in range(self._tokenizer.vocab_size)]
160+
for token, token_id in self._tokenizer.vocab.items():
161+
self._id_to_token[token_id] = token
162+
163+
logger.info(f"Model loaded successfully, {time.perf_counter()-start_time}s")

pyproject.toml

Lines changed: 9 additions & 1 deletion

```diff
@@ -12,6 +12,8 @@ dependencies = [
     "huggingface-hub>=0.26.0",
     "sentry-sdk>=2.34.1",
     "timdex-dataset-api",
+    "torch>=2.9.0",
+    "transformers>=4.57.1",
 ]

 [dependency-groups]
@@ -32,7 +34,10 @@ line-length = 90
 [tool.mypy]
 disallow_untyped_calls = true
 disallow_untyped_defs = true
-exclude = ["tests/"]
+exclude = [
+    "tests/",
+    "output/"
+]

 [tool.pytest.ini_options]
 log_level = "INFO"
@@ -101,3 +106,6 @@ embeddings = "embeddings.cli:main"
 [build-system]
 requires = ["setuptools>=61"]
 build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+py-modules = ["embeddings"]
```
