Skip to content

Commit 38deedf

Browse files
chr1sj0nesGoogle-ML-Automation
authored andcommitted
Internal change.
PiperOrigin-RevId: 831169654
1 parent bd2c072 commit 38deedf

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

src/MaxText/layers/moe.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,20 @@
2121
from typing import Iterable, Optional, Tuple, Union
2222

2323
from aqt.jax.v2 import aqt_tensor as aqt
24-
import flax.linen as nn
2524
from flax import nnx
25+
import flax.linen as nn
2626
import jax
2727
from jax import ad_checkpoint as adc
2828
from jax.experimental import xla_metadata
2929
import jax.numpy as jnp
30-
import numpy as np
31-
3230
from MaxText import common_types as ctypes
3331
from MaxText import max_logging
3432
from MaxText import max_utils
3533
from MaxText.kernels import megablox as mblx
36-
from MaxText.layers import attentions, linears, quantizations, nnx_wrappers
37-
from MaxText.layers.initializers import NdInitializer, nd_dense_init, default_bias_init, variable_to_logically_partitioned
38-
39-
from tokamax._src.ops.ragged_dot import api as tokamax_api
34+
from MaxText.layers import attentions, linears, nnx_wrappers, quantizations
35+
from MaxText.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned
36+
import numpy as np
37+
import tokamax
4038

4139
set_xla_metadata = xla_metadata.set_xla_metadata
4240

@@ -812,7 +810,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
812810
min(tiling[2], n),
813811
)
814812
if self.config.use_tokamax_gmm:
815-
output = tokamax_api.ragged_dot(
813+
output = tokamax.ragged_dot(
816814
lhs=inputs,
817815
rhs=kernel,
818816
group_sizes=group_sizes,

0 commit comments

Comments
 (0)