Skip to content

Commit d0511c0

Browse files
Stub CLI command for updating existing docs with embeddings (#370)
Why these changes are being introduced: * TIM requires a CLI command to update existing docs with embeddings. Currently, TIM uses the OpenSearch 'index' action to create/update docs using *full* records, which performs a full replacement of existing docs. To update docs using *partial* records (e.g., only a subset of TIMDEX fields [i.e., embeddings]), the OpenSearch 'update' action must be used. How this addresses that need: * Add 'bulk_update_embeddings' CLI command * Create BulkOperationError exception * Update tim.helpers.generate_bulk_actions to format inputs for OpenSearch 'update' action Side effects of this change: * TIMDEX Pipeline Lambdas will require updates to generate the 'bulk-update-embeddings' CLI command. Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/USE-122
1 parent a7689c8 commit d0511c0

File tree

4 files changed

+185
-4
lines changed

4 files changed

+185
-4
lines changed

tim/cli.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
},
2727
{
2828
"name": "Bulk record processing commands",
29-
"commands": ["bulk-update", "reindex-source"],
29+
"commands": ["bulk-update", "bulk-update-embeddings", "reindex-source"],
3030
},
3131
]
3232
}
@@ -325,6 +325,101 @@ def bulk_update(
325325
logger.info(f"Bulk update complete: {json.dumps(summary_results)}")
326326

327327

328+
# Bulk update existing records with embeddings commands


@main.command()
@click.option(
    "-i",
    "--index",
    help="Name of the index where the bulk update to add embeddings is performed.",
)
@click.option(
    "-s",
    "--source",
    type=click.Choice(VALID_SOURCES),
    help=(
        "Source whose primary-aliased index will receive the bulk updated "
        "records with embeddings."
    ),
)
@click.option("-d", "--run-date", help="Run date, formatted as YYYY-MM-DD.")
@click.option("-rid", "--run-id", help="Run ID.")
@click.argument("dataset_path", type=click.Path())
@click.pass_context
def bulk_update_embeddings(
    ctx: click.Context,
    index: str,
    source: str,
    run_date: str,
    run_id: str,
    dataset_path: str,
) -> None:
    """Bulk update existing records with vector embeddings for an index.

    Must provide either the name of an existing index in the cluster or a valid source.
    If source is provided, it will update existing records for the primary-aliased
    index for the source. If the provided index doesn't exist in the cluster, the
    method will log an error and abort.

    The method will read vector embeddings from a TIMDEXDataset
    located at dataset_path using the 'timdex-dataset-api' library. The dataset
    is filtered by run date and run ID.
    """
    # Local import keeps this fix self-contained; tim.errors is already a
    # dependency of this module (BulkIndexingError is imported at module level).
    from tim.errors import BulkOperationError

    client = ctx.obj["CLIENT"]
    index = helpers.validate_bulk_cli_options(index, source, client)

    logger.info(
        f"Bulk updating records with embeddings from dataset '{dataset_path}' "
        f"into '{index}'"
    )

    update_results = {"updated": 0, "errors": 0, "total": 0}

    td = TIMDEXDataset(location=dataset_path)

    # TODO @ghukill: https://mitlibraries.atlassian.net/browse/USE-143 # noqa: FIX002
    # Remove temporary code and replace with TDA
    # method to read embeddings
    # ==== START TEMPORARY CODE ====
    # The code below reads transformed records from
    # the TIMDEX dataset. To simulate embeddings,
    # which are added to the record post-creation, a list
    # of dicts containing only the 'timdex_record_id' and
    # the new field (i.e., what would be the embedding fields)
    # is created. For simulation purposes, the 'alternate_titles'
    # field represents the new field as it is already added
    # to the OpenSearch mapping in config/opensearch_mappings.json.
    # When testing, the user is expected to pass in a source that
    # does not set this field (e.g., libguides).
    # Once TDA has been updated to read/write embeddings
    # from/to the TIMDEX dataset, this code should be replaced
    # with a simple call to read vector embeddings, which should
    # return an iter of dicts representing the embeddings.
    transformed_records = td.read_transformed_records_iter(
        run_date=run_date,
        run_id=run_id,
        action="index",
    )

    # Generator (rather than iter(list-comprehension)) so partial update
    # records stream through streaming_bulk without materializing the dataset.
    records_to_update = (
        {
            "timdex_record_id": record["timdex_record_id"],
            "alternate_titles": [{"kind": "Test", "value": "Test Alternate Title"}],
        }
        for record in transformed_records
    )
    # ==== END TEMPORARY CODE ====
    try:
        update_results.update(tim_os.bulk_update(client, index, records_to_update))
    except (BulkIndexingError, BulkOperationError) as exception:
        # tim_os.bulk_update raises BulkOperationError (not BulkIndexingError);
        # catching only BulkIndexingError let failures propagate uncaught.
        # Failures are logged at ERROR level, not INFO.
        logger.error(f"Bulk update with embeddings failed: {exception}")

    logger.info(f"Bulk update with embeddings complete: {json.dumps(update_results)}")
421+
422+
328423
@main.command()
329424
@click.option(
330425
"-s",
@@ -340,7 +435,12 @@ def bulk_update(
340435
help="Alias to promote the index to in addition to the primary alias. May "
341436
"be repeated to promote the index to multiple aliases at once.",
342437
)
343-
@click.argument("dataset_path", type=click.Path())
438+
@click.argument(
439+
"dataset_path",
440+
type=click.Path(),
441+
help="Location of TIMDEX parquet dataset from which transformed records are read."
442+
"This value can be a local filepath or an S3 URI.",
443+
)
344444
@click.pass_context
345445
def reindex_source(
346446
ctx: click.Context,
@@ -395,3 +495,7 @@ def reindex_source(
395495

396496
summary_results = {"index": index_results}
397497
logger.info(f"Reindex source complete: {json.dumps(summary_results)}")
498+
499+
500+
if __name__ == "__main__":
501+
main()

tim/errors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@ def __init__(self, record: str, index: str, error: str) -> None:
1919
super().__init__(self.message)
2020

2121

22+
class BulkOperationError(Exception):
    """Exception raised when an unexpected error is returned during a bulk operation."""

    # Maps OpenSearch bulk action names to the gerund used in the error message.
    _ACTION_VERBS = {"index": "indexing", "update": "updating"}

    def __init__(self, action: str, record: str, index: str, error: str) -> None:
        """Initialize exception with provided index name and error for message.

        Args:
            action: OpenSearch bulk action (e.g. "index", "update").
            record: The 'timdex_record_id' of the failed record.
            index: Name of the target OpenSearch index.
            error: Details of the error returned by OpenSearch.
        """
        # Fall back to the raw action name for unrecognized actions; the
        # original if/elif left 'verb' unbound (UnboundLocalError) for any
        # action other than "index" or "update".
        verb = self._ACTION_VERBS.get(action, action)

        self.message = (
            f"Error {verb} record '{record}' into index '{index}'. Details: {error}"
        )
        super().__init__(self.message)
36+
37+
2238
class IndexExistsError(Exception):
2339
"""Exception raised when attempting to create an index that is already present."""
2440

tim/helpers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,13 @@ def generate_bulk_actions(
5050
"_index": index,
5151
"_id": record["timdex_record_id"],
5252
}
53-
if action != "delete":
54-
doc["_source"] = record
53+
54+
match action:
55+
case "update":
56+
doc["doc"] = record
57+
case _ if action != "delete":
58+
doc["_source"] = record
59+
5560
yield doc
5661

5762

tim/opensearch.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from tim.errors import (
1818
AliasNotFoundError,
1919
BulkIndexingError,
20+
BulkOperationError,
2021
IndexExistsError,
2122
IndexNotFoundError,
2223
)
@@ -370,6 +371,10 @@ def bulk_index(client: OpenSearch, index: str, records: Iterator[dict]) -> dict[
370371
If an error occurs during record indexing, it will be logged and bulk indexing will
371372
continue until all records have been processed.
372373
374+
NOTE: The update performed by the "index" action results in a full replacement of the
375+
document in OpenSearch. If a partial record is provided, this will result in a new
376+
document in OpenSearch containing only the fields provided in the partial record.
377+
373378
Returns total sums of: records created, records updated, errors, and total records
374379
processed.
375380
"""
@@ -413,3 +418,54 @@ def bulk_index(client: OpenSearch, index: str, records: Iterator[dict]) -> dict[
413418
)
414419
logger.debug(response)
415420
return result
421+
422+
423+
def bulk_update(
    client: OpenSearch, index: str, records: Iterator[dict]
) -> dict[str, int]:
    """Updates existing documents in the index using the streaming bulk helper.

    This method uses the OpenSearch "update" action, which updates existing documents
    and returns an error if the document does not exist. The "update" action can accept
    a full or partial record and will only update the corresponding fields in the
    document.

    Returns total sums of: records updated, errors, and total records
    processed.

    Raises:
        BulkOperationError: If an update fails for any reason other than a
            mapper parsing problem (parsing errors are logged and counted
            so the bulk run can continue).
    """
    result = {"updated": 0, "errors": 0, "total": 0}
    actions = helpers.generate_bulk_actions(index, records, "update")
    responses = streaming_bulk(
        client,
        actions,
        max_chunk_bytes=REQUEST_CONFIG["OPENSEARCH_BULK_MAX_CHUNK_BYTES"],
        raise_on_error=False,
    )
    # Initialized so the debug log below doesn't raise UnboundLocalError when
    # the records iterator yields nothing (e.g. an empty filtered dataset).
    response = None
    for response in responses:
        if response[0] is False:
            error = response[1]["update"]["error"]
            record = response[1]["update"]["_id"]
            if error["type"] == "mapper_parsing_exception":
                logger.error(
                    "Error updating record '%s'. Details: %s",
                    record,
                    json.dumps(error),
                )
                result["errors"] += 1
            else:
                raise BulkOperationError(
                    "update", record, index, json.dumps(error)  # noqa: EM101
                )
        elif response[1]["update"].get("result") in ("updated", "noop"):
            # "noop" is a normal OpenSearch outcome meaning the document
            # already contained the submitted values; treat it as a
            # successful update rather than an unexpected error.
            result["updated"] += 1
        else:
            logger.error(
                "Something unexpected happened during update. Bulk update response: %s",
                json.dumps(response),
            )
            result["errors"] += 1
        result["total"] += 1
        if result["total"] % int(os.getenv("STATUS_UPDATE_INTERVAL", "1000")) == 0:
            logger.info("Status update: %s records updated so far!", result["total"])
    if response is not None:
        logger.debug(response)
    return result

0 commit comments

Comments
 (0)