Commit 564b558

fix(llama): other llama flavours (ollama#12308)
* fix(llama): rope scale
* spm llama
* skip moe models
* cleanup
1 parent a417ac9 commit 564b558
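
This commit's central functional change is the RoPE scale handling: the scale is now read from the `rope.scaling.factor` metadata key (defaulting to 1) instead of `rope.freq_scale`, and every `fast.RoPE` call receives its reciprocal, since the stored value is a context-extension factor while the kernel expects a frequency scale. A minimal, self-contained sketch of that convention (the `freqScale` helper below is hypothetical, written only to illustrate the `1./ropeScale` expressions added throughout the diff):

package main

import "fmt"

// freqScale converts a context-extension factor (the value stored under
// "rope.scaling.factor") into the frequency scale the RoPE kernel expects:
// the reciprocal of the factor, where 1.0 means "no scaling".
func freqScale(scalingFactor float32) float32 {
	if scalingFactor == 0 {
		return 1 // guard for illustration only; the real code relies on the config default of 1
	}
	return 1. / scalingFactor
}

func main() {
	fmt.Println(freqScale(1)) // 1    -> positions rotated at the base frequencies
	fmt.Println(freqScale(4)) // 0.25 -> rotation slowed 4x for 4x context extension
}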

File tree

10 files changed: +74, -66 lines

model/models/gemma2/model.go

Lines changed: 3 additions & 3 deletions
@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
 			attnValLen:        int(c.Uint("attention.value_length")),
 			eps:               c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:          c.Float("rope.freq_base", 10000.0),
-			ropeScale:         c.Float("rope.freq_scale", 1.0),
+			ropeScale:         c.Float("rope.scaling.factor", 1.0),
 			attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
 			finalLogitSoftcap: c.Float("final_logit_softcapping"),
 		},
@@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

model/models/gemma3/model_text.go

Lines changed: 3 additions & 3 deletions
@@ -53,7 +53,7 @@ func newTextModel(c fs.Config) *TextModel {
 			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
 			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      c.Float("rope.freq_scale", 1.0),
+			ropeScale:      c.Float("rope.scaling.factor", 1.0),
 		},
 	}
 
@@ -84,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -95,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

model/models/gemma3n/model_text.go

Lines changed: 4 additions & 4 deletions
@@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
 		ropeBase = m.ropeBaseLocal
 	}
 
-	return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 type TextScaledWordEmbedding struct {
@@ -256,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
 	query := attn.Query.Forward(ctx, hiddenStates)
 	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
 	query = attn.QueryNorm.Forward(ctx, query, opts.eps)
-	query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
 
 	var key, value ml.Tensor
 	if !sharedKV {
 		key = attn.Key.Forward(ctx, hiddenStates)
 		key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
 		key = attn.KeyNorm.Forward(ctx, key, opts.eps)
-		key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+		key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
 
 		value = attn.Value.Forward(ctx, hiddenStates)
 		value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
@@ -349,7 +349,7 @@ func newTextModel(c fs.Config) *TextModel {
 			eps:           c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeBase:      c.Float("rope.freq_base", 1_000_000),
 			ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
-			ropeScale:     c.Float("rope.freq_scale", 1.0),
+			ropeScale:     c.Float("rope.scaling.factor", 1.0),
 
 			slidingWindowPattern:    c.Bools("attention.sliding_window_pattern"),
 			activationSparsityScale: c.Floats("activation_sparsity_scale"),

model/models/llama/model.go

Lines changed: 40 additions & 32 deletions
@@ -2,7 +2,6 @@ package llama
 
 import (
 	"cmp"
-	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/fs"
@@ -23,51 +22,60 @@ type Options struct {
 
 type Model struct {
 	model.Base
-	model.BytePairEncoding
+	model.TextProcessor
 
 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers         []Layer       `gguf:"blk"`
 	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
 	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
 
-	*Options
+	Options
 }
 
 func New(c fs.Config) (model.Model, error) {
-	// This model currently only supports the gpt2 tokenizer
-	if c.String("tokenizer.ggml.model") == "llama" {
-		return nil, fmt.Errorf("unsupported tokenizer: llama")
+	if c.Uint("expert_count") > 0 {
+		// TODO: support mixtures of experts
+		return nil, model.ErrUnsupportedModel
 	}
-	// Best effort detection of library/deepseek-coder model(s) which are incompatible
-	if c.String("general.name") == "deepseek-ai" {
-		return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
+
+	var processor model.TextProcessor
+	vocabulary := model.Vocabulary{
+		Values: c.Strings("tokenizer.ggml.tokens"),
+		Scores: c.Floats("tokenizer.ggml.scores"),
+		Types:  c.Ints("tokenizer.ggml.token_type"),
+		Merges: c.Strings("tokenizer.ggml.merges"),
+		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+		BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+		AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+		EOS: append(
+			[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+			c.Ints("tokenizer.ggml.eos_token_ids")...,
+		),
+	}
+	switch c.String("tokenizer.ggml.model") {
+	case "gpt2":
+		processor = model.NewBytePairEncoding(
+			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+			&vocabulary,
+		)
+	case "llama":
+		processor = model.NewSentencePiece(&vocabulary)
+	default:
+		return nil, model.ErrUnsupportedTokenizer
 	}
+
 	m := Model{
-		BytePairEncoding: model.NewBytePairEncoding(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
-			&model.Vocabulary{
-				Values: c.Strings("tokenizer.ggml.tokens"),
-				Types:  c.Ints("tokenizer.ggml.token_type"),
-				Merges: c.Strings("tokenizer.ggml.merges"),
-				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
-				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
-				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
-				EOS: append(
-					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
-					c.Ints("tokenizer.ggml.eos_token_ids")...,
-				),
-			},
-		),
-		Layers: make([]Layer, c.Uint("block_count")),
-		Options: &Options{
+		TextProcessor: processor,
+		Layers:        make([]Layer, c.Uint("block_count")),
+		Options: Options{
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			headDim:    int(c.Uint("attention.key_length")),
 			ropeDim:    int(c.Uint("rope.dimension_count")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
+			ropeBase:   c.Float("rope.freq_base", 1e5),
+			ropeScale:  c.Float("rope.scaling.factor", 1),
 		},
 	}
 
@@ -98,8 +106,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
-	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
 	attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -108,7 +116,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
-	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
+	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
 }
 
 type MLP struct {
@@ -163,7 +171,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 			outputs = batch.Outputs
 		}
 
-		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
+		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
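
The llama loader rewrite above covers the "spm llama" and "skip moe models" bullets: `New` now rejects mixture-of-experts checkpoints (`expert_count > 0`) and picks a text processor from `tokenizer.ggml.model`, using byte-pair encoding for `gpt2` vocabularies and SentencePiece for `llama` ones. A condensed sketch of that selection, with the ollama-internal types (`model.TextProcessor`, `model.Vocabulary`, the error values) replaced by local stubs so it reads in isolation:

package main

import (
	"errors"
	"fmt"
)

// Stubs standing in for the model package types referenced in the diff.
type TextProcessor interface{ Name() string }

type bytePairEncoding struct{}

func (bytePairEncoding) Name() string { return "byte-pair encoding" }

type sentencePiece struct{}

func (sentencePiece) Name() string { return "SentencePiece" }

var (
	errUnsupportedModel     = errors.New("unsupported model")
	errUnsupportedTokenizer = errors.New("unsupported tokenizer")
)

// newProcessor mirrors the switch added to model/models/llama/model.go:
// MoE checkpoints are rejected up front, gpt2-style vocabularies get BPE,
// llama-style vocabularies get SentencePiece, anything else is unsupported.
func newProcessor(tokenizerModel string, expertCount uint) (TextProcessor, error) {
	if expertCount > 0 {
		return nil, errUnsupportedModel // TODO in the real code: support mixtures of experts
	}
	switch tokenizerModel {
	case "gpt2":
		return bytePairEncoding{}, nil
	case "llama":
		return sentencePiece{}, nil
	default:
		return nil, errUnsupportedTokenizer
	}
}

func main() {
	if p, err := newProcessor("llama", 0); err == nil {
		fmt.Println(p.Name()) // SentencePiece, as used for spm-style llama checkpoints
	}
}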

model/models/llama4/model_text.go

Lines changed: 4 additions & 4 deletions
@@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
 	if useRope {
-		query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
-		key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+		query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+		key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 	}
 
 	if opts.useQKNorm {
@@ -196,7 +196,7 @@ func newTextModel(c fs.Config) *TextModel {
 			numExpertsUsed:      int(c.Uint("expert_used_count")),
 			ropeDim:             int(c.Uint("rope.dimension_count")),
 			ropeBase:            c.Float("rope.freq_base"),
-			ropeScale:           c.Float("rope.freq_scale", 1),
+			ropeScale:           c.Float("rope.scaling.factor", 1),
 			eps:                 c.Float("attention.layer_norm_rms_epsilon"),
 			interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
 			noRopeInterval:      int(c.Uint("no_rope_interval", 4)),
@@ -248,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
 }

model/models/mistral3/model_text.go

Lines changed: 4 additions & 4 deletions
@@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
 }
 
 type MLP struct {
@@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel {
 			ropeDim:   int(c.Uint("rope.dimension_count")),
 			eps:       c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:  c.Float("rope.freq_base"),
-			ropeScale: c.Float("rope.freq_scale", 1),
+			ropeScale: c.Float("rope.scaling.factor", 1),
 		},
 	}
 }

model/models/mllama/model_text.go

Lines changed: 4 additions & 4 deletions
@@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
+		return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
 	}
 
 	return key, nil
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
 			ropeDim:              int(c.Uint("rope.dimension_count")),
 			eps:                  c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:             c.Float("rope.freq_base"),
-			ropeScale:            c.Float("rope.freq_scale", 1),
+			ropeScale:            c.Float("rope.scaling.factor", 1),
 			crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
 		},
 	}

model/models/qwen2/model.go

Lines changed: 4 additions & 4 deletions
@@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
 	value := attn.Value.Forward(ctx, hiddenStates)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
-	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
 
 	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
 	attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
-	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 func New(c fs.Config) (model.Model, error) {
@@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) {
 			headDim:   int(c.Uint("attention.key_length")),
 			ropeDim:   int(c.Uint("rope.dimension_count")),
 			ropeBase:  c.Float("rope.freq_base"),
-			ropeScale: c.Float("rope.freq_scale", 1),
+			ropeScale: c.Float("rope.scaling.factor", 1),
 			eps:       c.Float("attention.layer_norm_rms_epsilon"),
 		},
 	}

model/models/qwen25vl/model_text.go

Lines changed: 4 additions & 4 deletions
@@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel {
 			originalContextLength: int(c.Uint("context_length", 128000)),
 			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:              c.Float("rope.freq_base"),
-			ropeScale:             c.Float("rope.freq_scale", 1),
+			ropeScale:             c.Float("rope.scaling.factor", 1),
 		},
 	}
 
@@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
+	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
+	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation
