3 changes: 2 additions & 1 deletion dspy/teleprompt/gepa/gepa.py
@@ -545,7 +545,8 @@ def feedback_fn(
             rng=rng,
             reflection_lm=self.reflection_lm,
             custom_instruction_proposer=self.custom_instruction_proposer,
-            warn_on_score_mismatch=self.warn_on_score_mismatch
+            warn_on_score_mismatch=self.warn_on_score_mismatch,
+            reflection_minibatch_size=self.reflection_minibatch_size,
         )
 
         # Instantiate GEPA with the simpler adapter-based API
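
For context, `reflection_minibatch_size` is the user-facing GEPA option now being threaded through to the adapter. A minimal usage sketch, assuming the public `dspy.GEPA` constructor; the metric, model name, and call budget below are illustrative placeholders:

import dspy

# Placeholder metric: GEPA accepts a callable that scores a prediction.
def exact_match(example, prediction, trace=None, pred_name=None, pred_trace=None):
    return float(example.output == prediction.output)

optimizer = dspy.GEPA(
    metric=exact_match,
    reflection_lm=dspy.LM("openai/gpt-4o-mini"),  # placeholder model
    reflection_minibatch_size=3,  # batches of size <= 3 count as reflection minibatches
    max_metric_calls=200,  # illustrative eval budget
)
# optimized = optimizer.compile(student_program, trainset=trainset, valset=valset)

With this change, evaluations on batches no larger than `reflection_minibatch_size` carry `{"disable_logging": True}` in their callback metadata, while larger (full) evaluations are tagged `{"metric_key": "eval_full"}`.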
11 changes: 7 additions & 4 deletions dspy/teleprompt/gepa/gepa_utils.py
@@ -76,7 +76,8 @@ def __init__(
         rng: random.Random | None = None,
         reflection_lm=None,
         custom_instruction_proposer: "ProposalFn | None" = None,
-        warn_on_score_mismatch: bool = True
+        warn_on_score_mismatch: bool = True,
+        reflection_minibatch_size: int | None = None,
     ):
         self.student = student_module
         self.metric_fn = metric_fn
@@ -88,6 +89,7 @@ def __init__(
         self.reflection_lm = reflection_lm
         self.custom_instruction_proposer = custom_instruction_proposer
         self.warn_on_score_mismatch = warn_on_score_mismatch
+        self.reflection_minibatch_size = reflection_minibatch_size
 
         if self.custom_instruction_proposer is not None:
             # We are only overriding the propose_new_texts method when a custom
@@ -128,12 +130,12 @@ def build_program(self, candidate: dict[str, str]):
 
     def evaluate(self, batch, candidate, capture_traces=False):
         program = self.build_program(candidate)
+        callback_metadata = {"metric_key": "eval_full"} if self.reflection_minibatch_size is None or len(batch) > self.reflection_minibatch_size else {"disable_logging": True}
 
         if capture_traces:
             # bootstrap_trace_data-like flow with trace capture
             from dspy.teleprompt import bootstrap_trace as bootstrap_trace_module
 
-            eval_callback_metadata = {"disable_logging": True}
             trajs = bootstrap_trace_module.bootstrap_trace_data(
                 program=program,
                 dataset=batch,
@@ -143,7 +145,7 @@ def evaluate(self, batch, candidate, capture_traces=False):
                 capture_failed_parses=True,
                 failure_score=self.failure_score,
                 format_failure_score=self.failure_score,
-                callback_metadata=eval_callback_metadata,
+                callback_metadata=callback_metadata,
             )
             scores = []
             outputs = []
@@ -165,7 +167,8 @@ def evaluate(self, batch, candidate, capture_traces=False):
                 return_all_scores=True,
                 failure_score=self.failure_score,
                 provide_traceback=True,
-                max_errors=len(batch) * 100
+                max_errors=len(batch) * 100,
+                callback_metadata=callback_metadata,
             )
             res = evaluator(program)
             outputs = [r[1] for r in res.results]
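
The single conditional added to `evaluate` is the heart of the change: a batch that fits within `reflection_minibatch_size` is treated as a reflection minibatch and its eval logging is silenced, while anything larger (or any eval when no minibatch size is set) is tagged as a full evaluation. A standalone sketch of that rule, with an illustrative function name that is not part of the diff:

def select_callback_metadata(reflection_minibatch_size: int | None, batch_size: int) -> dict:
    # Full evals keep their logging and are keyed for metric tracking;
    # reflection-sized minibatch evals are silenced.
    if reflection_minibatch_size is None or batch_size > reflection_minibatch_size:
        return {"metric_key": "eval_full"}
    return {"disable_logging": True}

# These mirror the parametrized cases in the test below.
assert select_callback_metadata(None, 0) == {"metric_key": "eval_full"}
assert select_callback_metadata(1, 0) == {"disable_logging": True}
assert select_callback_metadata(1, 2) == {"metric_key": "eval_full"}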
16 changes: 13 additions & 3 deletions tests/teleprompt/test_gepa.py
@@ -43,7 +43,16 @@ def bad_metric(example, prediction):
     return 0.0
 
 
-def test_gepa_adapter_disables_logging_during_trace_capture(monkeypatch):
+@pytest.mark.parametrize("reflection_minibatch_size, batch, expected_callback_metadata", [
+    (None, [], {"metric_key": "eval_full"}),
+    (None, [Example(input="What is the color of the sky?", output="blue")], {"metric_key": "eval_full"}),
+    (1, [], {"disable_logging": True}),
+    (1, [
+        Example(input="What is the color of the sky?", output="blue"),
+        Example(input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!"),
+    ], {"metric_key": "eval_full"}),
+])
+def test_gepa_adapter_disables_logging_on_minibatch_eval(monkeypatch, reflection_minibatch_size, batch, expected_callback_metadata):
     from dspy.teleprompt import bootstrap_trace as bootstrap_trace_module
     from dspy.teleprompt.gepa import gepa_utils
 
@@ -57,6 +66,7 @@ def forward(self, **kwargs):  # pragma: no cover - stub forward
         metric_fn=simple_metric,
         feedback_map={},
         failure_score=0.0,
+        reflection_minibatch_size=reflection_minibatch_size,
     )
 
     captured_kwargs: dict[str, Any] = {}
@@ -72,9 +82,9 @@ def dummy_bootstrap_trace_data(*args, **kwargs):
         lambda self, candidate: DummyModule(),
     )
 
-    adapter.evaluate(batch=[], candidate={}, capture_traces=True)
+    adapter.evaluate(batch=batch, candidate={}, capture_traces=True)
 
-    assert captured_kwargs["callback_metadata"] == {"disable_logging": True}
+    assert captured_kwargs["callback_metadata"] == expected_callback_metadata
 
 
 @pytest.fixture
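
The test relies on a capture-via-monkeypatch idiom worth noting: the collaborator (`bootstrap_trace_data`) is replaced with a stub that records its keyword arguments, and the assertion runs against the recording. A minimal self-contained sketch of the same pattern, with illustrative names:

from typing import Any

captured_kwargs: dict[str, Any] = {}

def dummy_collaborator(*args, **kwargs):
    captured_kwargs.update(kwargs)  # record everything the caller passed
    return []  # stand-in return value so the caller can proceed

# In the real test this is wired up via monkeypatch.setattr(...).
dummy_collaborator(callback_metadata={"disable_logging": True})
assert captured_kwargs["callback_metadata"] == {"disable_logging": True}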