from pathlib import Path
from typing import TYPE_CHECKING

import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForMaskedLM, AutoTokenizer

class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
    """OpenSearch Neural Sparse Encoding Doc v3 GTE model.

    This model generates sparse embeddings for documents by using a masked language
    model's logits to identify the most relevant tokens.

    HuggingFace URI: opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
    """

    def __init__(self, model_path: str | Path) -> None:
        super().__init__(model_path)
        self._model: PreTrainedModel | None = None
        self._tokenizer: DistilBertTokenizerFast | None = None
        self._special_token_ids: list[int] | None = None
        self._device: torch.device = torch.device("cpu")

    def download(self) -> Path:
        """Download and prepare model, saving to self.model_path."""
        ...

    def load(self) -> None:
        start_time = time.perf_counter()

        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found at path: {self.model_path}")

        # setup device (use CUDA if available, otherwise CPU)
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # load tokenizer
        self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]
            self.model_path,
            local_files_only=True,
        )

        # load model as AutoModelForMaskedLM (required for sparse embeddings)
        self._model = AutoModelForMaskedLM.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            local_files_only=True,
        )
        self._model.to(self._device)  # type: ignore[arg-type]
        self._model.eval()

        # set special token IDs (following the model card pattern)
        # these will be zeroed out in the sparse vectors
        self._special_token_ids = [
            self._tokenizer.vocab[token]  # type: ignore[index]
            for token in self._tokenizer.special_tokens_map.values()
        ]

        logger.info(
            f"Model loaded successfully on {self._device}, "
            f"{time.perf_counter() - start_time:.2f}s"
        )

    def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
        """Create a sparse embedding for the input text (document encoding).

        The process follows the model card:
        1. Tokenize the document
        2. Pass it through the masked language model to get logits
        3. Convert the logits to a sparse vector
        4. Return both the raw sparse vector and decoded token-weight pairs

        Args:
            input_record: The input containing text to embed
        """
        # generate the sparse embeddings
        sparse_vector, decoded_tokens = self._encode_documents([input_record.text])[0]

        # coerce sparse vector tensor into list[float]
        sparse_vector_list = sparse_vector.cpu().numpy().tolist()

        return Embedding(
            timdex_record_id=input_record.timdex_record_id,
            run_id=input_record.run_id,
            run_record_offset=input_record.run_record_offset,
            model_uri=self.model_uri,
            embedding_strategy=input_record.embedding_strategy,
            embedding_vector=sparse_vector_list,
            embedding_token_weights=decoded_tokens,
        )

    def _encode_documents(
        self,
        texts: list[str],
    ) -> list[tuple[torch.Tensor, dict[str, float]]]:
        """Encode documents into sparse vectors and decoded token weights.

        This follows the pattern outlined on the HuggingFace model card for document
        encoding.

        This method accepts a list of text inputs and returns a list of embeddings,
        even though the calling method create_embedding() handles a single input and
        output. Keeping the batch-capable signature leaves room for something like a
        create_multiple_embeddings() method in the future.

        The following is a rough sketch of how logits returned from the model are
        converted to a sparse vector, which can then be decoded to token:weight pairs:

        ----------------------------------------------------------------------------------
        Imagine your vocabulary is just 5 words: ["cat", "dog", "bird", "fish", "tree"]
        Vocabulary indices:                      [   0,     1,      2,      3,      4]

        1. MODEL RETURNS LOGITS
           Let's say you input the text: "cat and dog"
           After tokenization, you have 3 tokens at 3 sequence positions.
           The model outputs logits - a score for EVERY vocab word at EVERY position:

           logits = [
               # Position 0 (word "cat"): scores for each vocab word at this position
               [9.2, 1.1, -0.3, -0.5, -0.2],   # "cat" gets a high score (9.2)

               # Position 1 (word "and" - not in our toy vocab, but tokenized somehow):
               [2.1, 1.8, -0.4, -0.3, -0.9],   # moderate scores at best

               # Position 2 (word "dog"):
               [0.8, 8.7, -0.2, -0.4, -0.1],   # "dog" gets a high score (8.7)
           ]
           Shape: (3 positions, 5 vocab words)

        2. PRODUCE SPARSE VECTORS FROM LOGITS
           Collapse the sequence positions by taking the MAX score for each vocab word:

           max_scores = [
               max(9.2, 2.1, 0.8),     # "cat": take max across all 3 positions = 9.2
               max(1.1, 1.8, 8.7),     # "dog": take max = 8.7
               max(-0.3, -0.4, -0.2),  # "bird": take max = -0.2
               max(-0.5, -0.3, -0.4),  # "fish": take max = -0.3
               max(-0.2, -0.9, -0.1),  # "tree": take max = -0.1
           ]

           Apply transformations to make it sparse: relu() clips negative scores to
           exactly 0 (this is where the sparsity comes from, since most vocab words
           get negative logits), then log(1 + log(1 + x)) compresses what survives:

           sparse_vector = [1.20, 1.19, 0.0, 0.0, 0.0]

           Final result:
           {"cat": 1.20, "dog": 1.19}  # only the relevant words have non-zero weights
        ----------------------------------------------------------------------------------
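
        A quick numeric check of that activation (an illustrative sketch; the values
        match the toy example above):

            >>> import torch
            >>> scores = torch.tensor([9.2, 8.7, -0.2, -0.3, -0.1])
            >>> torch.log(1 + torch.log(1 + torch.relu(scores))).round(decimals=2)
            tensor([1.2000, 1.1900, 0.0000, 0.0000, 0.0000])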

        Args:
            texts: list of strings to create embeddings for
        """
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded. Call load() before create_embedding.")

        # tokenize the input texts
        features = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",  # returns PyTorch tensors instead of Python lists
            return_token_type_ids=False,
        )

        # move tensors to CPU or GPU device, depending on what's available
        features = {k: v.to(self._device) for k, v in features.items()}

        # get model logits output
        with torch.no_grad():
            output = self._model(**features)[0]

        # generate sparse vectors from model logits
        sparse_vectors = self._get_sparse_vectors(features, output)

        # decode to token-weight dictionaries
        decoded = self._decode_sparse_vectors(sparse_vectors)

        # return list of tuple(vector, decoded token weights) embedding results
        return [(sparse_vectors[i], decoded[i]) for i in range(len(texts))]

    def _get_sparse_vectors(
        self, features: dict[str, torch.Tensor], output: torch.Tensor
    ) -> torch.Tensor:
        """Convert model logits output to sparse vectors.

        This implements the get_sparse_vector function from the HuggingFace model card:
        https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte#usage-huggingface

        1. Max pooling with attention mask
        2. log(1 + log(1 + relu())) transformation
        3. Zero out special tokens

        Args:
            features: Tokenizer output with attention_mask
            output: Model logits of shape (batch_size, seq_len, vocab_size)

        Returns:
            Sparse vectors of shape (batch_size, vocab_size)
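
        A tiny illustration of the masked max pooling in step 1 (made-up numbers;
        batch of 1, three positions, vocab of 2, last position padded out):

            >>> import torch
            >>> logits = torch.tensor([[[9.2, 1.1], [0.8, 8.7], [5.0, 5.0]]])
            >>> mask = torch.tensor([[1, 1, 0]])
            >>> torch.max(logits * mask.unsqueeze(-1), dim=1)[0]
            tensor([[9.2000, 8.7000]])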
        """
        # max pooling with attention mask
        values, _ = torch.max(output * features["attention_mask"].unsqueeze(-1), dim=1)

        # apply the v3 model activation
        values = torch.log(1 + torch.log(1 + torch.relu(values)))

        # zero out special tokens
        values[:, self._special_token_ids] = 0

        return values

    def _decode_sparse_vectors(
        self, sparse_vectors: torch.Tensor
    ) -> list[dict[str, float]]:
        """Convert sparse vectors to token-weight dictionaries.

        Handles both single vectors and batches, returning a list of dictionaries
        mapping token strings to their weights.

        Args:
            sparse_vectors: Tensor of shape (batch_size, vocab_size) or (vocab_size,)

        Returns:
            List of dictionaries with token-weight pairs
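
        The non-zero extraction at the heart of this method, on a toy vector
        (illustrative only; real vectors are vocab_size wide):

            >>> import torch
            >>> vector = torch.tensor([0.0, 1.5, 0.0, 0.75])
            >>> idx = torch.nonzero(vector, as_tuple=False).squeeze(-1)
            >>> idx.tolist(), vector[idx].tolist()
            ([1, 3], [1.5, 0.75])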
        """
        if sparse_vectors.dim() == 1:
            sparse_vectors = sparse_vectors.unsqueeze(0)

        # move to CPU for processing
        sparse_vectors_cpu = sparse_vectors.cpu()

        results: list[dict[str, float]] = []
        for vector in sparse_vectors_cpu:

            # find non-zero indices
            nonzero_indices = torch.nonzero(vector, as_tuple=False).squeeze(-1)

            if nonzero_indices.numel() == 0:
                results.append({})
                continue

            # get weights
            weights = vector[nonzero_indices].tolist()

            # convert indices to token strings
            token_ids = nonzero_indices.tolist()
            tokens = self._tokenizer.convert_ids_to_tokens(token_ids)  # type: ignore[union-attr]

            # create token:weight dictionary
            token_dict = {
                token: weight
                for token, weight in zip(tokens, weights, strict=True)
                if token is not None
            }
            results.append(token_dict)

        return results
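

# A minimal usage sketch (hypothetical: assumes a local model directory and that
# EmbeddingInput accepts the fields referenced in create_embedding() above):
if __name__ == "__main__":
    model = OSNeuralSparseDocV3GTE("models/opensearch-doc-v3-gte")  # hypothetical path
    model.load()  # run model.download() first if the files are not on disk yet
    embedding = model.create_embedding(
        EmbeddingInput(
            timdex_record_id="example:001",  # illustrative values throughout
            run_id="run-1",
            run_record_offset=0,
            embedding_strategy="full_record",  # hypothetical strategy name
            text="cat and dog",
        )
    )
    # decoded sparse token weights, e.g. {"cat": 1.2, "dog": 1.19, ...}
    print(embedding.embedding_token_weights)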