
Commit a76ae95

Update quantization kernels
1 parent: 778b61c

7 files changed: +69 -76 lines changed

flake.lock

Lines changed: 4 additions & 3 deletions
Generated lockfile; diff not rendered by default.

flake.nix

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "hf-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    hf-nix.url = "github:huggingface/hf-nix";
+    hf-nix.url = "github:huggingface/hf-nix/quantization-0.1.0";
     nixpkgs.follows = "hf-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
@@ -33,7 +33,7 @@
         };
         pkgs = import nixpkgs {
           inherit system;
-          inherit (hf-nix.lib) config;
+          config = hf-nix.lib.config system;
           overlays = [
             rust-overlay.overlays.default
             hf-nix.overlays.default

server/kernels.lock

Lines changed: 14 additions & 38 deletions
@@ -223,82 +223,58 @@
     },
     {
       "repo_id": "kernels-community/quantization",
-      "sha": "6470f9b005797e00279eb9103463dfe0f8b7da00",
+      "sha": "229f047e826202eb49dc0321bb38aed5d3ab96e3",
       "variants": {
-        "torch25-cxx11-cu118-x86_64-linux": {
-          "hash": "sha256-f52c9b1a7cd98fb389c6d2a0b22a293cb36eb96af3a624f5aec761735861c96d",
-          "hash_type": "git_lfs_concat"
-        },
-        "torch25-cxx11-cu121-x86_64-linux": {
-          "hash": "sha256-e5f0da343363a562ce52f147a9534cd54a3efa90e70671f606cc2516f02a3876",
-          "hash_type": "git_lfs_concat"
-        },
-        "torch25-cxx11-cu124-x86_64-linux": {
-          "hash": "sha256-caad9300c155faf79c26426f10951ba75f931a05e741a5b39a24b064daabc040",
-          "hash_type": "git_lfs_concat"
-        },
-        "torch25-cxx98-cu118-x86_64-linux": {
-          "hash": "sha256-4fc87893de14a29ba4b55f5026ea05ec5901c0b52abd5ebae681ea0b791e858c",
-          "hash_type": "git_lfs_concat"
-        },
-        "torch25-cxx98-cu121-x86_64-linux": {
-          "hash": "sha256-72c975ea63fc524a38fcee5b2dbdb566eff0a0ea546ee5756441d04908e4e896",
-          "hash_type": "git_lfs_concat"
-        },
-        "torch25-cxx98-cu124-x86_64-linux": {
-          "hash": "sha256-28c5510e3b07eae2b3846b880f6111da65df024e1f24f81077d187a97c015364",
-          "hash_type": "git_lfs_concat"
-        },
         "torch26-cxx11-cu118-x86_64-linux": {
-          "hash": "sha256-8444cf77686578a6b0f7e2fd29bf2783ba120ebf7df41573f61d2521fd0acc10",
+          "hash": "sha256-354e86a4a1fc38bfaddb3bf98c083ccd8a00de721d6769e0f3c594b719c9dbd2",
           "hash_type": "git_lfs_concat"
         },
         "torch26-cxx11-cu124-x86_64-linux": {
-          "hash": "sha256-6ea8e00625b5fe799fbe407e7de0fc08228cac26f9bbed2d70a6500026fe3bab",
+          "hash": "sha256-99523c409552d6a0a514987bd31b427c273695abaa1085be85f9f243f6ff8184",
           "hash_type": "git_lfs_concat"
         },
         "torch26-cxx11-cu126-aarch64-linux": {
-          "hash": "sha256-0b8b8afbdaf9aa533895cb9e884e3ad3e9a34d483f05a1bbde1b8902f9dbeb0f",
+          "hash": "sha256-c7ed22cb6bb3cf23b3b36e157a3f902b2d22f2236a30e2e72110033aff4485c1",
           "hash_type": "git_lfs_concat"
         },
         "torch26-cxx11-cu126-x86_64-linux": {
-          "hash": "sha256-e115e855d7ca4b97787f04c88e128432256c6b43d4823fb8889ab9985dc4cf36",
+          "hash": "sha256-91498f3a73741f2e9b63467f0992fa28daabb3c0d9d06aec2fb650285fa7df92",
           "hash_type": "git_lfs_concat"
         },
         "torch26-cxx98-cu118-x86_64-linux": {
-          "hash": "sha256-509f08c48a05584cc85c058607277fcbe3193e6cc61846dd2416d39e27c1d68e",
+          "hash": "sha256-fcf32cbeb606021b80f3d1c86ca977a13a680fb4a7c15738487b35bc8f9edc04",
           "hash_type": "git_lfs_concat"
         },
         "torch26-cxx98-cu124-x86_64-linux": {
-          "hash": "sha256-a10236bffd435296c736ae2762ab0836da2421297e46b377368a17b39d70c27b",
+          "hash": "sha256-eeff3d5134795a25bb484b95a11f72658ef096766d13a126530cc379cb74850b",
           "hash_type": "git_lfs_concat"
         },
         "torch26-cxx98-cu126-aarch64-linux": {
-          "hash": "sha256-ca2cb56f3eea4c399a61e21ba9b577d718b250aa60a13f42f01019ddd5cd8b0c",
+          "hash": "sha256-8aaaae2f066c2b041828703d09882f80e9c058527385b0cfe256349972d12929",
           "hash_type": "git_lfs_concat"
         },
         "torch26-cxx98-cu126-x86_64-linux": {
-          "hash": "sha256-8fcd62d8243a30b63a03751cc0c15d24f6e00e43eae79f7281627f24e078bf9a",
+          "hash": "sha256-6556ddcd229b4532572294a1313f394a0f9f15be8d1cab1007dbc0ba712a1a94",
           "hash_type": "git_lfs_concat"
         },
         "torch27-cxx11-cu118-x86_64-linux": {
-          "hash": "sha256-60f5807ee3da937c57c1b6080c30632305aa4875ed5a52bf4e81968770b61b13",
+          "hash": "sha256-c5fe51f7830adc47a642151256b023fde606611e641fb12acccf9e5cc2d319e3",
           "hash_type": "git_lfs_concat"
         },
         "torch27-cxx11-cu126-aarch64-linux": {
-          "hash": "sha256-64298b1713dc1d950915dc6569a06e2f541de3ed80aa5b32084246c1fdc7a958",
+          "hash": "sha256-aec9c8d1e3653c700da624cedc7619af4eea77a1ba5b0f1093f7ea22d811f335",
           "hash_type": "git_lfs_concat"
         },
         "torch27-cxx11-cu126-x86_64-linux": {
-          "hash": "sha256-d9e219890dc28e8582ef21d6f81f2ebc361de218a86b742be63bc4714f102e5e",
+          "hash": "sha256-82623a36b6921357373ec767114438a4818a86087f56944d25fe21946b217420",
           "hash_type": "git_lfs_concat"
         },
         "torch27-cxx11-cu128-aarch64-linux": {
-          "hash": "sha256-d72549f51aefcf020bc74262bbbccb78094638c5ab9adc8667873d247c1cce86",
+          "hash": "sha256-d110173a26cb02d80c5462434491f30e41f66e21c3a9723f9e4edc4cf3a9bd9f",
           "hash_type": "git_lfs_concat"
         },
         "torch27-cxx11-cu128-x86_64-linux": {
-          "hash": "sha256-d31ac5f87d7c7f62c63c72946479193aed467c9417c0acead5137e0e1fa968f8",
+          "hash": "sha256-936c75f188ffcd8debbaffed37c73edbffaaa05462bd4d2fc78f767fc4678755",
           "hash_type": "git_lfs_concat"
         }
       }
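
The torch25 variants are dropped and every surviving torch26/torch27 variant gets a new hash for the updated kernel revision. For orientation, a hedged sketch of how a variant key such as torch27-cxx11-cu128-x86_64-linux appears to be composed, inferred only from the names in this lockfile (guess_variant_name is a hypothetical helper, not part of the kernels API):

import platform
import torch

def guess_variant_name() -> str:
    # Assumption: keys look like "torch{MAJOR}{MINOR}-{cxxABI}-cu{CUDA}-{arch}-linux".
    major, minor = torch.__version__.split(".")[:2]
    cxx = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
    cuda = torch.version.cuda  # None on CPU-only builds
    assert cuda is not None, "the lockfile variants shown here are all CUDA builds"
    arch = platform.machine()  # "x86_64" or "aarch64"
    return f"torch{major}{minor}-{cxx}-cu{cuda.replace('.', '')}-{arch}-linux"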

server/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ build-backend = "setuptools.build_meta"
 "kernels-community/paged-attention" = ">=0.0.2"
 "kernels-community/moe" = ">=0.1.1"
 "kernels-community/punica-sgmv" = ">=0.0.1"
-"kernels-community/quantization" = ">=0.0.3"
+"kernels-community/quantization" = ">=0.1.1"
 "kernels-community/quantization-eetq" = ">=0.0.1"
 "kernels-community/rotary" = ">=0.0.1"

server/text_generation_server/layers/marlin/fp8.py

Lines changed: 16 additions & 9 deletions
@@ -76,15 +76,21 @@ def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert quantization is not None

         A_flat = A.view(-1, A.shape[-1])
-        C = quantization.fp8_marlin_gemm(
-            A_flat,
-            self.qweight,
-            self.scales,
-            self.workspace,
-            8,
-            A_flat.shape[0],
-            self.scales.shape[1],
-            A_flat.shape[1],
+        C = quantization.gptq_marlin_gemm(
+            a=A_flat,
+            c=None,
+            b_q_weight=self.qweight,
+            b_scales=self.scales,
+            global_scale=None,
+            b_zeros=None,
+            g_idx=None,
+            perm=None,
+            workspace=self.workspace,
+            b_q_type=quantization.scalar_type.scalar_types.float8_e4m3fn,
+            size_m=A_flat.shape[0],
+            size_n=self.scales.shape[1],
+            size_k=A_flat.shape[1],
+            use_fp32_reduce=True,
         )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

@@ -143,5 +149,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     )

     scales = permute_scales(scales)
+    scales = quantization.marlin_utils_fp8.fp8_fused_exponent_bias_into_scales(scales)

     return repacked, scales
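
In quantization 0.1.x the dedicated fp8_marlin_gemm entry point is replaced by the unified, keyword-only gptq_marlin_gemm: the weight format is selected via b_q_type and the zero-point arguments stay None for FP8. A minimal sketch of the same call outside TGI, assuming the Hugging Face kernels loader and the 0.1.x signature shown in this diff (fp8_marlin_matmul is a hypothetical helper, and the tensors are assumed to be already repacked for Marlin):

import torch
from kernels import get_kernel  # Hugging Face kernels loader

quantization = get_kernel("kernels-community/quantization")

def fp8_marlin_matmul(
    a: torch.Tensor,          # activations (..., k), fp16/bf16
    qweight: torch.Tensor,    # FP8 weight repacked for Marlin
    scales: torch.Tensor,     # permuted scales with the fused exponent bias
    workspace: torch.Tensor,  # Marlin workspace buffer
) -> torch.Tensor:
    a_flat = a.view(-1, a.shape[-1])
    c = quantization.gptq_marlin_gemm(
        a=a_flat,
        c=None,
        b_q_weight=qweight,
        b_scales=scales,
        global_scale=None,  # FP8 path relies on per-channel scales only
        b_zeros=None,       # float8 weights carry no zero points
        g_idx=None,         # no activation reordering
        perm=None,
        workspace=workspace,
        b_q_type=quantization.scalar_type.scalar_types.float8_e4m3fn,
        size_m=a_flat.shape[0],
        size_n=scales.shape[1],
        size_k=a_flat.shape[1],
        use_fp32_reduce=True,
    )
    return c.reshape(a.shape[:-1] + (scales.shape[1],))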

server/text_generation_server/layers/marlin/gptq.py

Lines changed: 18 additions & 19 deletions
@@ -256,7 +256,7 @@ class GPTQMarlinWeight(Weight):
     """

     qweight: torch.Tensor
-    qzeros: torch.Tensor
+    qzeros: Optional[torch.Tensor]
     scales: torch.Tensor
     g_idx: torch.Tensor
     perm: torch.Tensor
@@ -268,6 +268,7 @@ def __post_init__(self):
         assert self.scales.dtype in (torch.float16, torch.bfloat16)
         assert self.g_idx.dtype == torch.int32
         assert self.perm.dtype == torch.int32
+        assert self.qzeros is None or self.qzeros.numel() > 0

     def get_linear(self, bias: torch.Tensor):
         return GPTQMarlinLinear(
@@ -350,9 +351,6 @@ def repack_gptq_for_marlin(
         qweight, perm, in_features, out_features, bits
     )

-    if qzeros is None:
-        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
-
     scales = permute_scales(scales)

     is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
@@ -392,7 +390,7 @@ def __init__(
         if weight.bits not in (4, 8):
             raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization")

-        if weight.qzeros.numel() > 0:
+        if weight.qzeros is not None:
             if weight.bits == 4:
                 self.quant_type = quantization.scalar_types.uint4
             else:
@@ -424,20 +422,21 @@ def forward(self, A: torch.Tensor) -> torch.Tensor:

         A_flat = A.view(-1, A.shape[-1])
         C = quantization.gptq_marlin_gemm(
-            A_flat,
-            self.qweight,
-            self.scales,
-            self.qzeros,
-            self.g_idx,
-            self.perm,
-            self.workspace,
-            self.quant_type,
-            A_flat.shape[0],
-            self.scales.shape[1],
-            A_flat.shape[1],
-            self.is_full_k,
-            self.qzeros.numel() > 0,
-            True,
+            a=A_flat,
+            c=None,
+            b_q_weight=self.qweight,
+            b_scales=self.scales,
+            global_scale=None,
+            b_zeros=self.qzeros,
+            g_idx=self.g_idx,
+            perm=self.perm,
+            workspace=self.workspace,
+            b_q_type=self.quant_type,
+            size_m=A_flat.shape[0],
+            size_n=self.scales.shape[1],
+            size_k=A_flat.shape[1],
+            is_k_full=self.is_full_k,
+            use_fp32_reduce=True,
         )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
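
Making qzeros Optional removes the empty-tensor sentinel: the kernel now learns whether zero points exist from b_zeros itself (None vs. a real tensor), so the old trailing has_zp boolean disappears from the call. A hedged sketch of the quant-type selection this implies; only the uint4 branch is confirmed by the diff above, and the remaining scalar-type names (uint8, uint4b8, uint8b128) are assumptions based on the usual Marlin conventions:

from typing import Optional

import torch

def select_quant_type(quantization, bits: int, qzeros: Optional[torch.Tensor]):
    # With zero points: plain unsigned types (uint4 confirmed by the diff,
    # uint8 assumed for the 8-bit branch).
    if qzeros is not None:
        return (
            quantization.scalar_types.uint4
            if bits == 4
            else quantization.scalar_types.uint8
        )
    # Without zero points: the bias-offset Marlin types are assumed here.
    return (
        quantization.scalar_types.uint4b8
        if bits == 4
        else quantization.scalar_types.uint8b128
    )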

server/text_generation_server/layers/moe/gptq_marlin.py

Lines changed: 14 additions & 4 deletions
@@ -202,9 +202,13 @@ def _pack_weight(
             device=weight.qweight.device,
         )
         qzeros = torch.empty(
-            (n_experts,) + weight.qzeros.shape,
-            dtype=weight.qzeros.dtype,
-            device=weight.qzeros.device,
+            (n_experts,) + ((0,) if weight.qzeros is None else weight.qzeros.shape),
+            dtype=(
+                weight.qweight.dtype if weight.qzeros is None else weight.qzeros.dtype
+            ),
+            device=(
+                weight.qweight.device if weight.qzeros is None else weight.qzeros.device
+            ),
         )
         scales = torch.empty(
             (n_experts,) + weight.scales.shape,
@@ -232,7 +236,13 @@ def _pack_weight(
         )

         moe_weight.qweight[expert] = weight.qweight
-        moe_weight.qzeros[expert] = weight.qzeros
+        moe_weight.qzeros[expert] = (
+            torch.zeros(
+                (0,), device=moe_weight.qzeros.device, dtype=moe_weight.qzeros.dtype
+            )
+            if weight.qzeros is None
+            else weight.qzeros
+        )
         moe_weight.scales[expert] = weight.scales
         moe_weight.g_idx[expert] = weight.g_idx
         moe_weight.perm[expert] = weight.perm
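
The packing code above keeps one uniform assignment loop even when an expert has no zero points by giving qzeros a zero-width slot per expert. A self-contained sketch of that pattern (pack_optional is a hypothetical name, not TGI code):

from typing import Optional

import torch

def pack_optional(
    n_experts: int, t: Optional[torch.Tensor], like: torch.Tensor
) -> torch.Tensor:
    # Zero-width slot per expert when t is absent; dtype/device fall back
    # to a tensor that is always present (qweight in the diff above).
    shape = (n_experts,) + ((0,) if t is None else t.shape)
    src = like if t is None else t
    return torch.empty(shape, dtype=src.dtype, device=src.device)

# Usage: packing a missing qzeros next to a real qweight.
qweight = torch.empty(8, 16, dtype=torch.int32)
packed = pack_optional(4, None, like=qweight)
packed[0] = torch.zeros((0,), dtype=packed.dtype)  # no-op assignment, keeps the loop uniform
print(packed.shape)  # torch.Size([4, 0])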

0 commit comments