Commit ccb64c3

update grad dtype

1 parent: fca6e8f

1 file changed (+3, −5 lines)

src/MaxText/vocabulary_tiling.py

@@ -198,15 +198,13 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
       _bwd_scan_body, initial_grad_params_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
   )
   grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
-  # TODO (chengnuojin): we may want to convert grad_params to bf16 to save memory
-  # grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
   # Chain-rule to accumulate gradients
-  grad_params = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_params)
+  grad_params = jax.tree_util.tree_map(lambda g, p: (g * loss_cotangent).astype(p.dtype), grad_params, gathered_params)
   # Give back sharding constraint
-  grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
+  grad_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
   return (
       grad_params,  # grad for params
-      grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype),
+      grad_hidden_states.astype(config.weight_dtype),  # Enforce activation gradients dtype as weight_dtype
       None,  # grad for reshaped_labels
       None,  # grad for reshaped_segmentation
   )
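The change replaces the old TODO (cast grad_params to bf16 to save memory) with an in-place fix: each parameter gradient is scaled by the loss cotangent and cast back to its own parameter's dtype in a single tree_map. Below is a minimal standalone sketch of that pattern; the parameter tree, shapes, dtypes, and loss_cotangent value are all made up for illustration and are not MaxText's.

import jax
import jax.numpy as jnp

# Hypothetical parameter tree with mixed dtypes (illustrative only).
params = {
    "embed": jnp.zeros((4, 8), dtype=jnp.bfloat16),
    "head": jnp.zeros((8, 2), dtype=jnp.float32),
}

# Raw gradients as they might come out of the backward scan, accumulated in float32.
grads = jax.tree_util.tree_map(lambda p: jnp.ones(p.shape, dtype=jnp.float32), params)

# Upstream cotangent of the scalar loss (made-up value).
loss_cotangent = jnp.float32(0.5)

# The pattern from the diff: scale each gradient by the cotangent, then cast it
# back to its parameter's dtype so bf16 params get bf16 grads (saving memory).
grads = jax.tree_util.tree_map(lambda g, p: (g * loss_cotangent).astype(p.dtype), grads, params)

print({k: str(v.dtype) for k, v in grads.items()})
# {'embed': 'bfloat16', 'head': 'float32'}

Doing the scale and the cast in one tree_map avoids materializing a second full-precision copy of the gradient tree, which is the memory saving the removed TODO was after.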
