diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index f26400b0..ff3d2344 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -31,17 +31,18 @@ "match_modules_set", "is_match", "is_narrow_match", + "FusedMapping", ] -FusedMappping = Mapping[str, Iterable[str]] +FusedMapping = Mapping[str, Iterable[str]] def match_named_modules( model: torch.nn.Module, targets: Optional[Iterable[str]], ignore: Optional[Iterable[str]] = None, - fused: Optional[FusedMappping] = None, + fused: Optional[FusedMapping] = None, warn_on_fail: bool = False, ) -> Generator[Tuple[str, torch.nn.Module]]: """ @@ -80,7 +81,7 @@ def match_named_parameters( model: torch.nn.Module, targets: Optional[Iterable[str]], ignore: Optional[Iterable[str]] = None, - fused: Optional[FusedMappping] = None, + fused: Optional[FusedMapping] = None, warn_on_fail: bool = False, ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]: """ @@ -159,9 +160,10 @@ def match_targets( def match_modules_set( model: torch.nn.Module, - targets: Optional[Iterable[str]], + targets: Iterable[str] | None, ignore: Optional[Iterable[str]] = None, -) -> Generator[Iterable[torch.nn.Module]]: + return_unmatched: bool = False, +) -> Generator[Iterable[torch.nn.Module], None, None | dict[str, torch.nn.Module]]: """ Yields modules grouped with the same order and size as `targets`. 
Values are returned in order of `model.named_modules()` @@ -192,11 +194,21 @@ def match_modules_set( For example, matching layer norms to their subsequent linear layers ```python3 for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)): - fuse_norm_linears(norm, [q, k, v]) + apply_smoothing(norm, [q, k, v]) + ``` + + Or for matching fused modules in a model + ```python3 + for q, k, v in match_modules_set(model, (q_tgt, k_tgt, v_tgt)): + fuse_global_scales(q, k, v) + ``` :param model: model containing modules to match against :param targets: target strings, potentially containing "re:" prefixes :param ignore: targets to ignore, potentially containing "re:" prefixes + :param return_unmatched: if True, return any modules which are partial matches, + otherwise raise a ValueError. Default is False. + :return: if return_unmatched is True, any modules which are partial matches """ targets = targets or [] ignore = ignore or [] @@ -207,7 +219,10 @@ def match_modules_set( for target in targets: if is_match(name, module, target, ignore): if matches[target] is not None: - raise ValueError(f"Matched a {target} twice before completing set") + raise ValueError( + f"Matched a {target} twice before " + f"completing set ({matches[target]}, {name})" + ) matches[target] = module # once we have a full set, yield and reset @@ -215,10 +230,12 @@ def match_modules_set( yield [matches[target] for target in targets] # ensure correct ordering matches = dict.fromkeys(targets, None) - # check that none are left over - unmatched_keys = [match for match, value in matches.items() if value is not None] - if len(unmatched_keys): - raise ValueError(f"Unable to match targets into set: {unmatched_keys}") + # handle unmatched remainder, if any + if any(value is not None for value in matches.values()): + if not return_unmatched: + raise ValueError(f"Unable to match targets into set: {matches}") + else: + return matches def is_match( @@ -226,7 +243,7 @@ def is_match( module: 
torch.nn.Module, targets: Union[str, Iterable[str]], ignore: Union[str, Iterable[str]] = tuple(), - fused: Optional[FusedMappping] = None, + fused: Optional[FusedMapping] = None, ) -> bool: """ Returns true if either module name or module parent classes match against target @@ -289,7 +306,7 @@ def is_narrow_match( ) -def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool: +def _match_name(name: str, target: str, fused: Optional[FusedMapping] = None) -> bool: """ Returns true if target string begins with "re:" and regex matches or if target string exactly matches name. If the name refers to a fused module defined by vLLM, diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 1129120c..a15d52d8 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -457,7 +457,21 @@ def test_incomplete_set_error(self): targets = ["layer1", "nonexistent_module"] with pytest.raises(ValueError, match="Unable to match targets into set"): - list(match_modules_set(model, targets)) + list(match_modules_set(model, targets, return_unmatched=False)) + + def test_incomplete_set_return(self): + """Test that the partial matches are returned when a set is incomplete""" + model = DummyModel() + targets = ["layer1", "nonexistent_module"] + + try: + while True: + next(match_modules_set(model, targets, return_unmatched=True)) + except StopIteration as exception: + assert exception.value == { + "layer1": model.layer1, + "nonexistent_module": None, + } def test_duplicate_match_error(self): """Test error when same target matches multiple times before set completion"""