Skip to content

Commit d433f17

Browse files
committed
Rename RecordText to EmbeddingInput
Why these changes are being introduced: Code review suggested that 'RecordText' was a confusing name for the object that we prepare to then create an embedding from. How this addresses that need: Renamign to 'EmbeddingInput' makes it crystal clear that we are preparing an object that will be used to create an embedding. Side effects of this change: * None Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/USE-112
1 parent 0024b8f commit d433f17

File tree

6 files changed

+21
-17
lines changed

6 files changed

+21
-17
lines changed

embeddings/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,10 @@ def create_embeddings(
230230
# DEBUG ------------------------------------------------------------------------------
231231
import json # noqa: PLC0415
232232

233-
from embeddings.embedding import RecordText # noqa: PLC0415
233+
from embeddings.embedding import EmbeddingInput # noqa: PLC0415
234234

235235
input_records = (
236-
RecordText(
236+
EmbeddingInput(
237237
timdex_record_id=timdex_record["timdex_record_id"],
238238
run_id=timdex_record["run_id"],
239239
run_record_offset=timdex_record["run_record_offset"],

embeddings/embedding.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44

55

66
@dataclass
7-
class RecordText:
8-
"""Input record for creating an embedding for.
7+
class EmbeddingInput:
8+
"""Encapsulates the inputs for an embedding.
9+
10+
When creating an embedding, we need to note what TIMDEX record the embedding is
11+
associated with and what strategy was used to prepare the embedding input text from
12+
the record itself.
913
1014
Args:
1115
(timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record

embeddings/models/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Iterator
55
from pathlib import Path
66

7-
from embeddings.embedding import Embedding, RecordText
7+
from embeddings.embedding import Embedding, EmbeddingInput
88

99

1010
class BaseEmbeddingModel(ABC):
@@ -51,15 +51,15 @@ def load(self) -> None:
5151
"""Load model from self.model_path."""
5252

5353
@abstractmethod
54-
def create_embedding(self, input_record: RecordText) -> Embedding:
55-
"""Create an Embedding for an RecordText.
54+
def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
55+
"""Create an Embedding for an EmbeddingInput.
5656
5757
Args:
58-
input_record: RecordText instance
58+
input_record: EmbeddingInput instance
5959
"""
6060

6161
def create_embeddings(
62-
self, input_records: Iterator[RecordText]
62+
self, input_records: Iterator[EmbeddingInput]
6363
) -> Iterator[Embedding]:
6464
"""Yield Embeddings for an iterator of InputRecords.
6565

embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from huggingface_hub import snapshot_download
1212
from transformers import AutoModelForMaskedLM, AutoTokenizer
1313

14-
from embeddings.embedding import Embedding, RecordText
14+
from embeddings.embedding import Embedding, EmbeddingInput
1515
from embeddings.models.base import BaseEmbeddingModel
1616

1717
if TYPE_CHECKING:
@@ -163,5 +163,5 @@ def load(self) -> None:
163163

164164
logger.info(f"Model loaded successfully, {time.perf_counter()-start_time}s")
165165

166-
def create_embedding(self, input_record: RecordText) -> Embedding:
166+
def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
167167
raise NotImplementedError

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
from click.testing import CliRunner
88

9-
from embeddings.embedding import Embedding, RecordText
9+
from embeddings.embedding import Embedding, EmbeddingInput
1010
from embeddings.models import registry
1111
from embeddings.models.base import BaseEmbeddingModel
1212

@@ -45,7 +45,7 @@ def download(self) -> Path:
4545
def load(self) -> None:
4646
logger.info("Model loaded successfully, 1.5s")
4747

48-
def create_embedding(self, input_record: RecordText) -> Embedding:
48+
def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
4949
return Embedding(
5050
timdex_record_id=input_record.timdex_record_id,
5151
run_id=input_record.run_id,

tests/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from embeddings.embedding import RecordText
5+
from embeddings.embedding import EmbeddingInput
66
from embeddings.models.base import BaseEmbeddingModel
77
from embeddings.models.registry import MODEL_REGISTRY, get_model_class
88

@@ -35,7 +35,7 @@ def test_mock_model_load(caplog, mock_model):
3535

3636

3737
def test_mock_model_create_embedding(mock_model):
38-
input_record = RecordText(
38+
input_record = EmbeddingInput(
3939
timdex_record_id="test-id",
4040
run_id="test-run",
4141
run_record_offset=42,
@@ -87,14 +87,14 @@ class InvalidModel(BaseEmbeddingModel):
8787

8888
def test_base_model_create_embeddings_calls_create_embedding(mock_model):
8989
input_records = [
90-
RecordText(
90+
EmbeddingInput(
9191
timdex_record_id="id-1",
9292
run_id="run-1",
9393
run_record_offset=0,
9494
embedding_strategy="full_record",
9595
text="text 1",
9696
),
97-
RecordText(
97+
EmbeddingInput(
9898
timdex_record_id="id-2",
9999
run_id="run-1",
100100
run_record_offset=1,

0 commit comments

Comments
 (0)