
Commit 9353640

Add gemma support (#69)
add gemma; shard config for llama and gemma; formatter
1 parent 137eb47

21 files changed: +1290, -112 lines

default_shardings/gemma.yaml

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+
+# Sharding config for gemma
+# null to signify "replicated".
+# An integer signifies the axis to shard: 0 <= shard axis < rank
+
+freqs_cis : null # torch.complex64 (16384, 128)
+layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048)
+layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048)
+layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048)
+layers.*.self_attn.wv.weight : 0 # -1, 1] # torch.float32 (256, 2048)
+layers.*.mlp.gate_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
+layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,)
+layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
+layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,)
+layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384)
+layers.*.mlp.down_proj.bias : null # torch.float32 (2048,)
+layers.*.input_layernorm.weight : null # torch.float32 (2048,)
+layers.*.post_attention_layernorm.weight : null # torch.float32 (2048,)
+norm.weight : null # torch.float32 (2048,)
+embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048)
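Each entry maps a parameter-name pattern to the axis of that tensor that gets split across the device mesh; null means the tensor is replicated on every device. As a rough stand-alone illustration of what "shard axis 0" means for a weight like layers.*.self_attn.wq.weight (annotated shape (2048, 2048)) — this sketch assumes a 1-D JAX mesh with an axis named "x" and is not code from this commit:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are available (assumed axis name "x").
mesh = Mesh(jax.devices(), axis_names=("x",))

w = jnp.zeros((2048, 2048))  # stand-in for layers.*.self_attn.wq.weight

# "0" in the YAML: split axis 0 across the mesh, keep axis 1 whole.
wq_sharded = jax.device_put(w, NamedSharding(mesh, P("x", None)))

# "null" in the YAML: fully replicated on every device.
norm_replicated = jax.device_put(jnp.zeros((2048,)), NamedSharding(mesh, P()))

print(wq_sharded.sharding, norm_replicated.sharding)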

default_shardings/llama-2.yaml

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+
+# Sharding config for llama-2
+# Sharding should either be an int between 0 and rank - 1
+# signifying the axis to shard or -1 / null signifying replicated
+
+
+freqs_cis : -1 # torch.complex64 (2048, 64)
+tok_embeddings.weight : 1 # torch.float32 (32000, 4096)
+layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096)
+layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,)
+layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096)
+layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,)
+layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096)
+layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,)
+layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096)
+layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,)
+layers.*.feed_forward.w1.weight : 0 # torch.float32 (11008, 4096)
+layers.*.feed_forward.w2.weight : 1 # torch.float32 (4096, 11008)
+layers.*.feed_forward.w3.weight : 0 # torch.float32 (11008, 4096)
+layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
+layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
+norm.weight : -1 # torch.float32 (4096,)
+output.weight : 0 # torch.float32 (32000, 4096)
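The *.weight / *.weight_scaler pairs reflect weight-only int8 quantization: the int8 matrix keeps one bfloat16 scale per output channel, and both are sharded along the same axis so every device holds matching slices. A minimal sketch of how such a pair could be produced — this is an assumption about the scheme, not the repo's WeightOnlyInt8Linear implementation:

import torch

def quantize_int8_per_channel(w: torch.Tensor):
  """Symmetric per-output-channel int8 quantization (illustrative only)."""
  # w: (out_features, in_features) float weight
  scaler = w.abs().amax(dim=1, keepdim=True) / 127.0   # (out_features, 1)
  q = torch.clamp(torch.round(w / scaler), -128, 127).to(torch.int8)
  return q, scaler.squeeze(1).to(torch.bfloat16)       # weight, weight_scaler

weight, weight_scaler = quantize_int8_per_channel(torch.randn(4096, 4096))
print(weight.dtype, tuple(weight.shape))                # torch.int8 (4096, 4096)
print(weight_scaler.dtype, tuple(weight_scaler.shape))  # torch.bfloat16 (4096,)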

jetstream_pt/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,3 +13,5 @@
 # limitations under the License.

 from jetstream_pt.engine import create_pytorch_engine
+
+__all__ = ["create_pytorch_engine"]

jetstream_pt/engine.py

Lines changed: 47 additions & 55 deletions
@@ -34,6 +34,7 @@
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
 from jetstream_pt.third_party.llama import model_exportable, model_args
+from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model


 Mesh = jax.sharding.Mesh
@@ -108,32 +109,6 @@ def __init__(
     # out_shardings=self.get_decode_state_sharding())
     self._lock = threading.RLock()

-  # pylint: disable-next=all
-  def sharding_by_name(self, name):
-
-    # This allows easier way to edit shardings
-    """
-    for key, val in self.env._data.experimental_sharding_axis_override.items():
-      if name.endswith(key):
-        return self.env.sharding_by_axis(val)
-    """
-
-    if "weight_scaler" in name:
-      return self.x_sharding
-    if "tok_embeddings." in name:
-      return self.y_sharding
-    if "attention." in name:
-      if "wo" in name:
-        return self.y_sharding
-      return self.x_sharding
-    if "feed_forward." in name:
-      if "w2" in name:
-        return self.y_sharding
-      return self.x_sharding
-    if "output" in name:
-      return self.x_sharding
-    return self.replicated
-
   # pylint: disable-next=all
   def init_decode_state(
       self,
@@ -561,7 +536,7 @@ def _load_from_safetensors(self, path):
       for key, model_weights in self.pt_model.state_dict().items():
         if key == "freqs_cis":
           continue
-        arr = jax.device_put(f.get_tensor(key), self.sharding_by_name(key))
+        arr = jax.device_put(f.get_tensor(key), self.env.sharding_by_name(key))
         assert tuple(model_weights.shape) == tuple(
             arr.shape
         ), f"key: {key} error: {model_weights.shape} != {arr.shape}"
@@ -587,7 +562,7 @@ def load_params(self) -> Params:
     else:
       jax_weights = self._make_state_dict_jax(self.pt_model.state_dict())
     jax_weights = {
-        key: jax.device_put(value, self.sharding_by_name(key))
+        key: jax.device_put(value, self.env.sharding_by_name(key))
         for key, value in jax_weights.items()
     }
     for k, v in jax_weights.items():
@@ -664,6 +639,7 @@ def create_pytorch_engine(
     quantize_weights=False,
     quantize_kv=False,
     max_cache_length=1024,
+    sharding_config=None,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""

@@ -706,42 +682,58 @@ def create_pytorch_engine(
   tokenizer = token_utils.load_vocab(tokenizer_path)
   pt_model = None

+  env_data = JetEngineEnvironmentData(
+      tokenizer_path=tokenizer_path,
+      checkpoint_path=checkpoint_path,
+      checkpoint_format=checkpoint_format,
+      batch_size=batch_size,
+      max_decode_length=max_decode_length,
+      max_input_sequence_length=context_length,
+      enable_weight_quantization=quantize_weights,
+      enable_kv_quantization=quantize_kv,
+      cache_sequence_length=max_cache_length,
+      bf16_enable=bf16_enable,
+      sharding_config_path=sharding_config,
+  )
+
   if model_name.startswith("llama"):

     args = model_args.get_model_args(
         model_name + "-" + param_size, context_length, batch_size, bf16_enable
     )
     args.device = "meta"
     args.quantize = quantize_weights
-    env_data = JetEngineEnvironmentData(
-        tokenizer_path=tokenizer_path,
-        checkpoint_path=checkpoint_path,
-        checkpoint_format=checkpoint_format,
-        model_type="llama-2-" + param_size,
-        batch_size=batch_size,
-        max_decode_length=max_decode_length,
-        max_input_sequence_length=context_length,
-        enable_weight_quantization=quantize_weights,
-        enable_kv_quantization=quantize_kv,
-        cache_sequence_length=max_cache_length,
-        bf16_enable=bf16_enable,
-        num_layers=args.n_layers,
-        cache_shape=(
-            batch_size,
-            args.n_kv_heads,
-            max_cache_length,
-            args.dim // args.n_heads,
-        ),
+    env_data.cache_shape = (
+        batch_size,
+        args.n_kv_heads,
+        max_cache_length,
+        args.dim // args.n_heads,
     )
+    env_data.model_type = "llama-2-" + param_size
+    env_data.num_layers = args.n_layers
     env = JetEngineEnvironment(env_data)
     pt_model = model_exportable.Transformer(args, env)
-
-    num_params_size = 0
-    num_params = 0
-    for _, v in pt_model.state_dict().items():
-      num_params += 1
-      num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
-    print("Number of param Gbytes:", num_params_size / (1 << 30))
-    print("Number of param: ", num_params)
+  elif model_name == "gemma":
+    args = gemma_config.get_model_config(param_size)
+    env_data.cache_shape = (
+        batch_size,
+        args.num_key_value_heads,
+        max_cache_length,
+        args.head_dim,
+    )
+    env_data.model_type = "gemma-" + param_size
+    env_data.num_layers = args.num_hidden_layers
+    env = JetEngineEnvironment(env_data)
+    pt_model = gemma_model.GemmaModel(args, env)
+  else:
+    raise RuntimeError(f"Model with name {model_name} not found")
+
+  num_params_size = 0
+  num_params = 0
+  for _, v in pt_model.state_dict().items():
+    num_params += 1
+    num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
+  print("Number of param Gbytes:", num_params_size / (1 << 30))
+  print("Number of param: ", num_params)

   return PyTorchEngine(pt_model=pt_model, env=env)
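With this change, model_name "gemma" routes model construction to gemma_config / GemmaModel, and the new sharding_config argument points JetEngineEnvironment at a YAML file such as default_shardings/gemma.yaml. A hedged usage sketch follows; keyword names not visible in this diff (model_name, param_size, tokenizer_path, checkpoint_path) are assumptions based on the variables used inside create_pytorch_engine:

from jetstream_pt import create_pytorch_engine

engine = create_pytorch_engine(
    model_name="gemma",
    param_size="2b",                            # hypothetical size string
    tokenizer_path="/path/to/tokenizer.model",  # placeholder path
    checkpoint_path="/path/to/checkpoint",      # placeholder path
    batch_size=1,
    max_cache_length=1024,
    sharding_config="default_shardings/gemma.yaml",
)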

jetstream_pt/environment.py

Lines changed: 45 additions & 0 deletions
@@ -15,6 +15,8 @@
 from typing import Tuple, Dict

 import dataclasses
+import yaml
+
 import jax
 import jax.sharding as jsharding
 from jax.experimental import mesh_utils
@@ -71,6 +73,8 @@ class JetEngineEnvironmentData:
   # If Ture, use bfloat16 as dtype. If False, use float32 as dtype
   bf16_enable: bool = True

+  sharding_config_path: str = ""
+

 # pylint: disable-next=all
 class JetEngineEnvironment:
@@ -100,6 +104,15 @@ def __init__(self, data: JetEngineEnvironmentData):
     self.cache_sharding = jsharding.NamedSharding(
         self._mesh, P(*cache_sharding)
     )
+    self._load_sharding_config()
+
+  def _load_sharding_config(self):
+    """Load sharding config"""
+    if self._data.sharding_config_path:
+      with open(self._data.sharding_config_path, encoding="utf-8") as f:
+        self._sharding_config = yaml.safe_load(f)
+    else:
+      self._sharding_config = {}

   def __getattr__(self, name):
     return getattr(self._data, name)
@@ -150,3 +163,35 @@ def make_caches_generate(self):
         )
     )
     return caches
+
+  def sharding_by_name(self, name):
+    """Create sharding specified in the config."""
+    if name in self._sharding_config:
+      return self.sharding_by_axis(self._sharding_config[name])
+
+    name = process_sharding_name(name)
+    if name in self._sharding_config:
+      return self.sharding_by_axis(self._sharding_config[name])
+
+    raise RuntimeError("Sharding for name: ", name, " not specified")
+
+
+def process_sharding_name(name):
+  """Replace integers in param name with *.
+
+  Presumably all layers should have the same sharding.
+  """
+
+  def is_integer(t):
+    try:
+      int(t)
+      return True
+    # pylint: disable-next=all
+    except:  # noqa: E722
+      return False
+
+  tokens = name.split(".")
+  for i, t in enumerate(tokens):
+    if is_integer(t):
+      tokens[i] = "*"
+  return ".".join(tokens)

jetstream_pt/layers.py

Lines changed: 28 additions & 31 deletions
@@ -132,57 +132,53 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
 class Attention(nn.Module):
   """Attention module."""

-  def __init__(self, args, env):
+  def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
     super().__init__()
-
-    self.n_kv_heads = (
-        args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-    )
-    self.n_local_heads = args.n_heads
-    self.n_local_kv_heads = self.n_kv_heads
-    self.n_rep = self.n_local_heads // self.n_local_kv_heads
-    self.head_dim = args.dim // args.n_heads
-    self.max_seq_len = args.max_seq_len
-    self.n_heads = args.n_heads
-
+    self.n_heads = n_heads
+    self.n_kv_heads = n_kv_heads
+    self.head_dim = head_dim
+    self.n_rep = self.n_heads // self.n_kv_heads
     self.env = env
+    self.hidden_size = hidden_size

-    LinearLayer = WeightOnlyInt8Linear if args.quantize else nn.Linear
+    LinearLayer = (
+        WeightOnlyInt8Linear if env.enable_weight_quantization else nn.Linear
+    )

     self.wo = LinearLayer(
-        args.n_heads * self.head_dim,
-        args.dim,
+        n_heads * self.head_dim,
+        hidden_size,
         bias=False,
-        device=args.device,
+        device=device,
     )
-    self.q_size = args.n_heads * self.head_dim
+    self.q_size = n_heads * self.head_dim
     self.kv_size = self.n_kv_heads * self.head_dim
     if self.env.qkv_fusion:
       self._register_load_state_dict_pre_hook(self.load_hook)
       self.wqkv = LinearLayer(
-          args.dim,
-          (args.n_heads + 2 * self.n_kv_heads) * self.head_dim,
+          hidden_size,
+          (n_heads + 2 * self.n_kv_heads) * self.head_dim,
           bias=False,
-          device=args.device,
+          device=device,
       )
     else:
       self.wq = LinearLayer(
-          args.dim,
-          args.n_heads * self.head_dim,
+          hidden_size,
+          n_heads * self.head_dim,
           bias=False,
-          device=args.device,
+          device=device,
       )
       self.wk = LinearLayer(
-          args.dim,
+          hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
-          device=args.device,
+          device=device,
       )
       self.wv = LinearLayer(
-          args.dim,
+          hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
-          device=args.device,
+          device=device,
       )

   def load_hook(self, state_dict, prefix, *args):
@@ -210,9 +206,9 @@ def forward(
       )
     else:
       xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-    xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-    xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+    xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
+    xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+    xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

     self.env.apply_sharding(xq, axis=2)
     self.env.apply_sharding(xk, axis=2)
@@ -262,7 +258,8 @@ def forward(
       self.env.apply_sharding(output, axis=1)
       output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
       self.env.apply_sharding(output, axis=2)
-      return self.wo(output)
+      output = self.wo(output)
+      return output
     else:
       with jax.named_scope("attn_insert_cache"):
         keys, values, k_scaler, v_scaler = cache.update(xk, xv)
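The refactor passes n_heads, n_kv_heads, head_dim and hidden_size explicitly so the same Attention module serves both llama and gemma; the (256, 2048) wk/wv shapes annotated in default_shardings/gemma.yaml suggest a single 256-dim KV head, so n_rep = n_heads // n_kv_heads is the factor by which KV heads are repeated to line up with the query heads. A rough stand-alone sketch of that repetition, modeled on the usual repeat_kv pattern rather than copied from this repo:

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
  """Repeat each KV head n_rep times:
  (bsz, seqlen, n_kv_heads, head_dim) -> (bsz, seqlen, n_kv_heads * n_rep, head_dim)."""
  bsz, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1:
    return x
  return (
      x[:, :, :, None, :]
      .expand(bsz, seqlen, n_kv_heads, n_rep, head_dim)
      .reshape(bsz, seqlen, n_kv_heads * n_rep, head_dim)
  )

kv = torch.randn(1, 8, 1, 256)        # e.g. a single KV head with head_dim 256
print(repeat_kv_sketch(kv, 8).shape)  # torch.Size([1, 8, 8, 256])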
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
