Commit 48a8a22

Enable Gemma 2B (#75)
* Enable Gemma 2B
* encoding
* lint
* formatt
1 parent 9c0d2ac commit 48a8a22

File tree: 6 files changed, +38 −72 lines

benchmarks/analyze_sharegpt.py

Lines changed: 6 additions & 57 deletions
@@ -17,61 +17,11 @@
 CUTOFF_INPUT = 1024
 CUTOFF_OUTPUT = 1024
 
-# batch size 60, ful cache, bfloat
-prefill_bucket_size_to_s = {
-    64: 0.007696230011060834,
-    128: 0.011508351005613805,
-    256: 0.01721684739459306,
-    512: 0.03257157760672271,
-    1024: 0.08185497261583805,
-}
-
-# batch size 96, ful cache, quantized
-prefill_bucket_size_to_s = {
-    64: 0.006911616190336645,
-    128: 0.011646182998083532,
-    256: 0.01875854718964547,
-    512: 0.0334438294172287,
-    1024: 0.0643601292045787,
-}
-
-# batch size 96, rolling, bfloat
-prefill_bucket_size_to_s = {
-    64: 0.007730783987790346,
-    128: 0.011515899002552033,
-    256: 0.01780580161139369,
-    512: 0.03115477201063186,
-    1024: 0.07443338260054588,
-}
-
-# batch size 160, rolling, quantized
-prefill_bucket_size_to_s = {
-    64: 0.006821704190224409,
-    128: 0.01175499300006777,
-    256: 0.018776051187887787,
-    512: 0.03392685519065708,
-    1024: 0.06476318498607725,
-}
-
-prefill_bucket_size_to_ms = {
-    k: p * 1000 for k, p in prefill_bucket_size_to_s.items()
-}
-
-# batch size 60, ful cache, bfloat
-SYSTEM_TIME_PER_DECODE_TOKEN_MS = 26.55 / 60
-
-# batch size 96, ful cache, quantized
-SYSTEM_TIME_PER_DECODE_TOKEN_MS = 26.0 / 96
-
-# batch size 96, rolling, bfloat
-SYSTEM_TIME_PER_DECODE_TOKEN_MS = 28.18 / 96
-
-# batch size 160, rolling, quantized
-SYSTEM_TIME_PER_DECODE_TOKEN_MS = 30 / 160
-
 
 # pylint: disable-next=all
-def do_simulation(prefill_bucket_size_to_ms, system_time_per_decode_token_ms):
+def do_simulation(
+    sharegpt_path, prefill_bucket_size_to_ms, system_time_per_decode_token_ms
+):
   def next_power_of_2(x):
     return 1 if x == 0 else 2 ** (x - 1).bit_length()
 
@@ -82,10 +32,9 @@ def tokens_in_input_str(s):
 
   convo_numbers = []
   # Please update with your own data file path
-  loaded_share_gpt = json.load(
-      # pylint: disable-next=all
-      open("~/data/ShareGPT_V3_unfiltered_cleaned_split.json", "r")
-  )
+
+  with open(sharegpt_path, "r", encoding="utf-8") as f:
+    loaded_share_gpt = json.load(f)
   for example in loaded_share_gpt:
     if len(example["conversations"]) < 2:
       continue
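
With the hard-coded timing tables gone, callers now supply the prefill timings and decode speed themselves and pass the ShareGPT path explicitly. A minimal calling sketch, reusing one of the deleted timing sets purely as sample input; the import path and file path are assumptions, not part of the commit:

from benchmarks import analyze_sharegpt  # import path is an assumption

prefill_bucket_size_to_s = {
    64: 0.007696230011060834,
    128: 0.011508351005613805,
    256: 0.01721684739459306,
    512: 0.03257157760672271,
    1024: 0.08185497261583805,
}
prefill_bucket_size_to_ms = {
    k: v * 1000 for k, v in prefill_bucket_size_to_s.items()
}
# decode step time (ms) divided by the batch size (here 60), as in the deleted constants
system_time_per_decode_token_ms = 26.55 / 60

analyze_sharegpt.do_simulation(
    "/path/to/ShareGPT_V3_unfiltered_cleaned_split.json",
    prefill_bucket_size_to_ms,
    system_time_per_decode_token_ms,
)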

benchmarks/run_offline.py

Lines changed: 13 additions & 1 deletion
@@ -67,6 +67,13 @@
 _MAX_CACHE_LENGTH = flags.DEFINE_integer(
     "max_cache_length", 1024, "kv_cache_quantize"
 )
+_MODEL_NAME = flags.DEFINE_string("model_name", "", "model_name")
+_SHARDING_CONFIG = flags.DEFINE_string(
+    "sharding_config", "", "path to sharding config"
+)
+_SHAREGPT_PATH = flags.DEFINE_string(
+    "sharegpt_path", "", "path to sharegpt json file"
+)
 
 
 def create_engine():
@@ -87,6 +94,8 @@ def create_engine():
       quantize_weights=_QUANTIZE_WEIGHTS.value,
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,
+      model_name=_MODEL_NAME.value,
+      sharding_config=_SHARDING_CONFIG.value,
   )
 
   print("Initialize engine", time.perf_counter() - start)
@@ -185,7 +194,10 @@ def main(argv):
   prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
   decode_time_ms = sum(dec_times) * 1000 / 10 / _BATCH_SIZE.value
 
-  analyze_sharegpt.do_simulation(prefill_times_ms, decode_time_ms)
+  if _SHAREGPT_PATH.value:
+    analyze_sharegpt.do_simulation(
+        _SHAREGPT_PATH.value, prefill_times_ms, decode_time_ms
+    )
 
 
 if __name__ == "__main__":
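
The three new flags follow the absl-flags pattern already used in this file. A self-contained sketch of how they are defined and consumed; only the flag definitions mirror the commit, the main() body here is illustrative:

from absl import app, flags

_MODEL_NAME = flags.DEFINE_string("model_name", "", "model_name")
_SHARDING_CONFIG = flags.DEFINE_string(
    "sharding_config", "", "path to sharding config"
)
_SHAREGPT_PATH = flags.DEFINE_string(
    "sharegpt_path", "", "path to sharegpt json file"
)


def main(argv):
  del argv
  # Mirrors the new guard: the ShareGPT simulation only runs when
  # --sharegpt_path is non-empty.
  if _SHAREGPT_PATH.value:
    print("would run the simulation against", _SHAREGPT_PATH.value)
  print("model:", _MODEL_NAME.value, "sharding config:", _SHARDING_CONFIG.value)


if __name__ == "__main__":
  app.run(main)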

jetstream_pt/engine.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def __init__(
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
-    self.cache_sharding = self.y_sharding
+    self.cache_sharding = self.env.cache_sharding
 
     self.prefill = jax.jit(
         self.prefill, out_shardings=self.get_prefix_destination_sharding()
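
The cache sharding now comes from the environment instead of always being the axis-1 sharding. For context, a plausible standalone reading of what a sharding_by_axis helper does, inferred from the NamedSharding/PartitionSpec usage visible in environment.py below; the repo's actual helper may differ:

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def sharding_by_axis(mesh: Mesh, axis: int) -> NamedSharding:
  # axis == -1 means fully replicated, matching env.sharding_by_axis(-1) above;
  # otherwise, map the given tensor axis onto the "x" mesh axis.
  if axis < 0:
    return NamedSharding(mesh, P())
  spec = [None] * (axis + 1)
  spec[axis] = "x"
  return NamedSharding(mesh, P(*spec))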

jetstream_pt/environment.py

Lines changed: 9 additions & 6 deletions
@@ -97,13 +97,16 @@ def __init__(self, data: JetEngineEnvironmentData):
     self.x_sharding = jsharding.NamedSharding(self._mesh, P("x"))
     self.replicated = jsharding.NamedSharding(self._mesh, P())
 
-    cache_sharding = (
-        "x" if axis == self._data.kv_cache_shard_axis else None
-        for axis in self._data.attention_kv_axis_names
-    )
-    self.cache_sharding = jsharding.NamedSharding(
-        self._mesh, P(*cache_sharding)
+    cache_sharding_axis = self.attention_kv_axis_names.index(
+        self.kv_cache_shard_axis
     )
+
+    if self.cache_shape[cache_sharding_axis] == 1:
+      # cannot shard on an axis that is 1
+      # default to last
+      cache_sharding_axis = len(self.cache_shape) - 1
+
+    self.cache_sharding = self.sharding_by_axis(cache_sharding_axis)
     self._load_sharding_config()
 
   def _load_sharding_config(self):
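
The new logic resolves the KV-cache shard axis by name and falls back to the last axis when the named axis has length 1, which is what happens with Gemma 2B's single KV head. A standalone sketch of that rule; the axis names and example cache shapes are illustrative assumptions, only the selection logic mirrors the diff:

def pick_cache_shard_axis(attention_kv_axis_names, kv_cache_shard_axis, cache_shape):
  axis = attention_kv_axis_names.index(kv_cache_shard_axis)
  if cache_shape[axis] == 1:
    # cannot shard on an axis that is 1; default to the last axis
    axis = len(cache_shape) - 1
  return axis

names = ("batch", "num_attn_heads", "sequence_length", "head_dim")
print(pick_cache_shard_axis(names, "num_attn_heads", (96, 16, 1024, 256)))  # -> 1
print(pick_cache_shard_axis(names, "num_attn_heads", (96, 1, 1024, 256)))   # -> 3 (single KV head)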

jetstream_pt/layers.py

Lines changed: 0 additions & 4 deletions
@@ -151,8 +151,6 @@ def __call__(self, xq, xk, xv, mask, cache):
 
     with jax.named_scope("attn_insert_cache"):
       keys, values = cache.update(xk, xv)
-      self.env.apply_sharding(keys, axis=1)
-      self.env.apply_sharding(values, axis=1)
       keys = repeat_kv(keys, n_rep)
       values = repeat_kv(values, n_rep)
     with jax.named_scope("attn_mat1"):
@@ -206,8 +204,6 @@ def __call__(self, xq, xk, xv, mask, cache):
 
     with jax.named_scope("attn_insert_cache"):
      keys, values, k_scaler, v_scaler = cache.update(xk, xv)
-      self.env.apply_sharding(keys, axis=1)
-      self.env.apply_sharding(values, axis=1)
       keys = repeat_kv(keys, n_rep)
       values = repeat_kv(values, n_rep)
     with jax.named_scope("attn_mat1"):

jetstream_pt/third_party/gemma/model.py

Lines changed: 9 additions & 3 deletions
@@ -148,9 +148,15 @@ def forward(
     xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
     xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
 
-    self.env.apply_sharding(xq, axis=2)
-    self.env.apply_sharding(xk, axis=2)
-    self.env.apply_sharding(xv, axis=2)
+    if self.num_kv_heads > 1:
+      self.env.apply_sharding(xq, axis=2)
+      self.env.apply_sharding(xk, axis=2)
+      self.env.apply_sharding(xv, axis=2)
+    else:
+      # Gemma 2B
+      self.env.apply_sharding(xq, axis=3)
+      self.env.apply_sharding(xk, axis=3)
+      self.env.apply_sharding(xv, axis=3)
 
     # Positional embedding.
     xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
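
Gemma 2B uses multi-query attention, so self.num_kv_heads is 1 and axis 2 of the reshaped xk/xv has length 1; it cannot be split across devices, so the head_dim axis (axis 3) is sharded instead. A shape-level sketch of the check; the concrete sizes below are assumptions for illustration:

import torch

def kv_shard_axis(num_kv_heads: int) -> int:
  # More than one KV head: shard the heads axis (2).
  # Single KV head (Gemma 2B): that axis has length 1, shard head_dim (3) instead.
  return 2 if num_kv_heads > 1 else 3

xk = torch.zeros(8, 1024, 1, 256)  # (batch, seq, num_kv_heads, head_dim)
print(kv_shard_axis(xk.shape[2]))  # -> 3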
