@@ -70,6 +70,12 @@ def __repr__(self):
7070# Tensor Subclass Definition #
7171##############################
7272
73+
class QuantizedLinearNotImplementedError(NotImplementedError):
    """Signal that no specialized quantized linear dispatch was found.

    Subclassing NotImplementedError gives callers a dedicated exception type
    to catch, so a fallback path can react to a missing dispatch without
    also swallowing unrelated errors.
    """
77+
78+
7379_QLINEAR_DISPATCH_TABLE = {}
7480def _register_quantized_linear_dispatch (dispatch_condition , impl ):
7581 _QLINEAR_DISPATCH_TABLE [dispatch_condition ] = impl
@@ -158,8 +164,7 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
158164 for dispatch_condition , impl in _QLINEAR_DISPATCH_TABLE .items ():
159165 if dispatch_condition (input_tensor , weight_tensor , bias ):
160166 return impl (input_tensor , weight_tensor , bias )
161-
162- raise NotImplementedError ("No specialized dispatch found for quantized linear op" )
167+ raise QuantizedLinearNotImplementedError ("No specialized dispatch found for quantized linear op" )
163168
164169 def __tensor_flatten__ (self ):
165170 return ["layout_tensor" ], [self .block_size , self .shape , self .quant_min , self .quant_max , self .zero_point_domain , self .dtype ]
@@ -887,7 +892,7 @@ def _(func, types, args, kwargs):
887892 # make the branches easier to understand in `_quantized_linear_op`
888893 try :
889894 return weight_tensor ._quantized_linear_op (input_tensor , weight_tensor , bias )
890- except :
895+ except QuantizedLinearNotImplementedError :
891896 if isinstance (input_tensor , AffineQuantizedTensor ):
892897 input_tensor = input_tensor .dequantize ()
893898 if isinstance (weight_tensor , AffineQuantizedTensor ):
@@ -910,7 +915,7 @@ def _(func, types, args, kwargs):
910915 try :
911916 weight_tensor = weight_tensor .t ()
912917 return weight_tensor ._quantized_linear_op (input_tensor , weight_tensor , bias )
913- except :
918+ except QuantizedLinearNotImplementedError :
914919 if isinstance (input_tensor , AffineQuantizedTensor ):
915920 input_tensor = input_tensor .dequantize ()
916921 if isinstance (weight_tensor , AffineQuantizedTensor ):
@@ -930,7 +935,7 @@ def _(func, types, args, kwargs):
930935 try :
931936 weight_tensor = weight_tensor .t ()
932937 return weight_tensor ._quantized_linear_op (input_tensor , weight_tensor , bias )
933- except :
938+ except QuantizedLinearNotImplementedError :
934939 if isinstance (input_tensor , AffineQuantizedTensor ):
935940 input_tensor = input_tensor .dequantize ()
936941 if isinstance (weight_tensor , AffineQuantizedTensor ):
0 commit comments