1717"""Alternative DeepSeek model definition with batch-split schedule."""
1818
1919from flax import linen as nn
20+ from flax import nnx
2021import jax
2122import jax .numpy as jnp
23+ from jax .sharding import Mesh
2224from MaxText import common_types
25+ from MaxText import max_utils
26+ from MaxText .common_types import Config
2327from MaxText .inference import page_manager
2428from MaxText .layers import attention_mla
2529from MaxText .layers import initializers
2630from MaxText .layers import linears
2731from MaxText .layers import moe
2832from MaxText .layers import normalizations
33+ from MaxText .layers import nnx_wrappers
2934from MaxText .layers import quantizations
3035
31-
32- class DeepSeekGenericLayer (nn .Module ):
36+ class DeepSeekBatchSplitGenericLayer (nnx .Module ):
3337 """Generic DeepSeek layer with Multi-Head Latent Attention.
3438
3539 This is to be used as a base class for DeepSeek layers with dense/sparse MLPs.
36-
3740 This class follows a pattern of separating module creation from execution.
38- `*_layer()` methods (e.g., `attention_layer`) are factories for `nn.Module`s,
39- called in `setup()` to initialize sub-layers. The module instances are stored
40- in `*_op` attributes (e.g., `self.attention_op`). The corresponding methods
41- (e.g., `attention`) are called during execution in `__call__` and wrap the
42- `*_op` modules with logic like logical constraints. This keeps `__call__`
43- clean and readable.
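+  Sub-layers are created eagerly in `__init__` and stored as attributes; the
+  corresponding `*_op` methods wrap their execution with logical sharding
+  constraints when called from `__call__`.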
4441 """
42+ def __init__ (
43+ self ,
44+ config : Config ,
45+ model_mode : str ,
46+ mesh : Mesh ,
47+ rngs : nnx .Rngs ,
48+ quant : quantizations .AqtQuantization | None = None ,
49+ ) -> None :
50+
51+ self .config = config
52+ self .model_mode = model_mode
53+ self .mesh = mesh
54+ self .quant = quant
55+ self .rngs = rngs
56+
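+    # Activation shape for this mode (batch, sequence, embed), used to size the sub-layers below.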
+    batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, model_mode)
+    self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)
+
+    self.pre_attention_layer_norm = normalizations.RMSNorm(
+        num_features=self.dummy_inputs_shape[-1],
+        dtype=config.dtype,
+        weight_dtype=config.weight_dtype,
+        kernel_axes=("norm",),
+        epsilon=config.normalization_layer_epsilon,
+        rngs=self.rngs,
+    )

-  config: common_types.Config
-  mesh: jax.sharding.Mesh
-  model_mode: str
-  quant: None | quantizations.AqtQuantization = None
+    self.post_attention_layer_norm = normalizations.RMSNorm(
+        num_features=self.dummy_inputs_shape[-1],
+        dtype=config.dtype,
+        weight_dtype=config.weight_dtype,
+        kernel_axes=("norm",),
+        epsilon=config.normalization_layer_epsilon,
+        rngs=self.rngs,
+    )
+
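+    # Multi-Head Latent Attention: low-rank query/key-value projections with
+    # separate RoPE and non-RoPE head dimensions, as configured below.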
+    self.self_attention = attention_mla.MLA(
+        config=self.config,
+        num_query_heads=self.config.num_query_heads,
+        num_kv_heads=self.config.num_kv_heads,
+        head_dim=self.config.head_dim,
+        max_target_length=self.config.max_target_length,
+        max_prefill_predict_length=self.config.max_prefill_predict_length,
+        attention_kernel=self.config.attention,
+        attention_type=self.config.attention_type,
+        inputs_q_shape=self.dummy_inputs_shape,
+        inputs_kv_shape=self.dummy_inputs_shape,
+        mesh=self.mesh,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        dropout_rate=self.config.dropout_rate,
+        quant=self.quant,
+        kv_quant=quantizations.configure_kv_quant(self.config),
+        q_lora_rank=self.config.q_lora_rank,
+        kv_lora_rank=self.config.kv_lora_rank,
+        qk_nope_head_dim=self.config.qk_nope_head_dim,
+        qk_rope_head_dim=self.config.qk_rope_head_dim,
+        v_head_dim=self.config.v_head_dim,
+        max_position_embeddings=self.config.max_position_embeddings,
+        original_max_position_embeddings=self.config.original_max_position_embeddings,
+        mscale=self.config.mscale,
+        rope_factor=self.config.rope_factor,
+        model_mode=self.model_mode,
+        rngs=self.rngs,
+    )
+
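+    # The dropout mask is broadcast along the sequence axis (broadcast_dims=(-2,)).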
+    self.dropout = linears.Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

  def __call__(
      self,
@@ -62,8 +121,8 @@ def __call__(
    x = self.with_logical_constraint(inputs)
    x = jax.ad_checkpoint.checkpoint_name(x, "decoder_layer_input")

-    x += self.attention(
-        self.pre_attention_norm(x),
+    x += self.attention_op(
+        self.pre_attention_norm_op(x),
        decoder_segment_ids,
        decoder_positions,
        deterministic,
@@ -72,19 +131,10 @@ def __call__(
        slot,
    )

-    x += self.mlp(self.post_attention_norm(x), deterministic)
-    x = self.dropout(x, deterministic)
+    x += self.mlp_op(self.post_attention_norm_op(x), deterministic)
+    x = self.dropout_op(x, deterministic)
    return self.post_process(x)

-  def setup(self):
-    self.pre_attention_norm_op = self.rms_norm_layer("pre_attention_layer_norm")
-    self.post_attention_norm_op = self.rms_norm_layer(
-        "post_attention_layer_norm"
-    )
-    self.attention_op = self.attention_layer()
-    self.mlp_op = self.mlp_layer()
-    self.dropout_op = self.dropout_layer()
-
  @property
  def logical_axis_names(self):
    if self.model_mode == common_types.MODEL_MODE_PREFILL:
@@ -103,59 +153,13 @@ def logical_axis_names(self):
  def with_logical_constraint(self, x):
    return nn.with_logical_constraint(x, self.logical_axis_names)

-  def rms_norm_layer(self, name):
-    return normalizations.rms_norm(
-        num_features=self.config.base_emb_dim,
-        dtype=self.config.dtype,
-        weight_dtype=self.config.weight_dtype,
-        name=name,
-        kernel_axes=("norm",),
-        epsilon=self.config.normalization_layer_epsilon,
-    )
-
-  def pre_attention_norm(self, x):
-    return self.with_logical_constraint(self.pre_attention_norm_op(x))
-
-  def post_attention_norm(self, x):
-    return self.with_logical_constraint(self.post_attention_norm_op(x))
+  def pre_attention_norm_op(self, x):
+    return self.with_logical_constraint(self.pre_attention_layer_norm(x))

-  def attention_layer(self):
-    inputs_shape = (
-        self.config.per_device_batch_size,
-        self.config.max_target_length,
-        self.config.base_emb_dim,
-    )
-    return attention_mla.mla_as_linen(
-        config=self.config,
-        num_query_heads=self.config.num_query_heads,
-        num_kv_heads=self.config.num_kv_heads,
-        head_dim=self.config.head_dim,
-        max_target_length=self.config.max_target_length,
-        max_prefill_predict_length=self.config.max_prefill_predict_length,
-        attention_kernel=self.config.attention,
-        attention_type=self.config.attention_type,
-        inputs_q_shape=inputs_shape,
-        inputs_kv_shape=inputs_shape,
-        mesh=self.mesh,
-        dtype=self.config.dtype,
-        weight_dtype=self.config.weight_dtype,
-        dropout_rate=self.config.dropout_rate,
-        name="self_attention",
-        quant=self.quant,
-        kv_quant=quantizations.configure_kv_quant(self.config),
-        q_lora_rank=self.config.q_lora_rank,
-        kv_lora_rank=self.config.kv_lora_rank,
-        qk_nope_head_dim=self.config.qk_nope_head_dim,
-        qk_rope_head_dim=self.config.qk_rope_head_dim,
-        v_head_dim=self.config.v_head_dim,
-        max_position_embeddings=self.config.max_position_embeddings,
-        original_max_position_embeddings=self.config.original_max_position_embeddings,
-        mscale=self.config.mscale,
-        rope_factor=self.config.rope_factor,
-        model_mode=self.model_mode,
-    )
+  def post_attention_norm_op(self, x):
+    return self.with_logical_constraint(self.post_attention_layer_norm(x))

-  def attention(
+  def attention_op(
      self,
      x,
      decoder_segment_ids,
@@ -167,7 +171,7 @@ def attention(
  ):
    """Executes the attention layer."""
    return self.with_logical_constraint(
-        self.attention_op(
+        self.self_attention(
            x,
            x,
            decoder_positions,
@@ -180,60 +184,87 @@ def __call__(
        )
    )

-  def mlp_layer(self):
+  def mlp_op(self, x, deterministic):
+    """Executes the MLP operation. To be implemented by subclasses."""
    raise NotImplementedError()

-  def mlp(self, x, deterministic):
-    raise NotImplementedError()
-
-  def dropout_layer(self):
-    return nn.Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,))
-
-  def dropout(self, x, deterministic):
+  def dropout_op(self, x, deterministic):
    return self.with_logical_constraint(
-        self.dropout_op(x, deterministic=deterministic)
+        self.dropout(x, deterministic=deterministic)
    )

  def post_process(self, x):
    """Collect statistics about the output of the layer."""
    if self.config.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(x))
-      self.sow("intermediates", "activation_stdev", jnp.std(x))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(x))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(x))
      self.sow(
-          "intermediates",
+          nnx.Intermediate,
          "activation_fraction_zero",
          jnp.sum(x == 0) / jnp.size(x),
      )

    if self.config.scan_layers:
      return x, None
-    else:
-      return x
+    return x


-class DeepSeekDenseLayer(DeepSeekGenericLayer):
+class DeepSeekDenseLayer(DeepSeekBatchSplitGenericLayer):
  """DeepSeek layer with dense MLP."""

-  def mlp_layer(self):
-    return linears.mlp_block(
-        in_features=self.config.base_emb_dim,
+  def __init__(
+      self,
+      config: Config,
+      model_mode: str,
+      mesh: Mesh,
+      rngs: nnx.Rngs,
+      quant: quantizations.AqtQuantization | None = None,
+  ):
+
+    super().__init__(config, model_mode, mesh, rngs, quant)
+
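+    # Dense feed-forward (MLP) block used in place of the routed MoE block.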
+    self.mlp = linears.MlpBlock(
+        config=self.config,
+        mesh=self.mesh,
+        in_features=self.dummy_inputs_shape[-1],
        intermediate_dim=self.config.mlp_dim,
        activations=self.config.mlp_activations,
        intermediate_dropout_rate=self.config.dropout_rate,
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
-        name="mlp",
-        config=self.config,
        quant=self.quant,
-        mesh=self.mesh,
+        model_mode=model_mode,
+        rngs=self.rngs,
    )

-  def mlp(self, x, deterministic):
-    return self.with_logical_constraint(self.mlp_op(x, deterministic))
+  def mlp_op(self, x, deterministic):
+    return self.with_logical_constraint(self.mlp(x, deterministic))
+

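+# Linen-compatible wrapper so the NNX layer can be used from the existing Linen decoder stack.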
+DeepSeekDenseLayerToLinen = nnx_wrappers.to_linen_class(
+    DeepSeekDenseLayer,
+    base_metadata_fn=initializers.variable_to_logically_partitioned,
+)

-class DeepSeekMoELayer(DeepSeekGenericLayer):
+class DeepSeekMoELayer(DeepSeekBatchSplitGenericLayer):
  """DeepSeek MoE layer that uses a batch-split schedule."""
+
+  def __init__(
+      self,
+      config: Config,
+      model_mode: str,
+      mesh: Mesh,
+      rngs: nnx.Rngs,
+      quant: quantizations.AqtQuantization | None = None,
+  ):
+
+    super().__init__(config, model_mode, mesh, rngs, quant)
+
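+    # The attribute is named `DeepSeekMoeBlock_0` (rather than e.g. `mlp`) so that
+    # weight names stay compatible with existing checkpoints.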
+    self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
+        config=self.config,
+        mesh=mesh,
+        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_axes=("embed", None),
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        quant=quant,
+        rngs=self.rngs,
+    )

  def __call__(
      self,
@@ -261,8 +292,8 @@ def _merge(x):
      return jnp.concatenate(x, axis=0)

    def _attn(x, decoder_segment_ids, decoder_positions):
-      return self.attention(
-          self.pre_attention_norm(x),
+      return self.attention_op(
+          self.pre_attention_norm_op(x),
          decoder_segment_ids,
          decoder_positions,
          deterministic,
@@ -272,7 +303,7 @@ def _attn(x, decoder_segment_ids, decoder_positions):
      )

    def _moe(x):
-      return self.mlp(self.post_attention_norm(x), deterministic)
+      return self.mlp_op(self.post_attention_norm_op(x), deterministic)

    # Split the inputs into micro-batches.
    x = _split(x)
@@ -288,29 +319,13 @@ def _moe(x):
    # Merge the micro-batches back into a single batch.
    x = _merge(x)

-    x = self.dropout(x, deterministic)
+    x = self.dropout_op(x, deterministic)
    return self.post_process(x)

-  def init(self, *args, **kwargs):
-    # Calls the parent init method for testing parity.
-    return super().init(*args, **kwargs, method=super().__call__)
-
-  def mlp_layer(self):
-    # NOTE: the naming mismatch here is to ensure reverse compatibility with
-    # existing checkpoints. The `name` represents the weight name in
-    # JAX/checkpoints and so the class name is just for readability.
-    return moe.get_routed_and_shared_moe(
-        name="DeepSeekMoeBlock_0",
-        config=self.config,
-        mesh=self.mesh,
-        kernel_init=initializers.nd_dense_init(
-            1.0, "fan_in", "truncated_normal"
-        ),
-        kernel_axes=("embed", None),
-        dtype=self.config.dtype,
-        weight_dtype=self.config.weight_dtype,
-        quant=self.quant,
-    )
+  def mlp_op(self, x, _):
+    return self.with_logical_constraint(self.DeepSeekMoeBlock_0(x))

-  def mlp(self, x, _):
-    return self.with_logical_constraint(self.mlp_op(x))
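+# Linen-compatible wrapper for the MoE variant.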
+DeepSeekMoELayerToLinen = nnx_wrappers.to_linen_class(
+    DeepSeekMoELayer,
+    base_metadata_fn=initializers.variable_to_logically_partitioned,
+)