Skip to content

Commit cb03062

Browse files
committed
Introduce strategies framework for preparing EmbeddingInputs
Why these changes are being introduced: A core requirement of this application is the ability to take a TIMDEX JSON record and "transform" all or parts of it into a single string for which an embedding can be created. We are calling these "embedding strategies" in the context of this app. While our first strategy will likely be a very simple, full record approach, we want to support multiple strategies in the application, and even multiple strategies for a single record in a single invocation. How this addresses that need: * A new 'strategies' module is created * A base 'BaseStrategy' class, with a required 'extract_text()' method for implementations * Our first strategy represented in class 'FullRecordStrategy', which JSON dumps the entire TIMDEX JSON record. * A registry of strategies, similar to our models, that allow CLI level validation. Side effects of this change: * None really, but further solidifies that this application is contains the opinionation about how text is prepared for the embedding process. Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/USE-131 * https://mitlibraries.atlassian.net/browse/USE-132
1 parent bd9de2b commit cb03062

File tree

7 files changed

+158
-46
lines changed

7 files changed

+158
-46
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ Options:
114114
default = 0. [required]
115115
--record-limit INTEGER Limit number of records after --run-record-
116116
offset, default = None (unlimited). [required]
117-
--strategy TEXT Pre-embedding record transformation strategy to
118-
use. Repeatable. [required]
117+
--strategy [full_record] Pre-embedding record transformation strategy.
118+
Repeatable to apply multiple strategies.
119+
[required]
119120
--output-jsonl TEXT Optionally write embeddings to local JSONLines
120121
file (primarily for testing).
121122
--help Show this message and exit.

embeddings/cli.py

Lines changed: 13 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import json
23
import logging
34
import time
45
from collections.abc import Callable
@@ -12,6 +13,8 @@
1213

1314
from embeddings.config import configure_logger, configure_sentry
1415
from embeddings.models.registry import get_model_class
16+
from embeddings.strategies.processor import create_embedding_inputs
17+
from embeddings.strategies.registry import STRATEGY_REGISTRY
1518

1619
logger = logging.getLogger(__name__)
1720

@@ -181,10 +184,13 @@ def test_model_load(ctx: click.Context) -> None:
181184
)
182185
@click.option(
183186
"--strategy",
184-
type=str, # WIP: establish an enum of supported strategies
187+
type=click.Choice(list(STRATEGY_REGISTRY.keys())),
185188
required=True,
186189
multiple=True,
187-
help="Pre-embedding record transformation strategy to use. Repeatable.",
190+
help=(
191+
"Pre-embedding record transformation strategy. "
192+
"Repeatable to apply multiple strategies."
193+
),
188194
)
189195
@click.option(
190196
"--output-jsonl",
@@ -222,48 +228,11 @@ def create_embeddings(
222228
action="index",
223229
)
224230

225-
# create an iterator of InputTexts applying all requested strategies to all records
226-
# WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that
227-
# create texts based on the requested strategies (e.g. "full record"), which are
228-
# captured in --strategy CLI args
229-
# WIP NOTE: the following simulates that...
230-
# DEBUG ------------------------------------------------------------------------------
231-
import json # noqa: PLC0415
232-
233-
from embeddings.embedding import EmbeddingInput # noqa: PLC0415
234-
235-
input_records = (
236-
EmbeddingInput(
237-
timdex_record_id=timdex_record["timdex_record_id"],
238-
run_id=timdex_record["run_id"],
239-
run_record_offset=timdex_record["run_record_offset"],
240-
embedding_strategy=_strategy,
241-
text=json.dumps(timdex_record["transformed_record"].decode()),
242-
)
243-
for timdex_record in timdex_records
244-
for _strategy in strategy
245-
)
246-
# DEBUG ------------------------------------------------------------------------------
247-
248-
# create an iterator of Embeddings via the embedding model
249-
# WIP NOTE: this will use the embedding class .create_embeddings() bulk method
250-
# WIP NOTE: the following simulates that...
251-
# DEBUG ------------------------------------------------------------------------------
252-
from embeddings.embedding import Embedding # noqa: PLC0415
253-
254-
embeddings = (
255-
Embedding(
256-
timdex_record_id=input_record.timdex_record_id,
257-
run_id=input_record.run_id,
258-
run_record_offset=input_record.run_record_offset,
259-
embedding_strategy=input_record.embedding_strategy,
260-
model_uri=model.model_uri,
261-
embedding_vector=[0.1, 0.2, 0.3],
262-
embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
263-
)
264-
for input_record in input_records
265-
)
266-
# DEBUG ------------------------------------------------------------------------------
231+
# create an iterator of EmbeddingInputs applying all requested strategies
232+
input_records = create_embedding_inputs(timdex_records, list(strategy))
233+
234+
# create embeddings via the embedding model
235+
embeddings = model.create_embeddings(input_records)
267236

268237
# if requested, write embeddings to a local JSONLines file
269238
if output_jsonl:

embeddings/strategies/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Strategies for transforming TIMDEX records into EmbeddingInputs."""

embeddings/strategies/base.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from abc import ABC, abstractmethod
2+
3+
from embeddings.embedding import EmbeddingInput
4+
5+
6+
class BaseStrategy(ABC):
7+
"""Base class for embedding input strategies.
8+
9+
All child classes must set class level attribute STRATEGY_NAME.
10+
"""
11+
12+
STRATEGY_NAME: str # type hint to document the requirement
13+
14+
def __init__(
15+
self,
16+
timdex_record_id: str,
17+
run_id: str,
18+
run_record_offset: int,
19+
transformed_record: dict,
20+
) -> None:
21+
"""Initialize strategy with TIMDEX record metadata.
22+
23+
Args:
24+
timdex_record_id: TIMDEX record ID
25+
run_id: TIMDEX ETL run ID
26+
run_record_offset: record offset within the run
27+
transformed_record: parsed TIMDEX record JSON
28+
"""
29+
self.timdex_record_id = timdex_record_id
30+
self.run_id = run_id
31+
self.run_record_offset = run_record_offset
32+
self.transformed_record = transformed_record
33+
34+
def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105
35+
super().__init_subclass__(**kwargs)
36+
37+
# require class level STRATEGY_NAME to be set
38+
if not hasattr(cls, "STRATEGY_NAME"):
39+
msg = f"{cls.__name__} must define 'STRATEGY_NAME' class attribute"
40+
raise TypeError(msg)
41+
if not isinstance(cls.STRATEGY_NAME, str):
42+
msg = f"{cls.__name__} must override 'STRATEGY_NAME' with a valid string"
43+
raise TypeError(msg)
44+
45+
@abstractmethod
46+
def extract_text(self) -> str:
47+
"""Extract text to be embedded from transformed_record."""
48+
49+
def to_embedding_input(self) -> EmbeddingInput:
50+
"""Create EmbeddingInput instance with strategy-specific extracted text."""
51+
return EmbeddingInput(
52+
timdex_record_id=self.timdex_record_id,
53+
run_id=self.run_id,
54+
run_record_offset=self.run_record_offset,
55+
embedding_strategy=self.STRATEGY_NAME,
56+
text=self.extract_text(),
57+
)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import json
2+
3+
from embeddings.strategies.base import BaseStrategy
4+
5+
6+
class FullRecordStrategy(BaseStrategy):
7+
"""Serialize entire TIMDEX record JSON as embedding input."""
8+
9+
STRATEGY_NAME = "full_record"
10+
11+
def extract_text(self) -> str:
12+
"""Serialize the entire transformed_record as JSON."""
13+
return json.dumps(self.transformed_record)

embeddings/strategies/processor.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import json
2+
from collections.abc import Iterator
3+
4+
from embeddings.embedding import EmbeddingInput
5+
from embeddings.strategies.registry import get_strategy_class
6+
7+
8+
def create_embedding_inputs(
9+
timdex_records: Iterator[dict],
10+
strategies: list[str],
11+
) -> Iterator[EmbeddingInput]:
12+
"""Yield EmbeddingInput instances for all records x all strategies.
13+
14+
Creates a cartesian product: each record is processed by each strategy,
15+
yielding one EmbeddingInput per combination.
16+
17+
Args:
18+
timdex_records: Iterator of TIMDEX records.
19+
Expected keys: timdex_record_id, run_id, run_record_offset,
20+
transformed_record (bytes)
21+
strategies: List of strategy names to apply
22+
23+
Yields:
24+
EmbeddingInput instances ready for embedding model
25+
26+
Example:
27+
100 records x 3 strategies = 300 EmbeddingInput instances
28+
"""
29+
for timdex_record in timdex_records:
30+
# decode and parse the TIMDEX JSON record
31+
transformed_record = json.loads(timdex_record["transformed_record"].decode())
32+
33+
# apply all strategies to the record and yield
34+
for strategy_name in strategies:
35+
strategy_class = get_strategy_class(strategy_name)
36+
strategy_instance = strategy_class(
37+
timdex_record_id=timdex_record["timdex_record_id"],
38+
run_id=timdex_record["run_id"],
39+
run_record_offset=timdex_record["run_record_offset"],
40+
transformed_record=transformed_record,
41+
)
42+
yield strategy_instance.to_embedding_input()

embeddings/strategies/registry.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import logging
2+
3+
from embeddings.strategies.base import BaseStrategy
4+
from embeddings.strategies.full_record import FullRecordStrategy
5+
6+
logger = logging.getLogger(__name__)
7+
8+
STRATEGY_CLASSES = [
9+
FullRecordStrategy,
10+
]
11+
12+
STRATEGY_REGISTRY: dict[str, type[BaseStrategy]] = {
13+
strategy.STRATEGY_NAME: strategy for strategy in STRATEGY_CLASSES
14+
}
15+
16+
17+
def get_strategy_class(strategy_name: str) -> type[BaseStrategy]:
18+
"""Get strategy class by name.
19+
20+
Args:
21+
strategy_name: Name of the strategy to retrieve
22+
"""
23+
if strategy_name not in STRATEGY_REGISTRY:
24+
available = ", ".join(sorted(STRATEGY_REGISTRY.keys()))
25+
msg = f"Unknown strategy: {strategy_name}. Available: {available}"
26+
logger.error(msg)
27+
raise ValueError(msg)
28+
29+
return STRATEGY_REGISTRY[strategy_name]

0 commit comments

Comments
 (0)