Skip to content

Commit 2717dce

Browse files
authored
convert: convert bf16 vision weights to fp16 (ollama#12324)
This change moves back to converting bf16 vision weights to fp16, specifically if they start with the name "v." (such as v.blk.0.attn_k.weight). This fixes a bug where converted images are failing because they are trying to call `im2col` which doesn't have a bf16 kernel in ggml.
1 parent 9b8187b commit 2717dce

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

convert/reader_safetensors.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ type safetensor struct {
9696

9797
func (st safetensor) Kind() uint32 {
9898
kind := st.tensorBase.Kind()
99-
if st.dtype == "BF16" && kind != tensorKindFP32 {
99+
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
100100
kind = tensorKindBF16
101101
}
102102

convert/reader_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) {
230230
})
231231
}
232232
}
233+
234+
func TestSafetensorKind(t *testing.T) {
235+
tests := []struct {
236+
name string
237+
st safetensor
238+
expected uint32
239+
}{
240+
{
241+
name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
242+
st: safetensor{
243+
tensorBase: &tensorBase{
244+
name: "weight.matrix",
245+
shape: []uint64{10, 10}, // will default to FP16
246+
},
247+
dtype: "BF16",
248+
},
249+
expected: tensorKindBF16,
250+
},
251+
{
252+
name: "BF16 dtype with v. prefix should return base kind",
253+
st: safetensor{
254+
tensorBase: &tensorBase{
255+
name: "v.weight.matrix",
256+
shape: []uint64{10, 10}, // will default to FP16
257+
},
258+
dtype: "BF16",
259+
},
260+
expected: tensorKindFP16,
261+
},
262+
{
263+
name: "BF16 dtype with FP32 base kind should return FP32",
264+
st: safetensor{
265+
tensorBase: &tensorBase{
266+
name: "weight.matrix",
267+
shape: []uint64{10}, // will default to FP32
268+
},
269+
dtype: "BF16",
270+
},
271+
expected: tensorKindFP32,
272+
},
273+
{
274+
name: "Non-BF16 dtype should return base kind",
275+
st: safetensor{
276+
tensorBase: &tensorBase{
277+
name: "weight.matrix",
278+
shape: []uint64{10, 10}, // will default to FP16
279+
},
280+
dtype: "FP16",
281+
},
282+
expected: tensorKindFP16,
283+
},
284+
}
285+
286+
for _, tt := range tests {
287+
t.Run(tt.name, func(t *testing.T) {
288+
result := tt.st.Kind()
289+
if result != tt.expected {
290+
t.Errorf("Kind() = %d, expected %d", result, tt.expected)
291+
}
292+
})
293+
}
294+
}

0 commit comments

Comments
 (0)