@@ -2718,7 +2718,7 @@ def test_beam_search_early_stop_heuristic(self):
         question = tokenizer.apply_chat_template(
             question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
         )
-        inputs = tokenizer(question, return_tensors="pt", padding=True).to("cuda")
+        inputs = tokenizer(question, return_tensors="pt", padding=True).to(torch_device)
         outputs = model.generate(**inputs, generation_config=generation_config)
         responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
         self.assertEqual(responses[0], EXPECTED_OUTPUT)
@@ -2737,7 +2737,7 @@ def test_beam_search_early_stop_heuristic(self):
         cot_question = tokenizer.apply_chat_template(
             cot_question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
         )
-        inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to("cuda")
+        inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to(torch_device)

         outputs = model.generate(**inputs, generation_config=generation_config)
         responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)