
Commit d507086

Refactor environment and engine (#65)
* Refactor environment and engine so that they don't depend on llama-specific code such as ModelArgs
* Fix lints
1 parent a58051d commit d507086
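
In short: the engine and the ray worker previously read llama-specific ModelArgs (for example self.param.max_batch_size and self.param.max_seq_len), and JetEngineEnvironment re-derived llama model args internally. After this commit they read generic JetEngineEnvironmentData fields (batch_size, max_input_sequence_length, num_layers, cache_shape), which the llama-specific path in create_pytorch_engine now populates. A minimal sketch of the new configuration flow, using only names that appear in the diffs below; the concrete values are illustrative:

# Sketch of the post-refactor flow; field values are illustrative, not defaults.
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.third_party.llama2 import model_args

# Model-specific knowledge stays on the llama side...
args = model_args.get_model_args("tiny", 128, 1, 32000, True)

# ...and is translated once into generic environment fields at engine-creation time.
env_data = JetEngineEnvironmentData()
env_data.batch_size = 1
env_data.max_input_sequence_length = 128
env_data.cache_sequence_length = 128
env_data.num_layers = args.n_layers
env_data.cache_shape = (1, args.n_kv_heads, 128, args.dim // args.n_heads)

env = JetEngineEnvironment(env_data)  # the environment no longer imports ModelArgs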

File tree

8 files changed, +101 -105 lines changed

install_everything.sh

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 TORCHXLA_TAG=jetstream-pytorch
+JETSTREAM_TAG=v0.2.0
 
 # Uninstall existing jax
 pip3 show jax && pip3 uninstall -y jax
@@ -34,6 +35,7 @@ git checkout $TORCHXLA_TAG
 pip install .
 popd # now at the folder deps
 pushd JetStream
+git checkout $JETSTREAM_TAG
 pip install .
 popd # now at the folder deps
 popd # now at the folder current file

jetstream_pt/engine.py

Lines changed: 29 additions & 24 deletions
@@ -29,11 +29,11 @@
 from jetstream.engine import engine_api, tokenizer_pb2, token_utils
 import torch_xla2
 from torch.utils import _pytree as pytree
-from jetstream_pt.third_party.llama2 import model_exportable, model_args
 
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
+from jetstream_pt.third_party.llama2 import model_exportable, model_args
 
 
 Mesh = jax.sharding.Mesh
@@ -81,9 +81,6 @@ def __init__(
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
 
-    # NOTE: this is llama2 specific now.
-    self.param = pt_model.params
-
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
@@ -486,7 +483,7 @@ def generate(
         mask,
         decode_state.input_pos,
     )
-    next_token = self._sampling(logits, self.param.max_batch_size)
+    next_token = self._sampling(logits, self.env.batch_size)
     lens = decode_state.lens + 1
     data = jnp.concatenate(
         [
@@ -621,7 +618,7 @@ def get_prefix_sequence_ddim(self) -> Any:
 
   @property
   def max_concurrent_decodes(self) -> int:
-    return self.param.max_batch_size
+    return self.env.batch_size
 
   @property
   def samples_per_slot(self) -> int:
@@ -630,7 +627,7 @@ def samples_per_slot(self) -> int:
 
   @property
   def max_prefill_length(self) -> int:
-    return self.param.max_seq_len
+    return self.env.max_input_sequence_length
 
   @property
   def max_decode_length(self) -> int:
@@ -693,24 +690,11 @@ def create_pytorch_engine(
     checkpoint_format = "safetensors"
     checkpoint_path = paths[0]
 
-  env_data = JetEngineEnvironmentData(
-      tokenizer_path=tokenizer_path,
-      checkpoint_path=checkpoint_path,
-      checkpoint_format=checkpoint_format,
-      model_type="llama-2-" + param_size,
-      batch_size=batch_size,
-      max_decode_length=max_decode_length,
-      max_input_sequence_length=context_length,
-      enable_weight_quantization=quantize_weights,
-      enable_kv_quantization=quantize_kv,
-      cache_sequence_length=max_cache_length,
-      bf16_enable=bf16_enable,
-  )
-  env = JetEngineEnvironment(env_data)
-
   tokenizer = token_utils.load_vocab(tokenizer_path)
   pt_model = None
-  if model_name == "llama":
+
+  if model_name.startswith("llama"):
+
     args = model_args.get_model_args(
         param_size,
         context_length,
@@ -720,13 +704,34 @@ def create_pytorch_engine(
     )
     args.device = "meta"
     args.quantize = quantize_weights
+    env_data = JetEngineEnvironmentData(
+        tokenizer_path=tokenizer_path,
+        checkpoint_path=checkpoint_path,
+        checkpoint_format=checkpoint_format,
+        model_type="llama-2-" + param_size,
+        batch_size=batch_size,
+        max_decode_length=max_decode_length,
+        max_input_sequence_length=context_length,
+        enable_weight_quantization=quantize_weights,
+        enable_kv_quantization=quantize_kv,
+        cache_sequence_length=max_cache_length,
+        bf16_enable=bf16_enable,
+        num_layers=args.n_layers,
+        cache_shape=(
+            batch_size,
+            args.n_kv_heads,
+            max_cache_length,
+            args.dim // args.n_heads,
+        ),
+    )
+    env = JetEngineEnvironment(env_data)
     pt_model = model_exportable.Transformer(args, env)
 
   num_params_size = 0
   num_params = 0
   for _, v in pt_model.state_dict().items():
     num_params += 1
-    num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
+    num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
   print("Number of param Gbytes:", num_params_size / (1 << 30))
   print("Number of param: ", num_params)

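One detail in the last hunk above: pt_model.state_dict() holds torch tensors, so v.dtype is a torch dtype and the old comparison against jnp.int8 was presumably never true, which made quantized int8 weights count as 2 bytes each in the size estimate. A small illustration of the corrected accounting (a sketch; the tensor is made up):

# Illustrative only: why the dtype check was switched to torch.int8.
import numpy as np
import torch

v = torch.zeros(4, 4, dtype=torch.int8)  # stand-in for a quantized weight

# Old check: a torch dtype does not equal jnp.int8, so it always took the 2-byte branch.
# New check: int8 weights count as 1 byte, everything else as 2 (bf16/fp16).
num_bytes = np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
print(num_bytes)  # 16
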
jetstream_pt/environment.py

Lines changed: 7 additions & 27 deletions
@@ -21,7 +21,6 @@
 import torch_xla2
 
 
-from jetstream_pt.third_party.llama2 import model_args
 from jetstream_pt import cache_manager
 
 
@@ -52,6 +51,11 @@ class JetEngineEnvironmentData:
       "head_dim",
   )
 
+  # Shape of cache len(cache_shape) == len(attention_kv_axis_names)
+  cache_shape: Tuple[int, ...] = ()
+
+  num_layers: int = 0
+
   # This is the axis to shard among the number of available devices
   # This string must be one of the values of attention_kv_axis_names above
   kv_cache_shard_axis: str = "num_attn_heads"
@@ -73,23 +77,8 @@ class JetEngineEnvironment:
 
   def __init__(self, data: JetEngineEnvironmentData):
     self._data = data
-    # Get 13b
-    self._model_arg = model_args.get_model_args(
-        data.model_type.replace("llama-2-", ""),
-        context_length=data.max_input_sequence_length,
-        batch_size=data.batch_size,
-        vocab_size=32000,  # ?
-        bf16_enable=data.bf16_enable,
-    )
 
-    self.batch_size = self._data.batch_size
     self.seq_len = self._data.max_input_sequence_length
-    self.num_layers = self._model_arg.n_layers
-    self.num_kv_heads = self._model_arg.n_kv_heads
-    self.num_heads = self._model_arg.n_heads
-    self.head_dim = self._model_arg.dim // self._model_arg.n_heads
-    self.cache_sequence_length = self._data.cache_sequence_length
-    self.bf16_enable = self._data.bf16_enable
 
     P = jax.sharding.PartitionSpec
 
@@ -115,11 +104,6 @@ def __init__(self, data: JetEngineEnvironmentData):
   def __getattr__(self, name):
     return getattr(self._data, name)
 
-  @property
-  def tokenizer_path(self):
-    """Return tokenizer path"""
-    return self._data.tokenizer_path
-
   # This is used by model to add activation sharding.
   def apply_sharding(self, tensor, *, axis: int | None):
     """Apply sharding for tensor"""
@@ -150,12 +134,8 @@ def make_caches_prefill(self):
   def make_caches_generate(self):
     """Create kv caches for inference generation"""
     caches = []
-    shape = (
-        self.batch_size,
-        self.num_kv_heads,
-        self._data.cache_sequence_length,
-        self.head_dim,
-    )
+    shape = self._data.cache_shape
+
     for _ in range(self.num_layers):
       if self.enable_kv_quantization:
         caches.append(

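The convenience attributes dropped from __init__ above (batch_size, num_layers, cache_sequence_length, bf16_enable) still resolve because __getattr__ forwards unknown names to the underlying JetEngineEnvironmentData, while the values that used to come from ModelArgs (num_kv_heads, head_dim) are folded into the new cache_shape field. A rough sketch of that forwarding, with illustrative values:

# Rough sketch of attribute forwarding after the refactor; values are illustrative.
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData

data = JetEngineEnvironmentData()
data.batch_size = 1
data.num_layers = 2
data.cache_shape = (1, 4, 128, 64)  # (batch, n_kv_heads, cache_len, head_dim)

env = JetEngineEnvironment(data)
# No longer set in __init__; __getattr__ falls through to the dataclass fields.
assert env.batch_size == 1
assert env.num_layers == 2
# make_caches_generate() now reads this shape directly instead of recomputing it
# from per-model head counts.
assert env.cache_shape == (1, 4, 128, 64)
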
jetstream_pt/ray_worker.py

Lines changed: 3 additions & 6 deletions
@@ -187,9 +187,6 @@ def __init__(
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
 
-    # NOTE: this is llama2 specific now.
-    self.param = pt_model.params
-
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
@@ -682,7 +679,7 @@ def generate(
     )
 
     logits = multihost_utils.process_allgather(logits, tiled=True)
-    next_token = self._sampling(logits, self.param.max_batch_size)
+    next_token = self._sampling(logits, self.env.batch_size)
 
     data = np.concatenate(
         [
@@ -837,7 +834,7 @@ def get_prefix_sequence_ddim(self) -> Any:
   @property
   def max_concurrent_decodes(self) -> int:
     """Max batch size for decodes"""
-    return self.param.max_batch_size
+    return self.env.batch_size
 
   @property
   def samples_per_slot(self) -> int:
@@ -847,7 +844,7 @@ def samples_per_slot(self) -> int:
   @property
   def max_prefill_length(self) -> int:
     """Maximum prefill length"""
-    return self.param.max_seq_len
+    return self.env.max_input_sequence_length
 
   @property
   def max_decode_length(self) -> int:

tests/helpers.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+import torch
+import jax
+from jetstream_pt.third_party.llama2 import model_args
+from jetstream_pt import environment
+
+
+def make_env_tiny(bf16_enable=True):
+  torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
+  torch.set_default_dtype(torch_dtype)
+  jax.config.update("jax_dynamic_shapes", False)
+  jax.config.update("jax_traceback_filtering", "off")
+  config = model_args.get_model_args("tiny", 128, 1, 32000, True)
+  environment_data = environment.JetEngineEnvironmentData()
+  environment_data.max_input_sequence_length = 128
+  environment_data.max_input_sequence_length = 128
+  environment_data.cache_sequence_length = 128
+  environment_data.bf16_enable = bf16_enable
+  environment_data.model_type = "llama-2-tiny"
+  environment_data.batch_size = 1
+  environment_data.num_layers = config.n_layers
+  environment_data.cache_shape = (
+      1,
+      config.n_kv_heads,
+      environment_data.cache_sequence_length,
+      config.dim // config.n_heads,
+  )
+  env = environment.JetEngineEnvironment(environment_data)
+  env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
+  return env, config
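
For context, a test would use the new helper roughly as follows. This is a hypothetical snippet: the import path and the meta-device line are assumptions, while make_env_tiny and model_exportable.Transformer(args, env) come from the diffs above.

# Hypothetical test usage; only make_env_tiny and Transformer(args, env) come from this commit.
from tests.helpers import make_env_tiny
from jetstream_pt.third_party.llama2 import model_exportable

env, config = make_env_tiny(bf16_enable=True)
config.device = "meta"  # as in create_pytorch_engine: build the module without materializing weights
model = model_exportable.Transformer(config, env)  # the model reads sharding/cache info from env
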
File renamed without changes.
