@@ -1119,24 +1119,24 @@ def _aqt_is_int8(aqt):
11191119 """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
11201120 return (
11211121 aqt .layout_tensor .dtype == torch .int8 and
1122- aqt .quant_min is None or aqt .quant_min == - 128 and
1123- aqt .quant_max is None or aqt .quant_max == 127
1122+ ( aqt .quant_min is None or aqt .quant_min == - 128 ) and
1123+ ( aqt .quant_max is None or aqt .quant_max == 127 )
11241124 )
11251125
11261126def _aqt_is_int8_reduced_range (aqt ):
11271127 return (
11281128 aqt .layout_tensor .dtype == torch .int8 and
11291129 aqt .quant_min == - 127 and
1130- aqt .quant_max is None or aqt .quant_max == 127
1130+ ( aqt .quant_max is None or aqt .quant_max == 127 )
11311131 )
11321132
1133- def _aqt_is_uint4 (aqt ):
1133+ def _aqt_is_tensor_core_tile_uint4 (aqt ):
11341134 """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
11351135 # TODO: use torch.uint4
11361136 return (
11371137 aqt .layout_tensor .dtype == torch .int32 and
1138- aqt .quant_min is None or aqt . quant_min == 0 and
1139- aqt .quant_max is None or aqt . quant_max == 15
1138+ aqt .quant_min == 0 and
1139+ aqt .quant_max == 15
11401140 )
11411141
11421142
@@ -1228,7 +1228,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
12281228 input_tensor .dtype == torch .bfloat16 and
12291229 # weight is uint4, group quantized tensor_core_tiled layout affine quantized tensor
12301230 isinstance (weight_tensor , AffineQuantizedTensor ) and
1231- _aqt_is_uint4 (weight_tensor ) and
1231+ _aqt_is_tensor_core_tile_uint4 (weight_tensor ) and
12321232 weight_tensor .dtype == torch .bfloat16 and
12331233 len (weight_tensor .shape ) == 2 and
12341234 weight_tensor .zero_point_domain == ZeroPointDomain .FLOAT and
@@ -1429,7 +1429,7 @@ def _linear_fp_act_fp8_weight_impl(
14291429def _linear_fp_act_int4_weight_sparse_marlin_check (input_tensor , weight_tensor , bias ):
14301430 return (
14311431 isinstance (weight_tensor , AffineQuantizedTensor ) and
1432- _aqt_is_uint4 (weight_tensor ) and
1432+ _aqt_is_tensor_core_tile_uint4 (weight_tensor ) and
14331433 input_tensor .dtype == torch .float16 and
14341434 len (weight_tensor .shape ) == 2 and
14351435 weight_tensor .zero_point_domain == ZeroPointDomain .INT and
0 commit comments