@@ -19,6 +19,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
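The hunk above extends the sync helper one backend at a time. A minimal alternative sketch, assuming only that the backend module (torch.cuda, torch.xpu) exposes synchronize(), dispatches on the device-type string instead of chaining elif branches:

    import torch

    def device_sync(device: str) -> None:
        device_type = device.split(":")[0]  # "cuda:0" -> "cuda"
        if device_type in ("cpu", "mps"):
            return  # nothing to synchronize for these backends in this script
        backend = getattr(torch, device_type, None)  # e.g. torch.cuda, torch.xpu
        if backend is None:
            print(f"device={device} is not recognized here")  # sketch-only handling
            return
        backend.synchronize(device)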
@@ -261,7 +263,10 @@ def main(
 
     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if "cuda" in device:
+                torch.cuda.reset_peak_memory_stats()
+            elif "xpu" in device:
+                torch.xpu.reset_peak_memory_stats()
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
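The same dispatch idea covers the peak-memory reset. A hedged helper, assuming the backend module exposes reset_peak_memory_stats():

    def reset_peak_memory_stats(device: str) -> None:
        backend = getattr(torch, device.split(":")[0], None)
        if backend is not None and hasattr(backend, "reset_peak_memory_stats"):
            backend.reset_peak_memory_stats()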
@@ -291,8 +296,17 @@ def callback(x):
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
+            else:
+                prof = torch.profiler.profile()
         with prof:
             y = generate(
                 model,
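On XPU the profiler is given explicit activities, since the private CUDA-graphs initialization is CUDA-specific. A self-contained usage sketch of the same profiler configuration; the Linear workload and the self_xpu_time_total sort key are assumptions (the latter mirrors the CUDA key), and it presumes an XPU-enabled PyTorch build:

    import torch

    model = torch.nn.Linear(16, 16).to("xpu")  # assumes an XPU-enabled build
    x = torch.randn(8, 16, device="xpu")
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.XPU,
        ],
    ) as prof:
        model(x)
    # Sort key assumed available on XPU builds, mirroring "self_cuda_time_total".
    print(prof.key_averages().table(sort_by="self_xpu_time_total", row_limit=10))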
@@ -328,7 +340,8 @@ def callback(x):
 
     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
-    mem = torch.cuda.max_memory_reserved() /1e9
+    max_memory_reserved = torch.cuda.max_memory_reserved() if "cuda" in device else torch.xpu.max_memory_reserved()
+    mem = max_memory_reserved / 1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
     print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
     print(f"Peak Memory Usage: {mem:.02f} GB")
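The ternary above assumes the device is either CUDA or XPU; on cpu/mps the else arm would fail. A hedged, backend-agnostic fallback in the same dispatch style as the earlier sketches:

    def max_memory_reserved(device: str) -> int:
        # Best-effort peak reserved bytes; 0 where no counter is exposed (cpu/mps).
        backend = getattr(torch, device.split(":")[0], None)
        if backend is not None and hasattr(backend, "max_memory_reserved"):
            return backend.max_memory_reserved()
        return 0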
@@ -378,6 +391,13 @@ def callback(x):
     parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result')
 
     args = parser.parse_args()
+
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex  # noqa: F401  # imported for its XPU-registration side effects
+        except ImportError:
+            raise ModuleNotFoundError("Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
+
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
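The guard imports intel_extension_for_pytorch purely for its side effect of registering the XPU backend; recent PyTorch builds also ship XPU support upstream, in which case the import may be unnecessary (treat that as an assumption to verify for your version). A quick availability probe before launching on XPU:

    import torch

    # Both calls exist once an XPU-capable build (or IPEX) is loaded.
    print("xpu available:", torch.xpu.is_available())
    print("xpu devices:", torch.xpu.device_count())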