Skip to content

Commit ae30526

Browse files
unit test
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 3095fab commit ae30526

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/compressed_tensors/utils/helpers.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def patch_attr(base: object, attr: str, value: Any):
372372
@contextlib.contextmanager
373373
def patch_attrs(bases: list[object], attr: str, values: list[Any]):
374374
"""
375+
Same as `patch_attr` but for a list of objects to patch
375376
Patch attribute for a list of objects with list of values.
376377
Original values are restored upon exit
377378
@@ -383,9 +384,11 @@ def patch_attrs(bases: list[object], attr: str, values: list[Any]):
383384
Usage:
384385
>>> from types import SimpleNamespace
385386
>>> obj = SimpleNamespace()
386-
>>> with patch_attr(obj, "attribute", "value"):
387-
... assert obj.attribute == "value"
388-
>>> assert not hasattr(obj, "attribute")
387+
>>> with patch_attr([obj1, obj2], "attribute", ["value1", "value2"]):
388+
... assert obj1.attribute == "value1"
389+
... assert obj2.attribute == "value2"
390+
>>> assert not hasattr(obj1, "attribute")
391+
>>> assert not hasattr(obj2, "attribute")
389392
"""
390393
_sentinel = object()
391394
original_values = [getattr(base, attr, _sentinel) for base in bases]

tests/test_utils/test_helpers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ParameterizedDefaultDict,
2222
load_compressed,
2323
patch_attr,
24+
patch_attrs,
2425
save_compressed,
2526
save_compressed_model,
2627
)
@@ -176,6 +177,23 @@ def test_patch_attr():
176177
assert not hasattr(obj, "attribute")
177178

178179

180+
def test_patch_attrs():
181+
num_objs = 4
182+
objs = [SimpleNamespace() for _ in range(num_objs)]
183+
for idx, obj in enumerate(objs):
184+
if idx % 2 == 0:
185+
obj.attribute = f"original_{idx}"
186+
with patch_attrs(objs, "attribute", [f"patched_{idx}" for idx in range(num_objs)]):
187+
for idx, obj in enumerate(objs):
188+
assert obj.attribute == f"patched_{idx}"
189+
obj.attribute = "modified"
190+
for idx, obj in enumerate(objs):
191+
if idx % 2 == 0:
192+
assert obj.attribute == f"original_{idx}"
193+
else:
194+
assert not hasattr(obj, "attribute")
195+
196+
179197
def test_parameterized_default_dict():
180198
def add_one(value):
181199
return value + 1

0 commit comments

Comments
 (0)