1 file changed: 5 additions, 6 deletions

@@ -108,14 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
         # FIXME: verify that ROCm transform nccl to rccl
         dist_backends = {
             "xpu": "ccl",
             "hpu": "hccl",
             "cuda": "nccl",
         }
-        dist_backend = dist_backends.get(device, 'gloo')
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
@@ -155,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'
 
     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
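The second hunk reuses the `device_type` parsed once at the top of the function. A hypothetical, self-contained sketch of the resulting behavior (`normalize_device` is an illustrative name, not a timm function):

```python
# Assumed helper name for illustration: parse the device string once, drop a
# user-specified index in distributed mode, and rebuild the device string from
# the parsed type plus the resolved local rank.
import logging

_logger = logging.getLogger(__name__)

def normalize_device(device: str, local_rank: int, distributed: bool) -> str:
    device_type, *device_idx = device.split(':', maxsplit=1)
    if distributed and device != 'cpu':
        if device_idx:
            _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
        device = f'{device_type}:{local_rank}'
    return device

print(normalize_device('cuda:3', local_rank=1, distributed=True))  # -> 'cuda:1'
```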