Commit faca712

Merge pull request #18 from MITLibraries/USE-136-implement-create-embeddings
USE 136 - implement create embeddings for OSNeuralSparseDocV3GTE
2 parents a725a58 + 4a0cb4d commit faca712

File tree

11 files changed: 599 additions & 311 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -24,6 +24,6 @@ repos:
         types: ["python"]
       - id: pip-audit
         name: pip-audit
-        entry: uv run pip-audit --ignore-vuln GHSA-4xh5-x5gv-qwph
+        entry: uv run pip-audit
         language: system
         pass_filenames: false

Makefile

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ ruff: # Run 'ruff' linter and print a preview of errors
 	uv run ruff check .
 
 safety: # Check for security vulnerabilities
-	uv run pip-audit --ignore-vuln GHSA-4xh5-x5gv-qwph
+	uv run pip-audit
 
 lint-apply: black-apply ruff-apply # Apply changes with 'black' and resolve 'fixable errors' with 'ruff'
 
embeddings/cli.py

Lines changed: 3 additions & 2 deletions
@@ -210,6 +210,7 @@ def create_embeddings(
 ) -> None:
     """Create embeddings for TIMDEX records."""
     model: BaseEmbeddingModel = ctx.obj["model"]
+    model.load()
 
     # init TIMDEXDataset
     timdex_dataset = TIMDEXDataset(dataset_location)
@@ -229,10 +230,10 @@ def create_embeddings(
     )
 
     # create an iterator of EmbeddingInputs applying all requested strategies
-    input_records = create_embedding_inputs(timdex_records, list(strategy))
+    embedding_inputs = create_embedding_inputs(timdex_records, list(strategy))
 
     # create embeddings via the embedding model
-    embeddings = model.create_embeddings(input_records)
+    embeddings = model.create_embeddings(embedding_inputs)
 
     # if requested, write embeddings to a local JSONLines file
     if output_jsonl:
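
Taken together, the change means a model is now explicitly loaded before any records are read, and the pipeline stays lazy end to end. A rough sketch of the resulting flow (only the names visible in the diff are real; the record-reading call is elided because it is not shown here):

    model: BaseEmbeddingModel = ctx.obj["model"]
    model.load()  # new in this commit: load weights before streaming records

    timdex_dataset = TIMDEXDataset(dataset_location)
    timdex_records = ...  # records read from timdex_dataset (call not shown in this diff)

    # both calls below return iterators, so records stream through one at a time
    embedding_inputs = create_embedding_inputs(timdex_records, list(strategy))
    embeddings = model.create_embeddings(embedding_inputs)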

embeddings/embedding.py

Lines changed: 12 additions & 0 deletions
@@ -23,6 +23,12 @@ class EmbeddingInput:
     embedding_strategy: str
     text: str
 
+    def __repr__(self) -> str:  # noqa: D105
+        return (
+            f"<EmbeddingInput - record:'{self.timdex_record_id}', "
+            f"strategy:'{self.embedding_strategy}', text length:{len(self.text)}>"
+        )
+
 
 @dataclass
 class Embedding:
@@ -49,6 +55,12 @@ class Embedding:
         default_factory=lambda: datetime.datetime.now(datetime.UTC)
     )
 
+    def __repr__(self) -> str:  # noqa: D105
+        return (
+            f"<Embedding - record:'{self.timdex_record_id}', "
+            f"strategy:'{self.embedding_strategy}'>"
+        )
+
     def to_dict(self) -> dict:
         """Marshal to dictionary."""
         return asdict(self)
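
A quick illustration of what the new __repr__ produces. This is a hypothetical sketch: the field values are invented, and the run_id / run_record_offset fields are inferred from their use elsewhere in this commit:

    from embeddings.embedding import EmbeddingInput

    embedding_input = EmbeddingInput(
        timdex_record_id="alma:123",       # invented record id
        run_id="run-1",                    # invented run id
        run_record_offset=0,
        embedding_strategy="full_record",  # invented strategy name
        text="hello world",
    )

    print(repr(embedding_input))
    # <EmbeddingInput - record:'alma:123', strategy:'full_record', text length:11>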

embeddings/models/base.py

Lines changed: 7 additions & 7 deletions
@@ -51,20 +51,20 @@ def load(self) -> None:
         """Load model from self.model_path."""
 
     @abstractmethod
-    def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
+    def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
         """Create an Embedding for an EmbeddingInput.
 
         Args:
-            input_record: EmbeddingInput instance
+            embedding_input: EmbeddingInput instance
         """
 
     def create_embeddings(
-        self, input_records: Iterator[EmbeddingInput]
+        self, embedding_inputs: Iterator[EmbeddingInput]
     ) -> Iterator[Embedding]:
-        """Yield Embeddings for an iterator of InputRecords.
+        """Yield Embeddings for a batch of EmbeddingInputs.
 
         Args:
-            input_records: iterator of InputRecords
+            embedding_inputs: iterator of EmbeddingInputs
         """
-        for input_text in input_records:
-            yield self.create_embedding(input_text)
+        for embedding_input in embedding_inputs:
+            yield self.create_embedding(embedding_input)
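
Because create_embeddings() is a generator, nothing is computed until the caller iterates. A minimal sketch of the contract, assuming load() and create_embedding() are the only abstract methods a subclass must supply (FakeModel and all of its values are hypothetical; the Embedding field names mirror those used elsewhere in this commit):

    class FakeModel(BaseEmbeddingModel):
        def load(self) -> None:
            pass

        def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
            # trivial stand-in implementation for illustration only
            return Embedding(
                timdex_record_id=embedding_input.timdex_record_id,
                run_id=embedding_input.run_id,
                run_record_offset=embedding_input.run_record_offset,
                model_uri="fake://model",
                embedding_strategy=embedding_input.embedding_strategy,
                embedding_vector=[0.0],
                embedding_token_weights={"token": 1.0},
            )

    inputs: list[EmbeddingInput] = []  # populate with EmbeddingInput instances
    for embedding in FakeModel("/tmp/fake").create_embeddings(iter(inputs)):
        print(embedding)  # one Embedding at a time, as the generator is consumed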

embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 195 additions & 15 deletions
@@ -8,6 +8,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING
 
+import torch
 from huggingface_hub import snapshot_download
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 
@@ -26,6 +27,9 @@
 class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
     """OpenSearch Neural Sparse Encoding Doc v3 GTE model.
 
+    This model generates sparse embeddings for documents by using a masked language
+    model's logits to identify the most relevant tokens.
+
     HuggingFace URI: opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
     """
 
@@ -40,8 +44,8 @@ def __init__(self, model_path: str | Path) -> None:
         super().__init__(model_path)
         self._model: PreTrainedModel | None = None
         self._tokenizer: DistilBertTokenizerFast | None = None
-        self._special_token_ids: list | None = None
-        self._id_to_token: list | None = None
+        self._special_token_ids: list[int] | None = None
+        self._device: torch.device = torch.device("cpu")
 
     def download(self) -> Path:
         """Download and prepare model, saving to self.model_path.
@@ -139,29 +143,205 @@ def load(self) -> None:
         if not self.model_path.exists():
             raise FileNotFoundError(f"Model not found at path: {self.model_path}")
 
-        # load local model and tokenizer
-        self._model = AutoModelForMaskedLM.from_pretrained(
+        # setup device (use CUDA if available, otherwise CPU)
+        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # load tokenizer
+        self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]
             self.model_path,
-            trust_remote_code=True,
             local_files_only=True,
         )
-        self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]
+
+        # load model as AutoModelForMaskedLM (required for sparse embeddings)
+        self._model = AutoModelForMaskedLM.from_pretrained(
             self.model_path,
+            trust_remote_code=True,
             local_files_only=True,
         )
+        self._model.to(self._device)  # type: ignore[arg-type]
+        self._model.eval()
 
-        # setup special tokens
+        # set special token IDs (following model card pattern)
+        # these will be zeroed out in the sparse vectors
         self._special_token_ids = [
-            self._tokenizer.vocab[str(token)]
+            self._tokenizer.vocab[token]  # type: ignore[index]
             for token in self._tokenizer.special_tokens_map.values()
         ]
 
-        # setup id_to_token mapping
-        self._id_to_token = ["" for _ in range(self._tokenizer.vocab_size)]
-        for token, token_id in self._tokenizer.vocab.items():
-            self._id_to_token[token_id] = token
+        logger.info(
+            f"Model loaded successfully on {self._device}, "
+            f"{time.perf_counter() - start_time:.2f}s"
+        )
+
+    def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
+        """Create sparse vector and decoded token weight embeddings for an input text.
+
+        Args:
+            embedding_input: EmbeddingInput object with a .text attribute
+        """
+        # generate the sparse embeddings
+        sparse_vector, decoded_tokens = self._encode_documents([embedding_input.text])[0]
+
+        # coerce sparse vector tensor into list[float]
+        sparse_vector_list = sparse_vector.cpu().numpy().tolist()
+
+        return Embedding(
+            timdex_record_id=embedding_input.timdex_record_id,
+            run_id=embedding_input.run_id,
+            run_record_offset=embedding_input.run_record_offset,
+            model_uri=self.model_uri,
+            embedding_strategy=embedding_input.embedding_strategy,
+            embedding_vector=sparse_vector_list,
+            embedding_token_weights=decoded_tokens,
+        )
+
+    def _encode_documents(
+        self,
+        texts: list[str],
+    ) -> list[tuple[torch.Tensor, dict[str, float]]]:
+        """Encode documents into sparse vectors and decoded token weights.
+
+        This follows the pattern outlined on the HuggingFace model card for document
+        encoding.
+
+        This method accommodates MULTIPLE text inputs and returns a list of embeddings,
+        even though the calling context of create_embedding() is a SINGULAR input +
+        output and uses only the first result. The ability to handle multiple inputs +
+        outputs is kept in the event we want something like a
+        create_multiple_embeddings() method in the future.
+
+        At a very high level, the following is performed:
+
+        1. We tokenize the input text into "features" using the model's tokenizer.
+
+        2. The features are fed to the model, returning model output logits. These
+        logits are "dense" in the sense that there are few zeros, but they are not
+        "dense vectors" (embeddings) in the sense that they meaningfully represent the
+        input document in geometric space; two logit tensors cannot be compared with
+        something like cosine similarity.
+
+        3. The logits are then converted into a sparse vector, which is a numeric
+        array of floats with the same number of values as the model's vocabulary. Each
+        value's position in the sparse array corresponds to the token id in the
+        vocabulary, and the value itself is the "weight" of this token in the input
+        text.
+
+        4. Lastly, we convert this sparse vector into a {token: weight} dictionary of
+        the actual token strings and their numerical weights. This dictionary may
+        contain tokens not present in the original text, but it will be considerably
+        shorter than the model vocabulary length, given that all zero- and low-scoring
+        tokens are dropped. This is the final form that we will ultimately index into
+        OpenSearch.
+
+        Args:
+            texts: list of strings to create embeddings for
+        """
+        if self._model is None or self._tokenizer is None:
+            raise RuntimeError("Model not loaded. Call load() before create_embedding.")
+
+        # tokenize the input texts
+        features = self._tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",  # returns PyTorch tensors instead of Python lists
+            return_token_type_ids=False,
+        )
+
+        # move to CPU or GPU device, depending on what's available
+        features = {k: v.to(self._device) for k, v in features.items()}
+
+        # pass features to the model and receive model output logits as a tensor
+        with torch.no_grad():
+            output = self._model(**features)[0]
+
+        # generate sparse vectors from model logits tensor
+        sparse_vectors = self._get_sparse_vectors(features, output)
+
+        # decode sparse vectors to token-weight dictionaries
+        decoded = self._decode_sparse_vectors(sparse_vectors)
+
+        # return list of tuple(vector, decoded token weights) embedding results
+        return [(sparse_vectors[i], decoded[i]) for i in range(len(texts))]
+
+    def _get_sparse_vectors(
+        self, features: dict[str, torch.Tensor], output: torch.Tensor
+    ) -> torch.Tensor:
+        """Convert model logits output to sparse vectors.
+
+        This follows the HuggingFace model card exactly: https://huggingface.co/
+        opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte#usage-huggingface
+
+        This implements the get_sparse_vector function from the model card:
+        1. Max pooling with attention mask
+        2. log(1 + log(1 + relu())) transformation
+        3. Zero out special tokens
+
+        The end result is a sparse vector with the length of the model vocabulary, with
+        each position representing a token in the model vocabulary and each value
+        representing that token's weight relative to the input text.
+
+        Args:
+            features: Tokenizer output with attention_mask
+            output: Model logits of shape (batch_size, seq_len, vocab_size)
+
+        Returns:
+            Sparse vectors of shape (batch_size, vocab_size)
+        """
+        # collapse sequence positions: take max logit for each vocab token across all
+        # positions (also masks out padding tokens)
+        values, _ = torch.max(output * features["attention_mask"].unsqueeze(-1), dim=1)
+
+        # compress values to create sparsity: ReLU removes negatives,
+        # double-log shrinks large values
+        values = torch.log(1 + torch.log(1 + torch.relu(values)))
+
+        # remove special tokens like [CLS], [SEP], [PAD]
+        values[:, self._special_token_ids] = 0
+
+        return values
+
+    def _decode_sparse_vectors(
+        self, sparse_vectors: torch.Tensor
+    ) -> list[dict[str, float]]:
+        """Convert sparse vectors to token-weight dictionaries.
+
+        Handles both single vectors and batches, returning a list of dictionaries
+        mapping token strings to their weights.
+
+        Args:
+            sparse_vectors: Tensor of shape (batch_size, vocab_size) or (vocab_size,)
+
+        Returns:
+            List of dictionaries with token-weight pairs
+        """
+        if sparse_vectors.dim() == 1:
+            sparse_vectors = sparse_vectors.unsqueeze(0)
+
+        # move to CPU for processing
+        sparse_vectors_cpu = sparse_vectors.cpu()
+
+        results: list[dict] = []
+        for vector in sparse_vectors_cpu:
+
+            # find non-zero indices and values
+            nonzero_indices = torch.nonzero(vector, as_tuple=False).squeeze(-1)
+
+            if nonzero_indices.numel() == 0:
+                results.append({})
+                continue
+
+            # get weights
+            weights = vector[nonzero_indices].tolist()
+
+            # convert indices to token strings
+            token_ids = nonzero_indices.tolist()
+            tokens = self._tokenizer.convert_ids_to_tokens(token_ids)  # type: ignore[union-attr]
 
-        logger.info(f"Model loaded successfully, {time.perf_counter()-start_time}s")
+            # create token:weight dictionary
+            token_dict = {
+                token: weight
+                for token, weight in zip(tokens, weights, strict=True)
+                if token is not None
+            }
+            results.append(token_dict)
 
-    def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
-        raise NotImplementedError
+        return results
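
The tensor arithmetic in _get_sparse_vectors and _decode_sparse_vectors is easy to verify on a toy example. The sketch below uses invented dimensions and values (a batch of 1, three sequence positions, a five-token vocabulary) but applies the same transformation steps as the code above:

    import torch

    # fake model logits: (batch=1, seq_len=3, vocab=5); the last position is padding
    logits = torch.tensor(
        [[[0.2, -1.0, 3.0, 0.0, 0.5],
          [1.5, 0.1, 0.0, 2.0, -0.3],
          [0.0, 0.0, 0.0, 0.0, 0.0]]]
    )
    attention_mask = torch.tensor([[1, 1, 0]])  # zero masks out the pad position

    # 1. max-pool each vocab token's logit across sequence positions
    values, _ = torch.max(logits * attention_mask.unsqueeze(-1), dim=1)

    # 2. relu removes negatives; the double log compresses large values
    values = torch.log(1 + torch.log(1 + torch.relu(values)))

    # 3. zero out special tokens (pretend token id 0 is [CLS]-like here)
    values[:, [0]] = 0

    # decode: keep only the non-zero entries, as _decode_sparse_vectors does
    vector = values[0]
    nonzero = torch.nonzero(vector, as_tuple=False).squeeze(-1)
    print({int(i): round(float(vector[i]), 3) for i in nonzero})
    # approximately: {1: 0.091, 2: 0.87, 3: 0.741, 4: 0.34}

In the real code the integer keys are then mapped to token strings via the tokenizer's convert_ids_to_tokens, giving the {token: weight} dictionary that is ultimately indexed into OpenSearch.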

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ ignore = [
     "D102",
     "D103",
     "D104",
+    "EM101",
     "EM102",
     "G004",
     "PLR0912",
