Skip to content

Commit 4541954

Browse files
committed
simplify
1 parent 05f287d commit 4541954

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/transformers/integrations/accelerate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import copy
2020
import inspect
2121
import os
22+
import re
2223
from collections import OrderedDict, defaultdict
2324
from contextlib import contextmanager
2425
from typing import TYPE_CHECKING, Optional, Union
@@ -492,11 +493,15 @@ def expand_device_map(device_map, param_names):
492493
"""
493494
Expand a device map to return the correspondence parameter name to device.
494495
"""
496+
# Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
497+
device_map_regex = re.compile(
498+
"|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
499+
)
495500
new_device_map = {}
496-
for module, device in device_map.items():
497-
new_device_map.update(
498-
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
499-
)
501+
for param in param_names:
502+
device_match = device_map_regex.match(param)
503+
new_device_map[param] = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
504+
500505
return new_device_map
501506

502507

0 commit comments

Comments
 (0)