Skip to content

Commit a725a58

Browse files
authored
Merge pull request #17 from MITLibraries/USE-131-embedding-input-transform-framework
USE 131 - Framework for embedding input strategies
2 parents bd9de2b + 421ba71 commit a725a58

File tree

8 files changed

+212
-46
lines changed

8 files changed

+212
-46
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ Options:
114114
default = 0. [required]
115115
--record-limit INTEGER Limit number of records after --run-record-
116116
offset, default = None (unlimited). [required]
117-
--strategy TEXT Pre-embedding record transformation strategy to
118-
use. Repeatable. [required]
117+
--strategy [full_record] Pre-embedding record transformation strategy.
118+
Repeatable to apply multiple strategies.
119+
[required]
119120
--output-jsonl TEXT Optionally write embeddings to local JSONLines
120121
file (primarily for testing).
121122
--help Show this message and exit.

embeddings/cli.py

Lines changed: 13 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import json
23
import logging
34
import time
45
from collections.abc import Callable
@@ -12,6 +13,8 @@
1213

1314
from embeddings.config import configure_logger, configure_sentry
1415
from embeddings.models.registry import get_model_class
16+
from embeddings.strategies.processor import create_embedding_inputs
17+
from embeddings.strategies.registry import STRATEGY_REGISTRY
1518

1619
logger = logging.getLogger(__name__)
1720

@@ -181,10 +184,13 @@ def test_model_load(ctx: click.Context) -> None:
181184
)
182185
@click.option(
183186
"--strategy",
184-
type=str, # WIP: establish an enum of supported strategies
187+
type=click.Choice(list(STRATEGY_REGISTRY.keys())),
185188
required=True,
186189
multiple=True,
187-
help="Pre-embedding record transformation strategy to use. Repeatable.",
190+
help=(
191+
"Pre-embedding record transformation strategy. "
192+
"Repeatable to apply multiple strategies."
193+
),
188194
)
189195
@click.option(
190196
"--output-jsonl",
@@ -222,48 +228,11 @@ def create_embeddings(
222228
action="index",
223229
)
224230

225-
# create an iterator of InputTexts applying all requested strategies to all records
226-
# WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that
227-
# create texts based on the requested strategies (e.g. "full record"), which are
228-
# captured in --strategy CLI args
229-
# WIP NOTE: the following simulates that...
230-
# DEBUG ------------------------------------------------------------------------------
231-
import json # noqa: PLC0415
232-
233-
from embeddings.embedding import EmbeddingInput # noqa: PLC0415
234-
235-
input_records = (
236-
EmbeddingInput(
237-
timdex_record_id=timdex_record["timdex_record_id"],
238-
run_id=timdex_record["run_id"],
239-
run_record_offset=timdex_record["run_record_offset"],
240-
embedding_strategy=_strategy,
241-
text=json.dumps(timdex_record["transformed_record"].decode()),
242-
)
243-
for timdex_record in timdex_records
244-
for _strategy in strategy
245-
)
246-
# DEBUG ------------------------------------------------------------------------------
247-
248-
# create an iterator of Embeddings via the embedding model
249-
# WIP NOTE: this will use the embedding class .create_embeddings() bulk method
250-
# WIP NOTE: the following simulates that...
251-
# DEBUG ------------------------------------------------------------------------------
252-
from embeddings.embedding import Embedding # noqa: PLC0415
253-
254-
embeddings = (
255-
Embedding(
256-
timdex_record_id=input_record.timdex_record_id,
257-
run_id=input_record.run_id,
258-
run_record_offset=input_record.run_record_offset,
259-
embedding_strategy=input_record.embedding_strategy,
260-
model_uri=model.model_uri,
261-
embedding_vector=[0.1, 0.2, 0.3],
262-
embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
263-
)
264-
for input_record in input_records
265-
)
266-
# DEBUG ------------------------------------------------------------------------------
231+
# create an iterator of EmbeddingInputs applying all requested strategies
232+
input_records = create_embedding_inputs(timdex_records, list(strategy))
233+
234+
# create embeddings via the embedding model
235+
embeddings = model.create_embeddings(input_records)
267236

268237
# if requested, write embeddings to a local JSONLines file
269238
if output_jsonl:

embeddings/strategies/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Strategies for transforming TIMDEX records into EmbeddingInputs."""

embeddings/strategies/base.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
class BaseStrategy(ABC):
5+
"""Base class for embedding input strategies.
6+
7+
All child classes must set class level attribute STRATEGY_NAME.
8+
"""
9+
10+
STRATEGY_NAME: str # type hint to document the requirement
11+
12+
def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105
13+
super().__init_subclass__(**kwargs)
14+
15+
# require class level STRATEGY_NAME to be set
16+
if not hasattr(cls, "STRATEGY_NAME"):
17+
raise TypeError(f"{cls.__name__} must define 'STRATEGY_NAME' class attribute")
18+
if not isinstance(cls.STRATEGY_NAME, str):
19+
raise TypeError(
20+
f"{cls.__name__} must override 'STRATEGY_NAME' with a valid string"
21+
)
22+
23+
@abstractmethod
24+
def extract_text(self, timdex_record: dict) -> str:
25+
"""Extract text to be embedded from transformed_record.
26+
27+
Args:
28+
timdex_record: TIMDEX JSON record ("transformed_record" in TIMDEX dataset)
29+
"""
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import json
2+
3+
from embeddings.strategies.base import BaseStrategy
4+
5+
6+
class FullRecordStrategy(BaseStrategy):
7+
"""Serialize entire TIMDEX record JSON as embedding input."""
8+
9+
STRATEGY_NAME = "full_record"
10+
11+
def extract_text(self, timdex_record: dict) -> str:
12+
"""Serialize the entire transformed_record as JSON."""
13+
return json.dumps(timdex_record)

embeddings/strategies/processor.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import json
2+
from collections.abc import Iterator
3+
4+
from embeddings.embedding import EmbeddingInput
5+
from embeddings.strategies.registry import get_strategy_class
6+
7+
8+
def create_embedding_inputs(
9+
timdex_dataset_records: Iterator[dict],
10+
strategies: list[str],
11+
) -> Iterator[EmbeddingInput]:
12+
"""Yield EmbeddingInput instances for all records x all strategies.
13+
14+
Creates a cartesian product: each record is processed by each strategy,
15+
yielding one EmbeddingInput per combination.
16+
17+
Args:
18+
timdex_dataset_records: Iterator of TIMDEX records.
19+
Expected keys: timdex_record_id, run_id, run_record_offset,
20+
transformed_record (bytes)
21+
strategies: List of strategy names to apply
22+
23+
Yields:
24+
EmbeddingInput instances ready for embedding model
25+
26+
Example:
27+
100 records x 3 strategies = 300 EmbeddingInput instances
28+
"""
29+
# instantiate strategy transformers
30+
transformers = [get_strategy_class(strategy)() for strategy in strategies]
31+
32+
# loop through records and apply all strategies, yielding an EmbeddingInput for each
33+
for timdex_dataset_record in timdex_dataset_records:
34+
35+
# decode and parse the TIMDEX JSON record once for all requested strategies
36+
timdex_record = json.loads(timdex_dataset_record["transformed_record"].decode())
37+
38+
for transformer in transformers:
39+
# prepare text for embedding from transformer strategy
40+
text = transformer.extract_text(timdex_record)
41+
42+
# emit an EmbeddingInput instance
43+
yield EmbeddingInput(
44+
timdex_record_id=timdex_dataset_record["timdex_record_id"],
45+
run_id=timdex_dataset_record["run_id"],
46+
run_record_offset=timdex_dataset_record["run_record_offset"],
47+
embedding_strategy=transformer.STRATEGY_NAME,
48+
text=text,
49+
)

embeddings/strategies/registry.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import logging
2+
3+
from embeddings.strategies.base import BaseStrategy
4+
from embeddings.strategies.full_record import FullRecordStrategy
5+
6+
logger = logging.getLogger(__name__)
7+
8+
STRATEGY_CLASSES = [
9+
FullRecordStrategy,
10+
]
11+
12+
STRATEGY_REGISTRY: dict[str, type[BaseStrategy]] = {
13+
strategy.STRATEGY_NAME: strategy for strategy in STRATEGY_CLASSES
14+
}
15+
16+
17+
def get_strategy_class(strategy_name: str) -> type[BaseStrategy]:
18+
"""Get strategy class by name.
19+
20+
Args:
21+
strategy_name: Name of the strategy to retrieve
22+
"""
23+
if strategy_name not in STRATEGY_REGISTRY:
24+
available = ", ".join(sorted(STRATEGY_REGISTRY.keys()))
25+
msg = f"Unknown strategy: {strategy_name}. Available: {available}"
26+
logger.error(msg)
27+
raise ValueError(msg)
28+
29+
return STRATEGY_REGISTRY[strategy_name]

tests/test_strategies.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import json
2+
3+
import pytest
4+
5+
from embeddings.strategies.base import BaseStrategy
6+
from embeddings.strategies.full_record import FullRecordStrategy
7+
from embeddings.strategies.processor import create_embedding_inputs
8+
from embeddings.strategies.registry import get_strategy_class
9+
10+
11+
def test_full_record_strategy_extracts_text():
12+
timdex_record = {"timdex_record_id": "test-123", "title": ["Test Title"]}
13+
strategy = FullRecordStrategy()
14+
15+
text = strategy.extract_text(timdex_record)
16+
17+
assert text == json.dumps(timdex_record)
18+
assert strategy.STRATEGY_NAME == "full_record"
19+
20+
21+
def test_create_embedding_inputs_yields_cartesian_product():
22+
# two records
23+
timdex_records = iter(
24+
[
25+
{
26+
"timdex_record_id": "id-1",
27+
"run_id": "run-1",
28+
"run_record_offset": 0,
29+
"transformed_record": b'{"title": ["Record 1"]}',
30+
},
31+
{
32+
"timdex_record_id": "id-2",
33+
"run_id": "run-1",
34+
"run_record_offset": 1,
35+
"transformed_record": b'{"title": ["Record 2"]}',
36+
},
37+
]
38+
)
39+
40+
# single strategy (for now)
41+
strategies = ["full_record"]
42+
43+
embedding_inputs = list(create_embedding_inputs(timdex_records, strategies))
44+
45+
assert len(embedding_inputs) == 2
46+
assert embedding_inputs[0].timdex_record_id == "id-1"
47+
assert embedding_inputs[0].embedding_strategy == "full_record"
48+
assert embedding_inputs[1].timdex_record_id == "id-2"
49+
assert embedding_inputs[1].embedding_strategy == "full_record"
50+
51+
52+
def test_get_strategy_class_returns_correct_class():
53+
strategy_class = get_strategy_class("full_record")
54+
assert strategy_class is FullRecordStrategy
55+
56+
57+
def test_get_strategy_class_raises_for_unknown_strategy():
58+
with pytest.raises(ValueError, match="Unknown strategy"):
59+
get_strategy_class("nonexistent_strategy")
60+
61+
62+
def test_subclass_without_strategy_name_raises_type_error():
63+
with pytest.raises(TypeError, match="must define 'STRATEGY_NAME' class attribute"):
64+
65+
class InvalidStrategy(BaseStrategy):
66+
pass
67+
68+
69+
def test_subclass_with_non_string_strategy_name_raises_type_error():
70+
with pytest.raises(
71+
TypeError, match="must override 'STRATEGY_NAME' with a valid string"
72+
):
73+
74+
class InvalidStrategy(BaseStrategy):
75+
STRATEGY_NAME = 123

0 commit comments

Comments
 (0)