Skip to content

Commit 86bcb8b

Browse files
committed
Fix linter
Signed-off-by: Vladimir Suvorov <suvorovv@google.com>
1 parent 38fd035 commit 86bcb8b

File tree

1 file changed

+2
-17
lines changed

1 file changed

+2
-17
lines changed

src/MaxText/layerwise_quantization.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import jax.numpy as jnp
3939
from absl import app
4040

41-
from flax import nnx
4241
from flax.linen import partitioning as nn_partitioning
4342

4443

@@ -94,29 +93,15 @@ def load_and_quantize(self, rng: None | PRNGKeyType = None) -> None:
9493

9594
self.quant.quant_mode = quantizations.get_quant_mode("convert")
9695

97-
if rng is None:
98-
rng = jax.random.PRNGKey(0)
99-
100-
dense_init_rng, moe_init_rng = jax.random.split(rng)
101-
dense_params_rng, dense_dropout_rng = jax.random.split(dense_init_rng)
102-
moe_params_rng, moe_dropout_rng = jax.random.split(moe_init_rng)
103-
104-
dense_layer_rngs = nnx.Rngs(params=dense_params_rng, dropout=dense_dropout_rng)
105-
moe_layer_rngs = nnx.Rngs(params=moe_params_rng, dropout=moe_dropout_rng)
106-
10796
layers = [
108-
deepseek.DeepSeekDenseLayer(
97+
deepseek.DeepSeekDenseLayer( # pylint: disable=no-value-for-parameter
10998
config,
110-
model_mode=common_types.MODEL_MODE_TRAIN,
11199
mesh=self._mesh,
112-
rngs=dense_layer_rngs,
113100
quant=self.quant,
114101
),
115-
deepseek.DeepSeekMoELayer(
102+
deepseek.DeepSeekMoELayer( # pylint: disable=no-value-for-parameter
116103
config,
117-
model_mode=common_types.MODEL_MODE_TRAIN,
118104
mesh=self._mesh,
119-
rngs=moe_layer_rngs,
120105
quant=self.quant,
121106
),
122107
]

0 commit comments

Comments
 (0)