@@ -308,7 +308,9 @@ def bulk_update(
         action="index",
     )
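+    # tim_os.bulk_index accepts an explicit 'action' argument ("index" here,
+    # "update" in the bulk-update-embeddings command below).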
     try:
-        index_results.update(tim_os.bulk_index(client, index, records_to_index))
+        index_results.update(
+            tim_os.bulk_index(client, index, records_to_index, action="index")
+        )
     except BulkIndexingError as exception:
         logger.info(f"Bulk indexing failed: {exception}")
 
@@ -325,6 +327,92 @@ def bulk_update(
     logger.info(f"Bulk update complete: {json.dumps(summary_results)}")
 
 
+# Bulk update existing records with embeddings commands
+
+
+@main.command()
+@click.option(
+    "-i",
+    "--index",
+    help="Name of the index where the bulk update to add embeddings is performed.",
+)
+@click.option(
+    "-s",
+    "--source",
+    type=click.Choice(VALID_SOURCES),
+    help=(
+        "Source whose primary-aliased index will receive the bulk updated "
+        "records with embeddings."
+    ),
+)
+@click.option("-d", "--run-date", help="Run date, formatted as YYYY-MM-DD.")
+@click.option("-rid", "--run-id", help="Run ID.")
+@click.argument("dataset_path", type=click.Path())
+@click.pass_context
+def bulk_update_embeddings(
+    ctx: click.Context,
+    index: str,
+    source: str,
+    run_date: str,
+    run_id: str,
+    dataset_path: str,
+) -> None:
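+    """Bulk update records in an index with embeddings.
+
+    Until TDA can read embeddings from the TIMDEX dataset, embeddings are
+    simulated via the 'alternate_titles' field (see temporary code below).
+    """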
+    client = ctx.obj["CLIENT"]
+    index = helpers.validate_bulk_cli_options(index, source, client)
+
+    logger.info(
+        f"Bulk updating records with embeddings from dataset '{dataset_path}' "
+        f"into '{index}'"
+    )
+
+    update_results = {"updated": 0, "errors": 0, "total": 0}
+
+    td = TIMDEXDataset(location=dataset_path)
+
+    # TODO @ghukill: https://mitlibraries.atlassian.net/browse/USE-143 # noqa: FIX002
+    # Remove temporary code and replace with TDA method to read embeddings.
+    # ==== START TEMPORARY CODE ====
+    # The code below reads transformed records from the TIMDEX dataset. To
+    # simulate embeddings, which are added to the record post-creation, a list
+    # of dicts containing only the 'timdex_record_id' and the new field (i.e.,
+    # what would be the embedding fields) is created. For simulation purposes,
+    # the 'alternate_titles' field represents the new field, as it is already
+    # added to the OpenSearch mapping in config/opensearch_mappings.json. When
+    # testing, the user is expected to pass a source that does not set this
+    # field (e.g., libguides). Once TDA has been updated to read/write
+    # embeddings from/to the TIMDEX dataset, this code should be replaced with
+    # a simple call to read vector embeddings, which should return an iterator
+    # of dicts representing the embeddings.
+    transformed_records = td.read_transformed_records_iter(
+        run_date=run_date,
+        run_id=run_id,
+        action="index",
+    )
+
+    records_to_update = iter(
+        [
+            {
+                "timdex_record_id": record["timdex_record_id"],
+                "alternate_titles": [{"kind": "Test", "value": "Test Alternate Title"}],
+            }
+            for record in transformed_records
+        ]
+    )
+    # ==== END TEMPORARY CODE ====
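+    # Once TDA can read embeddings, the temporary block above should reduce to
+    # a single call returning an iterator of embedding dicts, e.g.
+    # (hypothetical method name):
+    #   records_to_update = td.read_embeddings_iter(run_date=run_date, run_id=run_id)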
+    try:
+        update_results.update(
+            tim_os.bulk_index(client, index, records_to_update, action="update")
+        )
+    except BulkIndexingError as exception:
+        logger.info(f"Bulk update with embeddings failed: {exception}")
+
+    logger.info(f"Bulk update with embeddings complete: {json.dumps(update_results)}")
+
+
 @main.command()
 @click.option(
     "-s",
@@ -340,7 +428,12 @@ def bulk_update(
     help="Alias to promote the index to in addition to the primary alias. May "
     "be repeated to promote the index to multiple aliases at once.",
 )
-@click.argument("dataset_path", type=click.Path())
+# Location of the TIMDEX parquet dataset from which transformed records are
+# read; this value can be a local filepath or an S3 URI. Documented as a
+# comment because click arguments do not accept a 'help' parameter.
+@click.argument("dataset_path", type=click.Path())
 @click.pass_context
 def reindex_source(
     ctx: click.Context,
@@ -389,9 +482,15 @@ def reindex_source(
         action="index",
     )
     try:
-        index_results.update(tim_os.bulk_index(client, index, records_to_index))
+        index_results.update(
+            tim_os.bulk_index(client, index, records_to_index, action="index")
+        )
     except BulkIndexingError as exception:
         logger.info(f"Bulk indexing failed: {exception}")
 
     summary_results = {"index": index_results}
     logger.info(f"Reindex source complete: {json.dumps(summary_results)}")
+
+
+if __name__ == "__main__":
+    main()