diff --git a/tim/cli.py b/tim/cli.py
index fe35ba8..379add5 100644
--- a/tim/cli.py
+++ b/tim/cli.py
@@ -26,7 +26,7 @@
         },
         {
            "name": "Bulk record processing commands",
-           "commands": ["bulk-update", "reindex-source"],
+           "commands": ["bulk-update", "bulk-update-embeddings", "reindex-source"],
         },
     ]
 }
@@ -325,6 +325,101 @@ def bulk_update(
     logger.info(f"Bulk update complete: {json.dumps(summary_results)}")
 
 
+# Bulk update existing records with embeddings commands
+
+
+@main.command()
+@click.option(
+    "-i",
+    "--index",
+    help="Name of the index where the bulk update to add embeddings is performed.",
+)
+@click.option(
+    "-s",
+    "--source",
+    type=click.Choice(VALID_SOURCES),
+    help=(
+        "Source whose primary-aliased index will receive the bulk updated "
+        "records with embeddings."
+    ),
+)
+@click.option("-d", "--run-date", help="Run date, formatted as YYYY-MM-DD.")
+@click.option("-rid", "--run-id", help="Run ID.")
+@click.argument("dataset_path", type=click.Path())
+@click.pass_context
+def bulk_update_embeddings(
+    ctx: click.Context,
+    index: str,
+    source: str,
+    run_date: str,
+    run_id: str,
+    dataset_path: str,
+) -> None:
+    """Bulk update existing records with vector embeddings for an index.
+
+    Must provide either the name of an existing index in the cluster or a valid source.
+    If source is provided, it will update existing records for the primary-aliased
+    index for the source. If the provided index doesn't exist in the cluster, the
+    method will log an error and abort.
+
+    The method will read vector embeddings from a TIMDEXDataset
+    located at dataset_path using the 'timdex-dataset-api' library. The dataset
+    is filtered by run date and run ID.
+    """
+    client = ctx.obj["CLIENT"]
+    index = helpers.validate_bulk_cli_options(index, source, client)
+
+    logger.info(
+        f"Bulk updating records with embeddings from dataset '{dataset_path}' "
+        f"into '{index}'"
+    )
+
+    update_results = {"updated": 0, "errors": 0, "total": 0}
+
+    td = TIMDEXDataset(location=dataset_path)
+
+    # TODO @ghukill: https://mitlibraries.atlassian.net/browse/USE-143 # noqa: FIX002
+    # Remove temporary code and replace with TDA
+    # method to read embeddings
+    # ==== START TEMPORARY CODE ====
+    # The code below reads transformed records from
+    # the TIMDEX dataset. To simulate embeddings,
+    # which are added to the record post-creation, a list
+    # of dicts containing only the 'timdex_record_id' and
+    # the new field (i.e., what would be the embedding fields)
+    # is created. For simulation purposes, the 'alternate_titles'
+    # field represents the new field as it is already added
+    # to the OpenSearch mapping in config/opensearch_mappings.json.
+    # When testing, the user is expected to pass in a source that
+    # does not set this field (e.g., libguides).
+    # Once TDA has been updated to read/write embeddings
+    # from/to the TIMDEX dataset, this code should be replaced
+    # with a simple call to read vector embeddings, which should
+    # return an iter of dicts representing the embeddings.
+    transformed_records = td.read_transformed_records_iter(
+        run_date=run_date,
+        run_id=run_id,
+        action="index",
+    )
+
+    records_to_update = iter(
+        [
+            {
+                "timdex_record_id": record["timdex_record_id"],
+                "alternate_titles": [{"kind": "Test", "value": "Test Alternate Title"}],
+            }
+            for record in transformed_records
+        ]
+    )
+    # ==== END TEMPORARY CODE ====
+    try:
+        update_results.update(tim_os.bulk_update(client, index, records_to_update))
+    except BulkIndexingError as exception:
+        logger.info(f"Bulk update with embeddings failed: {exception}")
+
+    logger.info(f"Bulk update with embeddings complete: {json.dumps(update_results)}")
+
+
 @main.command()
 @click.option(
     "-s",
@@ -340,7 +435,12 @@ def bulk_update(
     help="Alias to promote the index to in addition to the primary alias. May "
     "be repeated to promote the index to multiple aliases at once.",
 )
-@click.argument("dataset_path", type=click.Path())
+@click.argument(
+    "dataset_path",
+    type=click.Path(),
+    # NOTE: click arguments do not accept 'help'; dataset_path is the location of the
+    # TIMDEX parquet dataset of transformed records (a local filepath or an S3 URI).
+)
 @click.pass_context
 def reindex_source(
     ctx: click.Context,
@@ -395,3 +495,7 @@ def reindex_source(
     summary_results = {"index": index_results}
 
     logger.info(f"Reindex source complete: {json.dumps(summary_results)}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tim/errors.py b/tim/errors.py
index 2905c68..184ad68 100644
--- a/tim/errors.py
+++ b/tim/errors.py
@@ -19,6 +19,22 @@ def __init__(self, record: str, index: str, error: str) -> None:
         super().__init__(self.message)
 
 
+class BulkOperationError(Exception):
+    """Exception raised when an unexpected error is returned during a bulk operation."""
+
+    def __init__(self, action: str, record: str, index: str, error: str) -> None:
+        """Initialize exception with provided index name and error for message."""
+        if action == "index":
+            verb = "indexing"
+        else:  # the "update" action; also the fallback for any other bulk action
+            verb = "updating"
+
+        self.message = (
+            f"Error {verb} record '{record}' into index '{index}'. Details: {error}"
+        )
+        super().__init__(self.message)
+
+
 class IndexExistsError(Exception):
     """Exception raised when attempting to create an index that is already present."""
diff --git a/tim/helpers.py b/tim/helpers.py
index 2c195a5..f3d1cc4 100644
--- a/tim/helpers.py
+++ b/tim/helpers.py
@@ -50,8 +50,13 @@ def generate_bulk_actions(
             "_index": index,
             "_id": record["timdex_record_id"],
         }
-        if action != "delete":
-            doc["_source"] = record
+
+        match action:
+            case "update":
+                doc["doc"] = record
+            case _ if action != "delete":
+                doc["_source"] = record
+
         yield doc
diff --git a/tim/opensearch.py b/tim/opensearch.py
index 5b29643..2bf4d0e 100644
--- a/tim/opensearch.py
+++ b/tim/opensearch.py
@@ -17,6 +17,7 @@ from tim.errors import (
     AliasNotFoundError,
     BulkIndexingError,
+    BulkOperationError,
     IndexExistsError,
     IndexNotFoundError,
 )
@@ -370,6 +371,10 @@ def bulk_index(client: OpenSearch, index: str, records: Iterator[dict]) -> dict[
     If an error occurs during record indexing, it will be logged and bulk indexing
     will continue until all records have been processed.
 
+    NOTE: The update performed by the "index" action results in a full replacement of the
+    document in OpenSearch. If a partial record is provided, this will result in a new
+    document in OpenSearch containing only the fields provided in the partial record.
+
     Returns total sums of: records created, records updated, errors, and total records
     processed.
""" @@ -413,3 +418,54 @@ def bulk_index(client: OpenSearch, index: str, records: Iterator[dict]) -> dict[ ) logger.debug(response) return result + + +def bulk_update( + client: OpenSearch, index: str, records: Iterator[dict] +) -> dict[str, int]: + """Updates existing documents in the index using the streaming bulk helper. + + This method uses the OpenSearch "update" action, which updates existing documents + and returns an error if the document does not exist. The "update" action can accept + a full or partial record and will only update the corresponding fields in the + document. + + Returns total sums of: records updated, errors, and total records + processed. + """ + result = {"updated": 0, "errors": 0, "total": 0} + actions = helpers.generate_bulk_actions(index, records, "update") + responses = streaming_bulk( + client, + actions, + max_chunk_bytes=REQUEST_CONFIG["OPENSEARCH_BULK_MAX_CHUNK_BYTES"], + raise_on_error=False, + ) + for response in responses: + if response[0] is False: + error = response[1]["update"]["error"] + record = response[1]["update"]["_id"] + if error["type"] == "mapper_parsing_exception": + logger.error( + "Error updating record '%s'. Details: %s", + record, + json.dumps(error), + ) + result["errors"] += 1 + else: + raise BulkOperationError( + "update", record, index, json.dumps(error) # noqa: EM101 + ) + elif response[1]["update"].get("result") == "updated": + result["updated"] += 1 + else: + logger.error( + "Something unexpected happened during update. Bulk update response: %s", + json.dumps(response), + ) + result["errors"] += 1 + result["total"] += 1 + if result["total"] % int(os.getenv("STATUS_UPDATE_INTERVAL", "1000")) == 0: + logger.info("Status update: %s records updated so far!", result["total"]) + logger.debug(response) + return result