Skip to content

Commit 9b467b5

Browse files
authored
Add Disable Fallback Option in ChatAdapter (#8984)
1 parent bf022c7 commit 9b467b5

File tree

2 files changed

+45
-6
lines changed

2 files changed

+45
-6
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ class FieldInfoWithName(NamedTuple):
2626

2727

2828
class ChatAdapter(Adapter):
29+
def __init__(
30+
self,
31+
callbacks=None,
32+
use_native_function_calling: bool = False,
33+
native_response_types=None,
34+
use_json_adapter_fallback: bool = True,
35+
):
36+
super().__init__(
37+
callbacks=callbacks,
38+
use_native_function_calling=use_native_function_calling,
39+
native_response_types=native_response_types,
40+
)
41+
self.use_json_adapter_fallback = use_json_adapter_fallback
42+
2943
def __call__(
3044
self,
3145
lm: LM,
@@ -40,9 +54,13 @@ def __call__(
4054
# fallback to JSONAdapter
4155
from dspy.adapters.json_adapter import JSONAdapter
4256

43-
if isinstance(e, ContextWindowExceededError) or isinstance(self, JSONAdapter):
44-
# On context window exceeded error or already using JSONAdapter, we don't want to retry with a different
45-
# adapter.
57+
if (
58+
isinstance(e, ContextWindowExceededError)
59+
or isinstance(self, JSONAdapter)
60+
or not self.use_json_adapter_fallback
61+
):
62+
# On context window exceeded error, already using JSONAdapter, or use_json_adapter_fallback is False
63+
# we don't want to retry with a different adapter. Raise the original error instead of the fallback error.
4664
raise e
4765
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
4866

@@ -60,9 +78,13 @@ async def acall(
6078
# fallback to JSONAdapter
6179
from dspy.adapters.json_adapter import JSONAdapter
6280

63-
if isinstance(e, ContextWindowExceededError) or isinstance(self, JSONAdapter):
64-
# On context window exceeded error or already using JSONAdapter, we don't want to retry with a different
65-
# adapter.
81+
if (
82+
isinstance(e, ContextWindowExceededError)
83+
or isinstance(self, JSONAdapter)
84+
or not self.use_json_adapter_fallback
85+
):
86+
# On context window exceeded error, already using JSONAdapter, or use_json_adapter_fallback is False
87+
# we don't want to retry with a different adapter. Raise the original error instead of the fallback error.
6688
raise e
6789
return await JSONAdapter().acall(lm, lm_kwargs, signature, demos, inputs)
6890

tests/adapters/test_chat_adapter.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,23 @@ def test_chat_adapter_fallback_to_json_adapter_on_exception():
442442
result = adapter(lm, {}, signature, [], {"question": "What is the capital of France?"})
443443
assert result == [{"answer": "Paris"}]
444444

445+
def test_chat_adapter_respects_use_json_adapter_fallback_flag():
446+
signature = dspy.make_signature("question->answer")
447+
adapter = dspy.ChatAdapter(use_json_adapter_fallback=False)
448+
449+
with mock.patch("litellm.completion") as mock_completion:
450+
mock_completion.return_value = ModelResponse(
451+
choices=[Choices(message=Message(content="nonsense"))],
452+
model="openai/gpt-4o-mini",
453+
)
454+
455+
lm = dspy.LM("openai/gpt-4o-mini", cache=False)
456+
457+
with mock.patch("dspy.adapters.json_adapter.JSONAdapter.__call__") as mock_json_adapter_call:
458+
with pytest.raises(dspy.utils.exceptions.AdapterParseError):
459+
adapter(lm, {}, signature, [], {"question": "What is the capital of France?"})
460+
mock_json_adapter_call.assert_not_called()
461+
445462

446463
@pytest.mark.asyncio
447464
async def test_chat_adapter_fallback_to_json_adapter_on_exception_async():

0 commit comments

Comments
 (0)