From 629c786f2bc5307e20cf3cd405f3c074b4830253 Mon Sep 17 00:00:00 2001
From: jianbinc
Date: Wed, 26 Nov 2025 11:07:08 +0800
Subject: [PATCH] FusedAdam: replace zeros(param.shape)/empty(param.shape) with zeros_like(param)/empty_like(param) to support DTensor

Signed-off-by: jianbinc
---
 transformer_engine/pytorch/optimizers/fused_adam.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py
index b5c87b4815..935af8ee0e 100644
--- a/transformer_engine/pytorch/optimizers/fused_adam.py
+++ b/transformer_engine/pytorch/optimizers/fused_adam.py
@@ -373,9 +373,9 @@ def _initialize_state(
         """
         dtype = self.name_to_dtype_map[state_name]
         if store_param_remainders:
-            data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+            data = torch.zeros_like(param, dtype=torch.int16, device=param.device)
         else:
-            data = torch.empty(param.shape, dtype=dtype, device=param.device)
+            data = torch.empty_like(param, dtype=dtype, device=param.device)
         if zero_buffer:
             data.zero_()
 