diff --git a/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py b/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py
index c33cc7c5a..40f641c26 100644
--- a/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py
+++ b/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py
@@ -41,7 +41,7 @@
 
 from MaxText import max_logging
 from MaxText.inference_utils import str2bool
-from MaxText import llama_or_mistral_ckpt
+from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt
 
 
 MODEL_PARAMS_DICT = {
diff --git a/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py b/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py
index ab0706bc7..6008c61d4 100644
--- a/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py
+++ b/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py
@@ -36,7 +36,7 @@
 from tqdm import tqdm
 
 from MaxText.utils.ckpt_scripts import convert_deepseek_family_ckpt as ds_ckpt
-from MaxText import llama_or_mistral_ckpt
+from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt
 from MaxText import max_logging
 from MaxText.inference_utils import str2bool
 from safetensors import safe_open
diff --git a/src/MaxText/utils/ckpt_scripts/convert_qwen3_moe.py b/src/MaxText/utils/ckpt_scripts/convert_qwen3_moe.py
index f10e516f7..e7d089c7b 100644
--- a/src/MaxText/utils/ckpt_scripts/convert_qwen3_moe.py
+++ b/src/MaxText/utils/ckpt_scripts/convert_qwen3_moe.py
@@ -32,8 +32,9 @@
 from safetensors import safe_open
 from tqdm import tqdm
 
-from MaxText import llama_or_mistral_ckpt, max_logging
+from MaxText import max_logging
 from MaxText.inference_utils import str2bool
+from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt
 
 # Static model parameters dictionary
 MODEL_PARAMS_DICT = {
@@ -45,7 +46,25 @@
     "head_dim": 128,
     "num_experts": 128,
     "moe_intermediate_size": 1536,
-  }
+  },
+  "qwen3-30b-a3b": {
+    "num_hidden_layers": 48,
+    "num_attention_heads": 32,
+    "num_key_value_heads": 4,
+    "hidden_size": 2048,
+    "head_dim": 128,
+    "num_experts": 128,
+    "moe_intermediate_size": 768,
+  },
+  "qwen3-480b-a35b": {
+    "num_hidden_layers": 62,
+    "num_attention_heads": 96,
+    "num_key_value_heads": 8,
+    "hidden_size": 6144,
+    "head_dim": 128,
+    "num_experts": 160,
+    "moe_intermediate_size": 2560,
+  },
 }
 
 
diff --git a/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py b/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py
index 8a9afe34d..bd303baca 100644
--- a/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py
+++ b/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py
@@ -48,7 +48,7 @@
 from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoConfig
 
 from MaxText import checkpointing
-from MaxText import llama_or_mistral_ckpt
+from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt
 from MaxText import max_logging
 from MaxText import maxtext_utils
 from MaxText import pyconfig