
Loading 8bit optimizer state from checkpoint causes dtype mismatch #3314

@yz-ppl

Description


We are using torch 2.8, with optimizer states quantized to 8-bit. Normal training jobs run fine, but jobs that resume from a checkpoint fail at optimizer.step(). Our AdamW optimizer is copied from an older version of torch/torchao and does its computation in fp32:

exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)

After resuming, this fails with an error indicating that exp_avg.float() is somehow bf16:

torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method lerp(*(DTensor(local_tensor=OptimState8bit(signed=True, block_size=256, shape=(1408, 2048), device=cuda:0, requires_grad=False), device_mesh=DeviceMesh('cuda', [0], mesh_dim_names=('fsdp_cp',)), placements=(Shard(dim=0),)), DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(1408, 2048)), device_mesh=DeviceMesh('cuda', [0], mesh_dim_names=('fsdp_cp',)), placements=(Shard(dim=0),)), 0.09999999999999998), **{}): got RuntimeError('expected dtype torch.bfloat16 for `end`, but got dtype torch.float32')

from user code:
   File "/traindata/yunfan/lotus/lotus/components/optim/adamw.py", line 165, in single_param_adam
    exp_avg_f32 = exp_avg_f32.lerp(grad_f32, 1 - beta1)
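The check that fails here is not specific to DTensor or 8-bit state: `Tensor.lerp` does not apply type promotion, so the start tensor (`self`) and `end` must have the same dtype. The sketch below uses plain CPU tensors with arbitrary sizes and values (an assumption for illustration, not the real optimizer state) to show that a bf16 start with an fp32 `end` raises this kind of error, while upcasting the start with `.float()`, as the optimizer code intends, avoids it:

```python
import torch

exp_avg = torch.zeros(4, dtype=torch.bfloat16)  # state that ended up in bf16
grad_f32 = torch.ones(4, dtype=torch.float32)
beta1 = 0.9

try:
    exp_avg.lerp(grad_f32, 1 - beta1)  # bf16 start, fp32 end: dtypes differ
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(err)  # the message names the dtype expected for `end`

# Upcasting the start tensor makes both operands fp32, so lerp succeeds.
exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)
print(exp_avg_f32.dtype)  # torch.float32
```

On plain tensors `.float()` really does produce an fp32 start, so the interesting question in this issue is why the traced call still sees a bf16 start.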

We suspect the casting in load_state_dict(), which converts state values such as exp_avg to bf16 to match the precision of the model weights. So I tried converting both the DTensor wrapper and the OptimState8bit local tensor to fp32 whenever they appear to be bf16 after the checkpoint is loaded, and added an assert before lerp() to make sure that exp_avg.float() has dtype fp32. Neither helped: the assert does not fire, yet the error persists, so it seems that bf16 is enforced somewhere inside the DTensor operation. Can I get help understanding this behavior and finding the correct fix? Thanks in advance!
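To see the suspected casting in isolation, here is a minimal sketch using the stock `torch.optim.AdamW` on a single bf16 CPU parameter, not the 8-bit optimizer or DTensor from this issue. It assumes the behavior of recent PyTorch releases, where `Optimizer.load_state_dict` casts floating-point state entries other than `step` to the dtype of the matching parameter; the tensor size and values are arbitrary:

```python
import copy

import torch

# A bf16 parameter, as when model weights are kept in bf16.
param = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
opt = torch.optim.AdamW([param], lr=1e-3)
param.grad = torch.ones_like(param)
opt.step()  # creates exp_avg / exp_avg_sq state in the parameter's dtype

# Simulate a checkpoint whose optimizer state was saved in fp32.
ckpt = copy.deepcopy(opt.state_dict())
ckpt["state"][0]["exp_avg"] = ckpt["state"][0]["exp_avg"].float()
print(ckpt["state"][0]["exp_avg"].dtype)  # torch.float32

# Loading casts the fp32 state back to the parameter's dtype.
opt.load_state_dict(ckpt)
loaded_dtype = opt.state[param]["exp_avg"].dtype
print(loaded_dtype)  # torch.bfloat16
```

With plain tensors the cast is visible directly in the loaded state's dtype; with a DTensor wrapping an OptimState8bit, the same `.to(dtype=...)` call goes through the subclasses' own dispatch, which is where the behavior in this issue diverges.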

Below is the more detailed stack trace:

Traceback (most recent call last):
  File "/traindata/yunfan/lotus/lotus/grpo.py", line 1051, in <module>
    recipe_main()
  File "/traindata/yunfan/lotus/lotus/utils/config.py", line 184, in wrapper
    recipe_main(conf)
  File "/traindata/yunfan/lotus/lotus/grpo.py", line 1046, in recipe_main
    recipe.train()
  File "/traindata/yunfan/lotus/lotus/grpo.py", line 813, in train
    step_output = self.train_step(
                  ^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/lotus/grpo.py", line 694, in train_step
    self._optimizer.step()
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py", line 133, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py", line 516, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/lotus/components/optim/adamw.py", line 166, in step
    adamw8bit_step_helper(self, self.param_groups, self._new_buffer, self.bf16_stochastic_round, self.is_adamw)
  File "/traindata/yunfan/lotus/lotus/components/optim/adamw.py", line 280, in adamw8bit_step_helper
    single_param_adam(
  File "/traindata/yunfan/lotus/lotus/components/optim/adamw.py", line 208, in single_param_adam
    exp_avg_f32 = exp_avg_float.lerp(grad_f32, 1 - beta1)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 154, in dispatch
    self.sharding_propagator.propagate(op_info)
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 266, in propagate
    OutputSharding, self.propagate_op_sharding(op_info.schema)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 45, in __call__
    return self.cache(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 279, in propagate_op_sharding_non_cached
    out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py", line 126, in _propagate_tensor_meta_non_cached
    fake_out = op_schema.op(*fake_args, **fake_kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2058, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 309, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/_meta_registrations.py", line 7886, in lerp
    torch._check(
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/__init__.py", line 1684, in _check
    _check_with(RuntimeError, cond, message)
  File "/traindata/yunfan/lotus/.venv/lib/python3.12/site-packages/torch/__init__.py", line 1666, in _check_with
    raise error_type(message_evaluated)
RuntimeError: expected dtype torch.bfloat16 for `end`, but got dtype torch.float32
