From 62adb6c1fb1794eae5a78a1776830fb920040e99 Mon Sep 17 00:00:00 2001 From: hvaria Date: Mon, 11 Mar 2024 21:13:37 -0700 Subject: [PATCH] Optimized the process of loading PyTorch state dictionaries, merging them and remapping their keys Integrating loading, merging, and remapping into one step reduces the overall processing time by minimizing redundant operations. Pre-compiling the regular expression used for identifying and transforming layer-related keys cuts down the processing time required for these operations. By directly updating the final results without an intermediate merged dictionary, the memory footprint is reduced. --- scripts/convert_hf_checkpoint.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index b92114c4..7fcb8074 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -63,23 +63,22 @@ def permute(w, n_head): .reshape(config.head_dim * n_head, dim) ) - merged_result = {} - for file in sorted(bin_files): - state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) - merged_result.update(state_dict) + final_result = {} - for key, value in merged_result.items(): - if "layers" in key: - abstract_key = re.sub(r'(\d+)', '{}', key) - layer_num = re.search(r'\d+', key).group(0) - new_key = weight_map[abstract_key] - if new_key is None: - continue - new_key = new_key.format(layer_num) - else: - new_key = weight_map[key] - - final_result[new_key] = value + layer_key_pattern = re.compile(r'.(\d+).') # Pre-compile regex pattern + + for file in sorted(pt_files): + state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = layer_key_pattern.sub('.{}.', key) + layer_num = layer_key_pattern.search(key).group(1) + if abstract_key in weight_map: # Check existence before formatting + new_key = weight_map[abstract_key].format(layer_num) + final_result[new_key] = value + else: + if key in weight_map: # Directly check and assign non-layer keys + final_result[weight_map[key]] = value for key in tuple(final_result.keys()): if "wq" in key: