diff --git a/dspy/teleprompt/gepa/gepa.py b/dspy/teleprompt/gepa/gepa.py
index f32483bd6c..c35e916691 100644
--- a/dspy/teleprompt/gepa/gepa.py
+++ b/dspy/teleprompt/gepa/gepa.py
@@ -545,7 +545,8 @@ def feedback_fn(
             rng=rng,
             reflection_lm=self.reflection_lm,
             custom_instruction_proposer=self.custom_instruction_proposer,
-            warn_on_score_mismatch=self.warn_on_score_mismatch
+            warn_on_score_mismatch=self.warn_on_score_mismatch,
+            reflection_minibatch_size=self.reflection_minibatch_size,
         )

         # Instantiate GEPA with the simpler adapter-based API
diff --git a/dspy/teleprompt/gepa/gepa_utils.py b/dspy/teleprompt/gepa/gepa_utils.py
index 844afe8b00..d2e6772cef 100644
--- a/dspy/teleprompt/gepa/gepa_utils.py
+++ b/dspy/teleprompt/gepa/gepa_utils.py
@@ -76,7 +76,8 @@ def __init__(
         rng: random.Random | None = None,
         reflection_lm=None,
         custom_instruction_proposer: "ProposalFn | None" = None,
-        warn_on_score_mismatch: bool = True
+        warn_on_score_mismatch: bool = True,
+        reflection_minibatch_size: int | None = None,
     ):
         self.student = student_module
         self.metric_fn = metric_fn
@@ -88,6 +89,7 @@ def __init__(
         self.reflection_lm = reflection_lm
         self.custom_instruction_proposer = custom_instruction_proposer
         self.warn_on_score_mismatch = warn_on_score_mismatch
+        self.reflection_minibatch_size = reflection_minibatch_size

         if self.custom_instruction_proposer is not None:
             # We are only overriding the propose_new_texts method when a custom
@@ -128,12 +130,12 @@ def build_program(self, candidate: dict[str, str]):

     def evaluate(self, batch, candidate, capture_traces=False):
         program = self.build_program(candidate)
+        callback_metadata = {"metric_key": "eval_full"} if self.reflection_minibatch_size is None or len(batch) > self.reflection_minibatch_size else {"disable_logging": True}

         if capture_traces:
             # bootstrap_trace_data-like flow with trace capture
             from dspy.teleprompt import bootstrap_trace as bootstrap_trace_module

-            eval_callback_metadata = {"disable_logging": True}
             trajs = bootstrap_trace_module.bootstrap_trace_data(
                 program=program,
                 dataset=batch,
@@ -143,7 +145,7 @@ def evaluate(self, batch, candidate, capture_traces=False):
                 capture_failed_parses=True,
                 failure_score=self.failure_score,
                 format_failure_score=self.failure_score,
-                callback_metadata=eval_callback_metadata,
+                callback_metadata=callback_metadata,
             )
             scores = []
             outputs = []
@@ -165,7 +167,8 @@ def evaluate(self, batch, candidate, capture_traces=False):
                 return_all_scores=True,
                 failure_score=self.failure_score,
                 provide_traceback=True,
-                max_errors=len(batch) * 100
+                max_errors=len(batch) * 100,
+                callback_metadata=callback_metadata,
             )
             res = evaluator(program)
             outputs = [r[1] for r in res.results]
diff --git a/tests/teleprompt/test_gepa.py b/tests/teleprompt/test_gepa.py
index da9a74da82..afe40d082a 100644
--- a/tests/teleprompt/test_gepa.py
+++ b/tests/teleprompt/test_gepa.py
@@ -43,7 +43,16 @@ def bad_metric(example, prediction):
     return 0.0


-def test_gepa_adapter_disables_logging_during_trace_capture(monkeypatch):
+@pytest.mark.parametrize("reflection_minibatch_size, batch, expected_callback_metadata", [
+    (None, [], {"metric_key": "eval_full"}),
+    (None, [Example(input="What is the color of the sky?", output="blue")], {"metric_key": "eval_full"}),
+    (1, [], {"disable_logging": True}),
+    (1, [
+        Example(input="What is the color of the sky?", output="blue"),
+        Example(input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!"),
+    ], {"metric_key": "eval_full"}),
+])
+def test_gepa_adapter_disables_logging_on_minibatch_eval(monkeypatch, reflection_minibatch_size, batch, expected_callback_metadata):
     from dspy.teleprompt import bootstrap_trace as bootstrap_trace_module
     from dspy.teleprompt.gepa import gepa_utils

@@ -57,6 +66,7 @@ def forward(self, **kwargs):  # pragma: no cover - stub forward
         metric_fn=simple_metric,
         feedback_map={},
         failure_score=0.0,
+        reflection_minibatch_size=reflection_minibatch_size,
     )

     captured_kwargs: dict[str, Any] = {}
@@ -72,9 +82,9 @@ def dummy_bootstrap_trace_data(*args, **kwargs):
         lambda self, candidate: DummyModule(),
     )

-    adapter.evaluate(batch=[], candidate={}, capture_traces=True)
+    adapter.evaluate(batch=batch, candidate={}, capture_traces=True)

-    assert captured_kwargs["callback_metadata"] == {"disable_logging": True}
+    assert captured_kwargs["callback_metadata"] == expected_callback_metadata


 @pytest.fixture
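
For reference, a minimal standalone sketch of the selection rule this patch adds to DspyAdapter.evaluate, assuming the helper name select_callback_metadata, which is hypothetical and not part of dspy: when reflection_minibatch_size is None (no minibatching configured) or the batch is larger than the minibatch size, the evaluation is tagged as a full eval; otherwise it is treated as a reflection minibatch and logging is suppressed.

    # Sketch only; select_callback_metadata is an illustrative name, not a dspy API.
    def select_callback_metadata(batch_size: int, reflection_minibatch_size: int | None) -> dict:
        # None means minibatching is not configured, so every call counts as a full eval.
        if reflection_minibatch_size is None or batch_size > reflection_minibatch_size:
            return {"metric_key": "eval_full"}
        # Batches within the minibatch size are reflection minibatches; suppress logging.
        return {"disable_logging": True}

    # These mirror the parametrized cases in the updated test.
    assert select_callback_metadata(0, None) == {"metric_key": "eval_full"}
    assert select_callback_metadata(0, 1) == {"disable_logging": True}
    assert select_callback_metadata(2, 1) == {"metric_key": "eval_full"}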