Skip to content

Commit 37276d6

Browse files
authored
Make developer experience better for extending AQT (#749)
Make developer experience better
1 parent c2f4460 commit 37276d6

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ def __repr__(self):
7070
# Tensor Subclass Definition #
7171
##############################
7272

73+
74+
class QuantizedLinearNotImplementedError(NotImplementedError):
75+
""" Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """
76+
pass
77+
78+
7379
_QLINEAR_DISPATCH_TABLE = {}
7480
def _register_quantized_linear_dispatch(dispatch_condition, impl):
7581
_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
@@ -158,8 +164,7 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
158164
for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items():
159165
if dispatch_condition(input_tensor, weight_tensor, bias):
160166
return impl(input_tensor, weight_tensor, bias)
161-
162-
raise NotImplementedError("No specialized dispatch found for quantized linear op")
167+
raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")
163168

164169
def __tensor_flatten__(self):
165170
return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
@@ -887,7 +892,7 @@ def _(func, types, args, kwargs):
887892
# make the branches easier to understand in `_quantized_linear_op`
888893
try:
889894
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
890-
except:
895+
except QuantizedLinearNotImplementedError:
891896
if isinstance(input_tensor, AffineQuantizedTensor):
892897
input_tensor = input_tensor.dequantize()
893898
if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -910,7 +915,7 @@ def _(func, types, args, kwargs):
910915
try:
911916
weight_tensor = weight_tensor.t()
912917
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
913-
except:
918+
except QuantizedLinearNotImplementedError:
914919
if isinstance(input_tensor, AffineQuantizedTensor):
915920
input_tensor = input_tensor.dequantize()
916921
if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -930,7 +935,7 @@ def _(func, types, args, kwargs):
930935
try:
931936
weight_tensor = weight_tensor.t()
932937
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
933-
except:
938+
except QuantizedLinearNotImplementedError:
934939
if isinstance(input_tensor, AffineQuantizedTensor):
935940
input_tensor = input_tensor.dequantize()
936941
if isinstance(weight_tensor, AffineQuantizedTensor):

torchao/quantization/autoquant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMi
413413
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
414414
uses a different kernel
415415
"""
416+
@staticmethod
416417
def _quantized_linear_op(act_mat, w_qtensor, bias):
417418
orig_shape = act_mat.shape
418419
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)

0 commit comments

Comments (0)