Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,22 @@ def permute(w, n_head):
.reshape(config.head_dim * n_head, dim)
)

# Collect the tensors from every checkpoint shard into one flat mapping.
# Shards are visited in sorted order, so a later shard wins on duplicate keys.
merged_result = {}
for shard_path in sorted(bin_files):
    shard_tensors = torch.load(
        str(shard_path), map_location="cpu", mmap=True, weights_only=True
    )
    merged_result.update(shard_tensors)

# Rename every source key to its target name through weight_map.
final_result = {}
for src_name, tensor in merged_result.items():
    if "layers" not in src_name:
        # Non-layer keys are looked up verbatim.
        final_result[weight_map[src_name]] = tensor
        continue

    # Per-layer keys: every digit run becomes "{}" to build the lookup
    # template, and the first digit run is substituted into the mapped name.
    template_key = re.sub(r'(\d+)', '{}', src_name)
    layer_idx = re.search(r'\d+', src_name).group(0)
    target = weight_map[template_key]
    if target is None:
        # Entries mapped to None are skipped.
        continue
    final_result[target.format(layer_idx)] = tensor
# Matches a numeric path component such as the "3" in "model.layers.3.mlp".
# The dots are escaped so that only literal "." separators delimit the number;
# an unescaped "." would also match spans like "w1." inside a parameter name.
layer_key_pattern = re.compile(r'\.(\d+)\.')

# Remap each shard's tensors straight into final_result, one shard at a time.
# NOTE(review): this loop reads `pt_files`; confirm the surrounding code
# defines that name (other code in this hunk iterates `bin_files`).
for file in sorted(pt_files):
    state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
    for key, value in state_dict.items():
        # Only keys that mention "layers" are treated as per-layer templates.
        layer_match = layer_key_pattern.search(key) if "layers" in key else None
        if layer_match is None:
            # Non-layer key (or no numeric component found): direct lookup.
            new_key = weight_map.get(key)
        else:
            abstract_key = layer_key_pattern.sub('.{}.', key)
            template = weight_map.get(abstract_key)
            new_key = None if template is None else template.format(layer_match.group(1))
        # Keys that are absent from weight_map, or explicitly mapped to None,
        # are skipped rather than raising or being stored under a None key.
        if new_key is not None:
            final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
Expand Down