Skip to content

Commit 6f479d5

Browse files
authored
extend test_beam_search_early_stop_heuristic case to other devices (#42078)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent d012f34 commit 6f479d5

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/generation/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2718,7 +2718,7 @@ def test_beam_search_early_stop_heuristic(self):
27182718
question = tokenizer.apply_chat_template(
27192719
question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
27202720
)
2721-
inputs = tokenizer(question, return_tensors="pt", padding=True).to("cuda")
2721+
inputs = tokenizer(question, return_tensors="pt", padding=True).to(torch_device)
27222722
outputs = model.generate(**inputs, generation_config=generation_config)
27232723
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
27242724
self.assertEqual(responses[0], EXPECTED_OUTPUT)
@@ -2737,7 +2737,7 @@ def test_beam_search_early_stop_heuristic(self):
27372737
cot_question = tokenizer.apply_chat_template(
27382738
cot_question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
27392739
)
2740-
inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to("cuda")
2740+
inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to(torch_device)
27412741

27422742
outputs = model.generate(**inputs, generation_config=generation_config)
27432743
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)

0 commit comments

Comments
 (0)