Skip to content

Commit 6b1ef88

Browse files
Merge pull request #2474 from AI-Hypercomputer:mohit/tokamax-gmm
PiperOrigin-RevId: 823078777
2 parents 804ff4e + d6f5dc8 commit 6b1ef88

File tree

5 files changed

+35
-10
lines changed

5 files changed

+35
-10
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ tensorflow-datasets
3737
tensorflow-text
3838
tensorflow
3939
tiktoken
40+
tokamax>=0.0.3
4041
transformers
4142
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
4243
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ sentencepiece>=0.2.0
2323
tensorflow-datasets
2424
tensorflow-text>=2.17.0
2525
tiktoken
26+
tokamax>=0.0.3
2627
transformers

src/MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,3 +878,6 @@ gdn_num_value_heads: 32
878878
gdn_chunk_size: 64
879879
# Whether to apply L2 normalization to query and key tensors inside the Gated Delta Rule kernel.
880880
use_qk_norm_in_gdn: True
881+
882+
# Use the tokamax library's ragged_dot as the grouped matmul (gmm) kernel. Only consulted when megablox is enabled; TPU only.
883+
use_tokamax_gmm: false

src/MaxText/layers/moe.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from MaxText.layers import attentions, linears, quantizations, nnx_wrappers
3737
from MaxText.layers.initializers import NdInitializer, nd_dense_init, default_bias_init, variable_to_logically_partitioned
3838

39+
# Gate the tokamax import on the parsed JAX version tuple. Comparing the raw
# version strings is lexicographic, so "0.10.0" >= "0.8.0" evaluates to False
# and the import would be skipped on newer JAX releases.
if jax.__version_info__ >= (0, 8, 0):
  from tokamax._src.ops.ragged_dot import api as tokamax_api
41+
3942
set_xla_metadata = xla_metadata.set_xla_metadata
4043

4144

@@ -807,16 +810,26 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
807810
min(tiling[2], n),
808811
)
809812
if self.config.megablox:
810-
output = mblx.gmm(
811-
lhs=inputs,
812-
rhs=kernel,
813-
group_sizes=group_sizes,
814-
preferred_element_type=self.dtype,
815-
tiling=tiling,
816-
lhs_quantize_dtype=lhs_quantize_dtype,
817-
rhs_quantize_dtype=rhs_quantize_dtype,
818-
use_qwix_quantization=self.config.use_qwix_quantization,
819-
)
813+
if self.config.use_tokamax_gmm:
814+
output = tokamax_api.ragged_dot( # pylint: disable=possibly-used-before-assignment
815+
lhs=inputs,
816+
rhs=kernel,
817+
group_sizes=group_sizes,
818+
precision=jax.lax.Precision.DEFAULT,
819+
preferred_element_type=self.dtype,
820+
implementation="mosaic",
821+
)
822+
else:
823+
output = mblx.gmm(
824+
lhs=inputs,
825+
rhs=kernel,
826+
group_sizes=group_sizes,
827+
preferred_element_type=self.dtype,
828+
tiling=tiling,
829+
lhs_quantize_dtype=lhs_quantize_dtype,
830+
rhs_quantize_dtype=rhs_quantize_dtype,
831+
use_qwix_quantization=self.config.use_qwix_quantization,
832+
)
820833
else:
821834
rhs_inputs = kernel
822835
if isinstance(kernel, aqt.QTensor):

src/MaxText/pyconfig.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,12 @@ def validate_quantization_methods(keys):
290290
raise ValueError(f"Invalid quantization method {keys['quantization']}. Valid options are {valid_quant_methods}")
291291

292292

293+
def validate_tokamax_usage(keys):
  """Check that the tokamax gmm kernel is only requested on TPU hardware.

  Args:
    keys: Config mapping. Reads "use_tokamax_gmm" and, only when that value
      is truthy, "hardware".

  Raises:
    ValueError: If "use_tokamax_gmm" is truthy and "hardware" is not "tpu".
  """
  # Nothing to validate when the tokamax kernel is not requested.
  if not keys["use_tokamax_gmm"]:
    return

  hardware = keys["hardware"]
  if hardware == "tpu":
    return

  raise ValueError(f"Invalid tokamax's megablox kernel usage for hardware {hardware}. Only TPU is supported.")
297+
298+
293299
def validate_data_input(keys):
294300
"""validate provided parameters for data input"""
295301
if not keys["hf_access_token"]:
@@ -737,6 +743,7 @@ def user_init(raw_keys):
737743
validate_data_input(raw_keys)
738744
validate_constant_bound(raw_keys)
739745
validate_quantization_methods(raw_keys)
746+
validate_tokamax_usage(raw_keys)
740747

741748
raw_keys["decoder_block"] = DecoderBlockType(raw_keys["decoder_block"])
742749

0 commit comments

Comments
 (0)