
Commit 137eb47: Support llama3 (#64)

* Support llama3
* Sync with main branch
* Fix CI
* Fix linting
* Fix pyink issues
* Fix run_offline script
* Fix pyink
* Fix after merging main
* Update jetstream version in install_everything.sh
* Fix unit tests
* Fix test

1 parent: 9606a1f

18 files changed: +101 additions, -67 deletions

benchmarks/run_offline.py

Lines changed: 3 additions & 4 deletions

@@ -21,7 +21,6 @@
 import jax
 import jax.numpy as jnp

-from jetstream.engine import token_utils
 from jetstream_pt import engine as je
 # pylint: disable-next=all
 from benchmarks import analyze_sharegpt

@@ -97,11 +96,11 @@ def create_engine():
 def run_prefill_time(engine, params, decode_state, seqlen):
   """Run prefill and measure time."""
   metadata = engine.get_tokenizer()
-  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
+  tokenizer = engine.build_tokenizer(metadata)

   text = "This is a beautiful day"
-  tokens, true_length = token_utils.tokenize_and_pad(
-      text, vocab, is_bos=True, prefill_lengths=[seqlen]
+  tokens, true_length = tokenizer.encode(
+      text, is_bos=True, prefill_lengths=[seqlen]
   )

   for _ in range(3):
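The replacement API folds tokenizer loading into the engine itself: get_tokenizer() returns tokenizer metadata and build_tokenizer() turns it into a concrete tokenizer object. A minimal sketch of the new flow, assuming `engine` is an already-constructed PyTorchEngine and the prefill bucket of 1024 is illustrative:

# Sketch of the new tokenizer flow; `engine` is assumed to be a
# PyTorchEngine built elsewhere (see create_pytorch_engine below).
metadata = engine.get_tokenizer()             # tokenizer_pb2.TokenizerParameters
tokenizer = engine.build_tokenizer(metadata)  # replaces token_utils.load_vocab

# encode() pads the prompt to a bucket from prefill_lengths and also
# returns the unpadded ("true") length needed by the prefill call.
tokens, true_length = tokenizer.encode(
    "This is a beautiful day", is_bos=True, prefill_lengths=[1024]
)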

install_everything.sh

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.

 TORCHXLA_TAG=jetstream-pytorch
-JETSTREAM_TAG=v0.2.0
+JETSTREAM_TAG=v0.2.1

 # Uninstall existing jax
 pip3 show jax && pip3 uninstall -y jax

jetstream_pt/engine.py

Lines changed: 17 additions & 8 deletions

@@ -26,14 +26,14 @@
 import torch
 import numpy as np

-from jetstream.engine import engine_api, tokenizer_pb2, token_utils
+from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
 import torch_xla2
 from torch.utils import _pytree as pytree

 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
-from jetstream_pt.third_party.llama2 import model_exportable, model_args
+from jetstream_pt.third_party.llama import model_exportable, model_args


 Mesh = jax.sharding.Mesh

@@ -526,6 +526,14 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
     # pylint: disable-next=all
     return tokenizer_pb2.TokenizerParameters(path=self.env.tokenizer_path)

+  def build_tokenizer(
+      self, metadata: tokenizer_pb2.TokenizerParameters  # pylint: disable=all
+  ) -> tokenizer_api.Tokenizer:
+    if "llama-3" in self.env.model_type:
+      return token_utils.TikToken(metadata)
+
+    return token_utils.SentencePieceTokenizer(metadata)
+
   def join_prefixes(
       self,
       prefix1: engine_api.Prefix,

@@ -652,13 +660,18 @@ def create_pytorch_engine(
     context_length: int = 1024,
     batch_size: int = 1,
     max_decode_length: int = 4096,
-    model_name="llama",
+    model_name="llama-2",
     quantize_weights=False,
     quantize_kv=False,
     max_cache_length=1024,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""

+  supported_models = ["llama-2", "llama-3"]
+  if model_name not in supported_models:
+    raise NotImplementedError(
+        f"Model name should be one of{','.join(supported_models)}"
+    )
   # See issue b/309529778 if it's turned on.
   jax.config.update("jax_dynamic_shapes", False)
   # Pytorch exports has int64 constants.

@@ -696,11 +709,7 @@ def create_pytorch_engine(
   if model_name.startswith("llama"):

     args = model_args.get_model_args(
-        param_size,
-        context_length,
-        batch_size,
-        tokenizer.vocab_size,
-        bf16_enable,
+        model_name + "-" + param_size, context_length, batch_size, bf16_enable
     )
     args.device = "meta"
     args.quantize = quantize_weights
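The new build_tokenizer dispatches on the engine's model type: llama-3 checkpoints ship a tiktoken BPE vocabulary, while llama-2 keeps its SentencePiece model. A standalone sketch of the same dispatch, with pick_tokenizer as a hypothetical helper outside the engine class and the path as a placeholder:

from jetstream.engine import token_utils, tokenizer_pb2

def pick_tokenizer(model_type: str, tokenizer_path: str):
  # Hypothetical helper mirroring PyTorchEngine.build_tokenizer:
  # llama-3 uses a tiktoken BPE vocabulary, everything else SentencePiece.
  metadata = tokenizer_pb2.TokenizerParameters(path=tokenizer_path)
  if "llama-3" in model_type:
    return token_utils.TikToken(metadata)
  return token_utils.SentencePieceTokenizer(metadata)

tok = pick_tokenizer("llama-3", "/path/to/tokenizer.model")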

jetstream_pt/ray_worker.py

Lines changed: 5 additions & 7 deletions

@@ -32,9 +32,9 @@
 from torch.utils import _pytree as pytree
 import torch_xla2

-from jetstream.engine import engine_api, tokenizer_pb2, token_utils
+from jetstream.engine import engine_api, tokenizer_pb2

-from jetstream_pt.third_party.llama2 import model_exportable, model_args
+from jetstream_pt.third_party.llama import model_exportable, model_args

 from jetstream_pt import cache_manager
 from jetstream_pt import quantize

@@ -99,7 +99,7 @@ def __init__(
     context_length: int = 1024,
     batch_size: int = 1,
     max_decode_length: int = 4096,
-    model_name="llama",
+    model_name="llama-2",
     quantize_weights=False,
     quantize_kv=False,
     max_cache_length=1024,

@@ -159,14 +159,12 @@ def __init__(
     )
     env = JetEngineEnvironment(env_data)

-    tokenizer = token_utils.load_vocab(tokenizer_path)
     pt_model = None
-    if model_name == "llama":
+    if "llama" in model_name:
       args = model_args.get_model_args(
-          param_size,
+          model_name + "-" + param_size,
           context_length,
           batch_size,
-          tokenizer.vocab_size,
           bf16_enable,
       )
       args.device = "meta"
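Both create_pytorch_engine and the Ray worker now derive the model-args key by joining model_name and param_size, so a single string selects the configuration block in model_args.py. A small illustration, with the concrete values as assumptions:

model_name = "llama-3"  # validated against ["llama-2", "llama-3"]
param_size = "8b"       # e.g. "tiny", "7b", "13b", "70b", "8b"

# This is the key get_model_args matches against, e.g. "llama-3-8b".
model_key = model_name + "-" + param_size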
File renamed without changes.

jetstream_pt/third_party/llama2/generation_original.py renamed to jetstream_pt/third_party/llama/generation_original.py

Lines changed: 2 additions & 2 deletions

@@ -5,9 +5,9 @@
 from typing import List, Literal, Optional, Tuple, TypedDict

 import torch
-from jetstream_pt.third_party.llama2 import model_original
+from jetstream_pt.third_party.llama import model_original
 from flax import struct
-from jetstream_pt.third_party.llama2.tokenizer import Tokenizer
+from jetstream_pt.third_party.llama.tokenizer import Tokenizer

 Role = Literal["system", "user", "assistant"]

jetstream_pt/third_party/llama2/model_args.py renamed to jetstream_pt/third_party/llama/model_args.py

Lines changed: 25 additions & 12 deletions

@@ -34,68 +34,81 @@ class ModelArgs:
   device = "cpu"
   quantize = False

+  rope_theta: float = 10000.0
+

 def get_arg(
-    param_size: str,
+    model_name: str,
     seqlen,
     batch_size,
-    vocab_size: int,
     bf16_enable: bool = False,
 ) -> ModelArgs:
   """Gets model args."""

   data = {}
-  if param_size == "tiny":
+  if model_name == "llama-2-tiny":
     data = {
         "dim": 128,
+        "vocab_size": 32000,
         "multiple_of": 32,
         "n_heads": 8,
         "n_layers": 3,
         "norm_eps": 1e-05,
     }
-  elif param_size == "7b":
+  elif model_name == "llama-2-7b":
     data = {
         "dim": 4096,
+        "vocab_size": 32000,
         "multiple_of": 256,
         "n_heads": 32,
         "n_layers": 32,
         "norm_eps": 1e-05,
     }
-  elif param_size == "13b":
+  elif model_name == "llama-2-13b":
     data = {
         "dim": 5120,
+        "vocab_size": 32000,
         "multiple_of": 256,
         "n_heads": 40,
         "n_layers": 40,
         "norm_eps": 1e-05,
     }
-  elif param_size == "70b":
+  elif model_name == "llama-2-70b":
     data = {
         "dim": 8192,
+        "vocab_size": 32000,
         "multiple_of": 4096,
         "ffn_dim_multiplier": 1.3,
         "n_heads": 64,
         "n_kv_heads": 8,
         "n_layers": 80,
         "norm_eps": 1e-05,
     }
+  elif model_name == "llama-3-8b":
+    data = {
+        "dim": 4096,
+        "vocab_size": 128256,
+        "multiple_of": 1024,
+        "ffn_dim_multiplier": 1.3,
+        "n_layers": 32,
+        "n_heads": 32,
+        "n_kv_heads": 8,
+        "norm_eps": 1e-05,
+        "rope_theta": 500000.0,
+    }
   return ModelArgs(
       max_seq_len=seqlen,
       max_batch_size=batch_size,
-      vocab_size=vocab_size,
       bf16_enable=bf16_enable,
       **data,
   )


-def get_model_args(
-    param_size, context_length, batch_size, vocab_size, bf16_enable
-):
+def get_model_args(model_name, context_length, batch_size, bf16_enable):
   model_args = get_arg(
-      param_size=param_size,
+      model_name=model_name,
       seqlen=context_length,
       batch_size=batch_size,
-      vocab_size=vocab_size,
       bf16_enable=bf16_enable,
   )
   model_args.n_kv_heads = (
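Each per-model dict is splatted into the ModelArgs constructor, so keys like vocab_size and rope_theta override the dataclass defaults: rope_theta stays 10000.0 for the llama-2 configs and becomes 500000.0 for llama-3-8b. A minimal sketch of the pattern, assuming a trimmed-down stand-in for the real ModelArgs:

from dataclasses import dataclass

@dataclass
class ModelArgs:  # trimmed-down stand-in for the real dataclass
  dim: int = 4096
  vocab_size: int = -1
  rope_theta: float = 10000.0

data = {"dim": 4096, "vocab_size": 128256, "rope_theta": 500000.0}
args = ModelArgs(**data)  # dict entries override the dataclass defaults
print(args.rope_theta)    # 500000.0 for llama-3-8b

args2 = ModelArgs(**{"dim": 128, "vocab_size": 32000})
print(args2.rope_theta)   # 10000.0: llama-2 configs keep the default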

jetstream_pt/third_party/llama2/model_exportable.py renamed to jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 3 additions & 1 deletion

@@ -157,7 +157,9 @@ def __init__(
     )
     # TODO what to do with this
     freqs_cis = precompute_freqs_cis(
-        self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+        self.params.dim // self.params.n_heads,
+        self.params.max_seq_len * 2,
+        theta=self.params.rope_theta,
     )

     self.register_buffer("freqs_cis", freqs_cis)
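The theta argument is the RoPE base frequency, which llama-3 raises from 10000 to 500000 to accommodate longer contexts. In the reference llama code, precompute_freqs_cis looks roughly like the following; this is a sketch for orientation, not necessarily the exact body in this repo:

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
  """Precompute rotary embedding frequencies as complex numbers.

  A larger theta slows the rotation of the high-index dimensions, which
  is how llama-3 (theta=500000.0) stretches the usable context length.
  """
  # Per-dimension base frequencies: theta^(-2i/dim) for even i.
  freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
  t = torch.arange(end, device=freqs.device)          # positions 0..end-1
  freqs = torch.outer(t, freqs).float()               # (end, dim // 2) angles
  return torch.polar(torch.ones_like(freqs), freqs)   # e^{i * angle}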
