Skip to content

Commit 790d30c

Browse files
authored
Use weights_only for load (#5)
1 parent d2c5d82 commit 790d30c

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):
219219
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
220220
model = simple_quantizer.convert_for_runtime()
221221

222-
checkpoint = torch.load(str(checkpoint_path), mmap=True)
222+
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
223223
model.load_state_dict(checkpoint, assign=True)
224224

225225
if use_tp:

quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ def quantize(
540540
with torch.device('meta'):
541541
model = Transformer.from_name(checkpoint_path.parent.name)
542542

543-
checkpoint = torch.load(str(checkpoint_path), mmap=True)
543+
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
544544
model.load_state_dict(checkpoint, assign=True)
545545
model = model.to(dtype=precision, device=device)
546546

scripts/convert_hf_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def permute(w, n_head):
6464

6565
merged_result = {}
6666
for file in sorted(bin_files):
67-
state_dict = torch.load(str(file), map_location="cpu", mmap=True)
67+
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
6868
merged_result.update(state_dict)
6969
final_result = {}
7070
for key, value in merged_result.items():

0 commit comments

Comments
 (0)