|
16 | 16 | def device_sync(device): |
17 | 17 | if "cuda" in device: |
18 | 18 | torch.cuda.synchronize(device) |
| 19 | + elif "xpu" in device: |
| 20 | + torch.xpu.synchronize(device) |
19 | 21 | elif ("cpu" in device) or ("mps" in device): |
20 | 22 | pass |
21 | 23 | else: |
@@ -271,7 +273,7 @@ def main( |
271 | 273 |
|
272 | 274 | global print |
273 | 275 | from tp import maybe_init_dist |
274 | | - rank = maybe_init_dist() |
| 276 | + rank = maybe_init_dist(device) |
275 | 277 | use_tp = rank is not None |
276 | 278 | if use_tp: |
277 | 279 | if rank != 0: |
@@ -302,7 +304,7 @@ def main( |
302 | 304 | torch.manual_seed(1234) |
303 | 305 | model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) |
304 | 306 | if compile: |
305 | | - if is_speculative and use_tp: # and ("cuda" in device): |
| 307 | + if is_speculative and use_tp and ("cuda" in device): |
306 | 308 | torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case |
307 | 309 |
|
308 | 310 | if is_speculative: |
@@ -353,8 +355,15 @@ def callback(x): |
353 | 355 | if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): |
354 | 356 | prof = contextlib.nullcontext() |
355 | 357 | else: |
356 | | - torch.profiler._utils._init_for_cuda_graphs() |
357 | | - prof = torch.profiler.profile() |
| 358 | + if "cuda" in device: |
| 359 | + torch.profiler._utils._init_for_cuda_graphs() |
| 360 | + prof = torch.profiler.profile() |
| 361 | + elif "xpu" in device: |
| 362 | + prof = torch.profiler.profile( |
| 363 | + activities=[ |
| 364 | + torch.profiler.ProfilerActivity.CPU, |
| 365 | + torch.profiler.ProfilerActivity.XPU], |
| 366 | + ) |
358 | 367 | with prof: |
359 | 368 | y, metrics = generate( |
360 | 369 | model, |
@@ -418,6 +427,11 @@ def callback(x): |
418 | 427 | parser.add_argument('--device', type=str, default=default_device, help='Device to use') |
419 | 428 |
|
420 | 429 | args = parser.parse_args() |
| 430 | + if "xpu" in args.device: |
| 431 | + try: |
| 432 | + import intel_extension_for_pytorch as ipex |
| 433 | + except: |
| 434 | + raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.") |
421 | 435 | main( |
422 | 436 | args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, |
423 | 437 | args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, |
|
0 commit comments