Skip to content

Commit 421ba71

Browse files
committed
Init transformer strategy once for all records
Why these changes are being introduced: Formerly, a transformer strategy class was instantiated in a per-record fashion, where things like the timdex_record_id and other record-level values were passed. This ultimately felt awkward, when we could just as easily instantiate it once in a more generic fashion, then build EmbeddingInput instances with the *result* of the strategy extracting text from the TIMDEX JSON record. How this addresses that need: All record-level details are removed as arguments for initializing a transformer strategy. Instead, the helper function create_embedding_inputs() is responsible for passing the TIMDEX JSON record to the transformer strategies, and then building an EmbeddingInput object before yielding. This keeps the init of those strategies much simpler, and preventing properties in the class they don't really need. Side effects of this change: * None Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/USE-131 * https://mitlibraries.atlassian.net/browse/USE-132
1 parent 145cd81 commit 421ba71

File tree

4 files changed

+36
-65
lines changed

4 files changed

+36
-65
lines changed

embeddings/strategies/base.py

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from abc import ABC, abstractmethod
22

3-
from embeddings.embedding import EmbeddingInput
4-
53

64
class BaseStrategy(ABC):
75
"""Base class for embedding input strategies.
@@ -11,26 +9,6 @@ class BaseStrategy(ABC):
119

1210
STRATEGY_NAME: str # type hint to document the requirement
1311

14-
def __init__(
15-
self,
16-
timdex_record_id: str,
17-
run_id: str,
18-
run_record_offset: int,
19-
transformed_record: dict,
20-
) -> None:
21-
"""Initialize strategy with TIMDEX record metadata.
22-
23-
Args:
24-
timdex_record_id: TIMDEX record ID
25-
run_id: TIMDEX ETL run ID
26-
run_record_offset: record offset within the run
27-
transformed_record: parsed TIMDEX record JSON
28-
"""
29-
self.timdex_record_id = timdex_record_id
30-
self.run_id = run_id
31-
self.run_record_offset = run_record_offset
32-
self.transformed_record = transformed_record
33-
3412
def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105
3513
super().__init_subclass__(**kwargs)
3614

@@ -43,15 +21,9 @@ def __init_subclass__(cls, **kwargs: dict) -> None: # noqa: D105
4321
)
4422

4523
@abstractmethod
46-
def extract_text(self) -> str:
47-
"""Extract text to be embedded from transformed_record."""
48-
49-
def to_embedding_input(self) -> EmbeddingInput:
50-
"""Create EmbeddingInput instance with strategy-specific extracted text."""
51-
return EmbeddingInput(
52-
timdex_record_id=self.timdex_record_id,
53-
run_id=self.run_id,
54-
run_record_offset=self.run_record_offset,
55-
embedding_strategy=self.STRATEGY_NAME,
56-
text=self.extract_text(),
57-
)
24+
def extract_text(self, timdex_record: dict) -> str:
25+
"""Extract text to be embedded from transformed_record.
26+
27+
Args:
28+
timdex_record: TIMDEX JSON record ("transformed_record" in TIMDEX dataset)
29+
"""

embeddings/strategies/full_record.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ class FullRecordStrategy(BaseStrategy):
88

99
STRATEGY_NAME = "full_record"
1010

11-
def extract_text(self) -> str:
11+
def extract_text(self, timdex_record: dict) -> str:
1212
"""Serialize the entire transformed_record as JSON."""
13-
return json.dumps(self.transformed_record)
13+
return json.dumps(timdex_record)

embeddings/strategies/processor.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
def create_embedding_inputs(
9-
timdex_records: Iterator[dict],
9+
timdex_dataset_records: Iterator[dict],
1010
strategies: list[str],
1111
) -> Iterator[EmbeddingInput]:
1212
"""Yield EmbeddingInput instances for all records x all strategies.
@@ -15,7 +15,7 @@ def create_embedding_inputs(
1515
yielding one EmbeddingInput per combination.
1616
1717
Args:
18-
timdex_records: Iterator of TIMDEX records.
18+
timdex_dataset_records: Iterator of TIMDEX records.
1919
Expected keys: timdex_record_id, run_id, run_record_offset,
2020
transformed_record (bytes)
2121
strategies: List of strategy names to apply
@@ -26,17 +26,24 @@ def create_embedding_inputs(
2626
Example:
2727
100 records x 3 strategies = 300 EmbeddingInput instances
2828
"""
29-
for timdex_record in timdex_records:
30-
# decode and parse the TIMDEX JSON record
31-
transformed_record = json.loads(timdex_record["transformed_record"].decode())
32-
33-
# apply all strategies to the record and yield
34-
for strategy_name in strategies:
35-
strategy_class = get_strategy_class(strategy_name)
36-
strategy_instance = strategy_class(
37-
timdex_record_id=timdex_record["timdex_record_id"],
38-
run_id=timdex_record["run_id"],
39-
run_record_offset=timdex_record["run_record_offset"],
40-
transformed_record=transformed_record,
29+
# instantiate strategy transformers
30+
transformers = [get_strategy_class(strategy)() for strategy in strategies]
31+
32+
# loop through records and apply all strategies, yielding an EmbeddingInput for each
33+
for timdex_dataset_record in timdex_dataset_records:
34+
35+
# decode and parse the TIMDEX JSON record once for all requested strategies
36+
timdex_record = json.loads(timdex_dataset_record["transformed_record"].decode())
37+
38+
for transformer in transformers:
39+
# prepare text for embedding from transformer strategy
40+
text = transformer.extract_text(timdex_record)
41+
42+
# emit an EmbeddingInput instance
43+
yield EmbeddingInput(
44+
timdex_record_id=timdex_dataset_record["timdex_record_id"],
45+
run_id=timdex_dataset_record["run_id"],
46+
run_record_offset=timdex_dataset_record["run_record_offset"],
47+
embedding_strategy=transformer.STRATEGY_NAME,
48+
text=text,
4149
)
42-
yield strategy_instance.to_embedding_input()

tests/test_strategies.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,14 @@
88
from embeddings.strategies.registry import get_strategy_class
99

1010

11-
def test_full_record_strategy_creates_embedding_input():
12-
transformed_record = {"timdex_record_id": "test-123", "title": ["Test Title"]}
13-
strategy = FullRecordStrategy(
14-
timdex_record_id="test-123",
15-
run_id="run-456",
16-
run_record_offset=42,
17-
transformed_record=transformed_record,
18-
)
11+
def test_full_record_strategy_extracts_text():
12+
timdex_record = {"timdex_record_id": "test-123", "title": ["Test Title"]}
13+
strategy = FullRecordStrategy()
1914

20-
embedding_input = strategy.to_embedding_input()
15+
text = strategy.extract_text(timdex_record)
2116

22-
assert embedding_input.timdex_record_id == "test-123"
23-
assert embedding_input.run_id == "run-456"
24-
assert embedding_input.run_record_offset == 42
25-
assert embedding_input.embedding_strategy == "full_record"
26-
assert embedding_input.text == json.dumps(transformed_record)
17+
assert text == json.dumps(timdex_record)
18+
assert strategy.STRATEGY_NAME == "full_record"
2719

2820

2921
def test_create_embedding_inputs_yields_cartesian_product():

0 commit comments

Comments
 (0)