# Run with: torchrun --nnodes 1 --nproc-per-node 4 <fn>
import os
import sys

import torch
from torch.distributed.tensor.debug import CommDebugMode

from log_utils import rank_log, get_logger, verify_min_gpu_count

# ---- GPU check ------------
# Minimum number of accelerators required for this sequence-parallel example.
_min_gpu_count = 2
@@ -63,9 +65,10 @@ def forward(self, x):
6365"""
logger = get_logger()

# Pick up whatever accelerator backend is active ("cuda", "xpu", ...) so the
# example is not hard-wired to CUDA.
device_type = torch.accelerator.current_accelerator().type

# Build a 1-D device mesh spanning every rank in the job.
world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh(
    device_type=device_type, mesh_shape=(world_size,)
)

_rank = device_mesh.get_rank()
@@ -75,7 +78,7 @@ def forward(self, x):
rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")

# Move the model onto the accelerator; init_device_mesh has already assigned
# the per-rank device index, so the bare device type is enough here.
model = ToyModel().to(device_type)
7982
8083# Custom parallelization plan for the model
8184sp_model = parallelize_module (
@@ -87,6 +90,8 @@ def forward(self, x):
8790 },
8891)
8992
# Dump the parallelized module structure once, from rank 0 only, so the
# output is not duplicated across every rank.
if torch.distributed.get_rank() == 0:
    print(f"model {sp_model}")
9095
# Create an optimizer for the parallelized module.
lr = 0.25
@@ -98,12 +103,19 @@ def forward(self, x):
# Number of demo training iterations to run.
num_iters = 10
rank_log(_rank, logger, "Sequence Parallel training starting...")
106+
for i in range(num_iters):
    # For SP, the input can be different across all ranks (each rank feeds
    # its own shard of the sequence). Batch of 1, feature dim 10.
    inp = torch.rand(1, 10, device=device_type)

    # Capture every collective issued during this iteration so the
    # communication pattern can be inspected below.
    comm_mode = CommDebugMode()
    with comm_mode:
        output = sp_model(inp)
        output.sum().backward()
        optimizer.step()
    # Clear gradients each iteration; without this, .backward() keeps
    # accumulating gradients across all 10 steps.
    optimizer.zero_grad()
    rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")

    # Report the per-rank communication debug info once, on the first iter.
    if i == 0:
        print(f" rank{torch.distributed.get_rank()} {i} get_comm_counts {comm_mode.get_comm_counts()} get_sharding_info() {comm_mode.get_sharding_info()} generate_comm_debug_tracing_table {comm_mode.generate_comm_debug_tracing_table(noise_level=1)}")

rank_log(_rank, logger, "Sequence Parallel training completed!")
0 commit comments