
import jax
import os
- from typing import Sequence, Any
+ from typing import Sequence
import time
from tqdm import tqdm
+ import numpy as np

from transformers import AutoTokenizer, AutoProcessor

)
from MaxText.utils.ckpt_conversion.utils.hf_shape import HF_SHAPE
from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
- from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, HF_IDS)
+ from MaxText.utils.ckpt_conversion.utils.utils import process_maxtext_param, save_model_files, HF_IDS


os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"


- def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
+ def _get_model_mappings(
+     model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters
+ ):
  """Retrieves parameter, shape, and hook function mappings for the model.

  Args:
    model_name: The name of the model (e.g., "gemma2-2b").
    scan_layers: Boolean indicating if the model was trained with scanned layers.
-     config_dict: The Hugging Face model configuration dictionary.
+     hf_config_dict: The Hugging Face model configuration dictionary.
+     maxtext_config: The MaxText model configuration.

  Returns:
    A dictionary containing the parameter mapping, shape mapping, and hook
@@ -97,12 +101,65 @@ def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
    raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")

  return {
-       "param_mapping": PARAM_MAPPING[model_name](config_dict, scan_layers),
-       "shape_mapping": HF_SHAPE[model_name](config_dict),
-       "hook_fn_mapping": HOOK_FNS[model_name](config_dict, scan_layers, saving_to_hf=True),
+       "param_mapping": PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers),
+       "shape_mapping": HF_SHAPE[model_name](hf_config_dict),
+       "hook_fn_mapping": HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True),
  }


+ def _check_param_map_keys(param_map_keys, maxtext_state_keys):
+   """Validates map coverage, handles N-to-1 mappings, and filters unused keys.
+ 
+   Ensures every MaxText checkpoint key (`maxtext_state_keys`) is covered by
+   the flattened parameter map. Keys in the map that are not present in the
+   checkpoint (common for multi-variant maps like gemma3, qwen3, deepseek) are skipped.
+ 
+   Tuple keys represent N-to-1 mappings (multiple MaxText keys combining into one
+   target key) and are only returned if all constituent keys exist in the checkpoint.
+ 
+   Args:
+     param_map_keys: Keys from the parameter mapping (strings or N-to-1 tuples).
+     maxtext_state_keys: Set of parameter keys loaded from the MaxText checkpoint.
+ 
+   Returns:
+     A list of 'filtered' mapping keys (strings or tuples) that are fully present
+     and valid based on `maxtext_state_keys`.
+ 
+   Raises:
+     ValueError: If `maxtext_state_keys` is NOT a subset of the flattened
+       `param_map_keys`.
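+ 
+   Example (illustrative only, with hypothetical key names):
+     param_map_keys = ["params-a", ("params-b", "params-c"), "params-d"]
+     maxtext_state_keys = {"params-a", "params-b", "params-c"}
+     # -> returns ["params-a", ("params-b", "params-c")]; "params-d" is logged and skipped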
+   """
+   flattened_map_keys = set()
+   for key in param_map_keys:
+     if isinstance(key, tuple):
+       flattened_map_keys.update(key)
+     else:
+       flattened_map_keys.add(key)
+ 
+   # every maxtext state key must be covered by param map
+   missing_keys = maxtext_state_keys - flattened_map_keys
+   if missing_keys:
+     raise ValueError(
+         "maxtext_state_dict must be a subset of flattened param_map"
+         + f"\nparam map\n{param_map_keys}"
+         + f"\nmaxtext:\n{maxtext_state_keys}"
+     )
+ 
+   # param map may have extra keys
+   extra_keys = flattened_map_keys - maxtext_state_keys
+   if extra_keys:
+     max_logging.log(f"Warning: extra keys in param_map are skipped: {extra_keys}")
+ 
+   # skip extra keys in param map
+   filtered_map_keys = []
+   for key in param_map_keys:
+     if (isinstance(key, str) and key in maxtext_state_keys) or (
+         isinstance(key, tuple) and all(k in maxtext_state_keys for k in key)
+     ):
+       filtered_map_keys.append(key)
+   return filtered_map_keys
+ 
+ 
def main(argv: Sequence[str]) -> None:
  """Main function to convert a MaxText checkpoint to HuggingFace format.

@@ -156,29 +213,52 @@ def main(argv: Sequence[str]) -> None:
  processor = AutoProcessor.from_pretrained(hf_tokenizer_id, token=hf_token) if config.use_multimodal else None

  # 3. Get parameter mappings
-   mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict())
+   mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict(), config)
  param_map = mappings["param_mapping"]
  shape_map = mappings["shape_mapping"]  # HF target shapes
  hook_fn_map = mappings["hook_fn_mapping"]

  # 4. Transform Weights
-   transformed_hf_weights: dict[str, Any] = {}
- 
  # MaxText `engine.load_params()` returns `state.params` (a FrozenDict).
  # The actual weights are typically under `state.params['params']`.
  actual_weights_dict = loaded_params_from_engine.get("params")
  if actual_weights_dict is None:
    raise ValueError("Loaded parameters from engine do not contain a 'params' key. Structure might be unexpected.")
- 
  leaves_with_paths = jax.tree_util.tree_leaves_with_path(actual_weights_dict)

-   # traverse leavse to build: mt_param_key:mt_weights
+   # Construct maxtext_state_dict: {parameter name: parameter weight}
+   maxtext_state_dict = {}
+   for path_tuple, leaf_value in leaves_with_paths:
+     # Construct maxtext_param_key from path_tuple
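+     # Illustrative (hypothetical path): ('decoder', 'layers', 'mlp', 'wo', 'kernel')
+     # would become "params-decoder-layers-mlp-wo-kernel"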
+     maxtext_param_key = "params-" + "-".join(k.key for k in path_tuple)
+     # Check leaf value is an array
+     if not isinstance(leaf_value, (jax.Array, np.ndarray)):
+       raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
+     maxtext_state_dict[maxtext_param_key] = leaf_value
+ 
+   # The param_map may contain tuples as keys, which represent N-to-1 mappings from maxtext to huggingface
+   # Check maxtext_state_dict is a subset of flattened param_map
+   # Skip extra keys from param_map
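+   # For example (hypothetical names), a tuple key like
+   #   ("params-decoder-layers-mlp-wi_0-kernel", "params-decoder-layers-mlp-wi_1-kernel")
+   # would combine two MaxText tensors into a single target HF tensor.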
+   filtered_map_keys = _check_param_map_keys(param_map.keys(), maxtext_state_dict.keys())
+ 
+   # Iterate through the parameter map to transform and collect weights.
+   # This loop handles both simple 1-to-1 mappings and complex N-to-1 mappings
+   # (where multiple MaxText weights are combined into a single HF weight).
  max_logging.log("\nProcessing weight...")
  start = time.time()
  processed_params_list = []
-   for path_tuple_iter, leaf_value_iter in tqdm(leaves_with_paths, total=len(leaves_with_paths)):
-     processed_params = process_leaf_param(path_tuple_iter, leaf_value_iter, param_map, shape_map, hook_fn_map, config)
+ 
+   for key in tqdm(filtered_map_keys, total=len(filtered_map_keys)):
+     if isinstance(key, tuple):
+       # if key is tuple of param names, weight is list of param weights
+       weight = [maxtext_state_dict[subkey] for subkey in key]
+     else:
+       # if key is single param name, weight is single param weight
+       weight = maxtext_state_dict[key]
+ 
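+     # process_maxtext_param evidently returns a list of (hf_param_name, tensor)
+     # pairs, which is why the collected results can be turned into a dict below.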
+     processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
    processed_params_list.extend(processed_params)
+ 
  transformed_hf_weights = dict(processed_params_list)
  max_logging.log(f"Elapsed: {(time.time() - start) / 60:.2f} min")
