 from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from utils import inspect_mixed_precision, inspect_model

+def verify_min_gpu_count(min_gpus: int = 2) -> bool:
+    """Verify that at least `min_gpus` accelerators are available to run the distributed example."""
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus

 def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
     for i, layer in enumerate(model.layers):
@@ -29,10 +34,23 @@ def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):


 def main(args):
+    _min_gpu_count = 2
+    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
39+ print (f"Unable to locate sufficient { _min_gpu_count } gpus to run this example. Exiting." )
40+ exit ()
     rank = int(os.environ["LOCAL_RANK"])
-    device = torch.device(f"cuda:{rank}")
-    torch.cuda.set_device(device)
-    torch.distributed.init_process_group(backend="nccl", device_id=device)
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        device = torch.device(f"{device_type}:{rank}")
+        torch.accelerator.set_device_index(rank)
46+ print (f"Running on rank { rank } on device { device } " )
47+ else :
48+ device = torch .device ("cpu" )
49+ print (f"Running on device { device } " )
50+
51+ backend = torch .distributed .get_default_backend_for_device (device )
52+ torch .distributed .init_process_group (backend = backend , device_id = device )
53+
3654 torch .manual_seed (0 )
3755 vocab_size = 1024
3856 batch_size = 32
@@ -64,7 +82,7 @@ def main(args):

     checkpointer = Checkpointer("checkpoints", dcp_api=args.dcp_api)
     if checkpointer.last_training_time is None:
-        model.to_empty(device="cuda")
+        model.to_empty(device=device)
         model.reset_parameters()
     else:
         checkpointer.load_model(model)
@@ -96,4 +114,5 @@ def main(args):
     parser.add_argument("--mixed-precision", action="store_true", default=False)
     parser.add_argument("--dcp-api", action="store_true", default=False)
     args = parser.parse_args()
+
     main(args)
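
Note: a quick way to sanity-check the device-agnostic setup introduced above, outside of a distributed launch, is the minimal sketch below. It assumes a PyTorch build recent enough to ship torch.accelerator and torch.distributed.get_default_backend_for_device; the train.py file name in the launch command afterwards is only an assumption about how the script is saved.

import torch
import torch.distributed as dist

# Mirror the selection logic added in main(): prefer the visible accelerator,
# fall back to CPU otherwise.
if torch.accelerator.is_available():
    device = torch.device(f"{torch.accelerator.current_accelerator()}:0")
else:
    device = torch.device("cpu")

# get_default_backend_for_device maps the device type to a collective backend,
# e.g. "nccl" for CUDA devices and "gloo" for CPU.
print(device, dist.get_default_backend_for_device(device))

The example itself is launched with torchrun (for instance: torchrun --nproc_per_node=2 train.py), which sets the LOCAL_RANK environment variable that main() reads.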