From ccb64c3a47e47868eb1369414c3d8411cfc9a6bd Mon Sep 17 00:00:00 2001
From: Nuojin Cheng
Date: Wed, 5 Nov 2025 22:27:45 +0000
Subject: [PATCH] update grad dtype

---
 src/MaxText/vocabulary_tiling.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/MaxText/vocabulary_tiling.py b/src/MaxText/vocabulary_tiling.py
index ff0a90e04..395d0400a 100644
--- a/src/MaxText/vocabulary_tiling.py
+++ b/src/MaxText/vocabulary_tiling.py
@@ -198,15 +198,13 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
       _bwd_scan_body, initial_grad_params_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
   )
   grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
-  # TODO (chengnuojin): we may want to convert grad_params to bf16 to save memory
-  # grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
   # Chain-rule to accumulate gradients
-  grad_params = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_params)
+  grad_params = jax.tree_util.tree_map(lambda g, p: (g * loss_cotangent).astype(p.dtype), grad_params, gathered_params)
   # Give back sharding constraint
-  grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
+  grad_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
   return (
       grad_params,  # grad for params
-      grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype),
+      grad_hidden_states.astype(config.weight_dtype),  # Enforce activation gradients dtype as config.weight_dtype
       None,  # grad for reshaped_labels
       None,  # grad for reshaped_segmentation
   )
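
For reference, a minimal standalone sketch of the dtype-cast pattern this patch adopts, assuming bf16 parameters and fp32 gradient accumulation; `params`, `grads`, and `loss_cotangent` are illustrative names, not the actual MaxText variables. The upstream cotangent is folded in via the chain rule and each gradient leaf is cast back to its parameter's dtype, so the gradient tree does not retain fp32 copies of bf16 weights.

import jax
import jax.numpy as jnp

# Illustrative parameter tree: bf16 weights, as in a memory-constrained LM head.
params = {
    "kernel": jnp.ones((4, 8), dtype=jnp.bfloat16),
    "bias": jnp.zeros((8,), dtype=jnp.bfloat16),
}

# Gradients accumulated across backward-scan chunks are typically fp32.
grads = jax.tree_util.tree_map(lambda p: jnp.ones_like(p, dtype=jnp.float32), params)

loss_cotangent = jnp.float32(0.5)  # upstream cotangent from the custom VJP

# Chain rule: scale each leaf by the cotangent, then cast it back to its
# parameter's dtype so no fp32 copy of the gradient tree is kept around.
grads = jax.tree_util.tree_map(
    lambda g, p: (g * loss_cotangent).astype(p.dtype), grads, params
)

print(jax.tree_util.tree_map(lambda g: g.dtype, grads))
# Expected: both leaves report bfloat16.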