
Commit 6259e98

Update Float8Tensor for GRPO training in unsloth (#3158)
**Summary:** Support a few extra ops called during the GRPO loop in unsloth/vllm for Float8Tensor.

**Test Plan:**
```
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_matmul_lora_variants
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_to_dtype_layout
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_has_compatible_shallow_copy_type
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_transpose
```

ghstack-source-id: d806897
Pull Request resolved: #3291
1 parent 1fbc364 commit 6259e98
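The ops in question are the ones a LoRA forward pass hits once the base linear weight has been swapped for a `Float8Tensor`: an explicit transpose of the weight followed by a plain `torch.matmul`, plus device moves via `aten.to.dtype_layout` and shallow-copy compatibility checks. Below is a minimal sketch of that call pattern, mirroring the `ToyLoRAModel` added in the test diff; the `TinyLoRA` module and its sizes are illustrative only and are not part of this commit.

```python
import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)


class TinyLoRA(torch.nn.Module):
    # Illustrative stand-in for the ToyLoRAModel used in the new test.
    def __init__(self, in_features=512, out_features=128, rank=8):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
        self.lora_A = torch.nn.Parameter(torch.randn(in_features, rank))
        self.lora_B = torch.nn.Parameter(torch.randn(rank, out_features))

    def forward(self, x):
        # Explicit matmul against the transposed quantized weight (as in the GRPO loop),
        # plus the high-precision LoRA delta.
        return torch.matmul(x, self.linear.weight.t()) + x @ self.lora_A @ self.lora_B


# Requires a GPU with fp8 support (compute capability >= 8.9), matching the test's skip conditions.
model = TinyLoRA().to(torch.bfloat16).to("cuda")
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
y = model(x)  # exercises the transpose + matmul paths this commit adds for Float8Tensor
```

The `test_to_dtype_layout`, `test_has_compatible_shallow_copy_type`, and `test_transpose` cases in the diff below cover the remaining ops directly on `Float8Tensor.from_hp` outputs.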

File tree: 4 files changed (+267, -37 lines)

test/integration/test_integration.py

Lines changed: 3 additions & 1 deletion
@@ -1877,7 +1877,9 @@ def forward(self, x):
         config = Float8DynamicActivationFloat8WeightConfig()
         quantize_(model, config)

-        ep = torch.export.export(model, (inp,))
+        # Need to export with strict=True
+        # https://github.com/pytorch/pytorch/issues/167007
+        ep = torch.export.export(model, (inp,), strict=True)
         print(ep)
         FileCheck().check_count(
             "torch.ops.torchao.choose_scale_float8.default", 1, exactly=True

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 159 additions & 16 deletions
@@ -18,6 +18,7 @@
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    Granularity,
     PerBlock,
     PerRow,
     PerTensor,
@@ -42,6 +43,8 @@
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features, bias):
         super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
         self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
         self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)

@@ -50,6 +53,21 @@ def forward(self, x):
         x = self.linear2(x)
         return x

+    def check_weight_scaling(self, granularity: Granularity):
+        qs1 = self.linear1.weight.scale
+        qs2 = self.linear2.weight.scale
+        N, K = (self.out_features, self.in_features)
+        if granularity == PerTensor():
+            assert qs1.shape == (1, 1)
+            assert qs2.shape == (1, 1)
+        elif granularity == PerRow():
+            assert qs1.shape == (N, 1)
+            assert qs2.shape == (K, 1)
+        else:
+            assert granularity == (PerBlock([1, 128]), PerBlock([128, 128]))
+            assert qs1.shape == (N // 128, K // 128)
+            assert qs2.shape == (K // 128, N // 128)
+

 class ToyConvModel(torch.nn.Module):
     def __init__(
@@ -73,6 +91,47 @@ def forward(self, x):
         return self.conv(x)


+class ToyLoRAModel(torch.nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        lora_rank: int,
+        device: torch.device,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.linear = torch.nn.Linear(
+            in_features,
+            out_features,
+            bias=False,
+            device=device,
+        )
+        self.lora_A = torch.nn.Parameter(
+            torch.randn(in_features, lora_rank, device=device),
+        )
+        self.lora_B = torch.nn.Parameter(
+            torch.randn(lora_rank, out_features, device=device),
+        )
+
+    def forward(self, x):
+        matmul_out = torch.matmul(x, self.linear.weight.t())
+        lora_out = x @ self.lora_A @ self.lora_B
+        return matmul_out + lora_out
+
+    def check_weight_scaling(self, granularity: Granularity):
+        qs = self.linear.weight.scale
+        N, K = (self.out_features, self.in_features)
+        if granularity == PerTensor():
+            assert qs.shape == (1, 1)
+        elif granularity == PerRow():
+            assert qs.shape == (N, 1)
+        else:
+            assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
+            assert qs.shape == (N // 128, K // 128)
+
+
 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -112,10 +171,75 @@ def test_fp8_linear_variants(
         dtype: torch.dtype,
         mode: str,
         compile: bool,
-        granularity,
+        granularity: Granularity,
         kernel_preference: KernelPreference,
         sizes: Tuple,
         bias: bool,
+    ):
+        _, N, K = sizes
+        self._test_fp8_matmul_model(
+            dtype,
+            mode,
+            compile,
+            granularity,
+            kernel_preference,
+            sizes,
+            bias,
+            ToyLinearModel(K, N, bias),
+        )
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
+    )
+    # Inputs are (M,..), K, N
+    @common_utils.parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+        ],
+    )
+    def test_fp8_matmul_lora_variants(
+        self,
+        dtype: torch.dtype,
+        mode: str,
+        compile: bool,
+        granularity: Granularity,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        _, N, K = sizes
+        model = ToyLoRAModel(K, N, lora_rank=8, device=torch.device("cpu"))
+        self._test_fp8_matmul_model(
+            dtype,
+            mode,
+            compile,
+            granularity,
+            kernel_preference,
+            sizes,
+            bias=False,
+            model=model.to("cuda"),
+        )
+
+    def _test_fp8_matmul_model(
+        self,
+        dtype: torch.dtype,
+        mode: str,
+        compile: bool,
+        granularity: Granularity,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+        bias: bool,
+        model: torch.nn.Module,
     ):
         if isinstance(granularity, PerTensor):
             if kernel_preference is KernelPreference.FBGEMM:
@@ -172,9 +296,7 @@ def test_fp8_linear_variants(
         with error_context:
             M, N, K = sizes
             input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-
-            # Create a linear layer with bfloat16 dtype
-            model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")
+            model = model.eval().to(dtype).to("cuda")

             quantized_model = copy.deepcopy(model)

@@ -190,18 +312,7 @@ def test_fp8_linear_variants(
             quantize_(quantized_model, config)

             # ensure weight scaling is what we expect
-            qs1 = quantized_model.linear1.weight.scale
-            qs2 = quantized_model.linear2.weight.scale
-            if granularity == PerTensor():
-                assert qs1.shape == (1, 1)
-                assert qs2.shape == (1, 1)
-            elif granularity == PerRow():
-                assert qs1.shape == (N, 1)
-                assert qs2.shape == (K, 1)
-            else:
-                assert granularity == (PerBlock([1, 128]), PerBlock([128, 128]))
-                assert qs1.shape == (N // 128, K // 128)
-                assert qs2.shape == (K // 128, N // 128)
+            quantized_model.check_weight_scaling(granularity)

             if compile:
                 quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -807,6 +918,38 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):

         self.assertEqual(sliced_dequantized, sliced_original)

+    def test_to_dtype_layout(self):
+        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        y_fp8 = torch.ops.aten.to.dtype_layout(
+            x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu"
+        )
+        self.assertEqual(y_fp8.dtype, x_fp8.dtype)
+        self.assertEqual(y_fp8.layout, x_fp8.layout)
+        self.assertEqual(y_fp8.device, torch.device("cpu"))
+
+    def test_has_compatible_shallow_copy_type(self):
+        x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
+        x1_fp8 = Float8Tensor.from_hp(x1)
+        x2_fp8 = Float8Tensor.from_hp(x2)
+        x3_fp8 = Float8Tensor.from_hp(x3)
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1, x2_fp8))
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x2))
+        self.assertTrue(torch._has_compatible_shallow_copy_type(x1_fp8, x2_fp8))
+        # Wrong shape
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8))
+
+    def test_transpose(self):
+        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        x_fp8_t = x_fp8.t()
+        torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0)
+        torch.testing.assert_close(x_fp8_t.scale, x_fp8.scale.t(), atol=0, rtol=0)
+        self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
+        self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)
+

 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
