
Commit 0eaf4b5

intel gpu : enable intel gpu
1 parent: f479b07

File tree (4 files changed, +67 -13 lines):

generate.py
mixtral-moe/generate.py
mixtral-moe/tp.py
tp.py

generate.py

Lines changed: 18 additions & 4 deletions

@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -271,7 +273,7 @@ def main(
 
     global print
     from tp import maybe_init_dist
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -302,7 +304,7 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp: # and ("cuda" in device):
+        if is_speculative and use_tp and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
 
         if is_speculative:
@@ -353,8 +355,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y, metrics = generate(
                 model,
@@ -418,6 +427,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default=default_device, help='Device to use')
 
     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except:
+            raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,

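The generate.py changes all dispatch on the device string passed via --device ("cuda", "xpu", "cpu"/"mps"). As a hedged illustration (not part of this commit), a small helper along these lines could choose that string automatically; it assumes intel_extension_for_pytorch registers the torch.xpu backend on PyTorch builds that do not expose it natively:

import torch

def pick_device() -> str:
    # Prefer CUDA; fall back to an Intel GPU ("xpu") when the XPU backend is usable; otherwise use CPU.
    # These are the same device strings that device_sync() and the profiler selection above dispatch on.
    if torch.cuda.is_available():
        return "cuda"
    try:
        import intel_extension_for_pytorch  # noqa: F401 -- registers torch.xpu on older PyTorch builds
    except ImportError:
        pass
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"
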
mixtral-moe/generate.py

Lines changed: 17 additions & 3 deletions

@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif "cpu" in device:
         pass
     else:
@@ -178,7 +180,7 @@ def main(
     assert tokenizer_path.is_file(), tokenizer_path
 
     global print
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -248,8 +250,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
@@ -302,6 +311,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default="cuda", help='device to use')
 
     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except:
+            raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device

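mixtral-moe/generate.py gets the same device_sync and profiler treatment. For reference, a minimal standalone sketch of the XPU profiling branch added above, assuming a working torch.xpu backend and a PyTorch profiler build that exposes ProfilerActivity.XPU (which the patch already relies on):

import torch
from torch.profiler import profile, ProfilerActivity

device = "xpu"
x = torch.randn(1024, 1024, device=device)
# CPU + XPU activities, matching the activities list used in the patch.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
    y = x @ x
    torch.xpu.synchronize(device)  # the same sync the patched device_sync() performs
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
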
mixtral-moe/tp.py

Lines changed: 16 additions & 3 deletions

@@ -28,7 +28,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
 
-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -41,8 +41,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None
 
-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except:
+            raise ModuleNotFoundError(f"OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) is required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank
 
 rank = _get_rank()

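On XPU, maybe_init_dist now builds the tensor-parallel process group with the oneCCL backend instead of NCCL. A hedged standalone sketch of that initialization, assuming the process is launched with torchrun (which provides LOCAL_RANK, LOCAL_WORLD_SIZE, MASTER_ADDR, and MASTER_PORT) and that intel_extension_for_pytorch and oneccl_bindings_for_pytorch are installed:

import os
import torch
import torch.distributed as dist
import intel_extension_for_pytorch  # noqa: F401 -- provides the torch.xpu backend
import oneccl_bindings_for_pytorch  # noqa: F401 -- registers the "ccl" backend

rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["LOCAL_WORLD_SIZE"])

# Same oneCCL environment the patch sets before creating the group.
os.environ["CCL_PROCESS_LAUNCHER"] = "none"
os.environ["CCL_LOCAL_SIZE"] = str(world_size)
os.environ["CCL_LOCAL_RANK"] = str(rank)

torch.xpu.set_device(rank)
dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
print(f"rank {dist.get_rank()}/{dist.get_world_size()} ready on xpu:{rank}")
dist.destroy_process_group()
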
tp.py

Lines changed: 16 additions & 3 deletions

@@ -33,7 +33,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
 
-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -46,8 +46,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None
 
-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except:
+            raise ModuleNotFoundError(f"OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) is required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank

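tp.py receives the same change as mixtral-moe/tp.py, so both entry points share one calling convention: the caller passes its device string into maybe_init_dist, which picks nccl for CUDA and ccl for XPU. A hypothetical usage sketch under torchrun (everything outside the repo's own functions is illustrative):

import torch
from tp import maybe_init_dist  # patched above to take the device string

device = "xpu"  # or "cuda" / "cpu"; "xpu" requires intel_extension_for_pytorch to be importable
rank = maybe_init_dist(device)   # returns None when the script is not launched via torchrun
use_tp = rank is not None
target = f"{device}:{rank}" if use_tp else device
x = torch.ones(8, device=target)
if "xpu" in device:
    torch.xpu.synchronize(device)  # the sync device_sync() in generate.py now performs for XPU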