Skip to content

Commit 17867e6

Browse files
authored
Move marlin_qqq_tensor to prototype/dtypes (#3307)
1 parent 2f903f8 commit 17867e6

File tree

11 files changed

+397
-354
lines changed

11 files changed

+397
-354
lines changed

benchmarks/microbenchmarks/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def string_to_config(
218218
)
219219
if "marlin" in quantization:
220220
if "qqq" in quantization:
221-
from torchao.dtypes import MarlinQQQLayout
221+
from torchao.prototype.dtypes import MarlinQQQLayout
222222

223223
return Int8DynamicActivationInt4WeightConfig(
224224
group_size=128,

docs/source/api_ref_dtypes.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ Layouts and Tensor Subclasses
2323
FloatxTensorCoreLayout
2424
MarlinSparseLayout
2525
UintxLayout
26-
MarlinQQQTensor
27-
MarlinQQQLayout
2826
Int4CPULayout
2927
CutlassSemiSparseLayout
3028

@@ -53,6 +51,8 @@ Prototype
5351
BlockSparseLayout
5452
CutlassInt4PackedLayout
5553
Int8DynamicActInt4WeightCPULayout
54+
MarlinQQQTensor
55+
MarlinQQQLayout
5656

5757
..
5858
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring

test/dtypes/test_uintx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def test_uintx_api_deprecation():
182182
),
183183
("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"),
184184
("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"),
185+
("MarlinQQQLayout", "torchao.dtypes.uintx.marlin_qqq_tensor"),
185186
]
186187

187188
for api_name, module_path in deprecated_apis:

test/quantization/test_marlin_qqq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch import nn
1111
from torch.testing._internal.common_utils import TestCase, run_tests
1212

13-
from torchao.dtypes import MarlinQQQLayout
13+
from torchao.prototype.dtypes import MarlinQQQLayout
1414
from torchao.quantization.marlin_qqq import (
1515
pack_to_marlin_qqq,
1616
unpack_from_marlin_qqq,

torchao/_models/llama/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def ffn_or_attn_only(mod, fqn):
460460
)
461461
if "marlin" in quantization:
462462
if "qqq" in quantization:
463-
from torchao.dtypes import MarlinQQQLayout
463+
from torchao.prototype.dtypes import MarlinQQQLayout
464464

465465
quantize_(
466466
model,

torchao/dtypes/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,21 @@
1616
from .uintx import (
1717
Int4CPULayout,
1818
Int4XPULayout,
19-
MarlinQQQLayout,
20-
MarlinQQQTensor,
2119
MarlinSparseLayout,
2220
PackedLinearInt8DynamicActivationIntxWeightLayout,
2321
QDQLayout,
2422
SemiSparseLayout,
2523
TensorCoreTiledLayout,
2624
UintxLayout,
27-
to_marlinqqq_quantized_intx,
2825
)
2926
from .uintx.block_sparse_layout import BlockSparseLayout
3027
from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
3128
from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout
29+
from .uintx.marlin_qqq_tensor import (
30+
MarlinQQQLayout,
31+
MarlinQQQTensor,
32+
to_marlinqqq_quantized_intx,
33+
)
3234
from .utils import (
3335
Layout,
3436
PlainLayout,

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@
3939
_linear_fp_act_uint4_weight_int8_zero_check,
4040
_linear_fp_act_uint4_weight_int8_zero_impl,
4141
)
42-
from torchao.dtypes.uintx.marlin_qqq_tensor import (
43-
_linear_int8_act_int4_weight_marlin_qqq_check,
44-
_linear_int8_act_int4_weight_marlin_qqq_impl,
45-
)
4642
from torchao.dtypes.uintx.marlin_sparse_layout import (
4743
_linear_fp_act_int4_weight_sparse_marlin_check,
4844
_linear_fp_act_int4_weight_sparse_marlin_impl,
@@ -94,6 +90,10 @@
9490
_linear_int8_act_int4_weight_cpu_check,
9591
_linear_int8_act_int4_weight_cpu_impl,
9692
)
93+
from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import (
94+
_linear_int8_act_int4_weight_marlin_qqq_check,
95+
_linear_int8_act_int4_weight_marlin_qqq_impl,
96+
)
9797
from torchao.quantization.quant_primitives import (
9898
ZeroPointDomain,
9999
_dequantize_affine_no_zero_point,

0 commit comments

Comments
 (0)