|
44 | 44 | "pack_bitmasks", |
45 | 45 | "unpack_bitmasks", |
46 | 46 | "patch_attr", |
| 47 | + "patch_attrs", |
47 | 48 | "ParameterizedDefaultDict", |
48 | 49 | "get_num_attn_heads", |
49 | 50 | "get_num_kv_heads", |
@@ -368,6 +369,39 @@ def patch_attr(base: object, attr: str, value: Any): |
368 | 369 | delattr(base, attr) |
369 | 370 |
|
370 | 371 |
|
| 372 | +@contextlib.contextmanager |
| 373 | +def patch_attrs(bases: list[object], attr: str, values: list[Any]): |
| 374 | + """ |
| 375 | + Patch attribute for a list of objects with list of values. |
| 376 | + Original values are restored upon exit |
| 377 | +
|
| 378 | + :param bases: objects which has the attribute to patch |
| 379 | + :param attr: name of the the attribute to patch |
| 380 | + :param values: used to replace original values. Must be same |
| 381 | + length as bases |
| 382 | +
|
| 383 | + Usage: |
| 384 | + >>> from types import SimpleNamespace |
| 385 | + >>> obj = SimpleNamespace() |
| 386 | + >>> with patch_attr(obj, "attribute", "value"): |
| 387 | + ... assert obj.attribute == "value" |
| 388 | + >>> assert not hasattr(obj, "attribute") |
| 389 | + """ |
| 390 | + _sentinel = object() |
| 391 | + original_values = [getattr(base, attr, _sentinel) for base in bases] |
| 392 | + |
| 393 | + for base, value in zip(bases, values): |
| 394 | + setattr(base, attr, value) |
| 395 | + try: |
| 396 | + yield |
| 397 | + finally: |
| 398 | + for base, original_value in zip(bases, original_values): |
| 399 | + if original_value is not _sentinel: |
| 400 | + setattr(base, attr, original_value) |
| 401 | + else: |
| 402 | + delattr(base, attr) |
| 403 | + |
| 404 | + |
371 | 405 | class ParameterizedDefaultDict(dict): |
372 | 406 | """ |
373 | 407 | Similar to `collections.DefaultDict`, but upon fetching a key which is missing, |
|
0 commit comments