File tree: 6 files changed, +21 -17 lines changed
Original file line number | Diff line number | Diff line change
@@ -230,10 +230,10 @@ def create_embeddings(
230230 # DEBUG ------------------------------------------------------------------------------
231231 import json # noqa: PLC0415
232232
233- from embeddings .embedding import RecordText # noqa: PLC0415
233+ from embeddings .embedding import EmbeddingInput # noqa: PLC0415
234234
235235 input_records = (
236- RecordText (
236+ EmbeddingInput (
237237 timdex_record_id = timdex_record ["timdex_record_id" ],
238238 run_id = timdex_record ["run_id" ],
239239 run_record_offset = timdex_record ["run_record_offset" ],
Original file line number Diff line number Diff line change 44
55
66@dataclass
7- class RecordText :
8- """Input record for creating an embedding for.
7+ class EmbeddingInput :
8+ """Encapsulates the inputs for an embedding.
9+
10+ When creating an embedding, we need to note what TIMDEX record the embedding is
11+ associated with and what strategy was used to prepare the embedding input text from
12+ the record itself.
913
1014 Args:
1115 (timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record
Original file line number Diff line number Diff line change 44from collections .abc import Iterator
55from pathlib import Path
66
7- from embeddings .embedding import Embedding , RecordText
7+ from embeddings .embedding import Embedding , EmbeddingInput
88
99
1010class BaseEmbeddingModel (ABC ):
@@ -51,15 +51,15 @@ def load(self) -> None:
5151 """Load model from self.model_path."""
5252
5353 @abstractmethod
54- def create_embedding (self , input_record : RecordText ) -> Embedding :
55- """Create an Embedding for an RecordText .
54+ def create_embedding (self , input_record : EmbeddingInput ) -> Embedding :
55+ """Create an Embedding for an EmbeddingInput .
5656
5757 Args:
58- input_record: RecordText instance
58+ input_record: EmbeddingInput instance
5959 """
6060
6161 def create_embeddings (
62- self , input_records : Iterator [RecordText ]
62+ self , input_records : Iterator [EmbeddingInput ]
6363 ) -> Iterator [Embedding ]:
6464 """Yield Embeddings for an iterator of EmbeddingInput instances.
6565
Original file line number Diff line number Diff line change 1111from huggingface_hub import snapshot_download
1212from transformers import AutoModelForMaskedLM , AutoTokenizer
1313
14- from embeddings .embedding import Embedding , RecordText
14+ from embeddings .embedding import Embedding , EmbeddingInput
1515from embeddings .models .base import BaseEmbeddingModel
1616
1717if TYPE_CHECKING :
@@ -163,5 +163,5 @@ def load(self) -> None:
163163
164164 logger .info (f"Model loaded successfully, { time .perf_counter ()- start_time } s" )
165165
166- def create_embedding (self , input_record : RecordText ) -> Embedding :
166+ def create_embedding (self , input_record : EmbeddingInput ) -> Embedding :
167167 raise NotImplementedError
Original file line number Diff line number Diff line change 66import pytest
77from click .testing import CliRunner
88
9- from embeddings .embedding import Embedding , RecordText
9+ from embeddings .embedding import Embedding , EmbeddingInput
1010from embeddings .models import registry
1111from embeddings .models .base import BaseEmbeddingModel
1212
@@ -45,7 +45,7 @@ def download(self) -> Path:
4545 def load (self ) -> None :
4646 logger .info ("Model loaded successfully, 1.5s" )
4747
48- def create_embedding (self , input_record : RecordText ) -> Embedding :
48+ def create_embedding (self , input_record : EmbeddingInput ) -> Embedding :
4949 return Embedding (
5050 timdex_record_id = input_record .timdex_record_id ,
5151 run_id = input_record .run_id ,
Original file line number Diff line number Diff line change 22
33import pytest
44
5- from embeddings .embedding import RecordText
5+ from embeddings .embedding import EmbeddingInput
66from embeddings .models .base import BaseEmbeddingModel
77from embeddings .models .registry import MODEL_REGISTRY , get_model_class
88
@@ -35,7 +35,7 @@ def test_mock_model_load(caplog, mock_model):
3535
3636
3737def test_mock_model_create_embedding (mock_model ):
38- input_record = RecordText (
38+ input_record = EmbeddingInput (
3939 timdex_record_id = "test-id" ,
4040 run_id = "test-run" ,
4141 run_record_offset = 42 ,
@@ -87,14 +87,14 @@ class InvalidModel(BaseEmbeddingModel):
8787
8888def test_base_model_create_embeddings_calls_create_embedding (mock_model ):
8989 input_records = [
90- RecordText (
90+ EmbeddingInput (
9191 timdex_record_id = "id-1" ,
9292 run_id = "run-1" ,
9393 run_record_offset = 0 ,
9494 embedding_strategy = "full_record" ,
9595 text = "text 1" ,
9696 ),
97- RecordText (
97+ EmbeddingInput (
9898 timdex_record_id = "id-2" ,
9999 run_id = "run-1" ,
100100 run_record_offset = 1 ,
You can’t perform that action at this time.
0 commit comments