
Commit a3e170c

Improve SimpleFSDP typing and remove a finished TODO (#1960)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):

* __->__ #1960
* #1959

As title, no logic change. **Squash and Merge button won't work for this PR. I'll merge by myself.**
1 parent bc3021e commit a3e170c
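
For readers unfamiliar with the typing change: the commit moves the file from the typing module's generics to built-in generics (PEP 585) and the union operator (PEP 604). A minimal, illustrative before/after sketch; the scale function here is hypothetical and not taken from the file:

    # before: requires imports from typing
    from typing import List, Optional

    def scale(values: List[float], factor: Optional[float] = None) -> List[float]:
        return [v * (factor if factor is not None else 1.0) for v in values]

    # after: built-in generics (PEP 585) and X | None unions (PEP 604), no typing imports
    def scale(values: list[float], factor: float | None = None) -> list[float]:
        return [v * (factor if factor is not None else 1.0) for v in values]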

File tree

1 file changed: +15 −15 lines changed

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 15 additions & 15 deletions
@@ -7,7 +7,6 @@
 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional
 
 import torch
 import torch.nn as nn
@@ -45,8 +44,8 @@ def disable_active_parametrization():
 
 @dataclass(frozen=True)
 class MixedPrecisionPolicy:
-    param_dtype: Optional[torch.dtype] = None
-    reduce_dtype: Optional[torch.dtype] = None
+    param_dtype: torch.dtype | None = None
+    reduce_dtype: torch.dtype | None = None
 
 
 class _ScaledPartial(Partial):
@@ -161,8 +160,8 @@ def _distribute_dtensor(
 
 
 def _register_parametrization(
-    module: nn.Module, param_names: List[str], parametrization: nn.Module
-):
+    module: nn.Module, param_names: list[str], parametrization: nn.Module
+) -> None:
     """
     It works with state_dict without incurring parametrization calls because
     state_dict accesses parameters directly from self._parameters, not from getters
@@ -230,16 +229,14 @@ def __init__(
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
 
-    def replicate_compute(self, x):
+    def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
         # the gradients are partial tensors that needs to perform reduction
         # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
         # support FSDP/DDP/HSDP + EP + TP (assuming TP shards the inner-most dim)
         non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim
         assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported"
         if non_dp_mesh_dims > 0:
-            # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
-            # after DeviceMesh supports slicing a non-root mesh
             dp_mesh = self.device_mesh
             # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
             sharded_local_tensor = x.to_local()
@@ -283,7 +280,7 @@ def replicate_compute(self, x):
 
         return output
 
-    def forward(self, x):
+    def forward(self, x: DTensor) -> torch.Tensor:
         global _active_parametrization
         # This should never be set to true during forward, only outside for model
         # inspection / debugging / initialization
@@ -296,7 +293,10 @@ def forward(self, x):
         if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
             # apply checkpointing to implement reshard_after_forward
             output = checkpoint(
-                self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
+                self.replicate_compute,
+                x,
+                use_reentrant=False,
+                context_fn=fsdp_policy,
             )
         else:
             output = self.replicate_compute(x)
@@ -305,13 +305,13 @@ def forward(self, x):
 
 
 def data_parallel(
-    model,
-    device_mesh,
-    mode="replicate",
+    model: nn.Module,
+    device_mesh: DeviceMesh,
+    mode: str = "replicate",
     ac_mode: str = "none",
-    mp_policy: Optional[MixedPrecisionPolicy] = None,
+    mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
-    reduction_divide_factor: Optional[float] = None,
+    reduction_divide_factor: float | None = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
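
For context, a minimal usage sketch of the now fully annotated data_parallel entry point. Only the data_parallel signature, the mode strings, and the MixedPrecisionPolicy fields come from the diff above; the mesh setup, the stand-in model, and the assumption that data_parallel returns the wrapped module are illustrative assumptions. It would be launched via torchrun with a matching number of ranks:

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh

    from torchtitan.experiments.simple_fsdp.simple_fsdp import (
        MixedPrecisionPolicy,
        data_parallel,
    )

    # assumed setup: 8 ranks forming a 1-D data-parallel mesh (launch via torchrun)
    dp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))

    model = nn.Linear(1024, 1024)  # stand-in for a real torchtitan model

    # mode="fully_shard" exercises the FSDP-style path referenced in forward();
    # bf16 params with fp32 reduction via the frozen MixedPrecisionPolicy dataclass
    model = data_parallel(  # assumes the wrapped module is returned
        model,
        dp_mesh,
        mode="fully_shard",
        ac_mode="none",
        mp_policy=MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        ),
        shard_dim=0,
        reduction_divide_factor=None,
    )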
