
Commit 93c8f8d

Gemma sharding and test (#70)
* Gemma sharding and test
* rm original
1 parent 9353640 commit 93c8f8d

2 files changed: +11 lines, -18 lines

default_shardings/gemma.yaml

Lines changed: 5 additions & 5 deletions
@@ -3,7 +3,7 @@
 # "replicated" to signify "replicated".
 # Integer signify axis to shard: 0 <= shard axis < rank
 
-freqs_cis : null # torch.complex64 (16384, 128)
+freqs_cis : -1 # torch.complex64 (16384, 128)
 layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048)
 layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048)
 layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048)
@@ -13,8 +13,8 @@ layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,)
 layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
 layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,)
 layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384)
-layers.*.mlp.down_proj.bias : null # torch.float32 (2048,)
-layers.*.input_layernorm.weight : null # torch.float32 (2048,)
-layers.*.post_attention_layernorm.weight : null # torch.float32 (2048,)
-norm.weight : null # torch.float32 (2048,)
+layers.*.mlp.down_proj.bias : -1 # torch.float32 (2048,)
+layers.*.input_layernorm.weight : -1 # torch.float32 (2048,)
+layers.*.post_attention_layernorm.weight : -1 # torch.float32 (2048,)
+norm.weight : -1 # torch.float32 (2048,)
 embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048)
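This YAML maps each Gemma weight (by name pattern) to how it is laid out across devices: an integer picks the tensor axis to shard, while a replicated entry keeps the full tensor on every device. The loader that consumes these annotations is not part of this commit, so how jetstream_pt interprets -1 (replicate vs. shard the last axis) is not visible here. The snippet below is only a minimal illustration of turning such an annotation into a JAX sharding, assuming None means replicate and negative axes count from the end, NumPy-style.

# Hypothetical sketch, not jetstream_pt's actual loader: convert one integer
# axis annotation from gemma.yaml into a JAX sharding. Assumes None replicates
# and negative axes are NumPy-style; the real meaning of -1 lives in the
# repo's loading code, which this diff does not show.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def sharding_from_axis(mesh, axis, rank):
    """Shard `axis` of a rank-`rank` tensor over the mesh, or replicate."""
    if axis is None:
        return NamedSharding(mesh, PartitionSpec())  # fully replicated
    spec = [None] * rank
    spec[axis % rank] = "x"  # split this dimension over the "x" mesh axis
    return NamedSharding(mesh, PartitionSpec(*spec))


# Example: "embedder.weight : 1" for a (256000, 2048) tensor shards dim 1.
mesh = Mesh(np.array(jax.devices()), ("x",))
embedder_sharding = sharding_from_axis(mesh, axis=1, rank=2)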

tests/test_llama_e2e.py

Lines changed: 6 additions & 13 deletions
@@ -22,13 +22,13 @@
 import torch
 import torch_xla2
 from torch.utils import _pytree as pytree
-from . import helpers
 
 
 from jetstream_pt.engine import PyTorchEngine
 from jetstream_pt.third_party.llama import model_exportable, model_args
 from jetstream_pt.third_party.llama.generation_original import LlamaOriginal
 from jetstream_pt import environment
+from tests import helpers
 
 
 class LlamaE2ETest(unittest.TestCase):
@@ -93,9 +93,8 @@ def test_jetstream_llama2_seed(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
     # pylint: disable-next=all
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
@@ -221,7 +220,6 @@ def test_llama_e2e_float32(self):
     print(f"---------> {jax.devices()}")
 
     env, model_arg = helpers.make_env_tiny(bf16_enable=False)
-    torch.set_default_dtype(torch.float32)
     out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
     self.assertEqual(out_tokens, expected_output_tokens)
 
@@ -232,7 +230,6 @@ def test_llama_e2e_bfloat16(self):
     print(f"---------> {jax.devices()}")
 
     env, model_arg = helpers.make_env_tiny(bf16_enable=True)
-    torch.set_default_dtype(torch.bfloat16)
     out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
     self.assertNotEqual(out_tokens, expected_output_tokens)
 
@@ -242,9 +239,8 @@ def test_llama_e2e_two_addtional_tokens(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
     # pylint: disable-next=all
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     tokens = np.append(tokens, [15050, 3503], axis=-1)
@@ -315,9 +311,8 @@ def test_llama_e2e_four_addtional_tokens(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
     # pylint: disable-next=all
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     tokens = np.append(tokens, [15050, 3503, 11833, 28551], axis=-1)
@@ -387,7 +382,6 @@ def test_llama_with_original_prefill_decode_32(self):
     print(f"---------> {jax.devices()}")
 
     env, model_arg = helpers.make_env_tiny(bf16_enable=False)
-    torch.set_default_dtype(torch.float32)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
@@ -458,12 +452,11 @@ def test_llama_with_original_prefill_decode_32(self):
 
   # pylint: disable-next=all
   def test_llama_with_original_prefill_decode(self):
-    """test jetstream llama by comparing original prefill and decode steps with float32"""
+    """test jetstream llama by comparing original prefill and decode steps with bf16"""
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
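Each test previously called torch.set_default_dtype(...) by hand before building the tiny environment; after this change the dtype choice is pushed into helpers.make_env_tiny via the bf16_enable flag, so each test states its precision exactly once. The helper's body is not shown in this diff; the snippet below is only a sketch of the pattern the tests now rely on, and the real tests/helpers.py may differ.

# Assumed helper behavior (the actual tests/helpers.make_env_tiny is not part
# of this diff): the dtype decision moves out of the individual tests and into
# the helper, keyed off bf16_enable.
import torch


def make_env_tiny(bf16_enable=True):
    """Pick the default torch dtype from the flag, then build the tiny env."""
    torch.set_default_dtype(torch.bfloat16 if bf16_enable else torch.float32)
    # The real helper goes on to construct and return (env, model_arg) for a
    # tiny llama config via jetstream_pt.environment; omitted here.
    return None, None


# Usage mirrors the updated tests:
env, model_arg = make_env_tiny(bf16_enable=True)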
