
Commit 51647b3

Enable quantization for Gemma 7b (#77)
quantize
1 parent 3de239b commit 51647b3

File tree: 4 files changed, +102 / -119 lines


convert_checkpoints.py

Lines changed: 49 additions & 114 deletions
@@ -28,7 +28,6 @@
 import os
 import time
 
-from collections.abc import Sequence
 from absl import app
 from absl import flags
 from etils import epath
@@ -94,7 +93,7 @@
     "output.weight": "ColumnParallelLinear",
 }
 
-_QUANTIZED_WEIGHTS_TO_SCALER_NAME = {
+_LLAMA_QUANTIZED_WEIGHTS_TO_SCALER_NAME = {
     "tok_embeddings.weight": "tok_embeddings.weight_scaler",
     "attention.wq.weight": "attention.wq.weight_scaler",
     "attention.wk.weight": "attention.wk.weight_scaler",
@@ -106,82 +105,29 @@
     "output.weight": "output.weight_scaler",
 }
 
-
-def _quantize_state_dict(state_dict):
-  updated_weights = {}
-  for key, val in state_dict.items():
-    for qname, qscale_name in _QUANTIZED_WEIGHTS_TO_SCALER_NAME.items():
-      if key.endswith(qname):
-        new_weights, scaler = quantize.quantize_torch_int8(
-            val, reduce_axis=(1,)
-        )
-        updated_weights[key] = new_weights
-        scale_name = key[: -len(qname)] + qscale_name
-        updated_weights[scale_name] = scaler
-  state_dict.update(updated_weights)
-  return state_dict
-
-
-_QUANTIZE_LINEAR_WEIGHTS = {
-    "attention.wq.weight",
-    "attention.wk.weight",
-    "attention.wv.weight",
-    "attention.wo.weight",
-    "feed_forward.w1.weight",
-    "feed_forward.w2.weight",
-    "feed_forward.w3.weight",
-    "output.weight",
+_GEMMA_QUANTIZED_WEIGHTS_TO_SCALER_NAME = {
+    "self_attn.o_proj.weight": "self_attn.o_proj.weight_scaler",
+    "self_attn.wq.weight": "self_attn.wq.weight_scaler",
+    "self_attn.wk.weight": "self_attn.wk.weight_scaler",
+    "self_attn.wv.weight": "self_attn.wv.weight_scaler",
+    "mlp.gate_proj.weight": "mlp.gate_proj.weight_scaler",
+    "mlp.up_proj.weight": "mlp.up_proj.weight_scaler",
+    "mlp.down_proj.weight": "mlp.down_proj.weight_scaler",
+    "embedder.weight": "embedder.weight_scaler",
 }
 
 
-def _quantize_state_dict(state_dict):
+def _quantize_state_dict(state_dict, weight_map, weight_axis):
   updated_weights = {}
   for key, val in state_dict.items():
-    for qname in _QUANTIZE_LINEAR_WEIGHTS:
+    for qname, qscale_name in weight_map.items():
       if key.endswith(qname):
         new_weights, scaler = quantize.quantize_torch_int8(
-            val, reduce_axis=(1,)
+            val, reduce_axis=(weight_axis(key),)
         )
         updated_weights[key] = new_weights
-        scale_name = key + "_scaler"
+        scale_name = key[: -len(qname)] + qscale_name
         updated_weights[scale_name] = scaler.squeeze()
-  tok_weights, tok_scalers = quantize.quantize_torch_int8(
-      state_dict["tok_embeddings.weight"], reduce_axis=(0,)
-  )
-  updated_weights["tok_embeddings.weight"] = tok_weights
-  updated_weights["tok_embeddings.weight_scaler"] = tok_scalers.squeeze()
-  state_dict.update(updated_weights)
-  return state_dict
-
-
-_QUANTIZE_LINEAR_WEIGHTS = {
-    "attention.wq.weight",
-    "attention.wk.weight",
-    "attention.wv.weight",
-    "attention.wo.weight",
-    "feed_forward.w1.weight",
-    "feed_forward.w2.weight",
-    "feed_forward.w3.weight",
-    "output.weight",
-}
-
-
-def _quantize_state_dict(state_dict):
-  updated_weights = {}
-  for key, val in state_dict.items():
-    for qname in _QUANTIZE_LINEAR_WEIGHTS:
-      if key.endswith(qname):
-        new_weights, scaler = quantize.quantize_torch_int8(
-            val, reduce_axis=(1,)
-        )
-        updated_weights[key] = new_weights
-        scale_name = key + "_scaler"
-        updated_weights[scale_name] = scaler.squeeze()
-  tok_weights, tok_scalers = quantize.quantize_torch_int8(
-      state_dict["tok_embeddings.weight"], reduce_axis=(0,)
-  )
-  updated_weights["tok_embeddings.weight"] = tok_weights
-  updated_weights["tok_embeddings.weight_scaler"] = tok_scalers.squeeze()
   state_dict.update(updated_weights)
   return state_dict
 
@@ -222,7 +168,9 @@ def _tensors_have_same_shape(tensors):
 
 
 # pylint: disable-next=all
-def _merge_weights(checkpoints, minimize_memory_footprint, enable_float32):
+def _merge_llama_weights(
+    checkpoints, minimize_memory_footprint, enable_float32
+):
   print("Starting to merge weights.")
   state_dict = {}
   tmp_dir: epath.Path = None
@@ -362,13 +310,7 @@ def _export_to_local(output_ckpt_dir: epath.Path, params, state_dict):
   checklist_file.write_text(_generate_md5_checklist(output_ckpt_dir))
 
 
-def merge_weights(
-    input_ckpt_dir: epath.Path,
-    output_ckpt_dir: epath.Path,
-    minimize_memory_footprint: bool = True,
-    enable_float32: bool = False,
-) -> None:
-  """merge weights"""
+def _get_llama_state_dict(input_ckpt_dir):
   start = time.perf_counter()
   if "gs://" in str(input_ckpt_dir):
     print(
@@ -382,35 +324,15 @@ def merge_weights
   print(f"Loading checkpoints takes {end - start} seconds")
 
   start = time.perf_counter()
-  state_dict = _merge_weights(
-      checkpoints, minimize_memory_footprint, enable_float32
+  state_dict = _merge_llama_weights(
+      checkpoints, _MINIMIZE_MEMORY_FOOTPRINT.value, _ENABLE_FLOAT32.value
   )
   end = time.perf_counter()
   print(f"Merging weights takes {end - start} seconds")
-
-  if _QUANTIZE.value:
-    start = time.perf_counter()
-    state_dict = _quantize_state_dict(state_dict)
-    end = time.perf_counter()
-    print(f"Quantizing weights takes {end - start} seconds")
-
-  print(f"Writing merged weights to dir {output_ckpt_dir}")
-  start = time.perf_counter()
-  if "gs://" in str(output_ckpt_dir):
-    _export_to_gcs(output_ckpt_dir, params, state_dict)
-  else:
-    _export_to_local(output_ckpt_dir, params, state_dict)
-  end = time.perf_counter()
-  print(f"Export outputs takes {end - start} seconds")
+  return state_dict, params
 
 
-def convert_hf_gemma_weights(
-    input_ckpt_dir: epath.Path, output_ckpt_dir: epath.Path
-):
-  """Convert gemma weights from Huggingface to be compatible with JetStream
-  1. Map attention weights to new names.
-  2. Split qkv fusion.
-  """
+def _get_gemma_state_dict(input_ckpt_dir):
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
   assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
   ckpt_file = ckpt_file[0]
@@ -450,24 +372,37 @@ def convert_hf_gemma_weights
 
     if new_key != key:
       state_dict[new_key] = state_dict.pop(key)
-  _export_to_local(output_ckpt_dir, model_config, state_dict)
+  return state_dict, model_config
 
 
-def main(argv: Sequence[str]) -> None:
-  """convert checkpoint main function"""
-  if len(argv) > 1:
-    raise app.UsageError("Too many command-line arguments.")
-  if "gemma" in _MODEL_TYPE.value:
-    convert_hf_gemma_weights(
-        _INPUT_CHECKPOINT_DIR.value, _OUTPUT_CHECKPOINT_DIR.value
-    )
+def main(argv) -> None:
+  """merge weights"""
+
+  if _MODEL_TYPE.value == "gemma":
+    state_dict, params = _get_gemma_state_dict(_INPUT_CHECKPOINT_DIR.value)
+    quantize_weight_map = _GEMMA_QUANTIZED_WEIGHTS_TO_SCALER_NAME
+    weight_axis = lambda x: 0 if x == "embedder.weight" else 1
   else:
-    merge_weights(
-        _INPUT_CHECKPOINT_DIR.value,
-        _OUTPUT_CHECKPOINT_DIR.value,
-        _MINIMIZE_MEMORY_FOOTPRINT.value,
-        _ENABLE_FLOAT32.value,
+    state_dict, params = _get_llama_state_dict(_INPUT_CHECKPOINT_DIR.value)
+    quantize_weight_map = _LLAMA_QUANTIZED_WEIGHTS_TO_SCALER_NAME
+    weight_axis = lambda x: 0 if x == "tok_embeddings.weight" else 1
+
+  if _QUANTIZE.value:
+    start = time.perf_counter()
+    state_dict = _quantize_state_dict(
+        state_dict, quantize_weight_map, weight_axis
     )
+    end = time.perf_counter()
+    print(f"Quantizing weights takes {end - start} seconds")
+
+  print(f"Writing merged weights to dir {_OUTPUT_CHECKPOINT_DIR.value}")
+  start = time.perf_counter()
+  if "gs://" in str(_OUTPUT_CHECKPOINT_DIR.value):
+    _export_to_gcs(_OUTPUT_CHECKPOINT_DIR.value, params, state_dict)
+  else:
+    _export_to_local(_OUTPUT_CHECKPOINT_DIR.value, params, state_dict)
+  end = time.perf_counter()
+  print(f"Export outputs takes {end - start} seconds")
 
 
 if __name__ == "__main__":
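
Note on the converter change above: _quantize_state_dict is now parameterized by a per-model weight map and a weight_axis callback, so one loop serves both Llama (reduce over axis 1 for linear weights, axis 0 for tok_embeddings.weight) and Gemma (axis 0 only for embedder.weight). Below is a minimal, self-contained sketch of per-channel symmetric int8 quantization in the spirit of quantize.quantize_torch_int8; the function name, rounding, and clamping details here are assumptions for illustration, not the repo's implementation.

import torch

def quantize_int8_sketch(w: torch.Tensor, reduce_axis=(1,)):
  # One scale per channel: map the max |w| along the reduced axis to 127.
  max_abs = w.abs().amax(dim=reduce_axis, keepdim=True)
  scaler = (max_abs / 127.0).clamp(min=1e-8)
  q = torch.round(w / scaler).to(torch.int8)
  return q, scaler  # dequantize with q.float() * scaler

# For a (out_features, in_features) linear weight, reduce_axis=(1,) yields one
# scale per output row; squeeze() gives the (out_features,) tensor stored under
# the "<name>.weight_scaler" key in the converted checkpoint.
w = torch.randn(8, 16)
q, s = quantize_int8_sketch(w, reduce_axis=(1,))
assert torch.allclose(q.float() * s, w, atol=s.max().item())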

default_shardings/gemma.yaml

Lines changed: 8 additions & 0 deletions
@@ -18,3 +18,11 @@ layers.*.input_layernorm.weight : -1 # torch.float32 (2048,)
 layers.*.post_attention_layernorm.weight : -1 # torch.float32 (2048,)
 norm.weight : -1 # torch.float32 (2048,)
 embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048)
+embedder.weight_scaler : 0
+layers.*.self_attn.o_proj.weight_scaler: 0
+layers.*.self_attn.wq.weight_scaler : 0
+layers.*.self_attn.wk.weight_scaler : 0
+layers.*.self_attn.wv.weight_scaler : 0
+layers.*.mlp.gate_proj.weight_scaler : 0
+layers.*.mlp.up_proj.weight_scaler : 0
+layers.*.mlp.down_proj.weight_scaler : 0
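
The new lines pair each quantized weight with a 1-D *_scaler tensor sharded along its only axis. For reading purposes, a hedged sketch rather than the repo's loader: the integer in this yaml appears to name the tensor axis to shard across devices, with -1 meaning replicate, matching the existing entries above. One way to turn such an axis index into a JAX sharding is shown below; the mesh axis name "x" and the helper are illustrative assumptions.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def sharding_from_axis(mesh: Mesh, ndim: int, axis: int) -> NamedSharding:
  # -1 -> fully replicated; otherwise shard the given axis over mesh axis "x".
  if axis == -1:
    return NamedSharding(mesh, P())
  spec = [None] * ndim
  spec[axis] = "x"
  return NamedSharding(mesh, P(*spec))

mesh = Mesh(np.array(jax.devices()), ("x",))
# e.g. "embedder.weight : 1" (2-D table, shard hidden dim) and its 1-D scaler:
embedder_sharding = sharding_from_axis(mesh, ndim=2, axis=1)
scaler_sharding = sharding_from_axis(mesh, ndim=1, axis=0)

Sharding the scaler on axis 0 splits it along the same hidden dimension that this config shards for embedder.weight on axis 1, so each shard keeps the scales for its own columns.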

jetstream_pt/third_party/gemma/model.py

Lines changed: 2 additions & 4 deletions
@@ -346,9 +346,7 @@ def forward(
     hidden_states = self.norm(hidden_states)
 
     embedder_weight = self.embedder.weight
-    if self.config.quant:
-      embedder_weight = embedder_weight * self.embedder.weight_scaler.unsqueeze(
-          -1
-      )
+    if self.env.enable_weight_quantization:
+      embedder_weight = embedder_weight * self.embedder.weight_scaler
     logits = torch.matmul(hidden_states, embedder_weight.t())
     return logits
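
The .unsqueeze(-1) goes away because of the converter change in this commit: embedder.weight is now quantized with reduce_axis=(0,), so the stored scaler has one entry per hidden-dimension column and broadcasts against the last axis of the (vocab, hidden) weight directly, whereas the old code evidently expected a per-row scaler that needed an explicit trailing axis. A small broadcasting sketch, with shapes reduced from the real (256000, 2048) table for illustration:

import torch

vocab, hidden = 256, 2048  # illustrative sizes
weight_int8 = torch.randint(-127, 128, (vocab, hidden), dtype=torch.int8)
scaler = torch.rand(hidden)  # per-column scaler from reduce_axis=(0,)

# New path: the 1-D scaler broadcasts over the last dim, no unsqueeze needed.
embedder_weight = weight_int8.float() * scaler

hidden_states = torch.randn(3, 7, hidden)
logits = torch.matmul(hidden_states, embedder_weight.t())
assert logits.shape == (3, 7, vocab)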

tests/test_quantization.py

Lines changed: 43 additions & 1 deletion
@@ -15,10 +15,11 @@
 import unittest
 import jax
 import jax.numpy as jnp
+from tests import helpers
 import torch
 import torch_xla2
 
-from jetstream_pt import cache_manager
+from jetstream_pt import cache_manager, layers, quantize
 
 
 class QuantizationTest(unittest.TestCase):
@@ -49,6 +50,47 @@ def test_kv_cache(self):
         jnp.allclose(v._elem, new_v._elem[:, :, 57:58, :], atol=0.1)
     )
 
+  def test_kv_kernel(self):
+    """test kv cache quantization"""
+    cache_shape = (3, 2, 100, 2)  # bs, num heads, seqlen, dim
+    with jax.default_device(jax.devices("cpu")[0]):
+      env, _ = helpers.make_env_tiny(False)
+      key = jax.random.PRNGKey(123)
+      key2 = jax.random.PRNGKey(456)
+      cache_k_jax = jax.random.normal(key, cache_shape)
+      cache_v_jax = jax.random.normal(key2, cache_shape)
+
+      cache_k, cache_v = torch_xla2.tensor.wrap((cache_k_jax, cache_v_jax))
+
+      cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None)
+
+      # 1 is seqlen
+      xq = jax.random.normal(key, (3, 2, 1, 2))
+      xk = jax.random.normal(key, (3, 2, 1, 2))
+      xv = jax.random.normal(key, (3, 2, 1, 2))
+
+      xq, xk, xv = torch_xla2.tensor.wrap((xq, xk, xv))
+
+      attention_float = layers.AttentionKernel(env)
+      float_res = attention_float(xq, xk, xv, None, cache)
+
+      # ==
+
+      cache_k, cache_v = torch_xla2.tensor.wrap((cache_k_jax, cache_v_jax))
+      cache_k_int, cache_k_scaler = quantize.quantize_torch_int8(
+          cache_k, (1, 3)
+      )
+      cache_v_int, cache_v_scaler = quantize.quantize_torch_int8(
+          cache_v, (1, 3)
+      )
+      cache_int = cache_manager.Int8KVCacheGenerate(
+          cache_k_int, cache_v_int, cache_k_scaler, cache_v_scaler, [0], None
+      )
+      attention_quant = layers.Int8KVAttentionKernel(env)
+      int_res = attention_quant(xq, xk, xv, None, cache_int)
+
+      self.assertTrue(jnp.allclose(float_res.jax(), int_res.jax(), atol=0.01))
+
 
 if __name__ == "__main__":
   unittest.main()
