diff --git a/src/MaxText/vocabulary_tiling.py b/src/MaxText/vocabulary_tiling.py index ff0a90e04..395d0400a 100644 --- a/src/MaxText/vocabulary_tiling.py +++ b/src/MaxText/vocabulary_tiling.py @@ -198,15 +198,13 @@ def _bwd_scan_body(grad_params_acc, chunk_data): _bwd_scan_body, initial_grad_params_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) ) grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec) - # TODO (chengnuojin): we may want to convert grad_params to bf16 to save memory - # grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params) # Chain-rule to accumulate gradients - grad_params = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_params) + grad_params = jax.tree_util.tree_map(lambda g, p: (g * loss_cotangent).astype(p.dtype), grad_params, gathered_params) # Give back sharding constraint - grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec) + grad_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec) return ( grad_params, # grad for params - grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype), + grad_hidden_states.astype(config.weight_dtype), # Enforce activation gradients dtype as None, # grad for reshaped_labels None, # grad for reshaped_segmentation )