
Commit ed4ef14

use real LLM for unit tests
1 parent 252191d commit ed4ef14

5 files changed: +45 -44 lines changed


.github/workflows/run_tests.yml

Lines changed: 5 additions & 0 deletions
@@ -81,6 +81,11 @@ jobs:
         uses: chartboost/ruff-action@v1
         with:
           args: check --fix-only
+      - name: Set LLM model
+        run: |
+          echo "LLM_MODEL=${{ secrets.LLM_MODEL }}" >> $GITHUB_ENV
+          echo "DATABRICKS_API_BASE=${{ secrets.DATABRICKS_API_BASE }}" >> $GITHUB_ENV
+          echo "DATABRICKS_API_KEY=${{ secrets.DATABRICKS_API_KEY }}" >> $GITHUB_ENV
       - name: Run tests with pytest
         run: uv run -p .venv pytest tests/
       - name: Install optional dependencies
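
The new step exports the three repository secrets into the job environment so that the pytest step that follows can construct a real dspy.LM. A rough local equivalent, assuming a Databricks-served model (the model name, endpoint, and key below are placeholders, not values from this commit):

    import os
    import subprocess

    # Placeholders only -- substitute your own Databricks model, endpoint, and token.
    os.environ["LLM_MODEL"] = "databricks/example-serving-endpoint"
    os.environ["DATABRICKS_API_BASE"] = "https://example.cloud.databricks.com/serving-endpoints"
    os.environ["DATABRICKS_API_KEY"] = "dapi-..."

    # Mirror the CI invocation: run the test suite inside the project venv.
    subprocess.run(["uv", "run", "-p", ".venv", "pytest", "tests/"], check=True)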

tests/conftest.py

Lines changed: 9 additions & 0 deletions
@@ -1,4 +1,5 @@
 import copy
+import os
 
 import pytest
 
@@ -49,3 +50,11 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if flag in item.keywords:
             item.add_marker(skip_mark)
+
+
+@pytest.fixture
+def llm_model():
+    model = os.environ.get("LLM_MODEL", None)
+    if model is None:
+        pytest.skip("LLM_MODEL is not set in the environment variables")
+    return model
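
Any test that declares llm_model as a parameter now runs against whatever model the environment names, and is skipped automatically when LLM_MODEL is absent. A minimal sketch of a consuming test (signature and question are illustrative, not part of this commit):

    import dspy

    def test_answers_with_configured_model(llm_model):
        # llm_model comes from the conftest fixture; without LLM_MODEL this test is skipped.
        dspy.settings.configure(lm=dspy.LM(llm_model, cache=False))
        prediction = dspy.Predict("question -> answer")(question="What is the capital of France?")
        assert prediction.answer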

tests/primitives/test_base_module.py

Lines changed: 13 additions & 14 deletions
@@ -230,30 +230,28 @@ def emit(self, record):
     logger.removeHandler(handler)
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
-def test_single_module_call_with_usage_tracker():
-    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True)
+def test_single_module_call_with_usage_tracker(llm_model):
+    dspy.settings.configure(lm=dspy.LM(llm_model, cache=False), track_usage=True)
 
     predict = dspy.ChainOfThought("question -> answer")
     output = predict(question="What is the capital of France?")
 
     lm_usage = output.get_lm_usage()
     assert len(lm_usage) == 1
-    assert lm_usage["openai/gpt-4o-mini"]["prompt_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["completion_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["total_tokens"] > 0
+    assert lm_usage[llm_model]["prompt_tokens"] > 0
+    assert lm_usage[llm_model]["completion_tokens"] > 0
+    assert lm_usage[llm_model]["total_tokens"] > 0
 
     # Test no usage being tracked when cache is enabled
-    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=True), track_usage=True)
+    dspy.settings.configure(lm=dspy.LM(llm_model, cache=True), track_usage=True)
     for _ in range(2):
         output = predict(question="What is the capital of France?")
 
     assert len(output.get_lm_usage()) == 0
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
-def test_multi_module_call_with_usage_tracker():
-    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True)
+def test_multi_module_call_with_usage_tracker(llm_model):
+    dspy.settings.configure(lm=dspy.LM(llm_model, cache=False), track_usage=True)
 
     class MyProgram(dspy.Module):
         def __init__(self):
@@ -270,12 +268,13 @@ def __call__(self, question: str) -> str:
 
     lm_usage = output.get_lm_usage()
     assert len(lm_usage) == 1
-    assert lm_usage["openai/gpt-4o-mini"]["prompt_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["prompt_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["completion_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["total_tokens"] > 0
+    assert lm_usage[llm_model]["prompt_tokens"] > 0
+    assert lm_usage[llm_model]["prompt_tokens"] > 0
+    assert lm_usage[llm_model]["completion_tokens"] > 0
+    assert lm_usage[llm_model]["total_tokens"] > 0
 
 
+# TODO: prepare second model for testing this unit test in ci
 @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
 def test_usage_tracker_in_parallel():
     class MyProgram(dspy.Module):
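
get_lm_usage() returns a dict keyed by the LiteLLM model string, which is why the assertions now index with the fixture value instead of a hard-coded "openai/gpt-4o-mini". Roughly the shape being asserted on (keys as in the diff, counts purely illustrative):

    llm_model = "databricks/example-endpoint"  # placeholder for whatever LLM_MODEL names
    lm_usage = {
        llm_model: {
            "prompt_tokens": 42,       # illustrative counts only
            "completion_tokens": 17,
            "total_tokens": 59,
        },
    }
    assert lm_usage[llm_model]["total_tokens"] > 0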

tests/streaming/test_streaming.py

Lines changed: 13 additions & 16 deletions
@@ -131,9 +131,8 @@ def module_start_status_message(self, instance, inputs):
     assert status_messages[2].message == "Predict starting!"
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
 @pytest.mark.anyio
-async def test_stream_listener_chat_adapter():
+async def test_stream_listener_chat_adapter(llm_model):
     class MyProgram(dspy.Module):
         def __init__(self):
             self.predict1 = dspy.Predict("question->answer")
@@ -154,7 +153,7 @@ def __call__(self, x: str, **kwargs):
         include_final_prediction_in_output_stream=False,
     )
     # Turn off the cache to ensure the stream is produced.
-    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
+    with dspy.context(lm=dspy.LM(llm_model, cache=False)):
         output = program(x="why did a chicken cross the kitchen?")
     all_chunks = []
     async for value in output:
@@ -194,9 +193,8 @@ async def acall(self, x: str):
     assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
 @pytest.mark.anyio
-async def test_stream_listener_json_adapter():
+async def test_stream_listener_json_adapter(llm_model):
     class MyProgram(dspy.Module):
         def __init__(self):
             self.predict1 = dspy.Predict("question->answer")
@@ -217,7 +215,7 @@ def __call__(self, x: str, **kwargs):
         include_final_prediction_in_output_stream=False,
     )
     # Turn off the cache to ensure the stream is produced.
-    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
+    with dspy.context(lm=dspy.LM(llm_model, cache=False), adapter=dspy.JSONAdapter()):
         output = program(x="why did a chicken cross the kitchen?")
     all_chunks = []
     async for value in output:
@@ -232,22 +230,22 @@ def __call__(self, x: str, **kwargs):
 
 
 @pytest.mark.anyio
-async def test_streaming_handles_space_correctly():
+async def test_streaming_handles_space_correctly(llm_model):
     my_program = dspy.Predict("question->answer")
     program = dspy.streamify(
         my_program, stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")]
     )
 
     async def gpt_4o_mini_stream(*args, **kwargs):
         yield ModelResponseStream(
-            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ## answer ## ]]\n"))]
+            model=llm_model, choices=[StreamingChoices(delta=Delta(content="[[ ## answer ## ]]\n"))]
         )
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="How "))])
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="are "))])
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="you "))])
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="doing?"))])
+        yield ModelResponseStream(model=llm_model, choices=[StreamingChoices(delta=Delta(content="How "))])
+        yield ModelResponseStream(model=llm_model, choices=[StreamingChoices(delta=Delta(content="are "))])
+        yield ModelResponseStream(model=llm_model, choices=[StreamingChoices(delta=Delta(content="you "))])
+        yield ModelResponseStream(model=llm_model, choices=[StreamingChoices(delta=Delta(content="doing?"))])
         yield ModelResponseStream(
-            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))]
+            model=llm_model, choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))]
        )
 
     with mock.patch("litellm.acompletion", side_effect=gpt_4o_mini_stream):
@@ -261,8 +259,7 @@ async def gpt_4o_mini_stream(*args, **kwargs):
     assert all_chunks[0].chunk == "How are you doing?"
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
-def test_sync_streaming():
+def test_sync_streaming(llm_model):
     class MyProgram(dspy.Module):
         def __init__(self):
             self.predict1 = dspy.Predict("question->answer")
@@ -284,7 +281,7 @@ def __call__(self, x: str, **kwargs):
         async_streaming=False,
     )
     # Turn off the cache to ensure the stream is produced.
-    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
+    with dspy.context(lm=dspy.LM(llm_model, cache=False)):
         output = program(x="why did a chicken cross the kitchen?")
     all_chunks = []
     for value in output:

tests/utils/test_usage_tracker.py

Lines changed: 5 additions & 14 deletions
@@ -1,11 +1,6 @@
-import os
-
-import pytest
-
 import dspy
 from dspy.utils.usage_tracker import UsageTracker, track_usage
 
-
 def test_add_usage_entry():
     """Test adding usage entries to the tracker."""
     tracker = UsageTracker()
@@ -137,12 +132,8 @@ def test_track_usage_with_multiple_models():
     assert total_usage["gpt-3.5-turbo"]["total_tokens"] == 900
 
 
-@pytest.mark.skipif(
-    not os.getenv("OPENAI_API_KEY"),
-    reason="Skip the test if OPENAI_API_KEY is not set.",
-)
-def test_track_usage_context_manager():
-    lm = dspy.LM("openai/gpt-4o-mini", cache=False)
+def test_track_usage_context_manager(llm_model):
+    lm = dspy.LM(llm_model, cache=False)
     dspy.settings.configure(lm=lm)
 
     predict = dspy.ChainOfThought("question -> answer")
@@ -151,12 +142,12 @@ def test_track_usage_context_manager():
         predict(question="What is the capital of Italy?")
 
     assert len(tracker.usage_data) > 0
-    assert len(tracker.usage_data["openai/gpt-4o-mini"]) == 2
+    assert len(tracker.usage_data[llm_model]) == 2
 
     total_usage = tracker.get_total_tokens()
-    assert "openai/gpt-4o-mini" in total_usage
+    assert llm_model in total_usage
     assert len(total_usage.keys()) == 1
-    assert isinstance(total_usage["openai/gpt-4o-mini"], dict)
+    assert isinstance(total_usage[llm_model], dict)
 
 
 def test_merge_usage_entries_with_new_keys():
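
For reference, the track_usage helper exercised here is a context manager that records per-model token usage for every LM call made inside the block, keyed by the model string, which is what the updated assertions rely on. A minimal sketch of the pattern (the model name is a placeholder):

    import dspy
    from dspy.utils.usage_tracker import track_usage

    dspy.settings.configure(lm=dspy.LM("databricks/example-endpoint", cache=False))  # placeholder model

    with track_usage() as tracker:
        dspy.ChainOfThought("question -> answer")(question="What is the capital of Italy?")

    # usage_data and get_total_tokens() are keyed by the model string, as the assertions above expect.
    print(tracker.get_total_tokens())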
