4 | 4 | from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
5 | 5 |
6 | 6 |
| 7 | +lib = torch.library.Library("torchao", "FRAGMENT")
| 8 | +lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor")
| 9 | +lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor")
| 10 | +lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor")
| 11 | +lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor")
| 12 | +
| 13 | +
7 | 14 | def register_custom_op(name):
8 | 15 |     def decorator(func):
9 | 16 |         if TORCH_VERSION_AT_LEAST_2_4:
@@ -39,7 +46,14 @@ def quant_llm_linear(
39 | 46 |
40 | 47 |
41 | 48 | @register_custom_op("torchao::quant_llm_linear")
42 | | -def _(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK = 1):
| 49 | +def _(
| 50 | +    EXPONENT: int,
| 51 | +    MANTISSA: int,
| 52 | +    _in_feats: Tensor,
| 53 | +    _weights: Tensor,
| 54 | +    _scales: Tensor,
| 55 | +    splitK: int = 1,
| 56 | +) -> Tensor:
43 | 57 |     torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
44 | 58 |     torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}")
45 | 59 |     torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D")
@@ -76,7 +90,7 @@ def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Ten
76 | 90 |     )
77 | 91 |
78 | 92 |
79 | | -@register_custom_op(f"torchao::unpack_tensor_core_tiled_layout")
| 93 | +@register_custom_op("torchao::unpack_tensor_core_tiled_layout")
80 | 94 | def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor:
81 | 95 |     torch._check(
82 | 96 |         packed_w.dim() == 4,
@@ -127,7 +141,7 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens
127 | 141 |     )
128 | 142 |
129 | 143 |
130 | | -@register_custom_op(f"torchao::dequantize_tensor_core_tiled_layout")
| 144 | +@register_custom_op("torchao::dequantize_tensor_core_tiled_layout")
131 | 145 | def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor:
132 | 146 |     # packed_w preconditions
133 | 147 |     torch._check(
@@ -192,7 +206,7 @@ def marlin_24_gemm(
192 | 206 |     )
193 | 207 |
194 | 208 |
195 | | -@register_custom_op(f"torchao::marlin_24_gemm")
| 209 | +@register_custom_op("torchao::marlin_24_gemm")
196 | 210 | def _(
197 | 211 |     x: Tensor,
198 | 212 |     weight_marlin: Tensor,