2020 LayoutType ,
2121 PlainLayoutType ,
2222)
23- from torchao .utils import TorchAOBaseTensor , _register_layout_cls , _get_layout_tensor_constructor
23+ from torchao .utils import TorchAOBaseTensor
2424
2525aten = torch .ops .aten
2626
@@ -191,12 +191,8 @@ def _apply_fn_to_data(self, fn):
191191# LayoutType and Layout Tensor Subclass Registration #
192192######################################################
193193
194- def register_layout_cls (layout_type_class : type (LayoutType )):
195- return _register_layout_cls (MyDTypeTensor , layout_type_class )
196-
197- def get_layout_tensor_constructor (layout_type_class : type (LayoutType )):
198- return _get_layout_tensor_constructor (MyDTypeTensor , layout_type_class )
199-
194+ register_layout_cls = MyDTypeTensor .register_layout_cls
195+ get_layout_tensor_constructor = MyDTypeTensor .get_layout_tensor_constructor
200196
201197@register_layout_cls (PlainLayoutType )
202198class PlainMyDTypeLayout (MyDTypeLayout ):
@@ -343,12 +339,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
343339
344340for _ in range (NUM_WARMUPS ):
345341 m (* example_inputs )
346- print ("before quantization:" , benchmark_model (m , NUM_RUNS , example_inputs [ 0 ] ))
342+ print ("before quantization:" , benchmark_model (m , NUM_RUNS , example_inputs ))
347343
348344compiled = torch .compile (m , mode = "max-autotune" )
349345for _ in range (NUM_WARMUPS ):
350346 compiled (* example_inputs )
351- print ("after compile:" , benchmark_model (compiled , NUM_RUNS , example_inputs [ 0 ] ))
347+ print ("after compile:" , benchmark_model (compiled , NUM_RUNS , example_inputs ))
352348
353349# convert weights to quantized weights
354350m .linear .weight = torch .nn .Parameter (
@@ -358,7 +354,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
358354for _ in range (NUM_WARMUPS ):
359355 m (* example_inputs )
360356
361- print ("after quantization:" , benchmark_model (m , NUM_RUNS , example_inputs [ 0 ] ))
357+ print ("after quantization:" , benchmark_model (m , NUM_RUNS , example_inputs ))
362358
363359m = torch .compile (m , mode = "max-autotune" )
364360
@@ -367,4 +363,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
367363
368364# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
369365# we plan to add a custom op example in the future, which will help us get a speedup
370- print ("after quantization and compile:" , benchmark_model (m , NUM_RUNS , example_inputs [ 0 ] ))
366+ print ("after quantization and compile:" , benchmark_model (m , NUM_RUNS , example_inputs ))
0 commit comments