 
 import copy
 import inspect
+import itertools
 import os
 import re
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Optional, Union
 
+from safetensors import safe_open
 from safetensors.torch import save_file
 
 from ..utils import (
@@ -505,36 +507,70 @@ def expand_device_map(device_map, param_names):
     return new_device_map
 
 
+def update_param_name(param_name: str, weight_pattern_alt, weight_pattern_by_group_name, source_to_target):
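+    # Map a checkpoint (source) parameter name to the model's (target) name using the pre-built
+    # glob alternation; the name is returned unchanged when nothing matches or when the matched
+    # converter is not a pure one-to-one renaming.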
+    from ..core_model_loading import match_glob
+
+    if weight_pattern_alt is None:
+        return param_name
+
+    matched_pattern = match_glob(param_name, weight_pattern_alt, weight_pattern_by_group_name)
+    if matched_pattern is not None:
+        converter = source_to_target[matched_pattern]
+        # Only change name if it's a simple renaming, i.e. no custom Ops
+        if len(converter.source_keys) == 1 and len(converter.target_keys) == 1:
+            source_pattern = converter.source_keys[0]
+            target_pattern = converter.target_keys[0]
+            return re.sub(source_pattern, target_pattern, param_name)
+    return param_name
+
+
 def accelerate_disk_offload(
     disk_offload_folder: str | None,
     checkpoint_files: list[str] | None,
     device_map: dict,
     expected_keys: list[str],
     sharded_metadata: dict | None,
     dtype: torch.dtype | None,
+    weight_mapping=None,
 ):
+    from ..core_model_loading import build_glob_alt
+
     if disk_offload_folder is not None:
         os.makedirs(disk_offload_folder, exist_ok=True)
     is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
 
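+    # Build one glob alternation over every source-key pattern in `weight_mapping`, plus a lookup
+    # from each source pattern to its converter, so checkpoint names can be renamed further below.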
+    _patterns, weight_pattern_alt, weight_pattern_by_group_name, source_to_target = None, None, None, None
+    if weight_mapping is not None:
+        _patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
+        source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
+        weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
+
     # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
     # Operations, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
     if is_offloaded_safetensors:
         param_device_map = expand_device_map(device_map, expected_keys)
         str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
         if sharded_metadata is None:
-            weight_map = dict.fromkeys(expected_keys, checkpoint_files[0])
+            weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
         else:
             folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
             weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}
+
+        # Update the weight names according to the `weight_mapping`
+        weight_renaming_map = {
+            update_param_name(k, weight_pattern_alt, weight_pattern_by_group_name, source_to_target): k
+            for k in weight_map
+        }
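+        # Keys are the model's (target) parameter names, values the original checkpoint (source)
+        # names still present in the safetensors files.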
+
+        # Prepare the index using existing safetensors files
         disk_offload_index = {
-            name: {
-                "safetensors_file": file,
-                "weight_name": name,
+            target_name: {
+                "safetensors_file": weight_map[source_name],
+                "weight_name": source_name,
                 "dtype": str_dtype,
             }
-            for name, file in weight_map.items()
-            if param_device_map[name] == "disk"
+            for target_name, source_name in weight_renaming_map.items()
+            if param_device_map[target_name] == "disk"
         }
     # In this case we will resave every offloaded weight
     else:
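
A minimal sketch of the renaming behaviour introduced above, using a toy converter class and made-up parameter names (the `ToyConverter` name, the example patterns, and the safetensors file name are illustrative assumptions, not the library's actual converter API): a converter with exactly one source pattern and one target pattern is applied with `re.sub`, and the resulting target name keys a disk offload index entry that still records the original checkpoint name.

```python
import re
from dataclasses import dataclass


@dataclass
class ToyConverter:
    # Stand-in for a weight-mapping converter: regex patterns on the checkpoint side
    # (source) and on the model side (target).
    source_keys: list
    target_keys: list


def rename_if_simple(param_name: str, converter: ToyConverter) -> str:
    # Mirror of the "simple renaming" branch: only rename for a one-to-one converter.
    if len(converter.source_keys) == 1 and len(converter.target_keys) == 1:
        return re.sub(converter.source_keys[0], converter.target_keys[0], param_name)
    return param_name


converter = ToyConverter(
    source_keys=[r"model\.layers\.(\d+)\.mlp"],
    target_keys=[r"model.layers.\1.feed_forward"],
)
source_name = "model.layers.3.mlp.gate_proj.weight"
target_name = rename_if_simple(source_name, converter)
print(target_name)  # model.layers.3.feed_forward.gate_proj.weight

# The offload index is keyed by the renamed (target) name but keeps the original (source)
# name and file, so the tensor can still be read from the untouched safetensors shard.
disk_offload_index = {
    target_name: {
        "safetensors_file": "model-00001-of-00002.safetensors",  # illustrative file name
        "weight_name": source_name,
        "dtype": "float32",
    }
}
```

Keeping the original (source) name in `weight_name` is what lets an offloaded tensor be read back from the existing safetensors file even though the model addresses it by the renamed (target) key.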