@@ -198,15 +198,13 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
         _bwd_scan_body, initial_grad_params_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
     )
     grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
-    # TODO (chengnuojin): we may want to convert grad_params to bf16 to save memory
-    # grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
     # Chain-rule to accumulate gradients
-    grad_params = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_params)
+    grad_params = jax.tree_util.tree_map(lambda g, p: (g * loss_cotangent).astype(p.dtype), grad_params, gathered_params)
     # Give back sharding constraint
-    grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
+    grad_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
     return (
         grad_params,  # grad for params
-        grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype),
+        grad_hidden_states.astype(config.weight_dtype),  # Enforce activation gradient dtype as weight_dtype
         None,  # grad for reshaped_labels
         None,  # grad for reshaped_segmentation
     )
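
The key change above is the gradient dtype handling: each gradient leaf is scaled by the scalar loss cotangent (chain rule) and immediately cast back to the dtype of its corresponding gathered parameter, so the backward pass no longer keeps a full float32 gradient tree, which is what the removed TODO was asking for. Below is a minimal, self-contained sketch of that pattern; the parameter tree, gradient tree, and scalar are hypothetical stand-ins for gathered_params, grad_params, and loss_cotangent, not the original module.

# Hedged illustration of dtype-preserving gradient scaling, not the PR's code.
import jax
import jax.numpy as jnp

# Hypothetical stand-in for gathered_params (often stored in bfloat16).
params = {
    "w": jnp.ones((4, 4), dtype=jnp.bfloat16),
    "b": jnp.zeros((4,), dtype=jnp.bfloat16),
}
# Backward passes typically produce gradients in float32.
grads = jax.tree_util.tree_map(lambda p: jnp.ones_like(p, dtype=jnp.float32), params)
loss_cotangent = jnp.float32(0.5)

# Scale every gradient leaf by the cotangent, then cast it back to its
# parameter's dtype so the gradient tree does not hold a float32 copy of the model.
grads = jax.tree_util.tree_map(lambda g, p: (g * loss_cotangent).astype(p.dtype), grads, params)

for g, p in zip(jax.tree_util.tree_leaves(grads), jax.tree_util.tree_leaves(params)):
    assert g.dtype == p.dtype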