Skip to content

Commit af44de3

Browse files
committed
update grad dtype
1 parent fca6e8f commit af44de3

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/MaxText/vocabulary_tiling.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import jax
2222
import jax.numpy as jnp
2323
from MaxText import max_utils
24+
from MaxText import max_logging
2425
from MaxText.sharding import maybe_shard_with_name, all_gather_over_fsdp
2526
from MaxText.common_types import ShardMode
2627

28+
max_logging.log("Using new new vocab tiling file!!!")
2729

2830
def vocab_tiling_linen_loss(
2931
hidden_states,
@@ -198,15 +200,13 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
198200
_bwd_scan_body, initial_grad_params_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
199201
)
200202
grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
201-
# TODO (chengnuojin): we may want to convert grad_params to bf16 to save memory
202-
# grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
203203
# Chain-rule to accumulate gradients
204-
grad_params = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_params)
204+
grad_params = jax.tree_util.tree_map(lambda g, p: (g * loss_cotangent).astype(p.dtype), grad_params, gathered_params)
205205
# Give back sharding constraint
206-
grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
206+
grad_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
207207
return (
208208
grad_params, # grad for params
209-
grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype),
209+
grad_hidden_states.astype(config.weight_dtype),  # Enforce activation gradient dtype to match config.weight_dtype
210210
None, # grad for reshaped_labels
211211
None, # grad for reshaped_segmentation
212212
)

0 commit comments

Comments
 (0)