Skip to content

Commit 70d8394

Browse files
committed
Strategies unit tests
1 parent cb03062 commit 70d8394

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

tests/test_strategies.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import json
2+
3+
import pytest
4+
5+
from embeddings.strategies.base import BaseStrategy
6+
from embeddings.strategies.full_record import FullRecordStrategy
7+
from embeddings.strategies.processor import create_embedding_inputs
8+
from embeddings.strategies.registry import get_strategy_class
9+
10+
11+
def test_full_record_strategy_creates_embedding_input():
12+
transformed_record = {"timdex_record_id": "test-123", "title": ["Test Title"]}
13+
strategy = FullRecordStrategy(
14+
timdex_record_id="test-123",
15+
run_id="run-456",
16+
run_record_offset=42,
17+
transformed_record=transformed_record,
18+
)
19+
20+
embedding_input = strategy.to_embedding_input()
21+
22+
assert embedding_input.timdex_record_id == "test-123"
23+
assert embedding_input.run_id == "run-456"
24+
assert embedding_input.run_record_offset == 42
25+
assert embedding_input.embedding_strategy == "full_record"
26+
assert embedding_input.text == json.dumps(transformed_record)
27+
28+
29+
def test_create_embedding_inputs_yields_cartesian_product():
30+
# two records
31+
timdex_records = iter(
32+
[
33+
{
34+
"timdex_record_id": "id-1",
35+
"run_id": "run-1",
36+
"run_record_offset": 0,
37+
"transformed_record": b'{"title": ["Record 1"]}',
38+
},
39+
{
40+
"timdex_record_id": "id-2",
41+
"run_id": "run-1",
42+
"run_record_offset": 1,
43+
"transformed_record": b'{"title": ["Record 2"]}',
44+
},
45+
]
46+
)
47+
48+
# single strategy (for now)
49+
strategies = ["full_record"]
50+
51+
embedding_inputs = list(create_embedding_inputs(timdex_records, strategies))
52+
53+
assert len(embedding_inputs) == 2
54+
assert embedding_inputs[0].timdex_record_id == "id-1"
55+
assert embedding_inputs[0].embedding_strategy == "full_record"
56+
assert embedding_inputs[1].timdex_record_id == "id-2"
57+
assert embedding_inputs[1].embedding_strategy == "full_record"
58+
59+
60+
def test_get_strategy_class_returns_correct_class():
61+
strategy_class = get_strategy_class("full_record")
62+
assert strategy_class is FullRecordStrategy
63+
64+
65+
def test_get_strategy_class_raises_for_unknown_strategy():
66+
with pytest.raises(ValueError, match="Unknown strategy"):
67+
get_strategy_class("nonexistent_strategy")
68+
69+
70+
def test_subclass_without_strategy_name_raises_type_error():
71+
with pytest.raises(TypeError, match="must define 'STRATEGY_NAME' class attribute"):
72+
73+
class InvalidStrategy(BaseStrategy):
74+
pass
75+
76+
77+
def test_subclass_with_non_string_strategy_name_raises_type_error():
78+
with pytest.raises(
79+
TypeError, match="must override 'STRATEGY_NAME' with a valid string"
80+
):
81+
82+
class InvalidStrategy(BaseStrategy):
83+
STRATEGY_NAME = 123

0 commit comments

Comments
 (0)