Commit 2ef5194

add tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 4033f31 commit 2ef5194

2 files changed: 45 additions, 35 deletions


tests/mock_observer.py (29 additions, 25 deletions)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Tuple
+from typing import Optional, Tuple
 from weakref import ref

 import torch
@@ -42,7 +42,7 @@ def get_min_max(self, observed: torch.Tensor):
         return min_vals, max_vals

     def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        observed = flatten_for_quantization(observed, self.base_name, self.args)
+        observed = flatten_for_calibration(observed, self.base_name, self.args)

         self.min_vals, self.max_vals = self.get_min_max(observed)

@@ -57,26 +57,31 @@ def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

     def get_global_scale(self, observed: torch.Tensor):
         observed = observed.reshape((1, 1, -1))  # per tensor reshape
-        min_vals, max_vals = self.get_min_max(observed)
-        global_scale = generate_gparam(min_vals, max_vals)
+        self.min_vals, self.max_vals = self.get_min_max(observed)
+        global_scale = generate_gparam(self.min_vals, self.max_vals)

         return global_scale


-def flatten_for_quantization(
-    value: torch.Tensor, base_name: str, args: QuantizationArgs
+def flatten_for_calibration(
+    value: torch.Tensor,
+    base_name: str,
+    args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if base_name == "weight":
-        return flatten_weight_for_quantization(value, args)
+        return _flatten_weight(value, args, g_idx)
     elif base_name in ("input", "output"):
-        return flatten_activation_for_quantization(value, args)
+        return _flatten_activation(value, args)
     elif base_name in ("q", "k", "v"):
-        return flatten_attention_for_quantization(value, args)
+        return _flatten_attention(value, args)
     else:
         raise ValueError(f"Unknown quantization base name: {base_name}")


-def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+def _flatten_weight(
+    value: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None
+):
     # value.shape = (num_rows, num_cols)

     if args.strategy == QuantizationStrategy.TENSOR:
@@ -91,34 +96,32 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
         return value.unsqueeze(-2).unsqueeze(0)

     if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        if g_idx is not None:
+            value = value.index_select(dim=1, index=torch.argsort(g_idx))
+
         # (1, num_rows, num_groups, group_size)
         return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0)

     if args.strategy == QuantizationStrategy.BLOCK:
         # (1, num_block_rows, num_block_cols, block_width * block_height)
         block_height, block_width = args.block_structure
-        num_rows, num_cols = value.shape
-        num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy)
-        num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy)
+        rows, cols = value.shape
+        block_rows = strategy_cdiv(rows, block_height, args.strategy, strict=True)
+        block_cols = strategy_cdiv(cols, block_width, args.strategy, strict=True)
         return (
-            value.reshape(
-                num_block_rows,
-                block_height,
-                num_block_cols,
-                block_width,
-            )
+            value.reshape(block_rows, block_height, block_cols, block_width)
             .transpose(1, 2)
             .flatten(-2, -1)
             .unsqueeze(0)
         )

     if args.strategy == QuantizationStrategy.ATTN_HEAD:
-        raise ValueError("attention head quantization cannot be applied to weights")
+        raise ValueError("Attention head quantization cannot be applied to weights")

     assert False, f"Unknown strategy {args.strategy}"


-def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+def _flatten_activation(value: torch.Tensor, args: QuantizationArgs):
     # value.shape = (batch_size, seq_len, hidden_dim)

     if args.strategy == QuantizationStrategy.TENSOR:
@@ -128,7 +131,7 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
     if args.strategy == QuantizationStrategy.TOKEN:
         # (batch_size, seq_len, hidden_dim)
         # warning: token quantization uses `compute_dynamic_scales_and_zp`
-        return value.flatten(2, -1)
+        return value

     if args.strategy == QuantizationStrategy.CHANNEL:
         raise ValueError("Channel quantization cannot be applied to activations")
@@ -142,12 +145,12 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
         raise ValueError("Block quantization cannot be applied to activations")

     if args.strategy == QuantizationStrategy.ATTN_HEAD:
-        raise ValueError("attention head quantization cannot be applied to linear acts")
+        raise ValueError("Attention head quantization cannot be applied to activations")

     assert False, f"Unknown strategy {args.strategy}"


-def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+def _flatten_attention(value: torch.Tensor, args: QuantizationArgs):
     # value.shape = (batch_size, num_heads, seq_len, head_dim)

     if args.strategy == QuantizationStrategy.TENSOR:
@@ -161,7 +164,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr
         raise ValueError("Channel quantization cannot be applied to attention")

     if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
-        raise ValueError("Group quantization cannot be applied to attention")
+        # (batch_size * num_heads * seq_len, num_groups, group_size)
+        return value.flatten(0, 2).unflatten(-1, (-1, args.group_size))

     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")

tests/test_quantization/lifecycle/test_static_lifecycle.py (16 additions, 10 deletions)

@@ -314,24 +314,26 @@ class MockAttention(torch.nn.Module):
         (
             QuantizationArgs(
                 num_bits=4,
-                type="int",
+                type="float",  # must be fp4
                 symmetric=True,
-                strategy="attn_head",
+                strategy="tensor_group",
+                dynamic="local",
+                group_size=2,
             ),
-            torch.tensor([[[0.0]], [[12.0]]]),
-            torch.tensor([[[11.0]], [[23.0]]]),
+            torch.tensor([0.0]),
+            torch.tensor([23.0]),
             torch.tensor(
                 [
                     [
                         [
-                            [0.0000, 1.4688, 1.4688, 2.9375],
-                            [4.4062, 4.4062, 5.8750, 7.3438],
-                            [7.3438, 8.8125, 10.2500, 10.2500],
+                            [0.0000, 1.0234, 2.0469, 3.0781],
+                            [3.2812, 4.9375, 4.9375, 7.3750],
+                            [9.0000, 9.0000, 10.6875, 10.6875],
                         ],
                         [
-                            [12.2500, 12.2500, 15.3125, 15.3125],
-                            [15.3125, 18.3750, 18.3750, 18.3750],
-                            [21.5000, 21.5000, 21.5000, 21.5000],
+                            [13.1250, 13.1250, 14.7500, 14.7500],
+                            [16.3750, 16.3750, 19.7500, 19.7500],
+                            [21.3750, 21.3750, 23.0000, 23.0000],
                         ],
                     ]
                 ]
@@ -369,6 +371,10 @@ def test_static_attention_quantization(
     attention.k_observer = MockMinMaxObserver("k", args, attention)

     # calibrate quantization parameters
+    if hasattr(attention, "k_global_scale"):
+        global_scale = attention.k_observer.get_global_scale(input)
+        attention.k_global_scale.data = global_scale
+
     if scheme.input_activations.dynamic is False:
         scale, zero_point = attention.k_observer(input)
         attention.k_scale.data = scale
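
The new calibration branch follows a global-then-local ordering for NVFP4-style schemes (tensor_group with dynamic="local"). Below is a self-contained sketch of that ordering; the observer is a toy stand-in rather than the repo's MockMinMaxObserver or generate_gparam, and the 448 / 6 constants are only illustrative FP8 (e4m3) and FP4 (e2m1) maxima, not the library's exact math.

import torch

class ToyObserver:
    # stand-in observer: per-tensor global scale, then per-group local scales
    def get_global_scale(self, observed: torch.Tensor) -> torch.Tensor:
        # per-tensor range -> one global scale (illustrative: 448 = FP8 e4m3 max)
        return observed.abs().amax().reshape(1) / 448.0

    def __call__(self, observed: torch.Tensor):
        # per-group range -> local scale / zero point
        # (group_size = 2, as in the test; illustrative: 6 = FP4 e2m1 max)
        groups = observed.reshape(-1, 2)
        scale = groups.abs().amax(dim=-1, keepdim=True) / 6.0
        zero_point = torch.zeros_like(scale, dtype=torch.int8)
        return scale, zero_point

k_states = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)  # (batch, heads, seq, head_dim)
observer = ToyObserver()

# 1. tensor_group + dynamic="local": calibrate the per-tensor global scale first
k_global_scale = observer.get_global_scale(k_states)

# 2. then calibrate the static (local) scale and zero point
k_scale, k_zero_point = observer(k_states)
print(k_global_scale.shape, k_scale.shape)  # torch.Size([1]) torch.Size([12, 1])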
