11"""OpenSearch Neural Sparse Doc v3 GTE model."""
22
3+ import json
34import logging
5+ import shutil
6+ import tempfile
7+ import time
48from pathlib import Path
9+ from typing import TYPE_CHECKING
10+
11+ from huggingface_hub import snapshot_download
12+ from transformers import AutoModelForMaskedLM , AutoTokenizer
513
614from embeddings .models .base import BaseEmbeddingModel
715
16+ if TYPE_CHECKING :
17+ from transformers import PreTrainedModel
18+ from transformers .models .distilbert .tokenization_distilbert_fast import (
19+ DistilBertTokenizerFast ,
20+ )
21+
822logger = logging .getLogger (__name__ )
923
1024
@@ -16,11 +30,134 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
1630
1731 MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
1832
19- def download (self , output_path : Path ) -> Path :
33+ def __init__ (self ) -> None :
34+ """Initialize the model."""
35+ super ().__init__ ()
36+ self ._model : PreTrainedModel | None = None
37+ self ._tokenizer : DistilBertTokenizerFast | None = None
38+ self ._special_token_ids : list | None = None
39+ self ._id_to_token : list | None = None
40+
41+ def download (self , output_path : str | Path ) -> Path :
2042 """Download and prepare model, saving to output_path.
2143
2244 Args:
23- output_path: Path where the model zip should be saved.
45+ output_path: Path where the model should be saved.
2446 """
25- logger .info (f"Downloading model: { self .model_uri } , saving to: { output_path } ." )
26- raise NotImplementedError
47+ start_time = time .perf_counter ()
48+
49+ output_path = Path (output_path )
50+ logger .info (f"Downloading model: { self .model_uri } , saving to: { output_path } ." )
51+
52+ with tempfile .TemporaryDirectory () as temp_dir :
53+ temp_path = Path (temp_dir )
54+
55+ # download snapshot of HuggingFace model
56+ snapshot_download (repo_id = self .model_uri , local_dir = temp_path )
57+ logger .debug ("Model download complete." )
58+
59+ # patch local model with files from dependency model "Alibaba-NLP/new-impl"
60+ self ._patch_local_model_with_alibaba_new_impl (temp_path )
61+
62+ # compress model directory as a zip file
63+ if output_path .suffix .lower () == ".zip" :
64+ logger .debug ("Creating zip file of model contents." )
65+ shutil .make_archive (str (output_path .with_suffix ("" )), "zip" , temp_path )
66+
67+ # copy to output directory without zipping
68+ else :
69+ logger .debug (f"Copying model contents to { output_path } " )
70+ if output_path .exists ():
71+ shutil .rmtree (output_path )
72+ shutil .copytree (temp_path , output_path )
73+
74+ logger .info (f"Model downloaded successfully, { time .perf_counter () - start_time } s" )
75+ return output_path
76+
77+ def _patch_local_model_with_alibaba_new_impl (self , model_temp_path : Path ) -> None :
78+ """Patch downloaded model with required assets from Alibaba-NLP/new-impl.
79+
80+ Our main model, opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte,
81+ has configurations that attempt dynamic downloading of another model for files.
82+ This can be seen here: https://huggingface.co/opensearch-project/opensearch-
83+ neural-sparse-encoding-doc-v3-gte/blob/main/config.json#L6-L14.
84+
85+ To avoid our deployed CLI application making requests to the HuggingFace API to
86+ retrieve these required files, which is problematic during high concurrency, we
87+ manually download these files and patch the model during our local download and
88+ save.
89+
90+ This allows us to load the primary model without any HuggingFace API calls.
91+ """
92+ logger .info ("Downloading custom code from Alibaba-NLP/new-impl" )
93+ with tempfile .TemporaryDirectory () as temp_dir :
94+ temp_path = Path (temp_dir )
95+ snapshot_download (
96+ repo_id = "Alibaba-NLP/new-impl" ,
97+ local_dir = str (temp_path ),
98+ )
99+
100+ logger .info ("Copying Alibaba code and updating config.json" )
101+ shutil .copy (temp_path / "modeling.py" , model_temp_path / "modeling.py" )
102+ shutil .copy (
103+ temp_path / "configuration.py" ,
104+ model_temp_path / "configuration.py" ,
105+ )
106+
107+ with open (model_temp_path / "config.json" ) as f :
108+ config_json = json .load (f )
109+ config_json ["auto_map" ] = {
110+ "AutoConfig" : "configuration.NewConfig" ,
111+ "AutoModel" : "modeling.NewModel" ,
112+ "AutoModelForMaskedLM" : "modeling.NewForMaskedLM" ,
113+ "AutoModelForMultipleChoice" : "modeling.NewForMultipleChoice" ,
114+ "AutoModelForQuestionAnswering" : "modeling.NewForQuestionAnswering" ,
115+ "AutoModelForSequenceClassification" : (
116+ "modeling.NewForSequenceClassification"
117+ ),
118+ "AutoModelForTokenClassification" : (
119+ "modeling.NewForTokenClassification"
120+ ),
121+ }
122+ with open (model_temp_path / "config.json" , "w" ) as f :
123+ f .write (json .dumps (config_json ))
124+
125+ logger .debug ("Dependency model Alibaba-NLP/new-impl downloaded and used." )
126+
127+ def load (self , model_path : str | Path ) -> None :
128+ """Load the model from the specified path.
129+
130+ Args:
131+ model_path: Path to the model directory.
132+ """
133+ start_time = time .perf_counter ()
134+ logger .info (f"Loading model from: { model_path } " )
135+ model_path = Path (model_path )
136+
137+ # ensure model exists locally
138+ if not model_path .exists ():
139+ raise FileNotFoundError (f"Model not found at path: { model_path } " )
140+
141+ # load local model and tokenizer
142+ self ._model = AutoModelForMaskedLM .from_pretrained (
143+ model_path ,
144+ trust_remote_code = True ,
145+ local_files_only = True ,
146+ )
147+ self ._tokenizer = AutoTokenizer .from_pretrained ( # type: ignore[no-untyped-call]
148+ model_path ,
149+ local_files_only = True ,
150+ )
151+
152+ # setup special tokens
153+ self ._special_token_ids = [
154+ self ._tokenizer .vocab [str (token )]
155+ for token in self ._tokenizer .special_tokens_map .values ()
156+ ]
157+
158+ # setup id_to_token mapping
159+ self ._id_to_token = ["" for _ in range (self ._tokenizer .vocab_size )]
160+ for token , token_id in self ._tokenizer .vocab .items ():
161+ self ._id_to_token [token_id ] = token
162+
163+ logger .info (f"Model loaded successfully, { time .perf_counter ()- start_time } s" )
0 commit comments