
Commit f092dbe

Stub CLI command for updating existing docs with embeddings
Why these changes are being introduced:
* TIM requires a CLI command to update existing docs with embeddings. Currently, TIM uses the OpenSearch 'index' action to create/update docs using *full* records, which performs a full replacement of existing docs. To update docs using *partial* records (e.g., only a subset of TIMDEX fields [i.e., embeddings]), the OpenSearch 'update' action must be used.

How this addresses that need:
* Add 'bulk_update_embeddings' CLI command
* Create BulkOperationError exception
* Update tim.helpers.generate_bulk_actions to format inputs for OpenSearch 'update' action

Side effects of this change:
* TIMDEX Pipeline Lambdas will require updates to generate the 'bulk-update-embeddings' CLI command.

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/USE-122
1 parent a7689c8 commit f092dbe
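The 'index' vs. 'update' distinction described above maps directly onto the shape of the action dicts sent to the opensearch-py bulk helpers. A minimal sketch for orientation (the index name, record id, and field values here are hypothetical, not taken from this commit):

# An "index" action carries the full record under "_source" and fully
# replaces any existing document with the same _id.
index_action = {
    "_op_type": "index",
    "_index": "timdex-prod",  # hypothetical index name
    "_id": "abc123",  # hypothetical record id
    "_source": {"timdex_record_id": "abc123", "title": "Full record"},
}

# An "update" action carries a partial record under "doc"; OpenSearch merges
# only those fields into the existing document, and errors if it is missing.
update_action = {
    "_op_type": "update",
    "_index": "timdex-prod",
    "_id": "abc123",
    "doc": {"alternate_titles": [{"kind": "Test", "value": "Test Alternate Title"}]},
}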

File tree

4 files changed: +173, -3 lines changed

tim/cli.py

Lines changed: 94 additions & 1 deletion
@@ -325,6 +325,90 @@ def bulk_update(
     logger.info(f"Bulk update complete: {json.dumps(summary_results)}")


+# Bulk update existing records with embeddings commands
+
+
+@main.command()
+@click.option(
+    "-i",
+    "--index",
+    help="Name of the index where the bulk update to add embeddings is performed.",
+)
+@click.option(
+    "-s",
+    "--source",
+    type=click.Choice(VALID_SOURCES),
+    help=(
+        "Source whose primary-aliased index will receive the bulk updated "
+        "records with embeddings."
+    ),
+)
+@click.option("-d", "--run-date", help="Run date, formatted as YYYY-MM-DD.")
+@click.option("-rid", "--run-id", help="Run ID.")
+@click.argument("dataset_path", type=click.Path())
+@click.pass_context
+def bulk_update_embeddings(
+    ctx: click.Context,
+    index: str,
+    source: str,
+    run_date: str,
+    run_id: str,
+    dataset_path: str,
+) -> None:
+    client = ctx.obj["CLIENT"]
+    index = helpers.validate_bulk_cli_options(index, source, client)
+
+    logger.info(
+        f"Bulk updating records with embeddings from dataset '{dataset_path}' "
+        f"into '{index}'"
+    )
+
+    update_results = {"updated": 0, "errors": 0, "total": 0}
+
+    td = TIMDEXDataset(location=dataset_path)
+
+    # TODO @ghukill: https://mitlibraries.atlassian.net/browse/USE-143 # noqa: FIX002
+    # Remove temporary code and replace with TDA
+    # method to read embeddings
+    # ==== START TEMPORARY CODE ====
+    # The code below reads transformed records from
+    # the TIMDEX dataset. To simulate embeddings,
+    # which are added to the record post-creation, a list
+    # of dicts containing only the 'timdex_record_id' and
+    # the new field (i.e., what would be the embedding fields)
+    # is created. For simulation purposes, the 'alternate_titles'
+    # field represents the new field as it is already added
+    # to the OpenSearch mapping in config/opensearch_mappings.json.
+    # When testing, the user is expected to pass in a source that
+    # does not set this field (e.g., libguides).
+    # Once TDA has been updated to read/write embeddings
+    # from/to the TIMDEX dataset, this code should be replaced
+    # with a simple call to read vector embeddings, which should
+    # return an iter of dicts representing the embeddings.
+    transformed_records = td.read_transformed_records_iter(
+        run_date=run_date,
+        run_id=run_id,
+        action="index",
+    )
+
+    records_to_update = iter(
+        [
+            {
+                "timdex_record_id": record["timdex_record_id"],
+                "alternate_titles": [{"kind": "Test", "value": "Test Alternate Title"}],
+            }
+            for record in transformed_records
+        ]
+    )
+    # ==== END TEMPORARY CODE ====
+    try:
+        update_results.update(tim_os.bulk_update(client, index, records_to_update))
+    except BulkOperationError as exception:
+        logger.error(f"Bulk update with embeddings failed: {exception}")
+
+    logger.info(f"Bulk update with embeddings complete: {json.dumps(update_results)}")
+
+
 @main.command()
 @click.option(
     "-s",
@@ -340,7 +424,12 @@ def bulk_update(
     help="Alias to promote the index to in addition to the primary alias. May "
     "be repeated to promote the index to multiple aliases at once.",
 )
-@click.argument("dataset_path", type=click.Path())
+@click.argument(
+    "dataset_path",
+    type=click.Path(),
+    # Location of the TIMDEX parquet dataset from which transformed records are
+    # read: a local filepath or an S3 URI. (click arguments do not support 'help'.)
+)
 @click.pass_context
 def reindex_source(
     ctx: click.Context,
@@ -395,3 +484,7 @@ def reindex_source(

     summary_results = {"index": index_results}
     logger.info(f"Reindex source complete: {json.dumps(summary_results)}")
+
+
+if __name__ == "__main__":
+    main()
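Per the temporary-code comments above, once TDA can read embeddings, the stubbed list comprehension should give way to an iterator of partial records: one dict per record, containing 'timdex_record_id' plus only the new fields. A rough sketch of that intended shape (the record id and the use of 'alternate_titles' as a stand-in field mirror the simulation above; the real embedding field names are not yet defined):

# Hypothetical shape of the partial records bulk_update expects.
records_to_update = iter(
    [
        {
            "timdex_record_id": "libguides:abc123",  # hypothetical record id
            # stand-in for the eventual embedding field(s)
            "alternate_titles": [{"kind": "Test", "value": "Test Alternate Title"}],
        }
    ]
)

Assuming the console entry point is `tim`, the pipeline-generated invocation noted in the commit message would then look something like `tim bulk-update-embeddings -s libguides -d 2025-01-01 -rid <run-id> <dataset_path>`.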

tim/errors.py

Lines changed: 16 additions & 0 deletions
@@ -19,6 +19,22 @@ def __init__(self, record: str, index: str, error: str) -> None:
         super().__init__(self.message)


+class BulkOperationError(Exception):
+    """Exception raised when an unexpected error is returned during a bulk operation."""
+
+    def __init__(self, action: str, record: str, index: str, error: str) -> None:
+        """Initialize exception with provided index name and error for message."""
+        # Map the bulk action to a verb for the message; fall back to the raw
+        # action name so an unexpected action cannot leave 'verb' undefined.
+        verbs = {"index": "indexing", "update": "updating"}
+        verb = verbs.get(action, action)
+
+        self.message = (
+            f"Error {verb} record '{record}' into index '{index}'. Details: {error}"
+        )
+        super().__init__(self.message)
+
+
 class IndexExistsError(Exception):
     """Exception raised when attempting to create an index that is already present."""

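For illustration, the new exception formats its message from the action verb, record, and index; all values below are hypothetical:

from tim.errors import BulkOperationError

exc = BulkOperationError(
    "update", "abc123", "timdex-prod", '{"type": "document_missing_exception"}'
)
print(exc)
# Error updating record 'abc123' into index 'timdex-prod'.
# Details: {"type": "document_missing_exception"}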

tim/helpers.py

Lines changed: 7 additions & 2 deletions
@@ -50,8 +50,13 @@ def generate_bulk_actions(
             "_index": index,
             "_id": record["timdex_record_id"],
         }
-        if action != "delete":
-            doc["_source"] = record
+
+        match action:
+            case "update":
+                doc["doc"] = record
+            case _ if action != "delete":
+                doc["_source"] = record
+
         yield doc

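The effect of the new branch is that 'update' actions nest the record under "doc" instead of "_source". A rough interactive sketch (the index name and record are illustrative; the action dict's other keys are built earlier in the function, as the context lines above show):

from tim import helpers

records = iter([{"timdex_record_id": "abc123", "alternate_titles": []}])
action = next(helpers.generate_bulk_actions("timdex-prod", records, "update"))
# expected shape (abridged):
# {"_index": "timdex-prod", "_id": "abc123",
#  "doc": {"timdex_record_id": "abc123", "alternate_titles": []}}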

tim/opensearch.py

Lines changed: 56 additions & 0 deletions
@@ -17,6 +17,7 @@
 from tim.errors import (
     AliasNotFoundError,
     BulkIndexingError,
+    BulkOperationError,
     IndexExistsError,
     IndexNotFoundError,
 )
@@ -370,6 +371,10 @@ def bulk_index(client: OpenSearch, index: str, records: Iterator[dict]) -> dict[
     If an error occurs during record indexing, it will be logged and bulk indexing will
     continue until all records have been processed.

+    NOTE: The update performed by the "index" action results in a full replacement of the
+    document in OpenSearch. If a partial record is provided, this will result in a new
+    document in OpenSearch containing only the fields provided in the partial record.
+
     Returns total sums of: records created, records updated, errors, and total records
     processed.
     """
@@ -413,3 +418,54 @@ def bulk_index(client: OpenSearch, index: str, records: Iterator[dict]) -> dict[
     )
     logger.debug(response)
     return result
+
+
+def bulk_update(
+    client: OpenSearch, index: str, records: Iterator[dict]
+) -> dict[str, int]:
+    """Updates existing documents in the index using the streaming bulk helper.
+
+    This method uses the OpenSearch "update" action, which updates existing documents
+    and returns an error if the document does not exist. The "update" action can accept
+    a full or partial record and will only update the corresponding fields in the
+    document.
+
+    Returns total sums of: records updated, errors, and total records
+    processed.
+    """
+    result = {"updated": 0, "errors": 0, "total": 0}
+    actions = helpers.generate_bulk_actions(index, records, "update")
+    responses = streaming_bulk(
+        client,
+        actions,
+        max_chunk_bytes=REQUEST_CONFIG["OPENSEARCH_BULK_MAX_CHUNK_BYTES"],
+        raise_on_error=False,
+    )
+    for response in responses:
+        if response[0] is False:
+            error = response[1]["update"]["error"]
+            record = response[1]["update"]["_id"]
+            if error["type"] == "mapper_parsing_exception":
+                logger.error(
+                    "Error updating record '%s'. Details: %s",
+                    record,
+                    json.dumps(error),
+                )
+                result["errors"] += 1
+            else:
+                raise BulkOperationError(
+                    "update", record, index, json.dumps(error)  # noqa: EM101
+                )
+        elif response[1]["update"].get("result") == "updated":
+            result["updated"] += 1
+        else:
+            logger.error(
+                "Something unexpected happened during update. Bulk update response: %s",
+                json.dumps(response),
+            )
+            result["errors"] += 1
+        result["total"] += 1
+        if result["total"] % int(os.getenv("STATUS_UPDATE_INTERVAL", "1000")) == 0:
+            logger.info("Status update: %s records updated so far!", result["total"])
+        logger.debug(response)
+    return result
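A minimal usage sketch of the new function (the client host, index name, and record are hypothetical; bulk_update merges only the provided fields and surfaces an error if the target document does not exist):

from opensearchpy import OpenSearch

import tim.opensearch as tim_os

client = OpenSearch("http://localhost:9200")  # hypothetical local cluster
partial_records = iter(
    [
        {
            "timdex_record_id": "abc123",
            "alternate_titles": [{"kind": "Test", "value": "Test Alternate Title"}],
        }
    ]
)
results = tim_os.bulk_update(client, "timdex-prod", partial_records)
# e.g. {"updated": 1, "errors": 0, "total": 1}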
