Commit b556bcd

feat: migrate deepseek batch split to nnx
1 parent 57e0ece commit b556bcd

File tree

3 files changed: +150, -130 lines

src/MaxText/layers/decoders.py

Lines changed: 1 addition & 1 deletion
@@ -395,7 +395,7 @@ def get_decoder_layers(self):
         return [mixtral.MixtralDecoderLayerToLinen]
       case DecoderBlockType.DEEPSEEK:
         if self.config.use_batch_split_schedule:
-          return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
+          return [deepseek_batchsplit.DeepSeekDenseLayerToLinen, deepseek_batchsplit.DeepSeekMoELayerToLinen]
         else:
           return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
       case DecoderBlockType.GEMMA:
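
Note on the change above: the decoder stack in decoders.py is still Flax Linen, so the batch-split layers, which become NNX modules in this commit, are registered through the `*ToLinen` wrapper classes produced by `nnx_wrappers.to_linen_class` (see the deepseek_batchsplit.py diff below), and the call site is otherwise unchanged. The following sketch is a minimal, hypothetical illustration of the construction difference that makes such a wrapper necessary; `TinyLinenLayer` and `TinyNNXLayer` are invented names, not part of MaxText or this commit.

```python
# Minimal sketch (not MaxText code): contrast Linen's lazy construction with
# NNX's eager construction. Names here are hypothetical.
from flax import linen as nn
from flax import nnx
import jax
import jax.numpy as jnp


class TinyLinenLayer(nn.Module):
  features: int  # hyperparameters are dataclass fields; no params exist yet

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)  # params created lazily at init/apply


class TinyNNXLayer(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    # Parameters are created right here, so shapes and RNGs must be known now.
    self.dense = nnx.Linear(in_features, out_features, rngs=rngs)

  def __call__(self, x):
    return self.dense(x)


x = jnp.ones((2, 8))
linen_params = TinyLinenLayer(features=4).init(jax.random.PRNGKey(0), x)
nnx_layer = TinyNNXLayer(in_features=8, out_features=4, rngs=nnx.Rngs(0))
y = nnx_layer(x)
```

A Linen parent can store the layer class and defer parameter creation to `init`/`apply`, while an NNX layer owns its parameters as soon as it is constructed; the wrapper bridges the second style back into the first.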

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 143 additions & 128 deletions
@@ -17,36 +17,95 @@
 """Alternative DeepSeek model definition with batch-split schedule."""
 
 from flax import linen as nn
+from flax import nnx
 import jax
 import jax.numpy as jnp
+from jax.sharding import Mesh
 from MaxText import common_types
+from MaxText import max_utils
+from MaxText.common_types import Config
 from MaxText.inference import page_manager
 from MaxText.layers import attention_mla
 from MaxText.layers import initializers
 from MaxText.layers import linears
 from MaxText.layers import moe
 from MaxText.layers import normalizations
+from MaxText.layers import nnx_wrappers
 from MaxText.layers import quantizations
 
-
-class DeepSeekGenericLayer(nn.Module):
+class DeepSeekBatchSplitGenericLayer(nnx.Module):
   """Generic DeepSeek layer with Multi-Head Latent Attention.
 
   This is to be used as a base class for DeepSeek layers with dense/sparse MLPs.
-
   This class follows a pattern of separating module creation from execution.
-  `*_layer()` methods (e.g., `attention_layer`) are factories for `nn.Module`s,
-  called in `setup()` to initialize sub-layers. The module instances are stored
-  in `*_op` attributes (e.g., `self.attention_op`). The corresponding methods
-  (e.g., `attention`) are called during execution in `__call__` and wrap the
-  `*_op` modules with logic like logical constraints. This keeps `__call__`
-  clean and readable.
   """
+  def __init__(
+      self,
+      config: Config,
+      model_mode: str,
+      mesh: Mesh,
+      rngs: nnx.Rngs,
+      quant: quantizations.AqtQuantization|None = None,
+  ) -> None:
+
+    self.config = config
+    self.model_mode = model_mode
+    self.mesh = mesh
+    self.quant = quant
+    self.rngs = rngs
+
+    batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, model_mode)
+    self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)
+
+    self.pre_attention_layer_norm = normalizations.RMSNorm(
+        num_features=self.dummy_inputs_shape[-1],
+        dtype=config.dtype,
+        weight_dtype=config.weight_dtype,
+        kernel_axes=("norm",),
+        epsilon=config.normalization_layer_epsilon,
+        rngs=self.rngs,
+    )
 
-  config: common_types.Config
-  mesh: jax.sharding.Mesh
-  model_mode: str
-  quant: None | quantizations.AqtQuantization = None
+    self.post_attention_layer_norm = normalizations.RMSNorm(
+        num_features=self.dummy_inputs_shape[-1],
+        dtype=config.dtype,
+        weight_dtype=config.weight_dtype,
+        kernel_axes=("norm",),
+        epsilon=config.normalization_layer_epsilon,
+        rngs=self.rngs,
+    )
+
+    self.self_attention = attention_mla.MLA(
+        config=self.config,
+        num_query_heads=self.config.num_query_heads,
+        num_kv_heads=self.config.num_kv_heads,
+        head_dim=self.config.head_dim,
+        max_target_length=self.config.max_target_length,
+        max_prefill_predict_length=self.config.max_prefill_predict_length,
+        attention_kernel=self.config.attention,
+        attention_type=self.config.attention_type,
+        inputs_q_shape=self.dummy_inputs_shape,
+        inputs_kv_shape=self.dummy_inputs_shape,
+        mesh=self.mesh,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        dropout_rate=self.config.dropout_rate,
+        quant=self.quant,
+        kv_quant=quantizations.configure_kv_quant(self.config),
+        q_lora_rank=self.config.q_lora_rank,
+        kv_lora_rank=self.config.kv_lora_rank,
+        qk_nope_head_dim=self.config.qk_nope_head_dim,
+        qk_rope_head_dim=self.config.qk_rope_head_dim,
+        v_head_dim=self.config.v_head_dim,
+        max_position_embeddings=self.config.max_position_embeddings,
+        original_max_position_embeddings=self.config.original_max_position_embeddings,
+        mscale=self.config.mscale,
+        rope_factor=self.config.rope_factor,
+        model_mode=self.model_mode,
+        rngs=self.rngs,
+    )
+
+    self.dropout = linears.Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)
 
   def __call__(
       self,
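
A consequence of moving from Linen's `setup()` to an NNX `__init__`, visible in the hunk above, is that parameters are allocated eagerly, so sub-layers can no longer infer feature sizes from their first input. The constructor therefore derives `dummy_inputs_shape` from the config via `max_utils.get_batch_seq_len_for_mode` and passes explicit sizes (`num_features`, `inputs_q_shape`, `inputs_kv_shape`) plus an `nnx.Rngs` to every sub-layer. As a rough sketch of why the size must be known up front, here is a toy RMSNorm written in the same eager style; it is hypothetical and not MaxText's `normalizations.RMSNorm`.

```python
# Toy RMSNorm in the eager NNX style (hypothetical, for illustration only).
# MaxText's real layers additionally thread an nnx.Rngs through every
# constructor for parameter initialization; this toy init is deterministic.
import jax
import jax.numpy as jnp
from flax import nnx


class ToyRMSNorm(nnx.Module):
  def __init__(self, num_features: int, epsilon: float = 1e-6):
    # The scale parameter is allocated here, so num_features must be known
    # at construction time instead of being inferred from the first input.
    self.scale = nnx.Param(jnp.ones((num_features,)))
    self.epsilon = epsilon

  def __call__(self, x):
    var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(var + self.epsilon) * self.scale.value


norm = ToyRMSNorm(num_features=16)
y = norm(jnp.ones((2, 8, 16)))  # (batch, seq, emb), the role dummy_inputs_shape plays above
```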
@@ -62,8 +121,8 @@ def __call__(
     x = self.with_logical_constraint(inputs)
     x = jax.ad_checkpoint.checkpoint_name(x, "decoder_layer_input")
 
-    x += self.attention(
-        self.pre_attention_norm(x),
+    x += self.attention_op(
+        self.pre_attention_norm_op(x),
         decoder_segment_ids,
         decoder_positions,
         deterministic,
@@ -72,19 +131,10 @@ def __call__(
         slot,
     )
 
-    x += self.mlp(self.post_attention_norm(x), deterministic)
-    x = self.dropout(x, deterministic)
+    x += self.mlp_op(self.post_attention_norm_op(x), deterministic)
+    x = self.dropout_op(x, deterministic)
     return self.post_process(x)
 
-  def setup(self):
-    self.pre_attention_norm_op = self.rms_norm_layer("pre_attention_layer_norm")
-    self.post_attention_norm_op = self.rms_norm_layer(
-        "post_attention_layer_norm"
-    )
-    self.attention_op = self.attention_layer()
-    self.mlp_op = self.mlp_layer()
-    self.dropout_op = self.dropout_layer()
-
   @property
   def logical_axis_names(self):
     if self.model_mode == common_types.MODEL_MODE_PREFILL:
@@ -103,59 +153,13 @@ def logical_axis_names(self):
   def with_logical_constraint(self, x):
     return nn.with_logical_constraint(x, self.logical_axis_names)
 
-  def rms_norm_layer(self, name):
-    return normalizations.rms_norm(
-        num_features=self.config.base_emb_dim,
-        dtype=self.config.dtype,
-        weight_dtype=self.config.weight_dtype,
-        name=name,
-        kernel_axes=("norm",),
-        epsilon=self.config.normalization_layer_epsilon,
-    )
-
-  def pre_attention_norm(self, x):
-    return self.with_logical_constraint(self.pre_attention_norm_op(x))
-
-  def post_attention_norm(self, x):
-    return self.with_logical_constraint(self.post_attention_norm_op(x))
+  def pre_attention_norm_op(self, x):
+    return self.with_logical_constraint(self.pre_attention_layer_norm(x))
 
-  def attention_layer(self):
-    inputs_shape = (
-        self.config.per_device_batch_size,
-        self.config.max_target_length,
-        self.config.base_emb_dim,
-    )
-    return attention_mla.mla_as_linen(
-        config=self.config,
-        num_query_heads=self.config.num_query_heads,
-        num_kv_heads=self.config.num_kv_heads,
-        head_dim=self.config.head_dim,
-        max_target_length=self.config.max_target_length,
-        max_prefill_predict_length=self.config.max_prefill_predict_length,
-        attention_kernel=self.config.attention,
-        attention_type=self.config.attention_type,
-        inputs_q_shape=inputs_shape,
-        inputs_kv_shape=inputs_shape,
-        mesh=self.mesh,
-        dtype=self.config.dtype,
-        weight_dtype=self.config.weight_dtype,
-        dropout_rate=self.config.dropout_rate,
-        name="self_attention",
-        quant=self.quant,
-        kv_quant=quantizations.configure_kv_quant(self.config),
-        q_lora_rank=self.config.q_lora_rank,
-        kv_lora_rank=self.config.kv_lora_rank,
-        qk_nope_head_dim=self.config.qk_nope_head_dim,
-        qk_rope_head_dim=self.config.qk_rope_head_dim,
-        v_head_dim=self.config.v_head_dim,
-        max_position_embeddings=self.config.max_position_embeddings,
-        original_max_position_embeddings=self.config.original_max_position_embeddings,
-        mscale=self.config.mscale,
-        rope_factor=self.config.rope_factor,
-        model_mode=self.model_mode,
-    )
+  def post_attention_norm_op(self, x):
+    return self.with_logical_constraint(self.post_attention_layer_norm(x))
 
-  def attention(
+  def attention_op(
       self,
       x,
       decoder_segment_ids,
@@ -167,7 +171,7 @@ def attention(
   ):
     """Executes the attention layer."""
     return self.with_logical_constraint(
-        self.attention_op(
+        self.self_attention(
            x,
            x,
            decoder_positions,
@@ -180,60 +184,87 @@
        )
     )
 
-  def mlp_layer(self):
+  def mlp_op(self, x, deterministic):
+    """Executes the MLP operation. To be implemented by subclasses."""
     raise NotImplementedError()
 
-  def mlp(self, x, deterministic):
-    raise NotImplementedError()
-
-  def dropout_layer(self):
-    return nn.Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,))
-
-  def dropout(self, x, deterministic):
+  def dropout_op(self, x, deterministic):
     return self.with_logical_constraint(
-        self.dropout_op(x, deterministic=deterministic)
+        self.dropout(x, deterministic=deterministic)
     )
 
   def post_process(self, x):
     """Collect statistics about the output of the layer."""
     if self.config.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(x))
-      self.sow("intermediates", "activation_stdev", jnp.std(x))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(x))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(x))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
          "activation_fraction_zero",
          jnp.sum(x == 0) / jnp.size(x),
      )
 
     if self.config.scan_layers:
       return x, None
-    else:
-      return x
+    return x
 
 
-class DeepSeekDenseLayer(DeepSeekGenericLayer):
+class DeepSeekDenseLayer(DeepSeekBatchSplitGenericLayer):
   """DeepSeek layer with dense MLP."""
 
-  def mlp_layer(self):
-    return linears.mlp_block(
-        in_features=self.config.base_emb_dim,
+  def __init__(self,
+               config: Config,
+               model_mode: str,
+               mesh: Mesh,
+               rngs: nnx.Rngs,
+               quant: quantizations.AqtQuantization|None = None,):
+
+    super().__init__(config, model_mode, mesh, rngs, quant)
+
+    self.mlp = linears.MlpBlock(
+        config=self.config,
+        mesh=self.mesh,
+        in_features=self.dummy_inputs_shape[-1],
         intermediate_dim=self.config.mlp_dim,
         activations=self.config.mlp_activations,
         intermediate_dropout_rate=self.config.dropout_rate,
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,
-        name="mlp",
-        config=self.config,
         quant=self.quant,
-        mesh=self.mesh,
+        model_mode=model_mode,
+        rngs=self.rngs,
     )
 
-  def mlp(self, x, deterministic):
-    return self.with_logical_constraint(self.mlp_op(x, deterministic))
+  def mlp_op(self, x, deterministic):
+    return self.with_logical_constraint(self.mlp(x, deterministic))
+
 
+DeepSeekDenseLayerToLinen = nnx_wrappers.to_linen_class(
+    DeepSeekDenseLayer,
+    base_metadata_fn=initializers.variable_to_logically_partitioned,
+)
 
-class DeepSeekMoELayer(DeepSeekGenericLayer):
+class DeepSeekMoELayer(DeepSeekBatchSplitGenericLayer):
   """DeepSeek MoE layer that uses a batch-split schedule."""
+  def __init__(self,
+               config: Config,
+               model_mode: str,
+               mesh: Mesh,
+               rngs: nnx.Rngs,
+               quant: quantizations.AqtQuantization|None = None,):
+
+    super().__init__(config, model_mode, mesh, rngs, quant)
+
+    self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
+        config=self.config,
+        mesh=mesh,
+        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_axes=("embed", None),
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        quant=quant,
+        rngs=self.rngs,
+    )
 
   def __call__(
       self,
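
On the `sow` change above: Linen sows intermediates into a string-named collection (`"intermediates"`), whereas NNX's `sow` takes a variable type such as `nnx.Intermediate` and stores the values as state on the module itself. A minimal sketch of sowing and reading such metrics back, using a hypothetical module and assuming `nnx.pop` for extraction, looks like this:

```python
# Minimal sketch (hypothetical module, not MaxText code): sowing activation
# statistics with NNX and reading them back after a forward pass.
from flax import nnx
import jax.numpy as jnp


class StatsLayer(nnx.Module):
  def __init__(self, features: int, rngs: nnx.Rngs):
    self.dense = nnx.Linear(features, features, rngs=rngs)

  def __call__(self, x):
    y = self.dense(x)
    # Each sow call appends to an nnx.Intermediate variable on the module.
    self.sow(nnx.Intermediate, "activation_mean", jnp.mean(y))
    self.sow(nnx.Intermediate, "activation_stdev", jnp.std(y))
    return y


layer = StatsLayer(features=4, rngs=nnx.Rngs(0))
_ = layer(jnp.ones((2, 4)))
# Extract (and clear) the sown values from the module's state.
intermediates = nnx.pop(layer, nnx.Intermediate)
```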
@@ -261,8 +292,8 @@ def _merge(x):
      return jnp.concatenate(x, axis=0)
 
     def _attn(x, decoder_segment_ids, decoder_positions):
-      return self.attention(
-          self.pre_attention_norm(x),
+      return self.attention_op(
+          self.pre_attention_norm_op(x),
          decoder_segment_ids,
          decoder_positions,
          deterministic,
@@ -272,7 +303,7 @@ def _attn(x, decoder_segment_ids, decoder_positions):
      )
 
     def _moe(x):
-      return self.mlp(self.post_attention_norm(x), deterministic)
+      return self.mlp_op(self.post_attention_norm_op(x), deterministic)
 
     # Split the inputs into micro-batches.
     x = _split(x)
@@ -288,29 +319,13 @@ def _moe(x):
     # Merge the micro-batches back into a single batch.
     x = _merge(x)
 
-    x = self.dropout(x, deterministic)
+    x = self.dropout_op(x, deterministic)
     return self.post_process(x)
 
-  def init(self, *args, **kwargs):
-    # Calls the parent init method for testing parity.
-    return super().init(*args, **kwargs, method=super().__call__)
-
-  def mlp_layer(self):
-    # NOTE: the naming mismatch here is to ensure reverse compatibility with
-    # existing checkpoints. The `name` represents the weight name in
-    # JAX/checkpoints and so the class name is just for readability.
-    return moe.get_routed_and_shared_moe(
-        name="DeepSeekMoeBlock_0",
-        config=self.config,
-        mesh=self.mesh,
-        kernel_init=initializers.nd_dense_init(
-            1.0, "fan_in", "truncated_normal"
-        ),
-        kernel_axes=("embed", None),
-        dtype=self.config.dtype,
-        weight_dtype=self.config.weight_dtype,
-        quant=self.quant,
-    )
+  def mlp_op(self, x, _):
+    return self.with_logical_constraint(self.DeepSeekMoeBlock_0(x))
 
-  def mlp(self, x, _):
-    return self.with_logical_constraint(self.mlp_op(x))
+DeepSeekMoELayerToLinen = nnx_wrappers.to_linen_class(
+    DeepSeekMoELayer,
+    base_metadata_fn=initializers.variable_to_logically_partitioned,
+)
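
The hunks above only rename the helpers called inside `__call__`; the split/interleave/merge structure of the batch-split schedule lives in unchanged lines. As a rough picture of what the schedule does (ignoring the scheduling that overlaps one micro-batch's MoE with another's attention), the layer splits the batch into micro-batches, applies attention and the MoE block with residual connections per micro-batch, and concatenates the results. The sketch below is a plain sequential version with invented names, not the MaxText implementation.

```python
# Simplified, sequential sketch of a batch-split schedule (illustrative only).
import jax.numpy as jnp


def batch_split_layer(x, attn_fn, moe_fn, num_microbatches=2):
  # Split the inputs into micro-batches along the batch axis.
  microbatches = jnp.split(x, num_microbatches, axis=0)
  outputs = []
  for xb in microbatches:
    xb = xb + attn_fn(xb)  # attention block with residual connection
    xb = xb + moe_fn(xb)   # MoE block with residual connection
    outputs.append(xb)
  # Merge the micro-batches back into a single batch.
  return jnp.concatenate(outputs, axis=0)


# Example with stand-in blocks:
y = batch_split_layer(jnp.ones((4, 8, 16)), attn_fn=lambda v: 0.1 * v, moe_fn=lambda v: 0.1 * v)
```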
