 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional
 
 import torch
 import torch.nn as nn
@@ -45,8 +44,8 @@ def disable_active_parametrization():
 
 @dataclass(frozen=True)
 class MixedPrecisionPolicy:
-    param_dtype: Optional[torch.dtype] = None
-    reduce_dtype: Optional[torch.dtype] = None
+    param_dtype: torch.dtype | None = None
+    reduce_dtype: torch.dtype | None = None
 
 
 class _ScaledPartial(Partial):
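A minimal construction sketch for the updated policy (not part of the diff; the bf16/fp32 choice is illustrative only):

import torch

# MixedPrecisionPolicy is the frozen dataclass shown in the hunk above.
# Hedged sketch: leaving a field as None presumably keeps the original dtype.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # dtype parameters are cast to for compute
    reduce_dtype=torch.float32,   # dtype used for gradient reductions
)
# frozen=True makes instances immutable:
# mp_policy.param_dtype = torch.float16  # raises dataclasses.FrozenInstanceError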
@@ -161,8 +160,8 @@ def _distribute_dtensor(
 
 
 def _register_parametrization(
-    module: nn.Module, param_names: List[str], parametrization: nn.Module
-):
+    module: nn.Module, param_names: list[str], parametrization: nn.Module
+) -> None:
     """
     It works with state_dict without incurring parametrization calls because
     state_dict accesses parameters directly from self._parameters, not from getters
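The docstring's point, that state_dict() never triggers the parametrization because it reads tensors straight out of module._parameters rather than going through attribute getters, can be reproduced with a self-contained sketch. This is not the PR's _register_parametrization: the _Scale module, the helper name, and the per-instance-subclass trick are illustrative assumptions.

import torch
import torch.nn as nn

class _Scale(nn.Module):
    # toy parametrization: scales the raw parameter on every attribute access
    def forward(self, w: torch.Tensor) -> torch.Tensor:
        return 2.0 * w

def _register_parametrization_sketch(
    module: nn.Module, param_name: str, parametrization: nn.Module
) -> None:
    # keep the raw tensor in module._parameters and redirect attribute access
    # through a property installed on a per-instance subclass
    def getter(self: nn.Module) -> torch.Tensor:
        return parametrization(self._parameters[param_name])

    cls = module.__class__
    patched = type(
        f"Parametrized{cls.__name__}", (cls,), {param_name: property(getter)}
    )
    module.__class__ = patched

lin = nn.Linear(4, 4)
raw = lin.weight.detach().clone()
_register_parametrization_sketch(lin, "weight", _Scale())

assert torch.allclose(lin.weight, 2.0 * raw)            # getter runs the parametrization
assert torch.allclose(lin.state_dict()["weight"], raw)  # state_dict reads _parameters directly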
@@ -230,16 +229,14 @@ def __init__(
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
 
-    def replicate_compute(self, x):
+    def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
         # the gradients are partial tensors that need to perform reduction
         # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
         # support FSDP/DDP/HSDP + EP + TP (assuming TP shards the inner-most dim)
         non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim
         assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported"
         if non_dp_mesh_dims > 0:
-            # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
-            # after DeviceMesh supports slicing a non-root mesh
             dp_mesh = self.device_mesh
             # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
             sharded_local_tensor = x.to_local()
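The re-wrap step above can be pictured with a standalone sketch (not the PR's code): a parameter sharded over a 2D ("dp", "tp") mesh has its local shard re-wrapped as a 1D DTensor on the DP mesh, so the FSDP all-gather only spans the data-parallel dimension. The mesh shape, dim names, and placements are assumptions; it needs torchrun with 4 GPUs (or a cpu/gloo mesh), and the torch.distributed.tensor import path assumes a recent PyTorch.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

# 2D mesh: outer data-parallel dim, inner tensor-parallel dim (4 ranks total)
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh_2d["dp"]

# pretend this is the rank-local shard of a parameter sharded over dp and tp,
# with TP sharding the inner-most dim (dim 1 here)
local_shard = torch.randn(64, 64, device="cuda")
x = DTensor.from_local(local_shard, mesh_2d, (Shard(0), Shard(1)))

# strip the 2D spec and re-wrap the same local data as a 1D DTensor on dp_mesh,
# so collectives only run over the data-parallel dimension
sharded_local_tensor = x.to_local()
x_dp = DTensor.from_local(sharded_local_tensor, dp_mesh, (Shard(0),))

# replicate over dp for local compute; gradients would come back as Partial
replicated = x_dp.redistribute(dp_mesh, (Replicate(),))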
@@ -283,7 +280,7 @@ def replicate_compute(self, x):
 
         return output
 
-    def forward(self, x):
+    def forward(self, x: DTensor) -> torch.Tensor:
         global _active_parametrization
         # This should never be set to true during forward, only outside for model
         # inspection / debugging / initialization
@@ -296,7 +293,10 @@ def forward(self, x):
         if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
             # apply checkpointing to implement reshard_after_forward
             output = checkpoint(
-                self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
+                self.replicate_compute,
+                x,
+                use_reentrant=False,
+                context_fn=fsdp_policy,
             )
         else:
             output = self.replicate_compute(x)
@@ -305,13 +305,13 @@ def forward(self, x):
 
 
 def data_parallel(
-    model,
-    device_mesh,
-    mode="replicate",
+    model: nn.Module,
+    device_mesh: DeviceMesh,
+    mode: str = "replicate",
     ac_mode: str = "none",
-    mp_policy: Optional[MixedPrecisionPolicy] = None,
+    mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
-    reduction_divide_factor: Optional[float] = None,
+    reduction_divide_factor: float | None = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
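For context, a hedged usage sketch of the annotated signature: the mode strings come from this diff, while the mesh layout, dim name, model, and the assumption that the wrapper returns the parametrized model are illustrative; it has to run under torchrun.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

# data_parallel / MixedPrecisionPolicy come from the module patched above
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

dp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))
model = data_parallel(
    model,
    dp_mesh,
    mode="fully_shard",                 # or "replicate" / "hybrid_shard"
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    ),
)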