@@ -191,46 +191,57 @@ def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor,
191191 from torchao .dtypes .floatx import Float8Layout
192192
193193 base_check = (
194- isinstance (input_tensor , AffineQuantizedTensor ) and
195- isinstance (input_tensor ._layout , Float8Layout ) and
196- input_tensor .dtype in (torch .float16 , torch .bfloat16 ) and
197- len (input_tensor .shape ) >= 2 and
198- input_tensor .tensor_impl .scale .dtype == torch .float32 and
199- isinstance (weight_tensor , AffineQuantizedTensor ) and
200- isinstance (weight_tensor ._layout , CutlassSemiSparseLayout ) and
201- weight_tensor .dtype == input_tensor .dtype and
202- len (weight_tensor .shape ) == 2 and
203- weight_tensor .tensor_impl .scale .dtype == torch .float32 and
204- (bias is None or bias .dtype == input_tensor .dtype ) and
205- (bias is None or len (bias .shape ) == 1 )
194+ isinstance (input_tensor , AffineQuantizedTensor )
195+ and isinstance (input_tensor ._layout , Float8Layout )
196+ and input_tensor .dtype in (torch .float16 , torch .bfloat16 )
197+ and len (input_tensor .shape ) >= 2
198+ and input_tensor .tensor_impl .scale .dtype == torch .float32
199+ and isinstance (weight_tensor , AffineQuantizedTensor )
200+ and isinstance (weight_tensor ._layout , CutlassSemiSparseLayout )
201+ and weight_tensor .dtype == input_tensor .dtype
202+ and len (weight_tensor .shape ) == 2
203+ and weight_tensor .tensor_impl .scale .dtype == torch .float32
204+ and (bias is None or bias .dtype == input_tensor .dtype )
205+ and (bias is None or len (bias .shape ) == 1 )
206206 )
207207
208208 if base_check :
209-
210209 # do extra check and reshape if needed
211210 input_tensor_squeezed = False
212- if len (input_tensor .tensor_impl .scale .shape ) == len (input_tensor .shape ) and \
213- len (input_tensor .tensor_impl .scale .shape ) > 1 and \
214- input_tensor .tensor_impl .scale .shape [- 1 ] == 1 :
215- input_tensor .tensor_impl .scale = torch .squeeze (input_tensor .tensor_impl .scale , dim = - 1 )
211+ if (
212+ len (input_tensor .tensor_impl .scale .shape ) == len (input_tensor .shape )
213+ and len (input_tensor .tensor_impl .scale .shape ) > 1
214+ and input_tensor .tensor_impl .scale .shape [- 1 ] == 1
215+ ):
216+ input_tensor .tensor_impl .scale = torch .squeeze (
217+ input_tensor .tensor_impl .scale , dim = - 1
218+ )
216219 input_tensor_squeezed = True
217-
220+
218221 weight_tensor_squeezed = False
219- if len (weight_tensor .tensor_impl .scale .shape ) == 2 and \
220- weight_tensor .tensor_impl .scale .shape [- 1 ] == 1 :
221- weight_tensor .tensor_impl .scale = torch .squeeze (weight_tensor .tensor_impl .scale , dim = - 1 )
222+ if (
223+ len (weight_tensor .tensor_impl .scale .shape ) == 2
224+ and weight_tensor .tensor_impl .scale .shape [- 1 ] == 1
225+ ):
226+ weight_tensor .tensor_impl .scale = torch .squeeze (
227+ weight_tensor .tensor_impl .scale , dim = - 1
228+ )
222229 weight_tensor_squeezed = True
223230
224231 extra_check = (
225232 len (input_tensor .tensor_impl .scale .shape ) == len (input_tensor .shape ) - 1
226233 and len (weight_tensor .tensor_impl .scale .shape ) == 1
227234 )
228235
229- if not extra_check : # revert if extra check failed
236+ if not extra_check : # revert if extra check failed
230237 if input_tensor_squeezed :
231- input_tensor .tensor_impl .scale = torch .unsqueeze (input_tensor .tensor_impl .scale , dim = - 1 )
238+ input_tensor .tensor_impl .scale = torch .unsqueeze (
239+ input_tensor .tensor_impl .scale , dim = - 1
240+ )
232241 if weight_tensor_squeezed :
233- weight_tensor .tensor_impl .scale = torch .unsqueeze (weight_tensor .tensor_impl .scale , dim = - 1 )
242+ weight_tensor .tensor_impl .scale = torch .unsqueeze (
243+ weight_tensor .tensor_impl .scale , dim = - 1
244+ )
234245
235246 return extra_check
236247
0 commit comments