From 1be87398c9b460e86c1f167e37b0076d8eb1e821 Mon Sep 17 00:00:00 2001
From: "Ma, Jing"
Date: Fri, 14 Jun 2024 00:46:08 -0700
Subject: [PATCH] enable Intel GPU (XPU) support

---
 generate.py             | 28 ++++++++++++++++++++++------
 mixtral-moe/generate.py | 26 +++++++++++++++++++++-----
 mixtral-moe/tp.py       | 19 ++++++++++++++++---
 quantize.py             |  6 +++++-
 tp.py                   | 47 ++++++++++++++++++++++++++++++++++++++++++-----
 5 files changed, 106 insertions(+), 20 deletions(-)

diff --git a/generate.py b/generate.py
index 24ba553d..6b29d5b8 100644
--- a/generate.py
+++ b/generate.py
@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -24,7 +26,8 @@
 
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
 
 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -238,8 +241,9 @@ def _load_model(checkpoint_path, device, precision, use_tp):
     model.load_state_dict(checkpoint, assign=True)
 
     if use_tp:
-        from tp import apply_tp
+        from tp import apply_tp, global_device
         print("Applying tensor parallel to model ...")
+        global_device(device)
         apply_tp(model)
 
     model = model.to(device=device, dtype=precision)
@@ -271,7 +275,7 @@ def main(
 
     global print
     from tp import maybe_init_dist
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -303,7 +307,7 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp: # and ("cuda" in device):
+        if is_speculative and use_tp and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
 
         if is_speculative:
@@ -354,8 +358,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y, metrics = generate(
                 model,
@@ -419,6 +430,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default=default_device, help='Device to use')
 
     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except ImportError:
+            raise ModuleNotFoundError("Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py
index 9aa076b6..ea1a5ff5 100644
--- a/mixtral-moe/generate.py
+++ b/mixtral-moe/generate.py
@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif "cpu" in device:
         pass
     else:
@@ -24,7 +26,8 @@
 
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
 
 
 # support running without installing as a package
@@ -178,7 +181,7 @@ def main(
     assert tokenizer_path.is_file(), str(tokenizer_path)
 
     global print
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -203,7 +206,8 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        torch._inductor.config.assert_indirect_indexing = False
+        if hasattr(torch._inductor.config, "assert_indirect_indexing"):
+            torch._inductor.config.assert_indirect_indexing = False
 
         global decode_one_token, prefill
         decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -248,8 +252,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
@@ -302,6 +313,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default="cuda", help='device to use')
 
     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except ImportError:
+            raise ModuleNotFoundError("Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
diff --git a/mixtral-moe/tp.py b/mixtral-moe/tp.py
index 75336b58..8ebf638a 100644
--- a/mixtral-moe/tp.py
+++ b/mixtral-moe/tp.py
@@ -28,7 +28,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
 
-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -41,8 +41,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None
 
-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except ImportError:
+            raise ModuleNotFoundError("OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) are required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank
 
 rank = _get_rank()
diff --git a/quantize.py b/quantize.py
index 4ebbe5f5..e6c95487 100644
--- a/quantize.py
+++ b/quantize.py
@@ -539,7 +539,6 @@ def quantize(
     device: str = default_device,
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
-    device = 'cpu'
     precision = torch.bfloat16
 
     print("Loading model ...")
@@ -621,4 +620,9 @@ def quantize(
     parser.add_argument('--device', type=str, default=default_device, help='device to use')
 
    args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except ImportError:
+            raise ModuleNotFoundError("Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)
diff --git a/tp.py b/tp.py
index a151a875..6da35095 100644
--- a/tp.py
+++ b/tp.py
@@ -18,6 +18,12 @@
 from model import Attention, FeedForward, Transformer
 from quantize import WeightOnlyInt4Linear
 
+Int4Device = "cpu"
+
+def global_device(device: str = "cpu"):
+    global Int4Device
+    Int4Device = device
+
 def _get_rank() -> int:
     return int(os.environ.get("LOCAL_RANK", "0"))
 
@@ -33,7 +39,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
 
-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -46,8 +52,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None
 
-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except ImportError:
+            raise ModuleNotFoundError("OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) are required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank
 
 
@@ -83,14 +102,32 @@ def shard_qkv(qkv, dim, weight_splits):
         assert len(weight_splits) == 3
 
         if isinstance(linear, WeightOnlyInt4Linear):
-            sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
+            if ("xpu" in Int4Device):
+                in_features = linear.in_features
+                out_features = linear.out_features//8
+                sharded_weight_size = list(linear.weight.size())
+                sharded_weight_size[shard_dim] = -1
+                weight = linear.weight.reshape((in_features, out_features))
+                sharded_weight = shard_qkv(weight, 1 - shard_dim, [i//8 for i in weight_splits])
+                sharded_weight = sharded_weight.reshape(sharded_weight_size)
+            else:
+                sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
             linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
         else:
             sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
         if hasattr(linear, "scales") and style == "colwise":
             linear.scales = shard_qkv(linear.scales, 0, weight_splits)
     else:
-        sharded_weight = shard(linear.weight, shard_dim)
+        if isinstance(linear, WeightOnlyInt4Linear) and ("xpu" in Int4Device):
+            in_features = linear.in_features
+            out_features = linear.out_features//8
+            sharded_weight_size = list(linear.weight.size())
+            sharded_weight_size[shard_dim] = -1
+            weight = linear.weight.reshape((in_features, out_features))
+            sharded_weight = shard(weight, 1 - shard_dim)
+            sharded_weight = sharded_weight.reshape(sharded_weight_size)
+        else:
+            sharded_weight = shard(linear.weight, shard_dim)
         if isinstance(linear, WeightOnlyInt4Linear):
             linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
             if style == "rowwise":
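
A minimal smoke-test sketch of the XPU path this patch relies on (assumes an XPU-enabled PyTorch build with intel_extension_for_pytorch installed; tensor shapes and dtype are illustrative):

    import torch
    # Importing IPEX registers the "xpu" device and the torch.xpu.* calls used in this patch.
    import intel_extension_for_pytorch as ipex

    device = "xpu" if torch.xpu.is_available() else "cpu"
    x = torch.randn(8, 8, device=device, dtype=torch.bfloat16)
    y = x @ x
    if "xpu" in device:
        # The same synchronize call that generate.py's device_sync() now issues for XPU devices.
        torch.xpu.synchronize(device)
    print(y.float().abs().sum().item())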