2 files changed, +10 -0 lines changed

aiu_fms_testing_utils/utils

@@ -45,6 +45,14 @@ def wrap_encoder(model: nn.Module) -> HFModelArchitecture:
     model.config.linear_config.pop("linear_type", None)
     return to_hf_api(model, task_specific_params=None)
 
+def move_to_device(batch: dict, device: torch.device) -> dict:
+    """Move batch to selected device."""
+
+    batch_on_device = {}
+    for k, v in batch.items():
+        batch_on_device[k] = v.to(device)
+    return batch_on_device
+
 
 class EncoderQAInfer():
     """Run QuestionAnswering task with encoder models."""
@@ -587,6 +595,7 @@ def run_evaluation(self) -> None:
             with torch.no_grad():
                 dprint(f"Step {step + 1} / {len(eval_dataloader)}")
                 batch = self.convert_batch_to_fms_style(batch)
+                batch = move_to_device(batch, args.device)
                 start_logits, end_logits = self.model(**batch)
                 all_start_logits.append(start_logits.cpu().numpy())
                 all_end_logits.append(end_logits.cpu().numpy())
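A minimal, self-contained sketch of how the new helper behaves; the batch contents and device choice below are illustrative stand-ins, not values from the PR (move_to_device is copied verbatim from the hunk above so the snippet runs on its own):

import torch

def move_to_device(batch: dict, device: torch.device) -> dict:
    """Move batch to selected device."""

    batch_on_device = {}
    for k, v in batch.items():
        batch_on_device[k] = v.to(device)
    return batch_on_device

# hypothetical batch; real batches come from the QuestionAnswering eval dataloader
batch = {
    "input_ids": torch.randint(0, 1000, (2, 16)),
    "attention_mask": torch.ones(2, 16, dtype=torch.long),
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = move_to_device(batch, device)

# every tensor in the dict now lives on the selected device
assert all(v.device.type == device.type for v in batch.values())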
Second changed file:

@@ -36,6 +36,7 @@
 
 # Main model setup
 default_dtype, device, dist_strat = setup_model(args)
+args.device = device
 
 # Retrieve linear configuration (quantized or not) to instantiate FMS model
 linear_config = get_linear_config(args)
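The one-line change above stashes the device selected during model setup on the shared argparse namespace, so code that only receives args (such as run_evaluation) can stage tensors onto the same device. A minimal illustration of the pattern; the Namespace and the CPU device below are stand-ins, not values from the PR:

import argparse
import torch

args = argparse.Namespace()     # stand-in for the script's parsed arguments
device = torch.device("cpu")    # stand-in for whatever setup_model selects
args.device = device            # same pattern as the diff above

# any code handed `args` can now target the same device without re-deriving it
x = torch.ones(2, 3).to(args.device)
print(x.device)                 # -> cpu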