Skip to content

Commit 6fc261d

Browse files
committed
feat(xpu): enable XPU for Llama
Signed-off-by: dbyoung18 <yang5.yang@intel.com>
1 parent cfabc13 commit 6fc261d

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

torchao/_models/llama/generate.py

Lines changed: 18 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -19,6 +19,8 @@
1919
def device_sync(device):
2020
if "cuda" in device:
2121
torch.cuda.synchronize(device)
22+
elif "xpu" in device:
23+
torch.xpu.synchronize(device)
2224
elif ("cpu" in device) or ("mps" in device):
2325
pass
2426
else:
@@ -261,7 +263,10 @@ def main(
261263

262264
for i in range(start, num_samples):
263265
if i==0:
264-
torch.cuda.reset_peak_memory_stats()
266+
if "cuda" in device:
267+
torch.cuda.reset_peak_memory_stats()
268+
elif "xpu" in device:
269+
torch.xpu.reset_peak_memory_stats()
265270
device_sync(device=device) # MKG
266271
if i >= 0 and interactive:
267272
prompt = input("What is your prompt? ")
@@ -291,8 +296,15 @@ def callback(x):
291296
if (i != num_samples - 1 or not profile):
292297
prof = contextlib.nullcontext()
293298
else:
294-
torch.profiler._utils._init_for_cuda_graphs()
295-
prof = torch.profiler.profile()
299+
if "cuda" in device:
300+
torch.profiler._utils._init_for_cuda_graphs()
301+
prof = torch.profiler.profile()
302+
elif "xpu" in device:
303+
prof = torch.profiler.profile(
304+
activities=[
305+
torch.profiler.ProfilerActivity.CPU,
306+
torch.profiler.ProfilerActivity.XPU],
307+
)
296308
with prof:
297309
y = generate(
298310
model,
@@ -342,7 +354,8 @@ def callback(x):
342354

343355
tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
344356
bandwidth = model_size * tokpersec
345-
mem = torch.cuda.max_memory_reserved() /1e9
357+
max_memory_reserved = torch.cuda.max_memory_reserved() if "cuda" in device else torch.xpu.max_memory_reserved()
358+
mem = max_memory_reserved / 1e9
346359
print(f"Average tokens/sec: {tokpersec:.2f}")
347360
print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
348361
print(f"Peak Memory Usage: {mem:.02f} GB")
@@ -399,6 +412,7 @@ def callback(x):
399412
parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result')
400413

401414
args = parser.parse_args()
415+
402416
main(
403417
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
404418
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result

0 commit comments

Comments (0)