import sys
+ import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
- from torch.distributed.tensor.parallel import (
-     parallelize_module,
-     ColwiseParallel,
-     RowwiseParallel,
- )
-
- import os
from log_utils import rank_log, get_logger, verify_min_gpu_count

-
# ---- GPU check ------------
_min_gpu_count = 4

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
    sys.exit()
# ---------------------------

- from torch.distributed._tensor.device_mesh import init_device_mesh
+ from llama2_model import Transformer, ModelArgs
+
+ from torch.distributed.device_mesh import init_device_mesh
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed._tensor import Shard, Replicate
+ from torch.distributed.tensor.parallel import (
+     parallelize_module,
+     ColwiseParallel,
+     RowwiseParallel,
+     PrepareModuleInput,
+     SequenceParallel
+ )


"""
This is the script to test 2D Parallel which combines Tensor/Sequence
- parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
- in the SPMD style. We show an E2E working flow from forward, backward
+ parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on an example
+ Llama2 model. We show an E2E working flow from forward, backward
and optimization.

We enabled Fully Sharded Data Parallel + Tensor Parallel in
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
======================================================================

- More details can be seen in the slide:
- https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
+ More details can be seen in the PyTorch tutorials:
+ https://pytorch.org/tutorials/intermediate/TP_tutorial.html
"""

-
- def find_multiple(n: int, k: int) -> int:
-     """function to find resizing multiple for SwiGLU MLP"""
-     if n % k == 0:
-         return n
-     return n + k - (n % k)
-
-
- class MLP_swiglu(nn.Module):
-     """SwiGLU to showcase a Llama style MLP model"""
-
-     def __init__(self, mlp_dim: int = 1024) -> None:
-         super().__init__()
-         hidden_dim = 4 * mlp_dim
-         scaled_hidden = int(2 * hidden_dim / 3)
-         rounded_hidden = find_multiple(scaled_hidden, 256)
-
-         self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
-         self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
-         self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = F.silu(self.in_proj(x)) * self.gate_proj(x)
-         x = self.out_proj(x)
-         return x
-
-
- """
- Main body of the demo of a basic version of tensor parallel by using
- PyTorch native APIs.
- """
tp_size = 2
logger = get_logger()

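Note: the 2D ("dp", "tp") device mesh described in the docstring, and the dp_mesh / tp_mesh handles used in the next hunk, are built in unchanged code that falls outside the hunks shown here. A minimal sketch of what that setup typically looks like, assuming a world size of dp_size * tp_size GPUs; the names _world_size, dp_size and device_mesh are illustrative, not taken from the file:

    # Sketch only: build a 2D mesh with a data-parallel and a tensor-parallel dimension.
    import os
    from torch.distributed.device_mesh import init_device_mesh

    tp_size = 2
    _world_size = int(os.environ["WORLD_SIZE"])   # set by torchrun
    dp_size = _world_size // tp_size              # remaining ranks form the DP dimension
    device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
    dp_mesh = device_mesh["dp"]   # 1D sub-mesh handed to FSDP
    tp_mesh = device_mesh["tp"]   # 1D sub-mesh handed to parallelize_module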
@@ -120,26 +92,72 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# to mimic the behavior of the dataloader.
dp_rank = dp_mesh.get_local_rank()

- # create model and move it to GPU with id rank
- _mlp_dim = 1024
- base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda")
-
-
- # Custom parallelization plan for the swiglu MLP model
- custom_tp_model = parallelize_module(
-     module=base_model_swiglu,
-     device_mesh=tp_mesh,
-     parallelize_plan={
-         "in_proj": ColwiseParallel(),
-         "gate_proj": ColwiseParallel(),
-         "out_proj": RowwiseParallel(),
-     },
+ # create model and move it to GPU - init_device_mesh has already mapped GPU ids.
+ simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000)
+
+ model = Transformer.from_model_args(simple_llama2_config).to("cuda")
+
+ # init model weights
+ model.init_weights()
+
+ # parallelize the first embedding and the last linear out projection
+ model = parallelize_module(
+     model,
+     tp_mesh,
+     {
+         "tok_embeddings": RowwiseParallel(
+             input_layouts=Replicate(),
+         ),
+         "output": ColwiseParallel(
+             input_layouts=Shard(1),
+             output_layouts=Replicate()
+         ),
+         "norm": SequenceParallel(),
+         "layers.0": PrepareModuleInput(
+             input_layouts=(Replicate(), None),
+             desired_input_layouts=(Shard(1), None),
+             use_local_output=True,
+         ),
+     }
)

- rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n")
+ for layer_id, transformer_block in enumerate(model.layers):
+     layer_tp_plan = {
+         "attention": PrepareModuleInput(
+             input_layouts=(Shard(1), None),
+             desired_input_layouts=(Replicate(), None),
+         ),
+         "attention.wq": ColwiseParallel(),
+         "attention.wk": ColwiseParallel(),
+         "attention.wv": ColwiseParallel(),
+         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+         "attention_norm": SequenceParallel(),
+         "feed_forward": PrepareModuleInput(
+             input_layouts=(Shard(1),),
+             desired_input_layouts=(Replicate(),),
+         ),
+         "feed_forward.w1": ColwiseParallel(),
+         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+         "feed_forward.w3": ColwiseParallel(),
+         "ffn_norm": SequenceParallel(),
+     }
+
+     # Adjust attention module to use the local number of heads
+     attn_layer = transformer_block.attention
+     attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+     attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
+
+     # Custom parallelization plan for the model
+     parallelize_module(
+         module=transformer_block,
+         device_mesh=tp_mesh,
+         parallelize_plan=layer_tp_plan
+     )

# Init FSDP using the dp device mesh
- sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)
+ sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
+
+ rank_log(_rank, logger, f"Model after parallelization {sharded_model=}\n")

# Create an optimizer for the parallelized and sharded model.
lr = 3e-3
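The optimizer construction itself lives in unchanged code outside this hunk, so it does not appear in the diff. A minimal sketch of what it typically looks like for the FSDP-wrapped model; the choice of AdamW and the foreach flag are assumptions, not confirmed by this commit:

    # Sketch only: any standard optimizer over the sharded parameters works here.
    optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True)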
@@ -156,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
for i in range(num_iterations):
    # seeding with dp_rank to ensure identical inputs for TP groups
    torch.manual_seed(i + dp_rank)
-     inp = torch.rand(batch_size, _mlp_dim, device="cuda")
+     inp = torch.randint(32000, (8, 256), device="cuda")

    output = sharded_model(inp)
    output.sum().backward()
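The loop above backpropagates a dummy objective (output.sum()) just to exercise the 2D-parallel forward/backward path. A hedged sketch of what a real language-modeling step would look like instead, assuming the Transformer returns replicated logits of shape (batch, seq_len, vocab_size) and using an illustrative labels tensor in place of real data:

    # Sketch only: standard next-token loss; labels would normally come from the dataloader.
    labels = torch.randint(32000, (8, 256), device="cuda")
    logits = sharded_model(inp)   # plain replicated tensor, since output_layouts=Replicate()
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss.backward()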