Commit 3af7b93

intel gpu: enable Intel GPU (XPU) support
1 parent f479b07 commit 3af7b93

File tree

4 files changed: +73 additions, -16 deletions

generate.py

Lines changed: 20 additions & 5 deletions

@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -24,7 +26,8 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future

 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -271,7 +274,7 @@ def main(

     global print
     from tp import maybe_init_dist
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -302,7 +305,7 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp: # and ("cuda" in device):
+        if is_speculative and use_tp and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case

         if is_speculative:
@@ -353,8 +356,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y, metrics = generate(
                 model,
@@ -418,6 +428,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default=default_device, help='Device to use')

     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except ImportError:
+            raise ModuleNotFoundError("Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
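Taken together, the generate.py changes route all device-specific work through the device string. Below is a minimal, self-contained sketch of that dispatch pattern, not verbatim repo code: make_profiler is a hypothetical helper name, and the XPU branches assume a PyTorch build that exposes torch.xpu and torch.profiler.ProfilerActivity.XPU (e.g. with intel_extension_for_pytorch installed).

import contextlib
import torch

def device_sync(device: str) -> None:
    # Block until all kernels queued on `device` have finished.
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "xpu" in device:
        torch.xpu.synchronize(device)
    elif ("cpu" in device) or ("mps" in device):
        pass  # nothing to wait for on these paths
    else:
        print(f"device={device} is not yet supported")

def make_profiler(device: str):
    # CUDA is picked up by torch.profiler automatically; XPU must be
    # listed as an explicit ProfilerActivity, which is why the diff
    # special-cases it.
    if "cuda" in device:
        return torch.profiler.profile()
    if "xpu" in device:
        return torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.XPU])
    return contextlib.nullcontext()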

mixtral-moe/generate.py

Lines changed: 21 additions & 5 deletions

@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif "cpu" in device:
         pass
     else:
@@ -24,7 +26,8 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future


 # support running without installing as a package
@@ -178,7 +181,7 @@ def main(
     assert tokenizer_path.is_file(), tokenizer_path

     global print
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -203,7 +206,8 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        torch._inductor.config.assert_indirect_indexing = False
+        if hasattr(torch._inductor.config, "assert_indirect_indexing"):
+            torch._inductor.config.assert_indirect_indexing = False

         global decode_one_token, prefill
         decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -248,8 +252,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
@@ -302,6 +313,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default="cuda", help='device to use')

     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except ImportError:
+            raise ModuleNotFoundError("Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
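Both generate.py files also guard optional inductor knobs the same way, since older PyTorch releases lack them and a direct assignment would raise AttributeError. A short sketch of that version-tolerant pattern (illustrative only; the knob list mirrors the two settings this commit touches):

import torch

# Probe before assigning: hasattr keeps the script importable on PyTorch
# builds where a given torch._inductor.config option does not exist yet.
for knob, value in [("fx_graph_cache", True),
                    ("assert_indirect_indexing", False)]:
    if hasattr(torch._inductor.config, knob):
        setattr(torch._inductor.config, knob, value)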

mixtral-moe/tp.py

Lines changed: 16 additions & 3 deletions

@@ -28,7 +28,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -41,8 +41,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None

-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except ImportError:
+            raise ModuleNotFoundError("OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) are required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank

 rank = _get_rank()
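The XPU branch above does three things before creating the process group: it imports the oneCCL bindings (which registers the "ccl" backend with torch.distributed), tells oneCCL that the processes were launched externally by torchrun rather than by its own launcher, and binds one rank per device. A condensed sketch of that sequence, assuming oneccl_bindings_for_pytorch is installed; init_xpu_process_group is a hypothetical name, not repo code:

import os
import torch
import torch.distributed as dist
import oneccl_bindings_for_pytorch  # noqa: F401  (registers the "ccl" backend)

def init_xpu_process_group(rank: int, world_size: int) -> None:
    os.environ["CCL_PROCESS_LAUNCHER"] = "none"   # torchrun spawned us, not mpiexec
    os.environ["CCL_LOCAL_SIZE"] = str(world_size)
    os.environ["CCL_LOCAL_RANK"] = str(rank)
    torch.xpu.set_device(rank)                    # one rank per XPU device
    dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)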

tp.py

Lines changed: 16 additions & 3 deletions

@@ -33,7 +33,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -46,8 +46,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None

-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except ImportError:
+            raise ModuleNotFoundError("OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) are required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank

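The net effect of both tp.py changes is that the collective backend is chosen from the device string instead of being hard-coded to NCCL. As a one-glance summary (a hypothetical helper, not part of the diff):

def dist_backend_for(device: str) -> str:
    # Mirrors the branches added to maybe_init_dist in this commit.
    if "cuda" in device:
        return "nccl"  # NVIDIA collective communications library
    if "xpu" in device:
        return "ccl"   # Intel oneCCL, via oneccl_bindings_for_pytorch
    raise ValueError(f"no distributed backend configured for device {device!r}")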
