 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
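For reference, here is a minimal standalone sketch of the same dispatch idea written as a single generic helper rather than the explicit if/elif chain above; the function name and the getattr-based lookup are illustrative and not part of this patch.

import torch

# Illustrative only: dispatch on the backend prefix of the device string
# ("cuda:0" -> "cuda", "xpu" -> "xpu") and call that backend's synchronize().
def generic_device_sync(device: str) -> None:
    backend = device.split(":")[0]
    if backend in ("cuda", "xpu"):
        # torch.cuda.synchronize / torch.xpu.synchronize block until all
        # kernels queued on the device have finished.
        getattr(torch, backend).synchronize(device)
    elif backend in ("cpu", "mps"):
        pass  # nothing to synchronize for these backends in this script
    else:
        print(f"device={device} is not yet supported")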
@@ -261,7 +263,10 @@ def main(

     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if "cuda" in device:
+                torch.cuda.reset_peak_memory_stats()
+            elif "xpu" in device:
+                torch.xpu.reset_peak_memory_stats()
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
@@ -291,8 +296,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
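As a sanity check of the profiler branch above, a small self-contained sketch (assuming a PyTorch build where torch.profiler.ProfilerActivity.XPU is available) that selects activities per backend and prints the aggregated results:

import torch
from torch.profiler import profile, ProfilerActivity

# Assumption: pick XPU when available, otherwise fall back to CPU.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

activities = [ProfilerActivity.CPU]
if device == "xpu":
    activities.append(ProfilerActivity.XPU)

x = torch.randn(1024, 1024, device=device)
with profile(activities=activities) as prof:
    y = x @ x  # any workload to trace
    if device == "xpu":
        torch.xpu.synchronize()

# Print a per-operator summary of the collected trace.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))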
@@ -342,7 +354,8 @@ def callback(x):

     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
-    mem = torch.cuda.max_memory_reserved() /1e9
+    max_memory_reserved = torch.cuda.max_memory_reserved() if "cuda" in device else torch.xpu.max_memory_reserved()
+    mem = max_memory_reserved / 1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
     print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
     print(f"Peak Memory Usage: {mem:.02f} GB")
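The per-backend reset and readout of peak memory could also be factored into small helpers; the sketch below is a hypothetical refactor for illustration, with names not taken from the patch.

import torch

# Hypothetical helpers pairing the per-backend reset used at the start of each
# sample with the per-backend readout used for the final report.
def reset_peak_memory(device: str) -> None:
    backend = device.split(":")[0]
    if backend in ("cuda", "xpu"):
        getattr(torch, backend).reset_peak_memory_stats()

def peak_memory_gb(device: str) -> float:
    backend = device.split(":")[0]
    if backend == "cuda":
        return torch.cuda.max_memory_reserved() / 1e9
    if backend == "xpu":
        return torch.xpu.max_memory_reserved() / 1e9
    return 0.0  # cpu / mps: no reserved-memory counter queried in this script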
@@ -399,6 +412,7 @@ def callback(x):
     parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result')

     args = parser.parse_args()
+
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result