Skip to content

Commit 4dd085c

Browse files
authored
Enable callback logging only for full eval on GEPA (#9050)
* Limit callback metadata to trace capture path

* Use batch length for eval

* refactor(gepa): rename full_eval_size to reflection_minibatch_size

  Updated the DspyAdapter and GEPA classes to replace the full_eval_size parameter with reflection_minibatch_size for improved clarity. Adjusted the evaluate method and corresponding tests to reflect this change, ensuring callback metadata is correctly generated based on the new parameter.

Signed-off-by: TomuHirata <tomu.hirata@gmail.com>
1 parent e2b9ef8 commit 4dd085c

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

dspy/teleprompt/gepa/gepa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,8 @@ def feedback_fn(
545545
rng=rng,
546546
reflection_lm=self.reflection_lm,
547547
custom_instruction_proposer=self.custom_instruction_proposer,
548-
warn_on_score_mismatch=self.warn_on_score_mismatch
548+
warn_on_score_mismatch=self.warn_on_score_mismatch,
549+
reflection_minibatch_size=self.reflection_minibatch_size,
549550
)
550551

551552
# Instantiate GEPA with the simpler adapter-based API

dspy/teleprompt/gepa/gepa_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def __init__(
7676
rng: random.Random | None = None,
7777
reflection_lm=None,
7878
custom_instruction_proposer: "ProposalFn | None" = None,
79-
warn_on_score_mismatch: bool = True
79+
warn_on_score_mismatch: bool = True,
80+
reflection_minibatch_size: int | None = None,
8081
):
8182
self.student = student_module
8283
self.metric_fn = metric_fn
@@ -88,6 +89,7 @@ def __init__(
8889
self.reflection_lm = reflection_lm
8990
self.custom_instruction_proposer = custom_instruction_proposer
9091
self.warn_on_score_mismatch = warn_on_score_mismatch
92+
self.reflection_minibatch_size = reflection_minibatch_size
9193

9294
if self.custom_instruction_proposer is not None:
9395
# We are only overriding the propose_new_texts method when a custom
@@ -128,12 +130,12 @@ def build_program(self, candidate: dict[str, str]):
128130

129131
def evaluate(self, batch, candidate, capture_traces=False):
130132
program = self.build_program(candidate)
133+
callback_metadata = {"metric_key": "eval_full"} if self.reflection_minibatch_size is None or len(batch) > self.reflection_minibatch_size else {"disable_logging": True}
131134

132135
if capture_traces:
133136
# bootstrap_trace_data-like flow with trace capture
134137
from dspy.teleprompt import bootstrap_trace as bootstrap_trace_module
135138

136-
eval_callback_metadata = {"disable_logging": True}
137139
trajs = bootstrap_trace_module.bootstrap_trace_data(
138140
program=program,
139141
dataset=batch,
@@ -143,7 +145,7 @@ def evaluate(self, batch, candidate, capture_traces=False):
143145
capture_failed_parses=True,
144146
failure_score=self.failure_score,
145147
format_failure_score=self.failure_score,
146-
callback_metadata=eval_callback_metadata,
148+
callback_metadata=callback_metadata,
147149
)
148150
scores = []
149151
outputs = []
@@ -165,7 +167,8 @@ def evaluate(self, batch, candidate, capture_traces=False):
165167
return_all_scores=True,
166168
failure_score=self.failure_score,
167169
provide_traceback=True,
168-
max_errors=len(batch) * 100
170+
max_errors=len(batch) * 100,
171+
callback_metadata=callback_metadata,
169172
)
170173
res = evaluator(program)
171174
outputs = [r[1] for r in res.results]

tests/teleprompt/test_gepa.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,16 @@ def bad_metric(example, prediction):
4343
return 0.0
4444

4545

46-
def test_gepa_adapter_disables_logging_during_trace_capture(monkeypatch):
46+
@pytest.mark.parametrize("reflection_minibatch_size, batch, expected_callback_metadata", [
47+
(None, [], {"metric_key": "eval_full"}),
48+
(None, [Example(input="What is the color of the sky?", output="blue")], {"metric_key": "eval_full"}),
49+
(1, [], {"disable_logging": True}),
50+
(1, [
51+
Example(input="What is the color of the sky?", output="blue"),
52+
Example(input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!"),
53+
], {"metric_key": "eval_full"}),
54+
])
55+
def test_gepa_adapter_disables_logging_on_minibatch_eval(monkeypatch, reflection_minibatch_size, batch, expected_callback_metadata):
4756
from dspy.teleprompt import bootstrap_trace as bootstrap_trace_module
4857
from dspy.teleprompt.gepa import gepa_utils
4958

@@ -57,6 +66,7 @@ def forward(self, **kwargs): # pragma: no cover - stub forward
5766
metric_fn=simple_metric,
5867
feedback_map={},
5968
failure_score=0.0,
69+
reflection_minibatch_size=reflection_minibatch_size,
6070
)
6171

6272
captured_kwargs: dict[str, Any] = {}
@@ -72,9 +82,9 @@ def dummy_bootstrap_trace_data(*args, **kwargs):
7282
lambda self, candidate: DummyModule(),
7383
)
7484

75-
adapter.evaluate(batch=[], candidate={}, capture_traces=True)
85+
adapter.evaluate(batch=batch, candidate={}, capture_traces=True)
7686

77-
assert captured_kwargs["callback_metadata"] == {"disable_logging": True}
87+
assert captured_kwargs["callback_metadata"] == expected_callback_metadata
7888

7989

8090
@pytest.fixture

0 commit comments

Comments (0)