Skip to content

Commit 284d6a0

Browse files
text/style fixes
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent feb3b49 commit 284d6a0

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,10 @@ def __init__(
338338

339339
self.quantization_compressor = {}
340340
for format in self.compression_formats:
341-
self.quantization_compressor[
342-
format
343-
] = BaseCompressor.load_from_registry(
344-
format, config=quantization_config
341+
self.quantization_compressor[format] = (
342+
BaseCompressor.load_from_registry(
343+
format, config=quantization_config
344+
)
345345
)
346346

347347
def get_missing_module_keys(self, model: Module) -> List[str]:

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,11 @@ def calculate_qparams(
115115
# 4. Update any 0s with small values to
116116
# prevent div by 0
117117
eps = _get_dtype_eps(
118-
dtype=quantization_args.scale_dtype
119-
if quantization_args.scale_dtype is not None
120-
else scales.dtype
118+
dtype=(
119+
quantization_args.scale_dtype
120+
if quantization_args.scale_dtype is not None
121+
else scales.dtype
122+
)
121123
)
122124
scales = torch.where(
123125
scales == 0,

src/compressed_tensors/utils/helpers.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,17 @@
1616
import warnings
1717
from functools import wraps
1818
from types import MappingProxyType
19-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar
19+
from typing import (
20+
TYPE_CHECKING,
21+
Any,
22+
Callable,
23+
Dict,
24+
Iterable,
25+
List,
26+
Mapping,
27+
Optional,
28+
TypeVar,
29+
)
2030

2131
import numpy
2232
import torch
@@ -391,9 +401,9 @@ def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]):
391401
>>> assert not hasattr(obj1, "attribute")
392402
>>> assert not hasattr(obj2, "attribute")
393403
"""
394-
with contextlib.exitstack() as stack:
404+
with contextlib.ExitStack() as stack:
395405
for base, value in zip(bases, values):
396-
stack.add(patch_attr(base, attr, value))
406+
stack.enter_context(patch_attr(base, attr, value))
397407
yield
398408

399409

0 commit comments

Comments
 (0)