
Commit 620a3bf (parent 37276d6)

feat(xpu): enable XPU for Llama BF16


torchao/_models/llama/generate.py

Lines changed: 24 additions & 4 deletions
@@ -19,6 +19,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
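
The added branch mirrors the CUDA path: torch.xpu.synchronize blocks the host until every kernel queued on the XPU device has finished, which is what makes the timing around generation meaningful. A minimal sketch of why the sync matters, assuming a build where torch.xpu is present (stock PyTorch 2.4+, or older builds with intel_extension_for_pytorch imported):

import time
import torch

def device_sync(device):
    # Same dispatch as the patched helper above.
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "xpu" in device:
        torch.xpu.synchronize(device)

device = "xpu" if torch.xpu.is_available() else "cpu"
x = torch.randn(2048, 2048, device=device)
t0 = time.perf_counter()
y = x @ x            # queued asynchronously on cuda/xpu
device_sync(device)  # without this, the timer stops before the kernel finishes
print(f"matmul: {time.perf_counter() - t0:.4f}s on {device}")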
@@ -261,7 +263,10 @@ def main(
 
     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if device == "cuda":
+                torch.cuda.reset_peak_memory_stats()
+            elif device == "xpu":
+                torch.xpu.reset_peak_memory_stats()
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
@@ -291,8 +296,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if device == "cuda":
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
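
The CUDA branch keeps the private _init_for_cuda_graphs() warm-up, which has no XPU equivalent; the XPU branch instead lists its activities explicitly, presumably because the profiler's default activity set did not yet cover XPU. A sketch of the same configuration used standalone (assumes a PyTorch build whose torch.profiler.ProfilerActivity exposes XPU, as the patch itself does; the trace filename is illustrative):

import torch

prof = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.XPU,
    ],
)
a = torch.randn(1024, 1024, device="xpu")
b = torch.randn(1024, 1024, device="xpu")
with prof:
    c = a @ b  # stand-in for the generate() call wrapped above
prof.export_chrome_trace("llama_xpu.json")  # view in Perfetto or chrome://tracing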
@@ -328,7 +340,8 @@ def callback(x):
 
     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
-    mem = torch.cuda.max_memory_reserved() /1e9
+    max_memory_reserved = torch.cuda.max_memory_reserved() if device == "cuda" else torch.xpu.max_memory_reserved()
+    mem = max_memory_reserved / 1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
     print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
     print(f"Peak Memory Usage: {mem:.02f} GB")
@@ -378,6 +391,13 @@ def callback(x):
     parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result')
 
     args = parser.parse_args()
+
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except:
+            raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
+
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
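
The guard fails fast when an XPU device is requested without Intel Extension for PyTorch installed (importing ipex is what registers the XPU backend on older PyTorch builds). One nit: the bare except: also swallows failures unrelated to a missing module. A hedged variant, not part of the commit, that narrows the clause and chains the original error:

    if "xpu" in args.device:
        try:
            import intel_extension_for_pytorch as ipex  # import registers the XPU backend
        except ImportError as e:
            raise ModuleNotFoundError(
                "intel_extension_for_pytorch is required to run on Intel GPU (XPU); "
                "see https://github.com/intel/intel-extension-for-pytorch"
            ) from e

With the extension installed, a run would look like: python torchao/_models/llama/generate.py --device xpu (assuming a --device flag backing the args.device field forwarded to main above).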
