Skip to content

Commit d6f5dc8

Browse files
committed
tokamax megablox kernel
1 parent a15fc00 commit d6f5dc8

File tree

5 files changed

+35
-10
lines changed

5 files changed

+35
-10
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ tensorflow-datasets
3737
tensorflow-text
3838
tensorflow
3939
tiktoken
40+
tokamax>=0.0.3
4041
transformers
4142
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
4243
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ sentencepiece>=0.2.0
2323
tensorflow-datasets
2424
tensorflow-text>=2.17.0
2525
tiktoken
26+
tokamax>=0.0.3
2627
transformers

src/MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,3 +876,6 @@ gdn_num_value_heads: 32
876876
gdn_chunk_size: 64
877877
# Whether to apply L2 normalization to query and key tensors inside the Gated Delta Rule kernel.
878878
use_qk_norm_in_gdn: True
879+
880+
# Whether to use the tokamax library's ragged-dot implementation for the grouped matmul (GMM) kernel (TPU only)
881+
use_tokamax_gmm: False

src/MaxText/layers/moe.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from MaxText.layers import attentions, linears, quantizations, nnx_wrappers
3737
from MaxText.layers.initializers import NdInitializer, nd_dense_init, default_bias_init, variable_to_logically_partitioned
3838

39+
# Import tokamax's ragged-dot kernel only when JAX is new enough to support it.
# Compare parsed version tuples rather than raw strings: lexicographic string
# comparison would wrongly order e.g. "0.10.0" before "0.8.0" and silently
# skip the import on newer JAX releases.
# NOTE(review): assumes jax.__version__ is a plain "X.Y.Z" release string;
# pre-release suffixes (e.g. "0.8.1rc1") would need stripping — confirm.
if tuple(int(part) for part in jax.__version__.split(".")[:3]) >= (0, 8, 0):
  from tokamax._src.ops.ragged_dot import api as tokamax_api
41+
3942
set_xla_metadata = xla_metadata.set_xla_metadata
4043

4144

@@ -807,16 +810,26 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
807810
min(tiling[2], n),
808811
)
809812
if self.config.megablox:
810-
output = mblx.gmm(
811-
lhs=inputs,
812-
rhs=kernel,
813-
group_sizes=group_sizes,
814-
preferred_element_type=self.dtype,
815-
tiling=tiling,
816-
lhs_quantize_dtype=lhs_quantize_dtype,
817-
rhs_quantize_dtype=rhs_quantize_dtype,
818-
use_qwix_quantization=self.config.use_qwix_quantization,
819-
)
813+
if self.config.use_tokamax_gmm:
814+
output = tokamax_api.ragged_dot( # pylint: disable=possibly-used-before-assignment
815+
lhs=inputs,
816+
rhs=kernel,
817+
group_sizes=group_sizes,
818+
precision=jax.lax.Precision.DEFAULT,
819+
preferred_element_type=self.dtype,
820+
implementation="mosaic",
821+
)
822+
else:
823+
output = mblx.gmm(
824+
lhs=inputs,
825+
rhs=kernel,
826+
group_sizes=group_sizes,
827+
preferred_element_type=self.dtype,
828+
tiling=tiling,
829+
lhs_quantize_dtype=lhs_quantize_dtype,
830+
rhs_quantize_dtype=rhs_quantize_dtype,
831+
use_qwix_quantization=self.config.use_qwix_quantization,
832+
)
820833
else:
821834
rhs_inputs = kernel
822835
if isinstance(kernel, aqt.QTensor):

src/MaxText/pyconfig.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,12 @@ def validate_quantization_methods(keys):
287287
raise ValueError(f"Invalid quantization method {keys['quantization']}. Valid options are {valid_quant_methods}")
288288

289289

290+
def validate_tokamax_usage(keys):
  """Validate tokamax usage for gmm kernel.

  The tokamax megablox (grouped matmul) kernel path is only wired up for TPU,
  so a config that enables it on any other hardware is rejected up front.

  Args:
    keys: raw config mapping; reads "use_tokamax_gmm" and "hardware".

  Raises:
    ValueError: if use_tokamax_gmm is enabled while hardware is not "tpu".
  """
  if not keys["use_tokamax_gmm"]:
    return
  hardware = keys["hardware"]
  if hardware != "tpu":
    raise ValueError(f"Invalid tokamax's megablox kernel usage for hardware {keys['hardware']}. Only TPU is supported.")
294+
295+
290296
def validate_data_input(keys):
291297
"""validate provided parameters for data input"""
292298
if not keys["hf_access_token"]:
@@ -734,6 +740,7 @@ def user_init(raw_keys):
734740
validate_data_input(raw_keys)
735741
validate_constant_bound(raw_keys)
736742
validate_quantization_methods(raw_keys)
743+
validate_tokamax_usage(raw_keys)
737744

738745
raw_keys["decoder_block"] = DecoderBlockType(raw_keys["decoder_block"])
739746

0 commit comments

Comments
 (0)