Commit bb174b6

Llama 3.1 RoPE scaling (#205)
* feat(llama): import RoPE scaling code

  This is imported from the original Llama reference implementation:
  https://github.com/meta-llama/llama-models/blob/7890266c5a3ccd29e739d53a71ea968bcf4ca400/models/llama3/reference_impl/model.py#L45
  Note that the function has no effect on the original model code as long as
  the use_scaled parameter is false (the default).

* feat(llama): add RopeScalingArgs

  These are aligned with the HF ones, so it will be easier to implement rope
  scaling the way it is done in Llama 3.1.

* feat(llama): support rope scaling arguments to improve flexibility

* chore: relax the safetensors pattern on download

* feat: untie weights when needed (e.g. Llama 3.2 1B)

* feat: add support for the Llama 3.1, 3.2 and 3.3 models
1 parent 08e4977 commit bb174b6

File tree

5 files changed: +126 −5 lines changed


README.md

Lines changed: 6 additions & 0 deletions
@@ -67,6 +67,12 @@ meta-llama/Meta-Llama-3-8B
 meta-llama/Meta-Llama-3-8B-Instruct
 meta-llama/Meta-Llama-3-70B
 meta-llama/Meta-Llama-3-70B-Instruct
+meta-llama/Llama-3.1-8B
+meta-llama/Llama-3.1-8B-Instruct
+meta-llama/Llama-3.2-1B
+meta-llama/Llama-3.2-1B-Instruct
+meta-llama/Llama-3.3-70B
+meta-llama/Llama-3.3-70B-Instruct
 google/gemma-2b
 google/gemma-2b-it
 google/gemma-7b

jetstream_pt/cli.py

Lines changed: 3 additions & 0 deletions
@@ -44,6 +44,9 @@
 def shard_weights(env, weights, weight_shardings):
   """Shard weights according to weight_shardings"""
   sharded = {}
+  # Some output and embeddings weights might be tied: in this case untie them
+  if weights["output.weight"].device.type == "meta":
+    weights["output.weight"] = weights["tok_embeddings.weight"].clone()
   for key, val in weights.items():
     sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
     with jax.default_device(jax.devices("cpu")[0]):
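
Note on the untie check: a parameter that was never materialized from the checkpoint stays on PyTorch's "meta" device, which is what the new condition detects before cloning the embedding matrix. A minimal standalone sketch of that situation (shapes are illustrative, not taken from the repository):

import torch

# Illustrative shapes only. A model with a tied output head (e.g. Llama 3.2 1B)
# ships no separate "output.weight" in its checkpoint, so the placeholder tensor
# is never filled and keeps device.type == "meta".
weights = {
    "tok_embeddings.weight": torch.randn(1000, 64),
    "output.weight": torch.empty(1000, 64, device="meta"),  # never loaded
}

# Untie by cloning the embedding matrix, as the patched shard_weights does.
if weights["output.weight"].device.type == "meta":
    weights["output.weight"] = weights["tok_embeddings.weight"].clone()

assert weights["output.weight"].device.type != "meta"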

jetstream_pt/fetch_models.py

Lines changed: 10 additions & 1 deletion
@@ -60,6 +60,9 @@ class ModelInfo:
 _llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 8)
 _llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
 _llama3_70 = _llama2_70
+_llama3_1_8b = _llama3_8
+_llama3_2_1b = ModelInfo(llama_model.Transformer, 16, 8, 64, 4)
+_llama3_3_70b = _llama2_70
 
 _mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)
 
@@ -78,6 +81,12 @@ class ModelInfo:
     "meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
     "meta-llama/Meta-Llama-3-70B": _llama3_70,
     "meta-llama/Meta-Llama-3-70B-Instruct": _llama3_70,
+    "meta-llama/Llama-3.1-8B": _llama3_1_8b,
+    "meta-llama/Llama-3.1-8B-Instruct": _llama3_1_8b,
+    "meta-llama/Llama-3.2-1B": _llama3_2_1b,
+    "meta-llama/Llama-3.2-1B-Instruct": _llama3_2_1b,
+    "meta-llama/Llama-3.3-70B": _llama3_3_70b,
+    "meta-llama/Llama-3.3-70B-Instruct": _llama3_3_70b,
     "google/gemma-2b": _gemma_2b,
     "google/gemma-2b-it": _gemma_2b,
     "google/gemma-7b": _gemma_7b,
@@ -215,7 +224,7 @@ def _hf_download(
       local_dir_use_symlinks=False,
       token=hf_token,
       allow_patterns=[
-          "model-?????-of-?????.safetensors",
+          "model*.safetensors",
           "*.json",
           "*.model",
       ],
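
A quick illustration of why the allow_patterns entry was relaxed. Assuming the patterns are matched with shell-style (fnmatch) globbing, as huggingface_hub does for allow_patterns, the old pattern only matched sharded checkpoints, while a single-file checkpoint named model.safetensors (the file names below are illustrative) was skipped:

from fnmatch import fnmatch

old_pattern = "model-?????-of-?????.safetensors"
new_pattern = "model*.safetensors"

# Hypothetical file names: a sharded checkpoint and a single-file one.
files = [
    "model-00001-of-00004.safetensors",
    "model.safetensors",
]

for name in files:
    print(name, fnmatch(name, old_pattern), fnmatch(name, new_pattern))
# model-00001-of-00004.safetensors True True
# model.safetensors False True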

jetstream_pt/third_party/llama/model_args.py

Lines changed: 65 additions & 0 deletions
@@ -5,6 +5,16 @@
 from typing import Optional
 
 
+@dataclasses.dataclass
+class RopeScalingArgs:
+  """Rope scaling configuration parameters."""
+
+  factor: float = 8.0
+  low_freq_factor: float = 1.0
+  high_freq_factor: float = 4.0
+  original_max_position_embeddings: int = 8192
+
+
 @dataclasses.dataclass
 class ModelArgs:
   """Model configuration parameters."""
@@ -29,6 +39,7 @@ class ModelArgs:
   device = "cpu"
 
   rope_theta: float = 10000.0
+  rope_scaling_args: RopeScalingArgs = None
 
 
 def get_arg(
@@ -103,6 +114,60 @@ def get_arg(
         "vocab_size": 128256,
         "rope_theta": 500000.0,
     }
+  elif model_name == "llama-3.1-8b":
+    data = {
+        "dim": 4096,
+        "vocab_size": 128256,
+        "multiple_of": 1024,
+        "ffn_dim_multiplier": 1.3,
+        "n_layers": 32,
+        "n_heads": 32,
+        "n_kv_heads": 8,
+        "norm_eps": 1e-05,
+        "rope_theta": 500000.0,
+        "rope_scaling_args": RopeScalingArgs(
+            factor=8.0,
+            low_freq_factor=1.0,
+            high_freq_factor=4.0,
+            original_max_position_embeddings=8192,
+        ),
+    }
+  elif model_name == "llama-3.2-1b":
+    data = {
+        "dim": 2048,
+        "vocab_size": 128256,
+        "multiple_of": 1024,
+        "ffn_dim_multiplier": 1.5,
+        "n_layers": 16,
+        "n_heads": 32,
+        "n_kv_heads": 8,
+        "norm_eps": 1e-05,
+        "rope_theta": 500000.0,
+        "rope_scaling_args": RopeScalingArgs(
+            factor=32.0,
+            low_freq_factor=1.0,
+            high_freq_factor=4.0,
+            original_max_position_embeddings=8192,
+        ),
+    }
+  elif model_name == "llama-3.3-70b":
+    data = {
+        "dim": 8192,
+        "vocab_size": 128256,
+        "multiple_of": 1024,
+        "ffn_dim_multiplier": 1.3,
+        "n_layers": 80,
+        "n_heads": 64,
+        "n_kv_heads": 8,
+        "norm_eps": 1e-05,
+        "rope_theta": 500000.0,
+        "rope_scaling_args": RopeScalingArgs(
+            factor=8.0,
+            low_freq_factor=1.0,
+            high_freq_factor=4.0,
+            original_max_position_embeddings=8192,
+        ),
+    }
 
   return ModelArgs(
       max_seq_len=seqlen,
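
Per the commit message, RopeScalingArgs is aligned with the Hugging Face parameters, i.e. the rope_scaling section of a Llama 3.1 config.json. A small sketch of that mapping; the hf_rope_scaling dict below is an illustrative excerpt from memory, not read from an actual checkpoint:

from jetstream_pt.third_party.llama import model_args

# Illustrative excerpt of the "rope_scaling" block in a Llama 3.1 config.json.
hf_rope_scaling = {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

scaling = model_args.RopeScalingArgs(
    factor=hf_rope_scaling["factor"],
    low_freq_factor=hf_rope_scaling["low_freq_factor"],
    high_freq_factor=hf_rope_scaling["high_freq_factor"],
    original_max_position_embeddings=hf_rope_scaling[
        "original_max_position_embeddings"
    ],
)
print(scaling)  # RopeScalingArgs(factor=8.0, low_freq_factor=1.0, ...)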

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 42 additions & 4 deletions
@@ -4,6 +4,7 @@
 from typing import Any, List, Optional
 import copy
 import jax
+import math
 import torch
 import torch.nn.functional as F
 import functools
@@ -170,12 +171,42 @@ def forward(
     return out
 
 
+def apply_scaling(freqs: torch.Tensor, config: model_args.RopeScalingArgs):
+  # Values obtained from grid search
+  scale_factor = config.factor
+  low_freq_factor = config.low_freq_factor
+  high_freq_factor = config.high_freq_factor
+  old_context_len = config.original_max_position_embeddings
+
+  low_freq_wavelen = old_context_len / low_freq_factor
+  high_freq_wavelen = old_context_len / high_freq_factor
+  new_freqs = []
+  for freq in freqs:
+    wavelen = 2 * math.pi / freq
+    if wavelen < high_freq_wavelen:
+      new_freqs.append(freq)
+    elif wavelen > low_freq_wavelen:
+      new_freqs.append(freq / scale_factor)
+    else:
+      assert low_freq_wavelen != high_freq_wavelen
+      smooth = (old_context_len / wavelen - low_freq_factor) / (
+          high_freq_factor - low_freq_factor
+      )
+      new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+  return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 def precompute_freqs_cis(
-    dim: int, end: int, theta: float = 10000.0
-) -> torch.Tensor:
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    rope_scaling_config: model_args.RopeScalingArgs = None,
+):
   freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-  t = torch.arange(end, device=freqs.device)  # type: ignore
-  freqs = torch.outer(t, freqs).float()  # type: ignore
+  t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+  if rope_scaling_config is not None:
+    freqs = apply_scaling(freqs, rope_scaling_config)
+  freqs = torch.outer(t, freqs)
   freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
   return freqs_cis
 
@@ -223,6 +254,7 @@ def __init__(
         self.params.dim // self.params.n_heads,
         self.params.max_seq_len * 2,
         theta=self.params.rope_theta,
+        rope_scaling_config=self.params.rope_scaling_args,
     )
 
     self.register_buffer("freqs_cis", freqs_cis)
@@ -306,6 +338,12 @@ def from_hf_model_id(cls, model_id, env, is_tiny=False):
         "meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b",
         "meta-llama/Meta-Llama-3-70B": "llama-3-70b",
         "meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b",
+        "meta-llama/Llama-3.1-8B": "llama-3.1-8b",
+        "meta-llama/Llama-3.1-8B-Instruct": "llama-3.1-8b",
+        "meta-llama/Llama-3.2-1B": "llama-3.2-1b",
+        "meta-llama/Llama-3.2-1B-Instruct": "llama-3.2-1b",
+        "meta-llama/Llama-3.3-70B": "llama-3.3-70b",
+        "meta-llama/Llama-3.3-70B-Instruct": "llama-3.3-70b",
     }.get(model_id)
     assert name
     args = model_args.get_model_args(
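
To make the three branches of apply_scaling concrete, the standalone sketch below partitions the RoPE base frequencies of a Llama 3.1-style head by wavelength, reusing only the arithmetic from the diff (dim=128, theta=500000, factor=8, low/high freq factors 1 and 4, original context 8192). The sampling stride and printout are illustrative:

import math
import torch

# Per-head RoPE frequencies, as computed in precompute_freqs_cis (dim = head dim).
dim, theta = 128, 500000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

# Llama 3.1 scaling parameters from the diff above.
scale_factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192
low_freq_wavelen = old_context_len / low_freq_factor    # 8192
high_freq_wavelen = old_context_len / high_freq_factor  # 2048

for freq in freqs[::8]:  # sample every 8th frequency for a short printout
    wavelen = 2 * math.pi / freq.item()
    if wavelen < high_freq_wavelen:
        regime = "kept as-is (high frequency, short wavelength)"
    elif wavelen > low_freq_wavelen:
        regime = "divided by scale_factor (low frequency, long wavelength)"
    else:
        regime = "smoothly interpolated between the two"
    print(f"wavelength {wavelen:12.2f} -> {regime}")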
