Skip to content

Commit 3095fab

Browse files
patch_attrs helper
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 2763f81 commit 3095fab

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

src/compressed_tensors/utils/helpers.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"pack_bitmasks",
4545
"unpack_bitmasks",
4646
"patch_attr",
47+
"patch_attrs",
4748
"ParameterizedDefaultDict",
4849
"get_num_attn_heads",
4950
"get_num_kv_heads",
@@ -368,6 +369,39 @@ def patch_attr(base: object, attr: str, value: Any):
368369
delattr(base, attr)
369370

370371

372+
@contextlib.contextmanager
373+
def patch_attrs(bases: list[object], attr: str, values: list[Any]):
374+
"""
375+
Patch attribute for a list of objects with list of values.
376+
Original values are restored upon exit
377+
378+
:param bases: objects which has the attribute to patch
379+
:param attr: name of the the attribute to patch
380+
:param values: used to replace original values. Must be same
381+
length as bases
382+
383+
Usage:
384+
>>> from types import SimpleNamespace
385+
>>> obj = SimpleNamespace()
386+
>>> with patch_attr(obj, "attribute", "value"):
387+
... assert obj.attribute == "value"
388+
>>> assert not hasattr(obj, "attribute")
389+
"""
390+
_sentinel = object()
391+
original_values = [getattr(base, attr, _sentinel) for base in bases]
392+
393+
for base, value in zip(bases, values):
394+
setattr(base, attr, value)
395+
try:
396+
yield
397+
finally:
398+
for base, original_value in zip(bases, original_values):
399+
if original_value is not _sentinel:
400+
setattr(base, attr, original_value)
401+
else:
402+
delattr(base, attr)
403+
404+
371405
class ParameterizedDefaultDict(dict):
372406
"""
373407
Similar to `collections.DefaultDict`, but upon fetching a key which is missing,

0 commit comments

Comments
 (0)