Commit 868f4ae

use real LLM for unit tests
1 parent 252191d commit 868f4ae

File tree

.github/workflows/run_tests.yml
tests/conftest.py
tests/primitives/test_base_module.py
tests/streaming/test_streaming.py
tests/utils/test_usage_tracker.py

5 files changed, +45 -44 lines changed

.github/workflows/run_tests.yml

Lines changed: 5 additions & 0 deletions

@@ -81,6 +81,11 @@ jobs:
         uses: chartboost/ruff-action@v1
         with:
           args: check --fix-only
+      - name: Set LLM model
+        run: |
+          echo "LLM_MODEL=${{ secrets.LLM_MODEL }}" >> $GITHUB_ENV
+          echo "DATABRICKS_API_BASE=${{ secrets.DATABRICKS_API_BASE }}" >> $GITHUB_ENV
+          echo "DATABRICKS_API_KEY=${{ secrets.DATABRICKS_API_KEY }}" >> $GITHUB_ENV
       - name: Run tests with pytest
         run: uv run -p .venv pytest tests/
       - name: Install optional dependencies
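
For orientation, this step only exports repository secrets into the job environment; the tests then build the model from LLM_MODEL. A minimal sketch of how those variables are consumed, assuming LLM_MODEL holds a LiteLLM-style model string (the Databricks values below are hypothetical placeholders, not part of this commit):

import os

import dspy

# Hypothetical local stand-ins for the CI secrets exported above.
os.environ.setdefault("LLM_MODEL", "databricks/databricks-meta-llama-3-1-70b-instruct")
os.environ.setdefault("DATABRICKS_API_BASE", "https://<workspace>.cloud.databricks.com/serving-endpoints")
os.environ.setdefault("DATABRICKS_API_KEY", "<personal-access-token>")

# dspy.LM delegates to LiteLLM, which resolves the "databricks/" provider
# from DATABRICKS_API_BASE / DATABRICKS_API_KEY in the environment.
lm = dspy.LM(os.environ["LLM_MODEL"], cache=False)
dspy.settings.configure(lm=lm)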

tests/conftest.py

Lines changed: 9 additions & 0 deletions

@@ -1,4 +1,5 @@
 import copy
+import os
 
 import pytest
 
@@ -49,3 +50,11 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if flag in item.keywords:
            item.add_marker(skip_mark)
+
+
+@pytest.fixture
+def llm_model():
+    model = os.environ.get("LLM_MODEL", None)
+    if model is None:
+        pytest.skip("LLM_MODEL is not set in the environment variables")
+    return model
tests/primitives/test_base_module.py

Lines changed: 13 additions & 14 deletions

@@ -230,30 +230,28 @@ def emit(self, record):
     logger.removeHandler(handler)
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
-def test_single_module_call_with_usage_tracker():
-    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True)
+def test_single_module_call_with_usage_tracker(llm_model):
+    dspy.settings.configure(lm=dspy.LM(llm_model, cache=False), track_usage=True)
 
     predict = dspy.ChainOfThought("question -> answer")
     output = predict(question="What is the capital of France?")
 
     lm_usage = output.get_lm_usage()
     assert len(lm_usage) == 1
-    assert lm_usage["openai/gpt-4o-mini"]["prompt_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["completion_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["total_tokens"] > 0
+    assert lm_usage[llm_model]["prompt_tokens"] > 0
+    assert lm_usage[llm_model]["completion_tokens"] > 0
+    assert lm_usage[llm_model]["total_tokens"] > 0
 
     # Test no usage being tracked when cache is enabled
-    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=True), track_usage=True)
+    dspy.settings.configure(lm=dspy.LM(llm_model, cache=True), track_usage=True)
     for _ in range(2):
         output = predict(question="What is the capital of France?")
 
     assert len(output.get_lm_usage()) == 0
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
-def test_multi_module_call_with_usage_tracker():
-    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True)
+def test_multi_module_call_with_usage_tracker(llm_model):
+    dspy.settings.configure(lm=dspy.LM(llm_model, cache=False), track_usage=True)
 
     class MyProgram(dspy.Module):
         def __init__(self):
@@ -270,12 +268,13 @@ def __call__(self, question: str) -> str:
 
     lm_usage = output.get_lm_usage()
     assert len(lm_usage) == 1
-    assert lm_usage["openai/gpt-4o-mini"]["prompt_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["prompt_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["completion_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["total_tokens"] > 0
+    assert lm_usage[llm_model]["prompt_tokens"] > 0
+    assert lm_usage[llm_model]["prompt_tokens"] > 0
+    assert lm_usage[llm_model]["completion_tokens"] > 0
+    assert lm_usage[llm_model]["total_tokens"] > 0
 
 
+# TODO: prepare second model for testing this unit test in ci
 @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
 def test_usage_tracker_in_parallel():
     class MyProgram(dspy.Module):

tests/streaming/test_streaming.py

Lines changed: 13 additions & 17 deletions

@@ -1,5 +1,4 @@
 import asyncio
-import os
 import time
 from unittest import mock
 from unittest.mock import AsyncMock
@@ -131,9 +130,8 @@ def module_start_status_message(self, instance, inputs):
     assert status_messages[2].message == "Predict starting!"
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
 @pytest.mark.anyio
-async def test_stream_listener_chat_adapter():
+async def test_stream_listener_chat_adapter(llm_model):
     class MyProgram(dspy.Module):
         def __init__(self):
             self.predict1 = dspy.Predict("question->answer")
@@ -154,7 +152,7 @@ def __call__(self, x: str, **kwargs):
         include_final_prediction_in_output_stream=False,
     )
     # Turn off the cache to ensure the stream is produced.
-    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
+    with dspy.context(lm=dspy.LM(llm_model, cache=False)):
         output = program(x="why did a chicken cross the kitchen?")
     all_chunks = []
     async for value in output:
@@ -194,9 +192,8 @@ async def acall(self, x: str):
     assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
 @pytest.mark.anyio
-async def test_stream_listener_json_adapter():
+async def test_stream_listener_json_adapter(llm_model):
     class MyProgram(dspy.Module):
         def __init__(self):
             self.predict1 = dspy.Predict("question->answer")
@@ -217,7 +214,7 @@ def __call__(self, x: str, **kwargs):
         include_final_prediction_in_output_stream=False,
     )
     # Turn off the cache to ensure the stream is produced.
-    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
+    with dspy.context(lm=dspy.LM(llm_model, cache=False), adapter=dspy.JSONAdapter()):
         output = program(x="why did a chicken cross the kitchen?")
     all_chunks = []
     async for value in output:
@@ -232,22 +229,22 @@ def __call__(self, x: str, **kwargs):
 
 
 @pytest.mark.anyio
-async def test_streaming_handles_space_correctly():
+async def test_streaming_handles_space_correctly(llm_model):
     my_program = dspy.Predict("question->answer")
     program = dspy.streamify(
         my_program, stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")]
     )
 
     async def gpt_4o_mini_stream(*args, **kwargs):
         yield ModelResponseStream(
-            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ## answer ## ]]\n"))]
+            model=llm_model, choices=[StreamingChoices(delta=Delta(content="[[ ## answer ## ]]\n"))]
         )
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="How "))])
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="are "))])
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="you "))])
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="doing?"))])
+        yield ModelResponseStream(model=llm_model, choices=[StreamingChoices(delta=Delta(content="How "))])
+        yield ModelResponseStream(model=llm_model, choices=[StreamingChoices(delta=Delta(content="are "))])
+        yield ModelResponseStream(model=llm_model, choices=[StreamingChoices(delta=Delta(content="you "))])
+        yield ModelResponseStream(model=llm_model, choices=[StreamingChoices(delta=Delta(content="doing?"))])
         yield ModelResponseStream(
-            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))]
+            model=llm_model, choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))]
        )
 
     with mock.patch("litellm.acompletion", side_effect=gpt_4o_mini_stream):
@@ -261,8 +258,7 @@ async def gpt_4o_mini_stream(*args, **kwargs):
     assert all_chunks[0].chunk == "How are you doing?"
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
-def test_sync_streaming():
+def test_sync_streaming(llm_model):
     class MyProgram(dspy.Module):
         def __init__(self):
             self.predict1 = dspy.Predict("question->answer")
@@ -284,7 +280,7 @@ def __call__(self, x: str, **kwargs):
         async_streaming=False,
     )
     # Turn off the cache to ensure the stream is produced.
-    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
+    with dspy.context(lm=dspy.LM(llm_model, cache=False)):
         output = program(x="why did a chicken cross the kitchen?")
     all_chunks = []
     for value in output:

tests/utils/test_usage_tracker.py

Lines changed: 5 additions & 13 deletions

@@ -1,7 +1,3 @@
-import os
-
-import pytest
-
 import dspy
 from dspy.utils.usage_tracker import UsageTracker, track_usage
 
@@ -137,12 +133,8 @@ def test_track_usage_with_multiple_models():
     assert total_usage["gpt-3.5-turbo"]["total_tokens"] == 900
 
 
-@pytest.mark.skipif(
-    not os.getenv("OPENAI_API_KEY"),
-    reason="Skip the test if OPENAI_API_KEY is not set.",
-)
-def test_track_usage_context_manager():
-    lm = dspy.LM("openai/gpt-4o-mini", cache=False)
+def test_track_usage_context_manager(llm_model):
+    lm = dspy.LM(llm_model, cache=False)
     dspy.settings.configure(lm=lm)
 
     predict = dspy.ChainOfThought("question -> answer")
@@ -151,12 +143,12 @@ def test_track_usage_context_manager():
         predict(question="What is the capital of Italy?")
 
     assert len(tracker.usage_data) > 0
-    assert len(tracker.usage_data["openai/gpt-4o-mini"]) == 2
+    assert len(tracker.usage_data[llm_model]) == 2
 
     total_usage = tracker.get_total_tokens()
-    assert "openai/gpt-4o-mini" in total_usage
+    assert llm_model in total_usage
     assert len(total_usage.keys()) == 1
-    assert isinstance(total_usage["openai/gpt-4o-mini"], dict)
+    assert isinstance(total_usage[llm_model], dict)
 
 
 def test_merge_usage_entries_with_new_keys():
