|
96 | 96 | to_weight_tensor_with_linear_activation_quantization_metadata, |
97 | 97 | ) |
98 | 98 | from torchao.utils import ( |
| 99 | + _ConfigDeprecationWrapper, |
99 | 100 | is_MI300, |
100 | 101 | is_sm_at_least_89, |
101 | 102 | is_sm_at_least_90, |
|
144 | 145 | "autoquant", |
145 | 146 | "_get_subclass_inserter", |
146 | 147 | "quantize_", |
| 148 | + "int8_dynamic_activation_int4_weight", |
| 149 | + "int8_dynamic_activation_int8_weight", |
| 150 | + "int8_dynamic_activation_int8_semi_sparse_weight", |
| 151 | + "int4_weight_only", |
| 152 | + "int8_weight_only", |
147 | 153 | "intx_quantization_aware_training", |
| 154 | + "float8_weight_only", |
| 155 | + "uintx_weight_only", |
| 156 | + "fpx_weight_only", |
| 157 | + "gemlite_uintx_weight_only", |
| 158 | + "float8_dynamic_activation_float8_weight", |
| 159 | + "float8_static_activation_float8_weight", |
148 | 160 | "Int8DynActInt4WeightQuantizer", |
149 | 161 | "Float8DynamicActivationFloat8SemiSparseWeightConfig", |
150 | 162 | "ModuleFqnToConfig", |
@@ -491,7 +503,7 @@ def quantize_( |
491 | 503 | # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile) |
492 | 504 | # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile) |
493 | 505 | # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile) |
494 | | - from torchao.quantization.quant_api import int4_weight_only |
| 506 | + from torchao.quantization.quant_api import Int4WeightOnlyConfig |
495 | 507 |
|
496 | 508 | m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) |
497 | 509 | quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1)) |
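
The docstring example now leads with the config-class spelling; the deprecated functional alias keeps working through the wrapper. A short sketch of how the two spellings relate after this change, assuming the alias forwards keyword arguments to the config unchanged:

```python
import torch.nn as nn

from torchao.quantization.quant_api import (
    Int4WeightOnlyConfig,
    int4_weight_only,  # deprecated BC alias
    quantize_,
)

# Preferred: construct the config class directly.
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))

# Deprecated spelling: still builds an Int4WeightOnlyConfig,
# but emits a DeprecationWarning on call.
m2 = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m2, int4_weight_only(group_size=32))
```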
@@ -613,6 +625,12 @@ def __post_init__(self): |
613 | 625 | ) |
614 | 626 |
|
615 | 627 |
|
| 628 | +# for BC |
| 629 | +int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( |
| 630 | + "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig |
| 631 | +) |
| 632 | + |
| 633 | + |
616 | 634 | @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) |
617 | 635 | def _int8_dynamic_activation_int4_weight_transform( |
618 | 636 | module: torch.nn.Module, |
@@ -978,6 +996,12 @@ def __post_init__(self): |
978 | 996 | ) |
979 | 997 |
|
980 | 998 |
|
| 999 | +# for BC |
| 1000 | +int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( |
| 1001 | + "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig |
| 1002 | +) |
| 1003 | + |
| 1004 | + |
981 | 1005 | @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) |
982 | 1006 | def _int4_dynamic_activation_int4_weight_transform( |
983 | 1007 | module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig |
@@ -1035,6 +1059,12 @@ def __post_init__(self): |
1035 | 1059 | ) |
1036 | 1060 |
|
1037 | 1061 |
|
| 1062 | +# for BC |
| 1063 | +gemlite_uintx_weight_only = _ConfigDeprecationWrapper( |
| 1064 | + "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig |
| 1065 | +) |
| 1066 | + |
| 1067 | + |
1038 | 1068 | @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) |
1039 | 1069 | def _gemlite_uintx_weight_only_transform( |
1040 | 1070 | module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig |
@@ -1112,6 +1142,11 @@ def __post_init__(self): |
1112 | 1142 | torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig") |
1113 | 1143 |
|
1114 | 1144 |
|
| 1145 | +# for BC |
| 1146 | +# TODO maybe change other callsites |
| 1147 | +int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig) |
| 1148 | + |
| 1149 | + |
1115 | 1150 | def _int4_weight_only_quantize_tensor(weight, config): |
1116 | 1151 | # TODO(future PR): perhaps move this logic to a different file, to keep the API |
1117 | 1152 | # file clean of implementation details |
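
A quick way to sanity-check one of these BC aliases, assuming the wrapper both returns a config instance and raises a `DeprecationWarning` (both assumptions, since the wrapper body is outside this diff):

```python
import warnings

from torchao.quantization.quant_api import Int4WeightOnlyConfig, int4_weight_only

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config = int4_weight_only(group_size=32)

# The alias should build the same config the class constructor would...
assert isinstance(config, Int4WeightOnlyConfig)
assert config.group_size == 32
# ...while nudging callers toward the new spelling.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```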
@@ -1323,6 +1358,10 @@ def __post_init__(self): |
1323 | 1358 | torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") |
1324 | 1359 |
|
1325 | 1360 |
|
| 1361 | +# for BC |
| 1362 | +int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig) |
| 1363 | + |
| 1364 | + |
1326 | 1365 | def _int8_weight_only_quantize_tensor(weight, config): |
1327 | 1366 | mapping_type = MappingType.SYMMETRIC |
1328 | 1367 | target_dtype = torch.int8 |
@@ -1480,6 +1519,12 @@ def __post_init__(self): |
1480 | 1519 | ) |
1481 | 1520 |
|
1482 | 1521 |
|
| 1522 | +# for BC |
| 1523 | +int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper( |
| 1524 | + "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig |
| 1525 | +) |
| 1526 | + |
| 1527 | + |
1483 | 1528 | def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): |
1484 | 1529 | layout = config.layout |
1485 | 1530 | act_mapping_type = config.act_mapping_type |
@@ -1585,6 +1630,12 @@ def __post_init__(self): |
1585 | 1630 | torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig") |
1586 | 1631 |
|
1587 | 1632 |
|
| 1633 | +# for BC |
| 1634 | +float8_weight_only = _ConfigDeprecationWrapper( |
| 1635 | + "float8_weight_only", Float8WeightOnlyConfig |
| 1636 | +) |
| 1637 | + |
| 1638 | + |
1588 | 1639 | def _float8_weight_only_quant_tensor(weight, config): |
1589 | 1640 | if config.version == 1: |
1590 | 1641 | warnings.warn( |
@@ -1743,6 +1794,12 @@ def __post_init__(self): |
1743 | 1794 | self.granularity = [activation_granularity, weight_granularity] |
1744 | 1795 |
|
1745 | 1796 |
|
| 1797 | +# for BC |
| 1798 | +float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper( |
| 1799 | + "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig |
| 1800 | +) |
| 1801 | + |
| 1802 | + |
1746 | 1803 | def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): |
1747 | 1804 | activation_dtype = config.activation_dtype |
1748 | 1805 | weight_dtype = config.weight_dtype |
@@ -1918,6 +1975,12 @@ def __post_init__(self): |
1918 | 1975 | ) |
1919 | 1976 |
|
1920 | 1977 |
|
| 1978 | +# for BC |
| 1979 | +float8_static_activation_float8_weight = _ConfigDeprecationWrapper( |
| 1980 | + "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig |
| 1981 | +) |
| 1982 | + |
| 1983 | + |
1921 | 1984 | @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig) |
1922 | 1985 | def _float8_static_activation_float8_weight_transform( |
1923 | 1986 | module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig |
@@ -2000,6 +2063,12 @@ def __post_init__(self): |
2000 | 2063 | torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig") |
2001 | 2064 |
|
2002 | 2065 |
|
| 2066 | +# for BC |
| 2067 | +uintx_weight_only = _ConfigDeprecationWrapper( |
| 2068 | + "uintx_weight_only", UIntXWeightOnlyConfig |
| 2069 | +) |
| 2070 | + |
| 2071 | + |
2003 | 2072 | @register_quantize_module_handler(UIntXWeightOnlyConfig) |
2004 | 2073 | def _uintx_weight_only_transform( |
2005 | 2074 | module: torch.nn.Module, config: UIntXWeightOnlyConfig |
@@ -2278,6 +2347,10 @@ def __post_init__(self): |
2278 | 2347 | torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig") |
2279 | 2348 |
|
2280 | 2349 |
|
| 2350 | +# for BC |
| 2351 | +fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig) |
| 2352 | + |
| 2353 | + |
2281 | 2354 | @register_quantize_module_handler(FPXWeightOnlyConfig) |
2282 | 2355 | def _fpx_weight_only_transform( |
2283 | 2356 | module: torch.nn.Module, config: FPXWeightOnlyConfig |
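
Each config class above is routed to its `_*_transform` function through `register_quantize_module_handler`. A minimal sketch of the registry pattern that decorator implies; everything here except the decorator's name is hypothetical:

```python
from typing import Callable, Dict, Type

# Hypothetical mirror of the registry that register_quantize_module_handler
# populates in torchao; the real storage and lookup live elsewhere.
_QUANTIZE_CONFIG_HANDLER: Dict[Type, Callable] = {}


def register_quantize_module_handler(config_type: Type):
    """Decorator: route config instances of `config_type` to the handler."""

    def decorator(handler: Callable) -> Callable:
        _QUANTIZE_CONFIG_HANDLER[config_type] = handler
        return handler

    return decorator
```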
|