99
1010import click
1111import jsonlines
12+ import smart_open
1213from timdex_dataset_api import TIMDEXDataset
1314
1415from embeddings .config import configure_logger , configure_sentry
@@ -156,32 +157,41 @@ def test_model_load(ctx: click.Context) -> None:
156157@click .pass_context
157158@model_required
158159@click .option (
159- "-d" ,
160160 "--dataset-location" ,
161- required = True ,
161+ required = False ,
162162 type = click .Path (),
163163 help = "TIMDEX dataset location, e.g. 's3://timdex/dataset', to read records from." ,
164164)
165165@click .option (
166166 "--run-id" ,
167- required = True ,
167+ required = False ,
168168 type = str ,
169169 help = "TIMDEX ETL run id." ,
170170)
171171@click .option (
172172 "--run-record-offset" ,
173- required = True ,
173+ required = False ,
174174 type = int ,
175175 default = 0 ,
176176 help = "TIMDEX ETL run record offset to start from, default = 0." ,
177177)
178178@click .option (
179179 "--record-limit" ,
180- required = True ,
180+ required = False ,
181181 type = int ,
182182 default = None ,
183183 help = "Limit number of records after --run-record-offset, default = None (unlimited)." ,
184184)
185+ @click .option (
186+ "--input-jsonl" ,
187+ required = False ,
188+ type = str ,
189+ default = None ,
190+ help = (
191+ "Optional filepath to JSONLines file containing "
192+ "TIMDEX records to create embeddings from."
193+ ),
194+ )
185195@click .option (
186196 "--strategy" ,
187197 type = click .Choice (list (STRATEGY_REGISTRY .keys ())),
@@ -205,50 +215,63 @@ def create_embeddings(
205215 run_id : str ,
206216 run_record_offset : int ,
207217 record_limit : int ,
218+ input_jsonl : str ,
208219 strategy : list [str ],
209220 output_jsonl : str ,
210221) -> None :
211222 """Create embeddings for TIMDEX records."""
212223 model : BaseEmbeddingModel = ctx .obj ["model" ]
213224 model .load ()
214225
215- # init TIMDEXDataset
216- timdex_dataset = TIMDEXDataset (dataset_location )
217-
218- # query TIMDEX dataset for an iterator of records
219- timdex_records = timdex_dataset .read_dicts_iter (
220- columns = [
221- "timdex_record_id" ,
222- "run_id" ,
223- "run_record_offset" ,
224- "transformed_record" ,
225- ],
226- run_id = run_id ,
227- where = f"""run_record_offset >= { run_record_offset } """ ,
228- limit = record_limit ,
229- action = "index" ,
230- )
226+ # read input records from TIMDEX dataset (default) or a JSONLines file
227+ if input_jsonl :
228+ with (
229+ smart_open .open (input_jsonl , "r" ) as file_obj , # type: ignore[no-untyped-call]
230+ jsonlines .Reader (file_obj ) as reader ,
231+ ):
232+ timdex_records = iter (list (reader ))
233+
234+ else :
235+ if not dataset_location or not run_id :
236+ raise click .UsageError (
237+ "Both '--dataset-location' and '--run-id' are required arguments "
238+ "when reading input records from the TIMDEX dataset."
239+ )
240+
241+ # init TIMDEXDataset
242+ timdex_dataset = TIMDEXDataset (dataset_location )
243+
244+ # query TIMDEX dataset for an iterator of records
245+ timdex_records = timdex_dataset .read_dicts_iter (
246+ columns = [
247+ "timdex_record_id" ,
248+ "run_id" ,
249+ "run_record_offset" ,
250+ "transformed_record" ,
251+ ],
252+ run_id = run_id ,
253+ where = f"""run_record_offset >= { run_record_offset } """ ,
254+ limit = record_limit ,
255+ action = "index" ,
256+ )
231257
232258 # create an iterator of EmbeddingInputs applying all requested strategies
233259 embedding_inputs = create_embedding_inputs (timdex_records , list (strategy ))
234260
235261 # create embeddings via the embedding model
236262 embeddings = model .create_embeddings (embedding_inputs )
237263
238- # if requested, write embeddings to a local JSONLines file
264+ # write embeddings to TIMDEX dataset (default) or to a JSONLines file
239265 if output_jsonl :
240- with jsonlines .open (
241- output_jsonl ,
242- mode = "w" ,
243- dumps = lambda obj : json .dumps (
244- obj ,
245- default = str ,
246- ),
247- ) as writer :
266+ with (
267+ smart_open .open (output_jsonl , "w" ) as s3_file , # type: ignore[no-untyped-call]
268+ jsonlines .Writer (
269+ s3_file ,
270+ dumps = lambda obj : json .dumps (obj , default = str ),
271+ ) as writer ,
272+ ):
248273 for embedding in embeddings :
249274 writer .write (embedding .to_dict ())
250-
251- # else, default writing embeddings back to TIMDEX dataset
252275 else :
253276 # WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
254277 # NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the
0 commit comments