@@ -21,9 +21,11 @@
 import jax
 import jax.numpy as jnp
 from MaxText import max_utils
+from MaxText import max_logging
 from MaxText.sharding import maybe_shard_with_name, all_gather_over_fsdp
 from MaxText.common_types import ShardMode
 
+max_logging.log("Using new new vocab tiling file!!!")
 
 def vocab_tiling_linen_loss(
     hidden_states,
@@ -198,15 +200,13 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
       _bwd_scan_body, initial_grad_params_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
   )
   grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
-  # TODO (chengnuojin): we may want to convert grad_params to bf16 to save memory
-  # grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
   # Chain-rule to accumulate gradients
-  grad_params = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_params)
+  grad_params = jax.tree_util.tree_map(lambda g, p: (g * loss_cotangent).astype(p.dtype), grad_params, gathered_params)
   # Give back sharding constraint
-  grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
+  grad_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
   return (
       grad_params,  # grad for params
-      grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype),
+      grad_hidden_states.astype(config.weight_dtype),  # Enforce activation gradients dtype as weight_dtype
       None,  # grad for reshaped_labels
       None,  # grad for reshaped_segmentation
   )
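
The core of this hunk is the two-tree tree_map: the upstream loss cotangent is applied to every gradient leaf and the result is cast to the matching parameter's dtype in the same pass, which also covers the removed TODO about converting grad_params to bf16. Below is a minimal, self-contained sketch of that pattern; the pytree contents, shapes, and dtypes are invented for illustration, and only the names gathered_params, grad_params, and loss_cotangent come from the diff.

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the gathered (bf16) parameters and the fp32
# gradient accumulator produced by the backward scan.
gathered_params = {
    "kernel": jnp.ones((4, 8), dtype=jnp.bfloat16),
    "bias": jnp.zeros((8,), dtype=jnp.bfloat16),
}
grad_params = {
    "kernel": jnp.full((4, 8), 0.5, dtype=jnp.float32),
    "bias": jnp.full((8,), 0.25, dtype=jnp.float32),
}
loss_cotangent = jnp.float32(2.0)  # upstream cotangent of the scalar loss

# Chain rule plus dtype cast in a single two-tree tree_map, mirroring the
# added line in the diff: scale each gradient leaf by the loss cotangent,
# then cast it back to its parameter's dtype.
grad_params = jax.tree_util.tree_map(
    lambda g, p: (g * loss_cotangent).astype(p.dtype), grad_params, gathered_params
)

print({k: v.dtype for k, v in grad_params.items()})  # both leaves are now bfloat16

Folding the cast into the chain-rule step keeps the accumulated parameter gradients in the parameters' (typically lower-precision) dtype rather than holding a separate full-precision copy, which is the memory saving the old TODO comment was asking for.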
|