@@ -19,6 +19,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
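For reference, the same per-backend branching can be collapsed by resolving the backend module from the device string. This is only an illustrative sketch, not part of the patch; the helper name is made up, and it assumes torch.cuda and torch.xpu share the synchronize(device) signature (which the patch itself relies on):

```python
import torch

def device_sync_generic(device: str) -> None:
    """Synchronize the accelerator named in `device` ("cuda", "cuda:0", "xpu", ...)."""
    backend = device.split(":")[0]
    if backend in ("cuda", "xpu"):
        # torch.cuda and torch.xpu expose the same synchronize(device) call.
        getattr(torch, backend).synchronize(device)
    # "cpu" and "mps" need no explicit synchronization here.
```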
@@ -288,7 +290,10 @@ def main(
 
     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if "cuda" in device:
+                torch.cuda.reset_peak_memory_stats()
+            elif "xpu" in device:
+                torch.xpu.reset_peak_memory_stats()
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
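The same split recurs for the peak-memory counters. A minimal sketch of the dispatch with a hypothetical helper name; it assumes both torch.cuda and torch.xpu expose reset_peak_memory_stats(), exactly as the patch uses them:

```python
import torch

def reset_peak_memory_stats(device: str) -> None:
    # Reset the peak-memory counter for whichever backend the device string names.
    backend = device.split(":")[0]
    if backend in ("cuda", "xpu"):
        getattr(torch, backend).reset_peak_memory_stats()
```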
@@ -318,8 +323,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
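The profiler branch is where the two backends diverge in shape: the CUDA path keeps the pre-existing _init_for_cuda_graphs() call with the default activity set, while the XPU path lists its activities explicitly. A hedged sketch of profiling an arbitrary callable with the XPU activity set and exporting a Chrome trace; the function name, workload, and output path are placeholders, not part of the patch:

```python
import torch

def profile_on_xpu(fn, trace_path="xpu_trace.json"):
    # Collect CPU + XPU events for one call of `fn` and export a Chrome trace.
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.XPU,
        ],
    ) as prof:
        fn()
    prof.export_chrome_trace(trace_path)
    return prof
```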
@@ -369,7 +381,8 @@ def callback(x):
 
     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
-    mem = torch.cuda.max_memory_reserved() /1e9
+    max_memory_reserved = torch.cuda.max_memory_reserved() if "cuda" in device else torch.xpu.max_memory_reserved()
+    mem = max_memory_reserved / 1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
     print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
     print(f"Peak Memory Usage: {mem:.02f} GB")
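Peak memory now comes from whichever backend is active, while the bandwidth figure stays model_size * tokens/sec. A small sketch of the same lookup factored into a helper (the name is hypothetical; it assumes max_memory_reserved() exists on both torch.cuda and torch.xpu, as the patch uses):

```python
import torch

def max_memory_reserved_gb(device: str) -> float:
    # Peak reserved memory in GB for the active backend ("cuda" or "xpu").
    backend = device.split(":")[0]
    return getattr(torch, backend).max_memory_reserved() / 1e9
```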
@@ -431,6 +444,7 @@ def callback(x):
     parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result')
 
     args = parser.parse_args()
+
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
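Since main() receives args.device unchanged, the CLI itself needs no further edits; running on Intel GPUs should only require passing an XPU device string (e.g. --device xpu, assuming the parser's existing --device flag accepts it). A quick sanity check one might run first, assuming a PyTorch build with XPU support:

```python
import torch

# Confirm the XPU backend is compiled in and a device is visible
# before launching generation with an "xpu" device string.
if torch.xpu.is_available():
    print("XPU device:", torch.xpu.get_device_name(0))
else:
    print("No XPU device available")
```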