
Commit 128be85

[wip]
1 parent 142f055 commit 128be85

4 files changed: +91 -25 lines changed


tests/test_opensearch.py

Lines changed: 4 additions & 4 deletions
@@ -459,7 +459,7 @@ def test_bulk_index_creates_records(
     test_opensearch_client, five_valid_index_libguides_records
 ):
     assert tim_os.bulk_index(
-        test_opensearch_client, "test-index", five_valid_index_libguides_records
+        test_opensearch_client, "test-index", five_valid_index_libguides_records, "index"
     ) == {
         "created": 5,
         "updated": 0,
@@ -474,22 +474,22 @@ def test_bulk_index_updates_records(
 ):
     monkeypatch.setenv("STATUS_UPDATE_INTERVAL", "5")
     assert tim_os.bulk_index(
-        test_opensearch_client, "test-index", five_valid_index_libguides_records
+        test_opensearch_client, "test-index", five_valid_index_libguides_records, "index"
     ) == {
         "created": 0,
         "updated": 5,
         "errors": 0,
         "total": 5,
     }
-    assert "Status update: 5 records indexed so far!" in caplog.text
+    assert "Status update: 5 records processed so far!" in caplog.text


 @my_vcr.use_cassette("opensearch/bulk_index_record_mapper_parsing_error.yaml")
 def test_bulk_index_logs_mapper_parsing_errors(
     caplog, test_opensearch_client, one_invalid_index_libguides_records
 ):
     assert tim_os.bulk_index(
-        test_opensearch_client, "test-index", one_invalid_index_libguides_records
+        test_opensearch_client, "test-index", one_invalid_index_libguides_records, "index"
     ) == {
         "created": 0,
         "updated": 0,

tim/cli.py

Lines changed: 65 additions & 9 deletions
@@ -308,7 +308,9 @@ def bulk_update(
         action="index",
     )
     try:
-        index_results.update(tim_os.bulk_index(client, index, records_to_index))
+        index_results.update(
+            tim_os.bulk_index(client, index, records_to_index, action="index")
+        )
     except BulkIndexingError as exception:
         logger.info(f"Bulk indexing failed: {exception}")

@@ -343,29 +345,72 @@ def bulk_update(
         "records with embeddings."
     ),
 )
+@click.option("-d", "--run-date", help="Run date, formatted as YYYY-MM-DD.")
 @click.option("-rid", "--run-id", help="Run ID.")
 @click.argument("dataset_path", type=click.Path())
 @click.pass_context
 def bulk_update_embeddings(
-    ctx: click.Context, index: str, source: str, run_id: str, dataset_path: str
-):
+    ctx: click.Context,
+    index: str,
+    source: str,
+    run_date: str,
+    run_id: str,
+    dataset_path: str,
+) -> None:
     client = ctx.obj["CLIENT"]
     index = helpers.validate_bulk_cli_options(index, source, client)

     logger.info(
-        f"Bulk updating records with embeddings from dataset '{dataset_path}' into '{index}'"
+        f"Bulk updating records with embeddings from dataset '{dataset_path}' "
+        f"into '{index}'"
     )

     update_results = {"updated": 0, "errors": 0, "total": 0}

     td = TIMDEXDataset(location=dataset_path)

-    # TODO: update TDA to read embeddings
+    # TODO @ghukill: https://mitlibraries.atlassian.net/browse/USE-143  # noqa: FIX002
+    # Remove temporary code and replace with TDA method to read embeddings.
+    # ==== START TEMPORARY CODE ====
+    # The code below reads transformed records from the TIMDEX dataset. To
+    # simulate embeddings, which are added to the record post-creation, a
+    # list of dicts containing only the 'timdex_record_id' and the new field
+    # (i.e., what would be the embedding fields) is created. For simulation
+    # purposes, the 'alternate_titles' field represents the new field, as it
+    # is already added to the OpenSearch mapping in
+    # config/opensearch_mappings.json. When testing, the user is expected to
+    # pass in a source that does not set this field (e.g., libguides).
+    # Once TDA has been updated to read/write embeddings from/to the TIMDEX
+    # dataset, this code should be replaced with a simple call to read vector
+    # embeddings, which should return an iter of dicts representing the
+    # embeddings.
+    transformed_records = td.read_transformed_records_iter(
+        run_date=run_date,
+        run_id=run_id,
+        action="index",
+    )

+    records_to_update = iter(
+        [
+            {
+                "timdex_record_id": record["timdex_record_id"],
+                "alternate_titles": [{"kind": "Test", "value": "Test Alternate Title"}],
+            }
+            for record in transformed_records
+        ]
+    )
+    # ==== END TEMPORARY CODE ====
     try:
-        update_results.update(tim_os.bulk_index(client, index, records_to_index))
+        update_results.update(
+            tim_os.bulk_index(client, index, records_to_update, action="update")
+        )
     except BulkIndexingError as exception:
-        logger.info(f"Bulk indexing failed: {exception}")
+        logger.info(f"Bulk update with embeddings failed: {exception}")
+
+    logger.info(f"Bulk update with embeddings complete: {json.dumps(update_results)}")


 @main.command()
@@ -383,7 +428,12 @@ def bulk_update_embeddings(
     help="Alias to promote the index to in addition to the primary alias. May "
     "be repeated to promote the index to multiple aliases at once.",
 )
-@click.argument("dataset_path", type=click.Path())
+@click.argument(
+    "dataset_path",
+    type=click.Path(),
+    help="Location of TIMDEX parquet dataset from which transformed records are read. "
+    "This value can be a local filepath or an S3 URI.",
+)
 @click.pass_context
 def reindex_source(
     ctx: click.Context,
@@ -432,9 +482,15 @@ def reindex_source(
         action="index",
     )
     try:
-        index_results.update(tim_os.bulk_index(client, index, records_to_index))
+        index_results.update(
+            tim_os.bulk_index(client, index, records_to_index, action="index")
+        )
     except BulkIndexingError as exception:
         logger.info(f"Bulk indexing failed: {exception}")

     summary_results = {"index": index_results}
     logger.info(f"Reindex source complete: {json.dumps(summary_results)}")
+
+
+if __name__ == "__main__":
+    main()
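
For reference, a hypothetical invocation of the new command (Click derives the dash-separated command name from the function name; the --index and --source options are defined just above the visible hunk, and all values here are illustrative):

tim bulk-update-embeddings --index test-index --source libguides \
    --run-date 2024-06-01 --run-id abc123 s3://timdex-dataset/dataset/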

tim/helpers.py

Lines changed: 7 additions & 2 deletions
@@ -50,8 +50,13 @@ def generate_bulk_actions(
             "_index": index,
             "_id": record["timdex_record_id"],
         }
-        if action != "delete":
-            doc["_source"] = record
+
+        match action:
+            case "update":
+                doc["doc"] = record
+            case _ if action != "delete":
+                doc["_source"] = record
+
         yield doc
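
For illustration, the bulk action documents this generator yields would look roughly like the following; the _op_type key is assumed from the opensearch-py bulk helpers convention and is not visible in this hunk, and the record content is made up:

record = {"timdex_record_id": "abc123", "title": "Example"}

# action="index": the full record is sent as the document source
{"_op_type": "index", "_index": "test-index", "_id": "abc123", "_source": record}

# action="update": the record is sent as a partial document under "doc",
# so only the supplied fields change on the existing document
{"_op_type": "update", "_index": "test-index", "_id": "abc123", "doc": record}

# action="delete": metadata only, no document body
{"_op_type": "delete", "_index": "test-index", "_id": "abc123"}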

tim/opensearch.py

Lines changed: 15 additions & 10 deletions
@@ -359,22 +359,26 @@ def bulk_delete(
     return result


-def bulk_index(client: OpenSearch, index: str, records: Iterator[dict]) -> dict[str, int]:
+def bulk_index(
+    client: OpenSearch, index: str, records: Iterator[dict], action: str
+) -> dict[str, int]:
     """Indexes records into an existing index using the streaming bulk helper.

-    This action function uses the OpenSearch "index" action, which is a
-    combination of create and update: if a record with the same _id exists in the
-    index, it will be updated. If it does not exist, the record will be indexed as a
-    new document.
+    This method uses the OpenSearch "index" and "update" operations.
+    - Setting `action` to "index" will either create or update a record.
+      If a record with the same _id exists in the index, it will be updated;
+      if it does not exist, the record will be added as a new document.
+    - Setting `action` to "update" will update a document only if it exists
+      in the index. Otherwise, an error is raised.

-    If an error occurs during record indexing, it will be logged and bulk indexing will
-    continue until all records have been processed.
+    If an error occurs during the operation, it will be logged, and the bulk
+    operation will continue until all records have been processed.

     Returns total sums of: records created, records updated, errors, and total records
     processed.
     """
     result = {"created": 0, "updated": 0, "errors": 0, "total": 0}
-    actions = helpers.generate_bulk_actions(index, records, "index")
+    actions = helpers.generate_bulk_actions(index, records, action)
     responses = streaming_bulk(
         client,
         actions,
@@ -400,13 +404,14 @@ def bulk_index(
             result["updated"] += 1
         else:
             logger.error(
-                "Something unexpected happened during ingest. Bulk index response: %s",
+                "Something unexpected happened during ingest. "
+                f"Bulk {action} response: %s",
                 json.dumps(response),
             )
             result["errors"] += 1
         result["total"] += 1
         if result["total"] % int(os.getenv("STATUS_UPDATE_INTERVAL", "1000")) == 0:
-            logger.info("Status update: %s records indexed so far!", result["total"])
+            logger.info("Status update: %s records processed so far!", result["total"])
     logger.info("All records ingested, refreshing index.")
     response = client.indices.refresh(
         index=index,
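
A minimal usage sketch of the updated bulk_index signature, assuming an existing OpenSearch client and index (record contents and result counts are illustrative):

records = iter([{"timdex_record_id": "abc123", "title": "Example"}])

# action="index": creates the document if absent, updates it if present
results = tim_os.bulk_index(client, "test-index", records, action="index")
# e.g. {"created": 1, "updated": 0, "errors": 0, "total": 1}

# action="update": modifies only documents that already exist; a record
# whose _id is not in the index surfaces as an error in the result counts
partial = iter([{"timdex_record_id": "abc123", "alternate_titles": []}])
results = tim_os.bulk_index(client, "test-index", partial, action="update")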
