108 changes: 106 additions & 2 deletions tim/cli.py
@@ -26,7 +26,7 @@
},
{
"name": "Bulk record processing commands",
"commands": ["bulk-update", "reindex-source"],
"commands": ["bulk-update", "bulk-update-embeddings", "reindex-source"],
},
]
}
@@ -325,6 +325,101 @@ def bulk_update(
logger.info(f"Bulk update complete: {json.dumps(summary_results)}")


# Bulk update existing records with embeddings commands


@main.command()
@click.option(
"-i",
"--index",
help="Name of the index where the bulk update to add embeddings is performed.",
)
@click.option(
"-s",
"--source",
type=click.Choice(VALID_SOURCES),
help=(
"Source whose primary-aliased index will receive the bulk updated "
"records with embeddings."
),
)
@click.option("-d", "--run-date", help="Run date, formatted as YYYY-MM-DD.")
@click.option("-rid", "--run-id", help="Run ID.")
@click.argument("dataset_path", type=click.Path())
@click.pass_context
def bulk_update_embeddings(
ctx: click.Context,
index: str,
source: str,
run_date: str,
run_id: str,
dataset_path: str,
) -> None:
"""Bulk update existing records with vector embeddings for an index.

Must provide either the name of an existing index in the cluster or a valid source.
If source is provided, it will update existing records for the primary-aliased
index for the source. If the provided index doesn't exist in the cluster, the
command will log an error and abort.

The command reads vector embeddings from a TIMDEXDataset located at
dataset_path using the 'timdex-dataset-api' library. The dataset is filtered
by run date and run ID.
"""
client = ctx.obj["CLIENT"]
index = helpers.validate_bulk_cli_options(index, source, client)

logger.info(
f"Bulk updating records with embeddings from dataset '{dataset_path}' "
f"into '{index}'"
)

update_results = {"updated": 0, "errors": 0, "total": 0}

td = TIMDEXDataset(location=dataset_path)

# TODO @ghukill: https://mitlibraries.atlassian.net/browse/USE-143 # noqa: FIX002
# Remove temporary code and replace with TDA
# method to read embeddings
# ==== START TEMPORARY CODE ====
# The code below reads transformed records from
# the TIMDEX dataset. To simulate embeddings,
# which are added to the record post-creation, a list
# of dicts containing only the 'timdex_record_id' and
# the new field (i.e., what would be the embedding fields)
# is created. For simulation purposes, the 'alternate_titles'
# field represents the new field as it is already added
# to the OpenSearch mapping in config/opensearch_mappings.json.
# When testing, the user is expected to pass in a source that
# does not set this field (e.g., libguides).
# Once TDA has been updated to read/write embeddings
# from/to the TIMDEX dataset, this code should be replaced
# with a simple call to read vector embeddings, which should
# return an iter of dicts representing the embeddings.
transformed_records = td.read_transformed_records_iter(
run_date=run_date,
run_id=run_id,
action="index",
)

records_to_update = iter(
[
{
"timdex_record_id": record["timdex_record_id"],
"alternate_titles": [{"kind": "Test", "value": "Test Alternate Title"}],
}
for record in transformed_records
]
)
# ==== END TEMPORARY CODE ====
try:
update_results.update(tim_os.bulk_update(client, index, records_to_update))
except BulkIndexingError as exception:
logger.info(f"Bulk update with embeddings failed: {exception}")

logger.info(f"Bulk update with embeddings complete: {json.dumps(update_results)}")
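For a quick sense of how this new command is wired together, here is a minimal sketch of invoking it through click's CliRunner. The source, run date, run ID, and dataset path below are hypothetical, and it assumes OpenSearch connection settings are available in the environment just as for the existing bulk-update command.

```python
# Sketch only: "libguides" as source, the run date/ID, and the dataset path are
# placeholder values; an OpenSearch cluster reachable via the usual environment
# configuration is assumed.
from click.testing import CliRunner

from tim.cli import main

runner = CliRunner()
result = runner.invoke(
    main,
    [
        "bulk-update-embeddings",
        "--source", "libguides",
        "--run-date", "2024-06-01",
        "--run-id", "abc123",
        "s3://timdex-dataset-bucket/dataset",
    ],
)
print(result.exit_code)
print(result.output)
```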


@main.command()
@click.option(
"-s",
@@ -340,7 +435,12 @@ def bulk_update(
help="Alias to promote the index to in addition to the primary alias. May "
"be repeated to promote the index to multiple aliases at once.",
)
@click.argument("dataset_path", type=click.Path())
@click.argument(
"dataset_path",
type=click.Path(),
help="Location of TIMDEX parquet dataset from which transformed records are read. "
"This value can be a local filepath or an S3 URI.",
)
@click.pass_context
def reindex_source(
ctx: click.Context,
@@ -395,3 +495,7 @@ def reindex_source(

summary_results = {"index": index_results}
logger.info(f"Reindex source complete: {json.dumps(summary_results)}")


if __name__ == "__main__":
main()
16 changes: 16 additions & 0 deletions tim/errors.py
@@ -19,6 +19,22 @@ def __init__(self, record: str, index: str, error: str) -> None:
super().__init__(self.message)


class BulkOperationError(Exception):
"""Exception raised when an unexpected error is returned during a bulk operation."""

def __init__(self, action: str, record: str, index: str, error: str) -> None:
"""Initialize exception with provided index name and error for message."""
if action == "index":
verb = "indexing"
elif action == "update":
verb = "updating"
else:
# Fall back to the raw action name for any other bulk action.
verb = action

self.message = (
f"Error {verb} record '{record}' into index '{index}'. Details: {error}"
)
super().__init__(self.message)
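For illustration, this is the message the new exception would build for a failed update; the record ID, index name, and error payload below are made up:

```python
from tim.errors import BulkOperationError

# Hypothetical record ID, index name, and error details, purely for illustration.
error = BulkOperationError(
    "update",
    "libguides:guide-123",
    "libguides-2024-06-01",
    '{"type": "mapper_parsing_exception"}',
)
print(error.message)
# Error updating record 'libguides:guide-123' into index 'libguides-2024-06-01'.
# Details: {"type": "mapper_parsing_exception"}
```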


class IndexExistsError(Exception):
"""Exception raised when attempting to create an index that is already present."""

9 changes: 7 additions & 2 deletions tim/helpers.py
@@ -50,8 +50,13 @@ def generate_bulk_actions(
"_index": index,
"_id": record["timdex_record_id"],
}
if action != "delete":
doc["_source"] = record

match action:
case "update":
doc["doc"] = record
case _ if action != "delete":
doc["_source"] = record

yield doc
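To make the effect of the new match statement concrete, below is a sketch of the action dicts the generator is expected to yield for each action type. The index name and record are placeholders, and it assumes the "_op_type" key is set earlier in the function (that part of the function is outside this hunk).

```python
# Placeholder record and index name; "_op_type" is assumed to be set earlier
# in generate_bulk_actions, outside this hunk.
record = {"timdex_record_id": "alma:123", "title": "Example Title"}

index_action = {  # "index" (and other non-delete, non-update actions): full document
    "_op_type": "index",
    "_index": "alma-2024-06-01",
    "_id": "alma:123",
    "_source": record,
}
update_action = {  # "update": full or partial record nested under "doc"
    "_op_type": "update",
    "_index": "alma-2024-06-01",
    "_id": "alma:123",
    "doc": {"timdex_record_id": "alma:123", "alternate_titles": [{"value": "Alt"}]},
}
delete_action = {  # "delete": metadata only, no document body
    "_op_type": "delete",
    "_index": "alma-2024-06-01",
    "_id": "alma:123",
}
```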


56 changes: 56 additions & 0 deletions tim/opensearch.py
@@ -17,6 +17,7 @@
from tim.errors import (
AliasNotFoundError,
BulkIndexingError,
BulkOperationError,
IndexExistsError,
IndexNotFoundError,
)
@@ -370,6 +371,10 @@ def bulk_index(client: OpenSearch, index: str, records: Iterator[dict]) -> dict[
If an error occurs during record indexing, it will be logged and bulk indexing will
continue until all records have been processed.

NOTE: The update performed by the "index" action results in a full replacement of the
document in OpenSearch. If a partial record is provided, this will result in a new
document in OpenSearch containing only the fields provided in the partial record.

Returns total sums of: records created, records updated, errors, and total records
processed.
"""
@@ -413,3 +418,54 @@ def bulk_index(client: OpenSearch, index: str, records: Iterator[dict]) -> dict[
)
logger.debug(response)
return result
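A small sketch of the behavioral difference the NOTE above describes, using a hypothetical stored document; the field names and values are placeholders:

```python
# Hypothetical document already stored in OpenSearch.
existing = {"timdex_record_id": "alma:123", "title": "Original Title", "dates": ["2020"]}

# A partial record containing only the fields to add.
partial = {"timdex_record_id": "alma:123", "alternate_titles": [{"value": "Alt"}]}

# bulk_index uses the "index" action: the stored document is fully replaced, so
# after sending `partial` only its fields remain.
after_bulk_index = partial

# bulk_update uses the "update" action: `partial` is merged into the stored
# document, so untouched fields are preserved.
after_bulk_update = {**existing, **partial}
```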


Review comments on bulk_update:

Contributor: Commenting here broadly: I think the addition of this helper was a great idea. As discussed off-PR, I think it's the right call to introduce code duplication now and get things working, and then could optionally take a second pass and consider combining bulk_index, bulk_delete, and bulk_update into some kind of bulk_operation with something like action=index|delete|update.

Contributor Author: @ghukill Should I go add a ticket in the USE backlog or create a GH issue? 🤔

Contributor: Let's try a GH issue! It doesn't feel at all required, just sort of optional code hygiene and maintenance consideration. If we don't touch it, no problem, but can get a feel if GH issues are valuable + remembered.

def bulk_update(
client: OpenSearch, index: str, records: Iterator[dict]
) -> dict[str, int]:
"""Updates existing documents in the index using the streaming bulk helper.

This method uses the OpenSearch "update" action, which updates existing documents
and returns an error if the document does not exist. The "update" action can accept
a full or partial record and will only update the corresponding fields in the
document.

Returns total sums of: records updated, errors, and total records
processed.
"""
result = {"updated": 0, "errors": 0, "total": 0}
actions = helpers.generate_bulk_actions(index, records, "update")
responses = streaming_bulk(
client,
actions,
max_chunk_bytes=REQUEST_CONFIG["OPENSEARCH_BULK_MAX_CHUNK_BYTES"],
raise_on_error=False,
)
for response in responses:
if response[0] is False:
error = response[1]["update"]["error"]
record = response[1]["update"]["_id"]
if error["type"] == "mapper_parsing_exception":
logger.error(
"Error updating record '%s'. Details: %s",
record,
json.dumps(error),
)
result["errors"] += 1
else:
raise BulkOperationError(
"update", record, index, json.dumps(error) # noqa: EM101
)
elif response[1]["update"].get("result") == "updated":
result["updated"] += 1
else:
logger.error(
"Something unexpected happened during update. Bulk update response: %s",
json.dumps(response),
)
result["errors"] += 1
result["total"] += 1
if result["total"] % int(os.getenv("STATUS_UPDATE_INTERVAL", "1000")) == 0:
logger.info("Status update: %s records updated so far!", result["total"])
logger.debug(response)
return result
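Finally, a sketch of calling the new helper directly with partial records. The client, index name, and records below are placeholders, and the import alias assumes the module is brought in as tim_os, as cli.py does.

```python
from tim import opensearch as tim_os

# `client` stands in for an already-configured OpenSearch client (e.g. the one
# the CLI stores in ctx.obj["CLIENT"]); the index name and records are hypothetical.
client = ...  # opensearchpy.OpenSearch instance

records = iter(
    [
        {"timdex_record_id": "alma:123", "alternate_titles": [{"value": "Alt 1"}]},
        {"timdex_record_id": "alma:456", "alternate_titles": [{"value": "Alt 2"}]},
    ]
)

results = tim_os.bulk_update(client, "alma-2024-06-01", records)
# Expected shape of the summary, e.g.: {"updated": 2, "errors": 0, "total": 2}
```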