|
1 | 1 | import functools |
| 2 | +import json |
2 | 3 | import logging |
3 | 4 | import time |
4 | 5 | from collections.abc import Callable |
|
12 | 13 |
|
13 | 14 | from embeddings.config import configure_logger, configure_sentry |
14 | 15 | from embeddings.models.registry import get_model_class |
| 16 | +from embeddings.strategies.processor import create_embedding_inputs |
| 17 | +from embeddings.strategies.registry import STRATEGY_REGISTRY |
15 | 18 |
|
16 | 19 | logger = logging.getLogger(__name__) |
17 | 20 |
|
@@ -181,10 +184,13 @@ def test_model_load(ctx: click.Context) -> None: |
181 | 184 | ) |
182 | 185 | @click.option( |
183 | 186 | "--strategy", |
184 | | - type=str, # WIP: establish an enum of supported strategies |
| 187 | + type=click.Choice(list(STRATEGY_REGISTRY.keys())), |
185 | 188 | required=True, |
186 | 189 | multiple=True, |
187 | | - help="Pre-embedding record transformation strategy to use. Repeatable.", |
| 190 | + help=( |
| 191 | + "Pre-embedding record transformation strategy. " |
| 192 | + "Repeatable to apply multiple strategies." |
| 193 | + ), |
188 | 194 | ) |
189 | 195 | @click.option( |
190 | 196 | "--output-jsonl", |
@@ -222,48 +228,11 @@ def create_embeddings( |
222 | 228 | action="index", |
223 | 229 | ) |
224 | 230 |
|
225 | | - # create an iterator of InputTexts applying all requested strategies to all records |
226 | | - # WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that |
227 | | - # create texts based on the requested strategies (e.g. "full record"), which are |
228 | | - # captured in --strategy CLI args |
229 | | - # WIP NOTE: the following simulates that... |
230 | | - # DEBUG ------------------------------------------------------------------------------ |
231 | | - import json # noqa: PLC0415 |
232 | | - |
233 | | - from embeddings.embedding import EmbeddingInput # noqa: PLC0415 |
234 | | - |
235 | | - input_records = ( |
236 | | - EmbeddingInput( |
237 | | - timdex_record_id=timdex_record["timdex_record_id"], |
238 | | - run_id=timdex_record["run_id"], |
239 | | - run_record_offset=timdex_record["run_record_offset"], |
240 | | - embedding_strategy=_strategy, |
241 | | - text=json.dumps(timdex_record["transformed_record"].decode()), |
242 | | - ) |
243 | | - for timdex_record in timdex_records |
244 | | - for _strategy in strategy |
245 | | - ) |
246 | | - # DEBUG ------------------------------------------------------------------------------ |
247 | | - |
248 | | - # create an iterator of Embeddings via the embedding model |
249 | | - # WIP NOTE: this will use the embedding class .create_embeddings() bulk method |
250 | | - # WIP NOTE: the following simulates that... |
251 | | - # DEBUG ------------------------------------------------------------------------------ |
252 | | - from embeddings.embedding import Embedding # noqa: PLC0415 |
253 | | - |
254 | | - embeddings = ( |
255 | | - Embedding( |
256 | | - timdex_record_id=input_record.timdex_record_id, |
257 | | - run_id=input_record.run_id, |
258 | | - run_record_offset=input_record.run_record_offset, |
259 | | - embedding_strategy=input_record.embedding_strategy, |
260 | | - model_uri=model.model_uri, |
261 | | - embedding_vector=[0.1, 0.2, 0.3], |
262 | | - embedding_token_weights={"coffee": 0.9, "seattle": 0.5}, |
263 | | - ) |
264 | | - for input_record in input_records |
265 | | - ) |
266 | | - # DEBUG ------------------------------------------------------------------------------ |
| 231 | + # create an iterator of EmbeddingInputs applying all requested strategies |
| 232 | + input_records = create_embedding_inputs(timdex_records, list(strategy)) |
| 233 | + |
| 234 | + # create embeddings via the embedding model |
| 235 | + embeddings = model.create_embeddings(input_records) |
267 | 236 |
|
268 | 237 | # if requested, write embeddings to a local JSONLines file |
269 | 238 | if output_jsonl: |
|
0 commit comments