|
| 1 | +import json |
| 2 | + |
| 3 | +import pytest |
| 4 | + |
| 5 | +from embeddings.strategies.base import BaseStrategy |
| 6 | +from embeddings.strategies.full_record import FullRecordStrategy |
| 7 | +from embeddings.strategies.processor import create_embedding_inputs |
| 8 | +from embeddings.strategies.registry import get_strategy_class |
| 9 | + |
| 10 | + |
| 11 | +def test_full_record_strategy_creates_embedding_input(): |
| 12 | + transformed_record = {"timdex_record_id": "test-123", "title": ["Test Title"]} |
| 13 | + strategy = FullRecordStrategy( |
| 14 | + timdex_record_id="test-123", |
| 15 | + run_id="run-456", |
| 16 | + run_record_offset=42, |
| 17 | + transformed_record=transformed_record, |
| 18 | + ) |
| 19 | + |
| 20 | + embedding_input = strategy.to_embedding_input() |
| 21 | + |
| 22 | + assert embedding_input.timdex_record_id == "test-123" |
| 23 | + assert embedding_input.run_id == "run-456" |
| 24 | + assert embedding_input.run_record_offset == 42 |
| 25 | + assert embedding_input.embedding_strategy == "full_record" |
| 26 | + assert embedding_input.text == json.dumps(transformed_record) |
| 27 | + |
| 28 | + |
| 29 | +def test_create_embedding_inputs_yields_cartesian_product(): |
| 30 | + # two records |
| 31 | + timdex_records = iter( |
| 32 | + [ |
| 33 | + { |
| 34 | + "timdex_record_id": "id-1", |
| 35 | + "run_id": "run-1", |
| 36 | + "run_record_offset": 0, |
| 37 | + "transformed_record": b'{"title": ["Record 1"]}', |
| 38 | + }, |
| 39 | + { |
| 40 | + "timdex_record_id": "id-2", |
| 41 | + "run_id": "run-1", |
| 42 | + "run_record_offset": 1, |
| 43 | + "transformed_record": b'{"title": ["Record 2"]}', |
| 44 | + }, |
| 45 | + ] |
| 46 | + ) |
| 47 | + |
| 48 | + # single strategy (for now) |
| 49 | + strategies = ["full_record"] |
| 50 | + |
| 51 | + embedding_inputs = list(create_embedding_inputs(timdex_records, strategies)) |
| 52 | + |
| 53 | + assert len(embedding_inputs) == 2 |
| 54 | + assert embedding_inputs[0].timdex_record_id == "id-1" |
| 55 | + assert embedding_inputs[0].embedding_strategy == "full_record" |
| 56 | + assert embedding_inputs[1].timdex_record_id == "id-2" |
| 57 | + assert embedding_inputs[1].embedding_strategy == "full_record" |
| 58 | + |
| 59 | + |
| 60 | +def test_get_strategy_class_returns_correct_class(): |
| 61 | + strategy_class = get_strategy_class("full_record") |
| 62 | + assert strategy_class is FullRecordStrategy |
| 63 | + |
| 64 | + |
| 65 | +def test_get_strategy_class_raises_for_unknown_strategy(): |
| 66 | + with pytest.raises(ValueError, match="Unknown strategy"): |
| 67 | + get_strategy_class("nonexistent_strategy") |
| 68 | + |
| 69 | + |
| 70 | +def test_subclass_without_strategy_name_raises_type_error(): |
| 71 | + with pytest.raises(TypeError, match="must define 'STRATEGY_NAME' class attribute"): |
| 72 | + |
| 73 | + class InvalidStrategy(BaseStrategy): |
| 74 | + pass |
| 75 | + |
| 76 | + |
| 77 | +def test_subclass_with_non_string_strategy_name_raises_type_error(): |
| 78 | + with pytest.raises( |
| 79 | + TypeError, match="must override 'STRATEGY_NAME' with a valid string" |
| 80 | + ): |
| 81 | + |
| 82 | + class InvalidStrategy(BaseStrategy): |
| 83 | + STRATEGY_NAME = 123 |
0 commit comments