Description
We are using torch 2.8, with optimizer states quantized to 8 bits. Normal training jobs run fine, but jobs that resume from a checkpoint fail at optimizer.step(). We use an AdamW optimizer copied from an older version of torch/torchao, in which the update is computed in fp32:
exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)
This fails with an error indicating that exp_avg.float() is somehow still bf16:
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method lerp(*(DTensor(local_tensor=OptimState8bit(signed=True, block_size=256, shape=(1408, 2048), device=cuda:0, requires_grad=False), device_mesh=DeviceMesh('cuda', [0], mesh_dim_names=('fsdp_cp',)), placements=(Shard(dim=0),)), DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(1408, 2048)), device_mesh=DeviceMesh('cuda', [0], mesh_dim_names=('fsdp_cp',)), placements=(Shard(dim=0),)), 0.09999999999999998), **{}): got RuntimeError('expected dtype torch.bfloat16 for `end`, but got dtype torch.float32')
from user code:
File "/traindata/yunfan/lotus/lotus/components/optim/adamw.py", line 165, in single_param_adam
exp_avg_f32 = exp_avg_f32.lerp(grad_f32, 1 - beta1)
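For reference, `exp_avg.lerp(grad, 1 - beta1)` computes `exp_avg + (1 - beta1) * (grad - exp_avg)`, which is the AdamW first-moment update `beta1 * exp_avg + (1 - beta1) * grad`. `lerp` also requires `self` and `end` to have the same dtype; it does not promote. The sketch below reproduces both points with plain CPU tensors. It does not involve DTensor or the 8-bit state, and the shapes and values are arbitrary. `beta1 = 0.9` is an assumption, chosen because it is consistent with the `0.09999999999999998` weight in the error above.

```python
import torch

beta1 = 0.9
exp_avg = torch.randn(4, 8, dtype=torch.bfloat16)  # stand-in for a bf16 optimizer state
grad_f32 = torch.randn(4, 8, dtype=torch.float32)

# Upcasting `self` first gives matching dtypes, so lerp runs entirely in fp32.
exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)
expected = beta1 * exp_avg.float() + (1 - beta1) * grad_f32
assert exp_avg_f32.dtype == torch.float32
assert torch.allclose(exp_avg_f32, expected, atol=1e-6)

# Without the upcast, `self` is bf16 and `end` is fp32, and lerp rejects the mix.
try:
    exp_avg.lerp(grad_f32, 1 - beta1)
    raised = False
except RuntimeError:
    raised = True
assert raised
```

So the traced `lerp` call can only fail this way if the tensor it receives as `self` is treated as bf16 at that point, whatever `.float()` was supposed to return.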
The casting in load_state_dict() looks suspicious: it converts state values such as exp_avg to bf16 to match the precision of the model weights. So I tried converting both the DTensor wrapper and the OptimState8bit local tensor to fp32 whenever they appear to be bf16 after checkpoint loading. I also added an assert before lerp() to make sure exp_avg.float() has dtype fp32. Neither helped. It seems that bf16 is enforced somewhere inside the DTensor operations without tripping the assert.

Can someone help me understand this behavior and find the correct fix? Thanks in advance!
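You can reproduce the cast described above with stock `torch.optim.AdamW` on plain CPU tensors. `Optimizer.load_state_dict` casts every floating-point state entry other than `step` to the dtype of the matching parameter. As a result, fp32 moments saved in a checkpoint come back as bf16 when the weights are bf16. The following is a minimal sketch with a toy parameter. It does not involve FSDP, DTensor, or the 8-bit `OptimState8bit` subclass, so it says nothing about how that subclass handles the cast.

```python
import copy

import torch

# Toy bf16 parameter standing in for a model weight.
param = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
opt = torch.optim.AdamW([param], lr=1e-3)

# One step so that AdamW allocates exp_avg / exp_avg_sq.
param.grad = torch.ones_like(param)
opt.step()

# Simulate a checkpoint whose moments were saved in fp32.
sd = copy.deepcopy(opt.state_dict())
for key in ("exp_avg", "exp_avg_sq"):
    sd["state"][0][key] = sd["state"][0][key].float()

opt.load_state_dict(sd)

# load_state_dict casts floating-point state (everything except "step")
# to the parameter's dtype, so the fp32 moments come back as bf16.
print(opt.state[param]["exp_avg"].dtype)  # torch.bfloat16
```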
Below is the more detailed stack trace:
Traceback (most recent call last):
File "/traindata/yunfan/lotus/lotus/grpo.py", line 1051, in <module>
recipe_main()
File "/traindata/yunfan/lotus/lotus/utils/config.py", line 184, in wrapper
recipe_main(conf)
File "/traindata/yunfan/lotus/lotus/grpo.py", line 1046, in recipe_main
recipe.train()
File "/traindata/yunfan/lotus/lotus/grpo.py", line 813, in train
step_output = self.train_step(
^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/lotus/grpo.py", line 694, in train_step
self._optimizer.step()
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py", line 133, in wrapper
return func.__get__(opt, opt.__class__)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py", line 516, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/lotus/components/optim/adamw.py", line 166, in step
adamw8bit_step_helper(self, self.param_groups, self._new_buffer, self.bf16_stochastic_round, self.is_adamw)
File "/traindata/yunfan/lotus/lotus/components/optim/adamw.py", line 280, in adamw8bit_step_helper
single_param_adam(
File "/traindata/yunfan/lotus/lotus/components/optim/adamw.py", line 208, in single_param_adam
exp_avg_f32 = exp_avg_float.lerp(grad_f32, 1 - beta1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
return DTensor._op_dispatcher.dispatch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 154, in dispatch
self.sharding_propagator.propagate(op_info)
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 266, in propagate
OutputSharding, self.propagate_op_sharding(op_info.schema)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 45, in __call__
return self.cache(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 279, in propagate_op_sharding_non_cached
out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 126, in _propagate_tensor_meta_non_cached
fake_out = op_schema.op(*fake_args, **fake_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2058, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 309, in _fn
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_meta_registrations.py", line 7886, in lerp
torch._check(
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/__init__.py", line 1684, in _check
_check_with(RuntimeError, cond, message)
File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/__init__.py", line 1666, in _check_with
raise error_type(message_evaluated)
RuntimeError: expected dtype torch.bfloat16 for `end`, but got dtype torch.float32
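The frames above show where the check actually fires. DTensor's sharding propagation (`_propagate_tensor_meta_non_cached`) runs `lerp` on fake tensors. In recent PyTorch releases those fake tensors appear to be rebuilt from the DTensor's spec metadata (the private `_spec.tensor_meta`), not from the wrapper's `.dtype`. One hypothesis consistent with this, not verified against this code base, is that the two disagree after loading: the spec says bf16 while `.dtype` says fp32. In that case `.float()` is a no-op, an assert on `.dtype` passes, and `lerp` still sees a bf16 `self`.

The sketch below defines two hypothetical helpers, `describe_dtypes` and `find_dtype_mismatches`. Neither is part of torch or torchao. Together they compare the wrapper, spec, and local dtypes of each optimizer state tensor. They rely on the private DTensor attributes `_spec` and `_local_tensor`, which may change between releases.

```python
import torch


def describe_dtypes(t: torch.Tensor) -> dict:
    """Collect the dtypes a (possibly DTensor-wrapped) state tensor reports.

    Plain tensors only report "wrapper"; DTensors also report the dtype
    stored in their spec and the dtype of their local shard.
    """
    info = {"wrapper": t.dtype}
    spec = getattr(t, "_spec", None)           # DTensorSpec on a DTensor (private)
    meta = getattr(spec, "tensor_meta", None)  # TensorMeta(shape, stride, dtype)
    if meta is not None:
        info["spec"] = meta.dtype
    local = getattr(t, "_local_tensor", None)  # e.g. an OptimState8bit shard (private)
    if local is not None:
        info["local"] = local.dtype
    return info


def find_dtype_mismatches(optimizer: torch.optim.Optimizer) -> list:
    """Hypothetical diagnostic: list state tensors whose reported dtypes disagree."""
    mismatches = []
    for group in optimizer.param_groups:
        for p in group["params"]:
            for key, value in optimizer.state.get(p, {}).items():
                if not isinstance(value, torch.Tensor):
                    continue
                info = describe_dtypes(value)
                if len(set(info.values())) > 1:
                    mismatches.append((key, tuple(p.shape), info))
    return mismatches


# Usage on a plain (non-distributed) optimizer: nothing to report.
param = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.AdamW([param])
param.grad = torch.ones_like(param)
opt.step()
print(find_dtype_mismatches(opt))  # []
```

On the resumed run, calling `find_dtype_mismatches` right after `load_state_dict()` and again just before `optimizer.step()` would show whether the spec dtype and the wrapper dtype diverge for `exp_avg`.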