Skip to content

Commit 8878f30

Browse files
authored
Add parq utility to create an optimizer (#3165)
* Add parq util * up
1 parent 41a6c11 commit 8878f30

File tree

1 file changed

+145
-0
lines changed

1 file changed

+145
-0
lines changed

torchao/prototype/parq/api.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
3+
4+
import torch
5+
6+
from torchao.prototype.parq.optim import QuantOptimizer
7+
from torchao.prototype.parq.quant import (
8+
Quantizer,
9+
StretchedUnifTorchaoQuantizer,
10+
UnifTorchaoQuantizer,
11+
)
12+
13+
14+
@dataclass(frozen=True, slots=True)
15+
class QuantConfig:
16+
bitwidth: int
17+
group_size: Optional[int] = None
18+
quantizer: Optional[Quantizer] = None
19+
20+
def __post_init__(self):
21+
if self.bitwidth < 2:
22+
raise ValueError("bitwidth must be >= 2")
23+
if self.group_size is not None and self.group_size <= 0:
24+
raise ValueError("group_size must be positive")
25+
26+
if self.quantizer is None:
27+
if self.bitwidth in [2, 3]:
28+
q = StretchedUnifTorchaoQuantizer(b=self.bitwidth)
29+
else:
30+
q = UnifTorchaoQuantizer()
31+
object.__setattr__(self, "quantizer", q)
32+
33+
34+
def create_param_groups_and_group_quantizer_map(
    model: torch.nn.Module,
    quant_configs_and_filter_fns: List[
        Tuple[QuantConfig, Callable[[torch.nn.Module, str], bool]]
    ],
):
    """Partition a model's parameters into optimizer param groups.

    One candidate group is created per ``(config, filter_fn)`` pair, plus a
    trailing group (with ``weight_decay`` set to 0.0) for parameters that no
    filter matches. Groups that end up empty are dropped.

    Args:
        model: Module whose ``named_parameters()`` are partitioned.
        quant_configs_and_filter_fns: ``(config, filter_fn)`` pairs.
            ``filter_fn(owning_module, param_name)`` receives the submodule
            that directly owns the parameter and the parameter's fully
            qualified name; a truthy result assigns the parameter to that
            config's group.

    Returns:
        A tuple ``(param_groups, group_quantizer_map)``. ``param_groups`` is a
        list of dicts holding ``"params"`` and, for config-backed groups,
        ``"quant_bits"`` and optionally ``"quant_block_size"``.
        ``group_quantizer_map`` maps the index of each config-backed group in
        the returned list to that config's quantizer.

    Raises:
        ValueError: If more than one filter matches the same parameter.
        AssertionError: If the number of grouped parameters does not equal
            the number of model parameters minus the skipped aliases.
    """
    param_groups = []
    for config, _ in quant_configs_and_filter_fns:
        param_group = {
            "params": [],
            "quant_bits": config.bitwidth,
        }
        if config.group_size is not None:
            param_group["quant_block_size"] = config.group_size
        # Stashed temporarily; popped into group_quantizer_map once the final
        # group indices are known.
        param_group["_quantizer"] = config.quantizer
        param_groups.append(param_group)

    # Non-quantized group goes last so that, before filtering, the index of a
    # config in quant_configs_and_filter_fns is also the index of its group.
    param_groups.append({"params": [], "weight_decay": 0.0})

    seen_data_ptrs = {}
    n_skipped_aliases = 0
    for param_name, param in model.named_parameters():
        module_name = param_name.rpartition(".")[0]
        owning_module = model.get_submodule(module_name) if module_name else model

        data_ptr = param.data_ptr()
        if data_ptr in seen_data_ptrs:
            # Report the parameter's name, not the tensor itself.
            print(
                f"Not considering {param_name} because it shares a data_ptr with {seen_data_ptrs[data_ptr]}, which was previously considered"
            )
            n_skipped_aliases += 1
            continue
        seen_data_ptrs[data_ptr] = param_name

        print(
            "param_name",
            param_name,
            "module_type",
            type(owning_module),
            "matching_config:",
            end="",
        )
        matching_config = None
        # Every filter is evaluated, even after a match, so that ambiguous
        # configurations are detected instead of silently taking the first.
        for idx, (config, filter_fn) in enumerate(quant_configs_and_filter_fns):
            if filter_fn(owning_module, param_name):
                param_groups[idx]["params"].append(param)
                if matching_config is None:
                    matching_config = config
                    print(f"{config.bitwidth},{config.group_size}")
                else:
                    raise ValueError(
                        f"Found multiple matching configs for {param_name}. Previous match={matching_config}, new match={config}."
                    )

        # If no match, add to no-quant group at last idx
        if matching_config is None:
            print("NONE")
            param_groups[-1]["params"].append(param)

    # Filter out empty param groups
    param_groups = [pg for pg in param_groups if len(pg["params"]) > 0]

    # Build the map only after filtering so that its keys are indices into
    # the list that is actually returned.
    group_quantizer_map = {}
    for idx, param_group in enumerate(param_groups):
        if "_quantizer" in param_group:
            group_quantizer_map[idx] = param_group.pop("_quantizer")

    # Sanity check: every parameter that was not deliberately skipped as a
    # storage alias must have landed in exactly one group.
    expected_n_params = sum(1 for _ in model.parameters()) - n_skipped_aliases
    n_found_params = sum(len(pg["params"]) for pg in param_groups)
    assert n_found_params == expected_n_params, (
        f"{n_found_params} != {expected_n_params=}"
    )

    return param_groups, group_quantizer_map
115+
116+
117+
from torchao.prototype.parq import ProxHardQuant
118+
119+
120+
def create_optimizer(
    model: torch.nn.Module,
    quant_configs_and_filter_fns: List[
        Tuple[QuantConfig, Callable[[torch.nn.Module, str], bool]]
    ],
    base_optimizer_cls: Type[torch.optim.Optimizer],
    base_optimizer_kwargs: Dict[str, Any],
    *,
    warmup_steps: int = 0,
    quant_period: int = 1,
    quant_per_channel: bool = True,
):
    """Build a ``QuantOptimizer`` wrapping a freshly constructed base optimizer.

    Args:
        model: Module whose parameters are grouped via
            ``create_param_groups_and_group_quantizer_map``.
        quant_configs_and_filter_fns: ``(config, filter_fn)`` pairs forwarded
            unchanged to the grouping helper.
        base_optimizer_cls: Optimizer class instantiated with the resulting
            param groups.
        base_optimizer_kwargs: Extra keyword arguments for
            ``base_optimizer_cls``.
        warmup_steps: Forwarded to ``QuantOptimizer``.
        quant_period: Forwarded to ``QuantOptimizer``.
        quant_per_channel: Forwarded to ``QuantOptimizer``.

    Returns:
        The ``QuantOptimizer`` instance.
    """
    groups, quantizer_by_group = create_param_groups_and_group_quantizer_map(
        model, quant_configs_and_filter_fns
    )
    inner_optimizer = base_optimizer_cls(groups, **base_optimizer_kwargs)

    # The wrapper always gets a UnifTorchaoQuantizer and ProxHardQuant; the
    # per-group quantizers from the configs travel in group_quantizer_map.
    wrapper_kwargs = {
        "quantizer": UnifTorchaoQuantizer(),
        "prox_map": ProxHardQuant(),
        "warmup_steps": warmup_steps,
        "quant_period": quant_period,
        "quant_per_channel": quant_per_channel,
        "group_quantizer_map": quantizer_by_group,
    }
    return QuantOptimizer(inner_optimizer, **wrapper_kwargs)

0 commit comments

Comments
 (0)