diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index b92114c4..7fcb8074 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -63,23 +63,22 @@ def permute(w, n_head):
         .reshape(config.head_dim * n_head, dim)
     )
 
-    merged_result = {}
-    for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
+    final_result = {}
 
-    for key, value in merged_result.items():
-        if "layers" in key:
-            abstract_key = re.sub(r'(\d+)', '{}', key)
-            layer_num = re.search(r'\d+', key).group(0)
-            new_key = weight_map[abstract_key]
-            if new_key is None:
-                continue
-            new_key = new_key.format(layer_num)
-        else:
-            new_key = weight_map[key]
-
-        final_result[new_key] = value
+    layer_key_pattern = re.compile(r'\.(\d+)\.')  # Pre-compile the layer-index pattern (matches e.g. ".0.")
+
+    for file in sorted(bin_files):
+        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+        for key, value in state_dict.items():
+            if "layers" in key:
+                abstract_key = layer_key_pattern.sub('.{}.', key)
+                layer_num = layer_key_pattern.search(key).group(1)
+                if weight_map.get(abstract_key) is not None:  # Skip keys absent from the map or mapped to None
+                    new_key = weight_map[abstract_key].format(layer_num)
+                    final_result[new_key] = value
+            else:
+                if weight_map.get(key) is not None:  # Map non-layer keys directly, skipping None entries
+                    final_result[weight_map[key]] = value
 
     for key in tuple(final_result.keys()):
         if "wq" in key:
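For review, here is a minimal, self-contained sketch of the per-key remapping the new loop performs. The `weight_map` entries and tensor placeholders below are illustrative stand-ins rather than the script's full table: a layer index such as `.0.` is abstracted to `.{}.`, looked up in the map, and re-instantiated with the captured layer number, while keys that are absent from the map or mapped to `None` are dropped.

```python
import re

# Illustrative stand-ins for the script's weight_map and a loaded shard's state_dict.
weight_map = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,  # mapped to None -> dropped
    "model.embed_tokens.weight": "tok_embeddings.weight",
}
state_dict = {
    "model.layers.0.self_attn.q_proj.weight": "tensor-0",
    "model.layers.0.self_attn.rotary_emb.inv_freq": "tensor-1",
    "model.embed_tokens.weight": "tensor-2",
}

layer_key_pattern = re.compile(r'\.(\d+)\.')

final_result = {}
for key, value in state_dict.items():
    if "layers" in key:
        abstract_key = layer_key_pattern.sub('.{}.', key)   # ".0." -> ".{}."
        layer_num = layer_key_pattern.search(key).group(1)  # "0"
        if weight_map.get(abstract_key) is not None:
            final_result[weight_map[abstract_key].format(layer_num)] = value
    elif weight_map.get(key) is not None:
        final_result[weight_map[key]] = value

print(final_result)
# {'layers.0.attention.wq.weight': 'tensor-0', 'tok_embeddings.weight': 'tensor-2'}
```

Compared to the old version, this streams one shard at a time instead of first accumulating everything into an intermediate `merged_result` dict, and hoisting the regex compilation out of the loop avoids the per-key pattern-cache lookup inside the `re` module.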