1 file changed, +2 -2 lines changed
@@ -59,7 +59,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return
 
     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
@@ -69,7 +69,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     local_scale_tensor = scale_tensor.to_local()
     for i, float8_linear in enumerate(float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such