Commit 4dbb8f9

fix vision test
1 parent e3c2913 commit 4dbb8f9

File tree

1 file changed (+14, -12)

tests/test_llava.py

Lines changed: 14 additions & 12 deletions
@@ -1,11 +1,12 @@
 import multiprocessing
-import ctypes
-
+import os
 from huggingface_hub import hf_hub_download
-
 import pytest
-
 import llama_cpp
+from llama_cpp.llama_chat_format import Llava15ChatHandler
+
+# Enable HF transfer
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 @pytest.fixture
 def mmproj_model_path():
@@ -23,8 +24,9 @@ def llava_cpp_model_path():
 
 def test_real_llava(llava_cpp_model_path, mmproj_model_path):
     print("initializing model")
-    model = llama_cpp.Llama(
-        llava_cpp_model_path,
+    model = llama_cpp.Llama.from_pretrained(
+        repo_id="second-state/Llava-v1.5-7B-GGUF",
+        filename="llava-v1.5-7b-Q8_0.gguf",
         n_ctx=2048,
         n_batch=512,
         n_threads=multiprocessing.cpu_count(),
@@ -34,9 +36,9 @@ def test_real_llava(llava_cpp_model_path, mmproj_model_path):
     )
 
     # Initialize the LLaVA chat handler
-    from llama_cpp.llama_chat_format import Llava15ChatHandler
     print("initializing chat handler")
-    chat_handler = Llava15ChatHandler(clip_model_path=mmproj_model_path, llama_model=model)
+    chat_handler = Llava15ChatHandler(clip_model_path=mmproj_model_path)
+    model.chat_handler = chat_handler
 
     # Create a chat message with the image
     print("creating chat message")
@@ -58,13 +60,13 @@ def test_real_llava(llava_cpp_model_path, mmproj_model_path):
 
     # Generate response
     print("generating response")
-    response = chat_handler(
-        llama=model,
+    response = model.create_chat_completion(
         messages=messages,
         max_tokens=200,
         temperature=0.2,
         top_p=0.95,
-        stream=False
+        stream=False,
+        stop=['<end_of_turn>', '<eos>']
     )
 
     print("response", response)
@@ -77,4 +79,4 @@ def test_real_llava(llava_cpp_model_path, mmproj_model_path):
 
     # The response should mention Leonardo da Vinci
     content = response["choices"][0]["message"]["content"].lower()
-    assert "leonardo" in content and "vinci" in content  # Artist name should be in response
+    assert "leonardo" in content and "vinci" in content  # Artist name should be in response
