
Commit 1be8739

intel gpu : enable intel gpu
1 parent 30d69b3 commit 1be8739

File tree

5 files changed: 106 additions & 20 deletions

generate.py

Lines changed: 22 additions & 6 deletions
@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -24,7 +26,8 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future

 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -238,8 +241,9 @@ def _load_model(checkpoint_path, device, precision, use_tp):
     model.load_state_dict(checkpoint, assign=True)

     if use_tp:
-        from tp import apply_tp
+        from tp import apply_tp, global_device
         print("Applying tensor parallel to model ...")
+        global_device(device)
         apply_tp(model)

     model = model.to(device=device, dtype=precision)
@@ -271,7 +275,7 @@ def main(

     global print
     from tp import maybe_init_dist
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -303,7 +307,7 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp: # and ("cuda" in device):
+        if is_speculative and use_tp and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case

         if is_speculative:
@@ -354,8 +358,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y, metrics = generate(
                 model,
@@ -419,6 +430,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default=default_device, help='Device to use')

     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except:
+            raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
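
Note: for orientation, a standalone sketch of the `device_sync` dispatch this commit extends. The xpu branch assumes a `torch.xpu` namespace, which `intel_extension_for_pytorch` registers when installed; the final call is only an illustrative no-op on CPU.

```python
import torch

def device_sync(device: str) -> None:
    """Block until all queued kernels on the given device have finished."""
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "xpu" in device:
        # torch.xpu is provided by intel_extension_for_pytorch (XPU builds).
        torch.xpu.synchronize(device)
    elif ("cpu" in device) or ("mps" in device):
        pass  # nothing to wait on for timing purposes
    else:
        print(f"device={device} is not supported")

device_sync("cpu")  # no-op; "cuda:0" / "xpu:0" would synchronize that accelerator
```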

mixtral-moe/generate.py

Lines changed: 21 additions & 5 deletions
@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif "cpu" in device:
         pass
     else:
@@ -24,7 +26,8 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future


 # support running without installing as a package
@@ -178,7 +181,7 @@ def main(
     assert tokenizer_path.is_file(), str(tokenizer_path)

     global print
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -203,7 +206,8 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        torch._inductor.config.assert_indirect_indexing = False
+        if hasattr(torch._inductor.config, "assert_indirect_indexing"):
+            torch._inductor.config.assert_indirect_indexing = False

         global decode_one_token, prefill
         decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -248,8 +252,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
@@ -302,6 +313,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default="cuda", help='device to use')

     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except:
+            raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
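
Note: the two `hasattr` guards in this file (`fx_graph_cache` and `assert_indirect_indexing`) share one defensive pattern for `torch._inductor.config` flags that exist only in some PyTorch versions. A small sketch of the idea; the helper name is mine, not something the commit adds.

```python
import torch
import torch._inductor.config

def set_inductor_flag(name: str, value) -> bool:
    """Set a torch._inductor.config flag only if this PyTorch build exposes it."""
    cfg = torch._inductor.config
    if hasattr(cfg, name):
        setattr(cfg, name, value)
        return True
    return False

# Mirrors the guarded assignments in the diff above.
set_inductor_flag("fx_graph_cache", True)
set_inductor_flag("assert_indirect_indexing", False)
```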

mixtral-moe/tp.py

Lines changed: 16 additions & 3 deletions
@@ -28,7 +28,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -41,8 +41,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None

-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except:
+            raise ModuleNotFoundError(f"OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) is required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank

 rank = _get_rank()
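
Note: a hedged sketch of how the new xpu branch of `maybe_init_dist` is exercised. It assumes a torchrun launch (which exports `LOCAL_RANK` / `LOCAL_WORLD_SIZE`) and that `oneccl_bindings_for_pytorch` is installed to register the "ccl" backend; the trailing all-reduce is only an illustrative smoke test, not part of the commit.

```python
# Launch with e.g.: torchrun --standalone --nproc_per_node=2 this_script.py
import os
import torch
import torch.distributed as dist

def init_xpu_process_group() -> int:
    """Mirror of the xpu branch in maybe_init_dist(device)."""
    rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

    import oneccl_bindings_for_pytorch  # noqa: F401  (registers the "ccl" backend)

    os.environ["CCL_PROCESS_LAUNCHER"] = "none"
    os.environ["CCL_LOCAL_SIZE"] = str(world_size)
    os.environ["CCL_LOCAL_RANK"] = str(rank)

    torch.xpu.set_device(rank)
    dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
    return rank

if __name__ == "__main__":
    rank = init_xpu_process_group()
    t = torch.ones(1, device=f"xpu:{rank}")
    dist.all_reduce(t)  # illustrative smoke test: each element sums to world_size
    print(f"rank {rank}: {t.item()}")
```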

quantize.py

Lines changed: 5 additions & 1 deletion
@@ -539,7 +539,6 @@ def quantize(
     device: str = default_device,
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
-    device = 'cpu'
     precision = torch.bfloat16

     print("Loading model ...")
@@ -621,4 +620,9 @@ def quantize(
     parser.add_argument('--device', type=str, default=default_device, help='device to use')

     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except:
+            raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)
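
Note: the same IPEX import check is duplicated in generate.py, mixtral-moe/generate.py, and quantize.py. Below is a reusable sketch of that guard; the helper name is hypothetical and not part of the commit.

```python
def require_ipex_for_xpu(device: str) -> None:
    """Fail fast when an XPU device is requested but IPEX is not installed."""
    if "xpu" not in device:
        return
    try:
        import intel_extension_for_pytorch  # noqa: F401
    except ImportError as exc:
        raise ModuleNotFoundError(
            "Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run "
            "PyTorch code on Intel GPU (XPU). "
            "See https://github.com/intel/intel-extension-for-pytorch for details."
        ) from exc

require_ipex_for_xpu("cpu")  # no-op; require_ipex_for_xpu("xpu") raises without IPEX
```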

tp.py

Lines changed: 42 additions & 5 deletions
@@ -18,6 +18,12 @@
 from model import Attention, FeedForward, Transformer
 from quantize import WeightOnlyInt4Linear

+Int4Device = "cpu"
+
+def global_device(device: str = "cpu"):
+    global Int4Device
+    Int4Device = device
+

 def _get_rank() -> int:
     return int(os.environ.get("LOCAL_RANK", "0"))
@@ -33,7 +39,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -46,8 +52,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None

-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except:
+            raise ModuleNotFoundError(f"OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) is required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank


@@ -83,14 +102,32 @@ def shard_qkv(qkv, dim, weight_splits):
         assert len(weight_splits) == 3

         if isinstance(linear, WeightOnlyInt4Linear):
-            sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
+            if ("xpu" in Int4Device):
+                in_features = linear.in_features
+                out_features = linear.out_features//8
+                sharded_weight_size = list(linear.weight.size())
+                sharded_weight_size[shard_dim] = -1
+                weight = linear.weight.reshape((in_features, out_features))
+                sharded_weight = shard_qkv(weight, 1 - shard_dim, [i//8 for i in weight_splits])
+                sharded_weight = sharded_weight.reshape(sharded_weight_size)
+            else:
+                sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
             linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
         else:
             sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
         if hasattr(linear, "scales") and style == "colwise":
             linear.scales = shard_qkv(linear.scales, 0, weight_splits)
     else:
-        sharded_weight = shard(linear.weight, shard_dim)
+        if isinstance(linear, WeightOnlyInt4Linear) and ("xpu" in Int4Device):
+            in_features = linear.in_features
+            out_features = linear.out_features//8
+            sharded_weight_size = list(linear.weight.size())
+            sharded_weight_size[shard_dim] = -1
+            weight = linear.weight.reshape((in_features, out_features))
+            sharded_weight = shard(weight, 1 - shard_dim)
+            sharded_weight = sharded_weight.reshape(sharded_weight_size)
+        else:
+            sharded_weight = shard(linear.weight, shard_dim)
         if isinstance(linear, WeightOnlyInt4Linear):
             linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
             if style == "rowwise":
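
Note: to make the `shard_dim` / `1 - shard_dim` bookkeeping in the last hunk easier to follow, here is a minimal sketch with a plain float weight. It re-declares a `torch.tensor_split`-based `shard()` so it runs standalone (the real helper lives in tp.py); the closing comment is my reading of why the xpu int4 path shards along the opposite dimension, not a statement from the commit.

```python
import torch

world_size, rank = 2, 0
out_features, in_features = 8, 4
weight = torch.randn(out_features, in_features)  # nn.Linear stores weight as (out, in)

def shard(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Split evenly across ranks along `dim` and keep this rank's slice.
    assert x.size(dim) % world_size == 0
    return torch.tensor_split(x, world_size, dim=dim)[rank]

colwise = shard(weight, 0)  # shard_dim = 0: (out_features // world_size, in_features)
rowwise = shard(weight, 1)  # shard_dim = 1: (out_features, in_features // world_size)
print(colwise.shape, rowwise.shape)

# The xpu int4 path above views the packed weight as (in_features, out_features // 8),
# i.e. with the logical axes swapped relative to (out, in), which is why it shards
# along 1 - shard_dim and then restores the stored layout via reshape(sharded_weight_size).
```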
