Skip to content

Commit 8c37568

Browse files
committed
Move floatx_tensor_core_layout to prototype/dtypes (#3317)
1 parent 42fc6bd commit 8c37568

File tree

10 files changed: 723 additions and 672 deletions.

benchmarks/benchmark_fp6.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
 from tqdm import tqdm
 
 from torchao.dtypes import to_affine_quantized_fpx
-from torchao.dtypes.floatx import FloatxTensorCoreLayout
+from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout
 from torchao.utils import benchmark_torch_function_in_microseconds
 
 
docs/source/api_ref_dtypes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ Layouts and Tensor Subclasses
 TensorCoreTiledLayout
 Float8Layout
 FloatxTensor
-FloatxTensorCoreLayout
 MarlinSparseLayout
 Int4CPULayout
 CutlassSemiSparseLayout
@@ -52,6 +51,7 @@ Prototype
 Int8DynamicActInt4WeightCPULayout
 MarlinQQQTensor
 MarlinQQQLayout
+FloatxTensorCoreLayout
 UintxLayout
 
 ..

test/dtypes/test_floatx.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,20 @@
 run_tests,
 )
 
-from torchao.dtypes.floatx import (
+from torchao.prototype.custom_fp_utils import (
+_f32_to_floatx_unpacked,
+_floatx_unpacked_to_f32,
+)
+from torchao.prototype.dtypes.floatx import (
 FloatxTensorCoreLayout,
 from_scaled_tc_floatx,
 to_scaled_tc_floatx,
 )
-from torchao.dtypes.floatx.floatx_tensor_core_layout import (
+from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import (
 FloatxTensorCoreAQTTensorImpl,
 _pack_tc_floatx,
 _pack_tc_fp6,
 )
-from torchao.prototype.custom_fp_utils import (
-_f32_to_floatx_unpacked,
-_floatx_unpacked_to_f32,
-)
 from torchao.quantization import (
 FPXWeightOnlyConfig,
 quantize_,

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
 if output_dtype is None:
 output_dtype = self.dtype
 
-from torchao.dtypes.floatx import Float8Layout, FloatxTensorCoreLayout
+from torchao.dtypes.floatx import Float8Layout
+from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout
 
 if isinstance(self._layout, FloatxTensorCoreLayout):
 int_data, scale = self.tensor_impl.get_plain()
@@ -539,7 +540,7 @@ def from_hp_to_fpx(
 _layout: Layout,
 ):
 """Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7."""
-from torchao.dtypes.floatx import FloatxTensorCoreLayout
+from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout
 
 assert isinstance(_layout, FloatxTensorCoreLayout), (
 f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}"

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@
 _linear_fp_act_fp8_weight_check,
 _linear_fp_act_fp8_weight_impl,
 )
-from torchao.dtypes.floatx.floatx_tensor_core_layout import (
-_linear_f16_bf16_act_floatx_weight_check,
-_linear_f16_bf16_act_floatx_weight_impl,
-)
 from torchao.dtypes.uintx.int4_cpu_layout import (
 _linear_fp_act_uint4_weight_cpu_check,
 _linear_fp_act_uint4_weight_cpu_impl,
@@ -72,6 +68,10 @@
 _linear_bf16_act_uint4_weight_check,
 _linear_bf16_act_uint4_weight_impl,
 )
+from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import (
+_linear_f16_bf16_act_floatx_weight_check,
+_linear_f16_bf16_act_floatx_weight_impl,
+)
 from torchao.prototype.dtypes.uintx.block_sparse_layout import (
 _linear_int8_act_int8_weight_block_sparse_check,
 _linear_int8_act_int8_weight_block_sparse_impl,

0 commit comments

Comments
 (0)