Commit a50f256
OSNeuralSparseDocV3GTE model download and load
Why these changes are being introduced:

Each embedding class needs a way to download the model assets (e.g. weights and related files) locally, such that it can be loaded without calls to the HuggingFace API. Some models may require work beyond HF's `snapshot_download()` function, e.g. cloning dependency models or configurations. To test whether a model downloaded and configured correctly, you must then also load it. Ideally this includes creating a test embedding, but even a load without errors is a good step.

How this addresses that need:

The base class is extended to include a `load()` method. Our first embedding class, `OSNeuralSparseDocV3GTE`, has a first pass at `download()` and `load()` methods. The model we are using has some unusual dependency requirements that most commonly rely on additional HuggingFace calls at load time. To avoid this, we include some manual work to clone the dependency model `Alibaba-NLP/new-impl` and copy required files into our local model clone. The `load()` method confirms that the model loads successfully, without making any HuggingFace API calls.

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/USE-113
1 parent 462aef5 commit a50f256
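
For orientation, a minimal sketch of the download-then-load flow this commit enables. The import path is an assumption (the module filename is not captured in this diff), and the output directory is illustrative:

from embeddings.models import OSNeuralSparseDocV3GTE  # assumed import path

model = OSNeuralSparseDocV3GTE()

# one-time: snapshot the HF model, patch in Alibaba-NLP/new-impl files, save locally
local_path = model.download("output/os-neural-sparse-doc-v3-gte")

# later, e.g. in the deployed CLI: load entirely from the local copy, no HF API calls
model.load(local_path)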

File tree

5 files changed: +641, -6 lines changed


embeddings/cli.py

Lines changed: 20 additions & 0 deletions

@@ -1,5 +1,6 @@
 import functools
 import logging
+import os
 import time
 from collections.abc import Callable
 from datetime import timedelta

@@ -90,6 +91,25 @@ def download_model(model_uri: str, output: Path) -> None:
     click.echo(result_path)
 
 
+@main.command()
+def test_model_load() -> None:
+    """Test loading of embedding class and local model based on env vars.
+
+    In a deployed context, the following env vars are expected:
+    - TE_MODEL_URI
+    - TE_MODEL_DOWNLOAD_PATH
+
+    With these set, the embedding class should be registered successfully and
+    initialized, and the model loaded from a local copy.
+    """
+    # load embedding model class
+    model_class = get_model_class(os.environ["TE_MODEL_URI"])
+    model = model_class()
+
+    model.load(os.environ["TE_MODEL_DOWNLOAD_PATH"])
+    click.echo("OK")
+
+
 @main.command()
 @model_required
 def create_embeddings(_model_uri: str) -> None:
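
A quick way to exercise the new command without a deployed environment is click's test runner; a sketch, where the download path value is illustrative and assumes download_model has already populated it:

from click.testing import CliRunner

from embeddings.cli import main

runner = CliRunner()
result = runner.invoke(
    main,
    ["test-model-load"],
    env={
        "TE_MODEL_URI": "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte",
        "TE_MODEL_DOWNLOAD_PATH": "output/model",  # illustrative local model directory
    },
)
assert result.exit_code == 0, result.output
assert "OK" in result.output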

embeddings/models/base.py

Lines changed: 9 additions & 1 deletion

@@ -28,9 +28,17 @@ def model_uri(self) -> str:
         return self.MODEL_URI
 
     @abstractmethod
-    def download(self, output_path: Path) -> Path:
+    def download(self, output_path: str | Path) -> Path:
        """Download and prepare model, saving to output_path.
 
         Args:
             output_path: Path where the model zip should be saved.
         """
+
+    @abstractmethod
+    def load(self, model_path: str | Path) -> None:
+        """Load model from local, downloaded instance.
+
+        Args:
+            model_path: Path of local model directory.
+        """
Lines changed: 141 additions & 4 deletions

@@ -1,10 +1,24 @@
 """OpenSearch Neural Sparse Doc v3 GTE model."""
 
+import json
 import logging
+import shutil
+import tempfile
+import time
 from pathlib import Path
+from typing import TYPE_CHECKING
+
+from huggingface_hub import snapshot_download
+from transformers import AutoModelForMaskedLM, AutoTokenizer
 
 from embeddings.models.base import BaseEmbeddingModel
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+    from transformers.models.distilbert.tokenization_distilbert_fast import (
+        DistilBertTokenizerFast,
+    )
+
 logger = logging.getLogger(__name__)
 
 

@@ -16,11 +30,134 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
 
     MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
 
-    def download(self, output_path: Path) -> Path:
+    def __init__(self) -> None:
+        """Initialize the model."""
+        super().__init__()
+        self._model: PreTrainedModel | None = None
+        self._tokenizer: DistilBertTokenizerFast | None = None
+        self._special_token_ids: list | None = None
+        self._id_to_token: list | None = None
+
+    def download(self, output_path: str | Path) -> Path:
         """Download and prepare model, saving to output_path.
 
         Args:
-            output_path: Path where the model zip should be saved.
+            output_path: Path where the model should be saved.
         """
-        logger.info(f"Downloading model: { self.model_uri}, saving to: {output_path}.")
-        raise NotImplementedError
+        start_time = time.perf_counter()
+
+        output_path = Path(output_path)
+        logger.info(f"Downloading model: {self.model_uri}, saving to: {output_path}.")
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            temp_path = Path(temp_dir)
+
+            # download snapshot of HuggingFace model
+            snapshot_download(repo_id=self.model_uri, local_dir=temp_path)
+            logger.debug("Model download complete.")
+
+            # patch local model with files from dependency model "Alibaba-NLP/new-impl"
+            self._patch_local_model_with_alibaba_new_impl(temp_path)
+
+            # compress model directory as a zip file
+            if output_path.suffix.lower() == ".zip":
+                logger.debug("Creating zip file of model contents.")
+                shutil.make_archive(str(output_path.with_suffix("")), "zip", temp_path)
+
+            # copy to output directory without zipping
+            else:
+                logger.debug(f"Copying model contents to {output_path}")
+                if output_path.exists():
+                    shutil.rmtree(output_path)
+                shutil.copytree(temp_path, output_path)
+
+        logger.info(f"Model downloaded successfully, {time.perf_counter() - start_time}s")
+        return output_path
+
+    def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> None:
+        """Patch downloaded model with required assets from Alibaba-NLP/new-impl.
+
+        Our main model, opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte,
+        has configurations that attempt dynamic downloading of another model for files.
+        This can be seen here: https://huggingface.co/opensearch-project/opensearch-
+        neural-sparse-encoding-doc-v3-gte/blob/main/config.json#L6-L14.
+
+        To avoid our deployed CLI application making requests to the HuggingFace API to
+        retrieve these required files, which is problematic during high concurrency, we
+        manually download these files and patch the model during our local download and
+        save.
+
+        This allows us to load the primary model without any HuggingFace API calls.
+        """
+        logger.info("Downloading custom code from Alibaba-NLP/new-impl")
+        with tempfile.TemporaryDirectory() as temp_dir:
+            temp_path = Path(temp_dir)
+            snapshot_download(
+                repo_id="Alibaba-NLP/new-impl",
+                local_dir=str(temp_path),
+            )
+
+            logger.info("Copying Alibaba code and updating config.json")
+            shutil.copy(temp_path / "modeling.py", model_temp_path / "modeling.py")
+            shutil.copy(
+                temp_path / "configuration.py",
+                model_temp_path / "configuration.py",
+            )
+
+        with open(model_temp_path / "config.json") as f:
+            config_json = json.load(f)
+        config_json["auto_map"] = {
+            "AutoConfig": "configuration.NewConfig",
+            "AutoModel": "modeling.NewModel",
+            "AutoModelForMaskedLM": "modeling.NewForMaskedLM",
+            "AutoModelForMultipleChoice": "modeling.NewForMultipleChoice",
+            "AutoModelForQuestionAnswering": "modeling.NewForQuestionAnswering",
+            "AutoModelForSequenceClassification": (
+                "modeling.NewForSequenceClassification"
+            ),
+            "AutoModelForTokenClassification": (
+                "modeling.NewForTokenClassification"
+            ),
+        }
+        with open(model_temp_path / "config.json", "w") as f:
+            f.write(json.dumps(config_json))
+
+        logger.debug("Dependency model Alibaba-NLP/new-impl downloaded and used.")
+
+    def load(self, model_path: str | Path) -> None:
+        """Load the model from the specified path.
+
+        Args:
+            model_path: Path to the model directory.
+        """
+        start_time = time.perf_counter()
+        logger.info(f"Loading model from: {model_path}")
+        model_path = Path(model_path)
+
+        # ensure model exists locally
+        if not model_path.exists():
+            raise FileNotFoundError(f"Model not found at path: {model_path}")
+
+        # load local model and tokenizer
+        self._model = AutoModelForMaskedLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            local_files_only=True,
+        )
+        self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]
+            model_path,
+            local_files_only=True,
+        )
+
+        # setup special tokens
+        self._special_token_ids = [
+            self._tokenizer.vocab[str(token)]
+            for token in self._tokenizer.special_tokens_map.values()
+        ]
+
+        # setup id_to_token mapping
+        self._id_to_token = ["" for _ in range(self._tokenizer.vocab_size)]
+        for token, token_id in self._tokenizer.vocab.items():
+            self._id_to_token[token_id] = token
+
+        logger.info(f"Model loaded successfully, {time.perf_counter() - start_time}s")

pyproject.toml

Lines changed: 9 additions & 1 deletion

@@ -12,6 +12,8 @@ dependencies = [
     "huggingface-hub>=0.26.0",
     "sentry-sdk>=2.34.1",
     "timdex-dataset-api",
+    "torch>=2.9.0",
+    "transformers>=4.57.1",
 ]
 
 [dependency-groups]

@@ -32,7 +34,10 @@ line-length = 90
 [tool.mypy]
 disallow_untyped_calls = true
 disallow_untyped_defs = true
-exclude = ["tests/"]
+exclude = [
+    "tests/",
+    "output/"
+]
 
 [tool.pytest.ini_options]
 log_level = "INFO"

@@ -101,3 +106,6 @@ embeddings = "embeddings.cli:main"
 [build-system]
 requires = ["setuptools>=61"]
 build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+py-modules = ["embeddings"]
