Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions src/compressed_tensors/utils/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,18 @@
"match_modules_set",
"is_match",
"is_narrow_match",
"FusedMapping",
]


FusedMappping = Mapping[str, Iterable[str]]
FusedMapping = Mapping[str, Iterable[str]]


def match_named_modules(
model: torch.nn.Module,
targets: Optional[Iterable[str]],
ignore: Optional[Iterable[str]] = None,
fused: Optional[FusedMappping] = None,
fused: Optional[FusedMapping] = None,
warn_on_fail: bool = False,
) -> Generator[Tuple[str, torch.nn.Module]]:
"""
Expand Down Expand Up @@ -80,7 +81,7 @@ def match_named_parameters(
model: torch.nn.Module,
targets: Optional[Iterable[str]],
ignore: Optional[Iterable[str]] = None,
fused: Optional[FusedMappping] = None,
fused: Optional[FusedMapping] = None,
warn_on_fail: bool = False,
) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
"""
Expand Down Expand Up @@ -159,9 +160,10 @@ def match_targets(

def match_modules_set(
model: torch.nn.Module,
targets: Optional[Iterable[str]],
targets: Iterable[str] | None,
ignore: Optional[Iterable[str]] = None,
) -> Generator[Iterable[torch.nn.Module]]:
return_unmatched: bool = False,
) -> Generator[Iterable[torch.nn.Module], None, None | dict[str, torch.nn.Module]]:
"""
Yields modules grouped with the same order and size as `targets`.
Values are returned in order of `model.named_modules()`
Expand Down Expand Up @@ -192,11 +194,21 @@ def match_modules_set(
For example, matching layer norms to their subsequent linear layers
```python3
for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
fuse_norm_linears(norm, [q, k, v])
apply_smoothing(norm, [q, k, v])
```

Or for matching fused modules in a model
```python3
for q, k, v in match_modules_set(model, (q_tgt, k_tgt, v_tgt)):
fuse_global_scales(q, k, v)
```

:param model: model containing modules to match against
:param targets: target strings, potentially containing "re:" prefixes
:param ignore: targets to ignore, potentially containing "re:" prefixes
:param return_unmatched: if True, return any modules which are partial matches,
otherwise raise a ValueError. Default is False.
:return: if return_unmatched is True, any modules which are partial matches
"""
targets = targets or []
ignore = ignore or []
Expand All @@ -207,26 +219,31 @@ def match_modules_set(
for target in targets:
if is_match(name, module, target, ignore):
if matches[target] is not None:
raise ValueError(f"Matched a {target} twice before completing set")
raise ValueError(
f"Matched a {target} twice before "
f"completing set ({matches[target]}, {name})"
)
matches[target] = module

# once we have a full set, yield and reset
if targets and all((matches[target] is not None for target in targets)):
yield [matches[target] for target in targets] # ensure correct ordering
matches = dict.fromkeys(targets, None)

# check that none are left over
unmatched_keys = [match for match, value in matches.items() if value is not None]
if len(unmatched_keys):
raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
# handle unmatched remainder, if any
if any(value is not None for value in matches.values()):
if not return_unmatched:
raise ValueError(f"Unable to match targets into set: {matches}")
else:
return matches


def is_match(
name: str,
module: torch.nn.Module,
targets: Union[str, Iterable[str]],
ignore: Union[str, Iterable[str]] = tuple(),
fused: Optional[FusedMappping] = None,
fused: Optional[FusedMapping] = None,
) -> bool:
"""
Returns true if either module name or module parent classes match against target
Expand Down Expand Up @@ -289,7 +306,7 @@ def is_narrow_match(
)


def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool:
def _match_name(name: str, target: str, fused: Optional[FusedMapping] = None) -> bool:
"""
Returns true if target string begins with "re:" and regex matches or if target
string exactly matches name. If the name refers to a fused module defined by vLLM,
Expand Down
16 changes: 15 additions & 1 deletion tests/test_utils/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,21 @@ def test_incomplete_set_error(self):
targets = ["layer1", "nonexistent_module"]

with pytest.raises(ValueError, match="Unable to match targets into set"):
list(match_modules_set(model, targets))
list(match_modules_set(model, targets, return_unmatched=False))

def test_incomplete_set_return(self):
"""Test error when unable to complete a set"""
model = DummyModel()
targets = ["layer1", "nonexistent_module"]

try:
while True:
next(match_modules_set(model, targets, return_unmatched=True))
except StopIteration as exception:
assert exception.value == {
"layer1": model.layer1,
"nonexistent_module": None,
}

def test_duplicate_match_error(self):
"""Test error when same target matches multiple times before set completion"""
Expand Down