From 55aef6f0d29528fb6ba1a6f5e338808d57403e6a Mon Sep 17 00:00:00 2001
From: Ty Todd
Date: Mon, 3 Nov 2025 10:54:20 -0800
Subject: [PATCH 1/3] Replaced teleprompt.compile's signature with one using a
 TypeVar that tracks the exact type of the student module passed in

---
 dspy/teleprompt/teleprompt.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/dspy/teleprompt/teleprompt.py b/dspy/teleprompt/teleprompt.py
index ad1b8b865d..5cedc3067d 100644
--- a/dspy/teleprompt/teleprompt.py
+++ b/dspy/teleprompt/teleprompt.py
@@ -1,13 +1,23 @@
-from typing import Any
+from typing import Any, TypeVar
 
 from dspy.primitives import Example, Module
 
+M = TypeVar("M", bound=Module)
+
 
 class Teleprompter:
     def __init__(self):
         pass
 
-    def compile(self, student: Module, *, trainset: list[Example], teacher: Module | None = None, valset: list[Example] | None = None, **kwargs) -> Module:
+    def compile(
+        self,
+        student: M,
+        *,
+        trainset: list[Example],
+        teacher: Module | None = None,
+        valset: list[Example] | None = None,
+        **kwargs,
+    ) -> M:
         """
         Optimize the student program.
 

From 315840fb82002ac929564925c346857ad4e6f5f2 Mon Sep 17 00:00:00 2001
From: Ty Todd
Date: Mon, 3 Nov 2025 11:09:52 -0800
Subject: [PATCH 2/3] Added TypeVar to all other compilers

---
 dspy/teleprompt/bettertogether.py     |  25 +--
 dspy/teleprompt/bootstrap_finetune.py |   8 +-
 dspy/teleprompt/grpo.py               | 266 ++++++++++++++++++--------
 dspy/teleprompt/mipro_optimizer_v2.py |  46 ++---
 dspy/teleprompt/simba.py              |  27 ++-
 scratchpad.py                         |  19 ++
 6 files changed, 252 insertions(+), 139 deletions(-)
 create mode 100644 scratchpad.py

diff --git a/dspy/teleprompt/bettertogether.py b/dspy/teleprompt/bettertogether.py
index d1154f9ae4..b772b29bad 100644
--- a/dspy/teleprompt/bettertogether.py
+++ b/dspy/teleprompt/bettertogether.py
@@ -1,6 +1,6 @@
 import logging
 import random
-from typing import Callable
+from typing import Callable, TypeVar
 
 import dspy
 from dspy.primitives.example import Example
@@ -14,20 +14,22 @@
 )
 from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch
 from dspy.teleprompt.teleprompt import Teleprompter
+from dspy.primitives import Module
 
+M = TypeVar("M", bound=Module)
 
 logger = logging.getLogger(__name__)
 
 
 class BetterTogether(Teleprompter):
-
     STRAT_SEP = " -> "
 
-    def __init__(self,
+    def __init__(
+        self,
         metric: Callable,
         prompt_optimizer: Teleprompter | None = None,
         weight_optimizer: Teleprompter | None = None,
         seed: int | None = None,
-    ):
+    ):
         if not dspy.settings.experimental:
             raise ValueError("This is an experimental optimizer. Set `dspy.settings.experimental` to `True` to use it.")
 
@@ -37,7 +39,9 @@ def __init__(self,
         # a BootstrapFinetune without a metric, say, if there aren't labels
         # available for the training data. Should this be noted somewhere?
         # TODO: We should re-consider if the metric should be required.
-        self.prompt_optimizer = prompt_optimizer if prompt_optimizer else BootstrapFewShotWithRandomSearch(metric=metric)
+        self.prompt_optimizer = (
+            prompt_optimizer if prompt_optimizer else BootstrapFewShotWithRandomSearch(metric=metric)
+        )
         self.weight_optimizer = weight_optimizer if weight_optimizer else BootstrapFinetune(metric=metric)
 
         is_supported_prompt = isinstance(self.prompt_optimizer, BootstrapFewShotWithRandomSearch)
@@ -52,11 +56,11 @@ def __init__(self,
 
     def compile(
         self,
-        student: Module,
+        student: M,
         trainset: list[Example],
         strategy: str = "p -> w -> p",
-        valset_ratio = 0.1,
-    ) -> Module:
+        valset_ratio=0.1,
+    ) -> M:
         # TODO: We could record acc on a different valset to pick the best
         # strategy within the provided strategy
         logger.info("Validating the strategy")
@@ -91,10 +95,9 @@ def _run_strategies(self, parsed_strategy, student, trainset, valset_ratio) -> M
         launched_flag = False
 
         for ind, step_code in enumerate(parsed_strategy):
-            current_strategy = self.STRAT_SEP.join(parsed_strategy[:ind + 1])
+            current_strategy = self.STRAT_SEP.join(parsed_strategy[: ind + 1])
             logger.info(
-                f"\n########## Step {ind + 1} of {len(parsed_strategy)} - Strategy "
-                f"'{current_strategy}' ##########"
+                f"\n########## Step {ind + 1} of {len(parsed_strategy)} - Strategy " f"'{current_strategy}' ##########"
             )
 
             logger.info("Shuffling the trainset...")

diff --git a/dspy/teleprompt/bootstrap_finetune.py b/dspy/teleprompt/bootstrap_finetune.py
index 6a431e355e..e7a4710826 100644
--- a/dspy/teleprompt/bootstrap_finetune.py
+++ b/dspy/teleprompt/bootstrap_finetune.py
@@ -1,6 +1,6 @@
 import logging
 from collections import defaultdict
-from typing import Any, Callable
+from typing import Any, Callable, TypeVar
 
 import dspy
 from dspy.adapters.base import Adapter
@@ -16,6 +16,8 @@
 
 logger = logging.getLogger(__name__)
 
+M = TypeVar("M", bound=Module)
+
 
 class FinetuneTeleprompter(Teleprompter):
     def __init__(
@@ -57,9 +59,7 @@ def __init__(
         self.exclude_demos = exclude_demos
         self.num_threads = num_threads
 
-    def compile(
-        self, student: Module, trainset: list[Example], teacher: Module | list[Module] | None = None
-    ) -> Module:
+    def compile(self, student: M, trainset: list[Example], teacher: Module | list[Module] | None = None) -> M:
         # TODO: Print statements can be converted to logger.info if we ensure
         # that the default DSPy logger logs info level messages in notebook
         # environments.
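The same pattern repeats in every optimizer that follows: a module-level M = TypeVar("M", bound=Module) plus a compile(student: M, ...) -> M signature, so compile returns the caller's concrete Module subclass instead of the base class. A minimal, self-contained sketch of the idea is below; QAProgram and IdentityOptimizer are hypothetical names used for illustration only, not part of this patch:

from typing import TypeVar

import dspy
from dspy.primitives import Example, Module
from dspy.teleprompt.teleprompt import Teleprompter

M = TypeVar("M", bound=Module)


class QAProgram(Module):
    # A student module with an attribute the base Module does not define.
    def __init__(self):
        super().__init__()
        self.answer = dspy.Predict("question -> answer")

    def helper(self) -> str:
        return "only defined on QAProgram"


class IdentityOptimizer(Teleprompter):
    # Not a real optimizer; it only demonstrates the generic signature.
    def compile(self, student: M, *, trainset: list[Example], **kwargs) -> M:
        return student  # returning the student unchanged preserves its type M


program = IdentityOptimizer().compile(student=QAProgram(), trainset=[])
print(program.helper())  # inferred as QAProgram, so this access type-checks

Because compile is generic in M, a checker such as mypy or pyright infers program as QAProgram, and subclass-only attributes survive compilation without a cast.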
diff --git a/dspy/teleprompt/grpo.py b/dspy/teleprompt/grpo.py
index 7998d2608d..6e5fa8abc9 100644
--- a/dspy/teleprompt/grpo.py
+++ b/dspy/teleprompt/grpo.py
@@ -2,7 +2,7 @@
 import random
 import time
 from collections import Counter, deque
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, TypeVar
 
 from dspy.adapters.base import Adapter
 from dspy.adapters.chat_adapter import ChatAdapter
@@ -22,6 +22,8 @@
 
 logger = logging.getLogger(__name__)
 
+M = TypeVar("M", bound=Module)
+
 
 class GRPO(FinetuneTeleprompter):
     def __init__(
@@ -41,7 +43,9 @@ def __init__(
         report_train_scores: bool = False,
         failure_score: float = 0,
         format_failure_score: float = -1,
-        variably_invoked_predictor_grouping_mode: Literal["truncate"] | Literal["fill"] | Literal["ragged"] = "truncate",
+        variably_invoked_predictor_grouping_mode: Literal["truncate"]
+        | Literal["fill"]
+        | Literal["ragged"] = "truncate",
         variably_invoked_predictor_fill_strategy: Literal["randint"] | Literal["max"] | None = None,
     ):
         super().__init__(train_kwargs=train_kwargs)
@@ -60,7 +64,9 @@ def __init__(
         self.failure_score = failure_score
         self.format_failure_score = format_failure_score
 
-        assert failure_score > format_failure_score, "failure_score must be greater than format_failure_score since the range [format_failure_score, failure_score] is used to provide dspy formatting rewards"
+        assert (
+            failure_score > format_failure_score
+        ), "failure_score must be greater than format_failure_score since the range [format_failure_score, failure_score] is used to provide dspy formatting rewards"
 
         if self.use_train_as_val:
             assert report_train_scores, "If use_train_as_val is True, report_train_scores must be True."
@@ -72,8 +78,13 @@ def __init__(
         # If multitask is False, the backend will be called with a batch of (num_dspy_examples_per_grpo_step * num_rollouts_per_grpo_step) per training job
         self.variably_invoked_predictor_grouping_mode = variably_invoked_predictor_grouping_mode
         if variably_invoked_predictor_grouping_mode == "fill":
-            assert variably_invoked_predictor_fill_strategy is not None, "variably_invoked_predictor_fill_strategy must be set when variably_invoked_predictor_grouping_mode is 'fill'"
-            assert variably_invoked_predictor_fill_strategy in ["randint", "max"], "variably_invoked_predictor_fill_strategy must be either 'randint' or 'max'"
+            assert (
+                variably_invoked_predictor_fill_strategy is not None
+            ), "variably_invoked_predictor_fill_strategy must be set when variably_invoked_predictor_grouping_mode is 'fill'"
+            assert variably_invoked_predictor_fill_strategy in [
+                "randint",
+                "max",
+            ], "variably_invoked_predictor_fill_strategy must be either 'randint' or 'max'"
         self.variably_invoked_predictor_fill_strategy = variably_invoked_predictor_fill_strategy
 
         self.shuffled_trainset_ids = []
@@ -92,16 +103,26 @@ def validate_trace_data_and_log_issues(
     ):
         # At this point, trace_data: list[example_idx -> list[teacher_idx -> [num_samples_per_input * Dict(example, prediction, trace, example_ind, score)]]]
         # Shape of trace is: [dspy_module_invocation_idx -> Tuple[Predictor, PredictorInputs, Prediction]]
-        assert len(trace_data) == len(subsample_training_dataset), f"Trace data length {len(trace_data)} does not match the number of examples {len(subsample_training_dataset)}"
-        assert len(trace_data[0]) == num_teachers, f"Trace data length {len(trace_data[0])} does not match the number of teachers {num_teachers}"
+        assert (
+            len(trace_data) == len(subsample_training_dataset)
+        ), f"Trace data length {len(trace_data)} does not match the number of examples {len(subsample_training_dataset)}"
+        assert (
+            len(trace_data[0]) == num_teachers
+        ), f"Trace data length {len(trace_data[0])} does not match the number of teachers {num_teachers}"
        # TODO(GRPO Team): Ideally, once the dspy format issue is fixed, this change should be reverted back to being a normal assert.
         if len(trace_data[0][0]) == 0:
-            logger.warning(f"Trace data for example {0} and teacher {0} is empty. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
+            logger.warning(
+                f"Trace data for example {0} and teacher {0} is empty. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format."
+            )
         elif len(trace_data[0][0]) != num_samples_per_input:
-            logger.warning(f"Trace data length {len(trace_data[0][0])} does not match the expected number of samples per input {num_samples_per_input}")
+            logger.warning(
+                f"Trace data length {len(trace_data[0][0])} does not match the expected number of samples per input {num_samples_per_input}"
+            )
         assert "trace" in trace_data[0][0][0], "Trace data does not contain the 'trace' key"
         assert len(trace_data[0][0][0]["trace"]) > 0, "Trace data is empty"
-        assert len(trace_data[0][0][0]["trace"][0]) == 3, f"Trace tuple length {len(trace_data[0][0][0]['trace'][0])} does not match the expected length 3"
+        assert (
+            len(trace_data[0][0][0]["trace"][0]) == 3
+        ), f"Trace tuple length {len(trace_data[0][0][0]['trace'][0])} does not match the expected length 3"
 
         for example_data in trace_data:
             for teacher_data in example_data:
@@ -118,33 +139,43 @@ def report_validation_metrics(self, student, trainset, valset, logger, step_idx=
         if valset is not None:
             # Validation set provided by user
             assert not self.use_train_as_val, "If valset is provided, use_train_as_val must be False."
-            assert isinstance(self.num_steps_for_val, int) and self.num_steps_for_val > 0, "num_steps_for_val must be a positive integer."
+            assert (
+                isinstance(self.num_steps_for_val, int) and self.num_steps_for_val > 0
+            ), "num_steps_for_val must be a positive integer."
 
             if self.report_train_scores:
                 if step_idx == -1:
-                    logger.info("Using user provided validation set and reporting train scores for every validation step in addition.")
+                    logger.info(
+                        "Using user provided validation set and reporting train scores for every validation step in addition."
+                    )
 
                 valset_evaluator = Evaluate(
                     devset=valset + trainset,
                     num_threads=self.num_threads,
                     display_progress=True,
                     provide_traceback=False,  # TODO(check with team)
-                    max_errors=len(valset)*10, # TODO(check with team)
-                    failure_score=self.failure_score
+                    max_errors=len(valset) * 10,  # TODO(check with team)
+                    failure_score=self.failure_score,
                 )
                 if step_idx == -1:
                     logger.info("Evaluating the student program on the train+validation set before training loop...")
                 else:
-                    logger.info(f"Evaluating the student program on the validation set after training step {step_idx + 1}/{self.num_train_steps}")
+                    logger.info(
+                        f"Evaluating the student program on the validation set after training step {step_idx + 1}/{self.num_train_steps}"
+                    )
                 valset_evaluation = valset_evaluator(student, metric=self.metric)
-                trainset_scores = [r[-1] for r in valset_evaluation.results[len(valset):]]
-                valset_scores = [r[-1] for r in valset_evaluation.results[:len(valset)]]
+                trainset_scores = [r[-1] for r in valset_evaluation.results[len(valset) :]]
+                valset_scores = [r[-1] for r in valset_evaluation.results[: len(valset)]]
                 trainset_agg = sum(trainset_scores) / len(trainset_scores)
                 valset_agg = sum(valset_scores) / len(valset_scores)
                 if step_idx == -1:
                     logger.info(f"Student program training set score before training loop: {trainset_agg}")
                     logger.info(f"Student program validation set score before training loop: {valset_agg}")
                 else:
-                    logger.info(f"Student program training set score after training step {step_idx + 1}/{self.num_train_steps}: {trainset_agg}")
-                    logger.info(f"Student program validation set score after training step {step_idx + 1}/{self.num_train_steps}: {valset_agg}")
+                    logger.info(
+                        f"Student program training set score after training step {step_idx + 1}/{self.num_train_steps}: {trainset_agg}"
+                    )
+                    logger.info(
+                        f"Student program validation set score after training step {step_idx + 1}/{self.num_train_steps}: {valset_agg}"
+                    )
@@ -153,23 +184,29 @@ def report_validation_metrics(self, student, trainset, valset, logger, step_idx=
             else:
                 if step_idx == -1:
                     logger.info("Using user provided validation set and not reporting train scores.")
                 valset_evaluator = Evaluate(
                     devset=valset,
                     num_threads=self.num_threads,
                     display_progress=True,
                     provide_traceback=False,  # TODO(check with team)
-                    max_errors=len(valset)*10, # TODO(check with team)
-                    failure_score=self.failure_score
+                    max_errors=len(valset) * 10,  # TODO(check with team)
+                    failure_score=self.failure_score,
                 )
                 if step_idx == -1:
                     logger.info("Evaluating the student program on the validation set before training loop...")
                 else:
-                    logger.info(f"Evaluating the student program on the validation set after training step {step_idx + 1}/{self.num_train_steps}")
+                    logger.info(
+                        f"Evaluating the student program on the validation set after training step {step_idx + 1}/{self.num_train_steps}"
+                    )
                 valset_evaluation = valset_evaluator(student, metric=self.metric)
                 if step_idx == -1:
                     logger.info(f"Student program validation set score before training loop: {valset_evaluation.score}")
                 else:
-                    logger.info(f"Student program validation set score after training step {step_idx + 1}/{self.num_train_steps}: {valset_evaluation.score}")
+                    logger.info(
+                        f"Student program validation set score after training step {step_idx + 1}/{self.num_train_steps}: {valset_evaluation.score}"
+                    )
         else:
             # No validation set provided by user
             if self.report_train_scores:
                 assert self.use_train_as_val, "If report_train_scores is True, use_train_as_val must be True when valset is not provided explicitly."
-                assert isinstance(self.num_steps_for_val, int) and self.num_steps_for_val > 0, "num_steps_for_val must be a positive integer."
+                assert (
+                    isinstance(self.num_steps_for_val, int) and self.num_steps_for_val > 0
+                ), "num_steps_for_val must be a positive integer."
                 if step_idx == -1:
                     logger.info("Using trainset as validation set.")
                 valset_evaluator = Evaluate(
                     devset=trainset[: self.num_steps_for_val],
@@ -177,18 +214,22 @@ def report_validation_metrics(self, student, trainset, valset, logger, step_idx=
                     num_threads=self.num_threads,
                     display_progress=True,
                     provide_traceback=False,  # TODO(check with team)
-                    max_errors=len(trainset)*10, # TODO(check with team)
-                    failure_score=self.failure_score
+                    max_errors=len(trainset) * 10,  # TODO(check with team)
+                    failure_score=self.failure_score,
                 )
                 if step_idx == -1:
                     logger.info("Evaluating the student program on the validation set before training loop...")
                 else:
-                    logger.info(f"Evaluating the student program on the validation set after training step {step_idx + 1}/{self.num_train_steps}")
+                    logger.info(
+                        f"Evaluating the student program on the validation set after training step {step_idx + 1}/{self.num_train_steps}"
+                    )
                 valset_evaluation = valset_evaluator(student, metric=self.metric)
                 if step_idx == -1:
                     logger.info(f"Student program training set score before training loop: {valset_evaluation.score}")
                 else:
-                    logger.info(f"Student program training set score after training step {step_idx + 1}/{self.num_train_steps}: {valset_evaluation.score}")
+                    logger.info(
+                        f"Student program training set score after training step {step_idx + 1}/{self.num_train_steps}: {valset_evaluation.score}"
+                    )
             else:
                 # No valset provided, and not using train as val
                 assert not self.use_train_as_val, "If report_train_scores is False, use_train_as_val must be False."
@@ -201,7 +242,9 @@ def update_shuffled_trainset(self, original_trainset):
         for id in self.shuffled_trainset_ids:
             self.id_freqs[id] += 1
 
-        num_to_pad = self.num_dspy_examples_per_grpo_step - (len(original_trainset) % self.num_dspy_examples_per_grpo_step)
+        num_to_pad = self.num_dspy_examples_per_grpo_step - (
+            len(original_trainset) % self.num_dspy_examples_per_grpo_step
+        )
         if num_to_pad > 0:
             # Select ids based on least frequent ids
             for _ in range(num_to_pad):
@@ -224,32 +267,40 @@ def select_training_sample_and_update_shuffled_trainset(
             self.epoch = curr_epoch
             self.update_shuffled_trainset(original_trainset)
 
-        assert len(self.shuffled_trainset_ids) >= self.num_dspy_examples_per_grpo_step, f"Shuffled trainset length {len(self.shuffled_trainset_ids)} is less than num_dspy_examples_per_grpo_step {self.num_dspy_examples_per_grpo_step}"
-        assert len(self.shuffled_trainset_ids) % self.num_dspy_examples_per_grpo_step == 0, f"Shuffled trainset length {len(self.shuffled_trainset_ids)} is not divisible by num_dspy_examples_per_grpo_step {self.num_dspy_examples_per_grpo_step}"
+        assert (
+            len(self.shuffled_trainset_ids) >= self.num_dspy_examples_per_grpo_step
+        ), f"Shuffled trainset length {len(self.shuffled_trainset_ids)} is less than num_dspy_examples_per_grpo_step {self.num_dspy_examples_per_grpo_step}"
+        assert (
+            len(self.shuffled_trainset_ids) % self.num_dspy_examples_per_grpo_step == 0
+        ), f"Shuffled trainset length {len(self.shuffled_trainset_ids)} is not divisible by num_dspy_examples_per_grpo_step {self.num_dspy_examples_per_grpo_step}"
         base_idx = base_idx % len(self.shuffled_trainset_ids)
         end_idx = base_idx + self.num_dspy_examples_per_grpo_step
-        assert end_idx <= len(self.shuffled_trainset_ids), f"End index {end_idx} is out of bounds for shuffled trainset length {len(self.shuffled_trainset_ids)}"
+        assert end_idx <= len(
+            self.shuffled_trainset_ids
+        ), f"End index {end_idx} is out of bounds for shuffled trainset length {len(self.shuffled_trainset_ids)}"
         selected_ids = self.shuffled_trainset_ids[base_idx:end_idx]
         selected_trainset = [original_trainset[i] for i in selected_ids]
         return selected_trainset
 
     def compile(
         self,
-        student: Module,
+        student: M,
         trainset: list[Example],
         teacher: Module | list[Module] | None = None,
         valset: list[Example] | None = None,
         **kwargs,
-    ) -> Module:
-        logger.info("Starting the GRPO compilation process... The LM(s) for the student program will be updated in place at the end of the training.")
+    ) -> M:
+        logger.info(
+            "Starting the GRPO compilation process... The LM(s) for the student program will be updated in place at the end of the training."
+        )
         logger.info("Validating the inputs...")
 
         assert len(trainset) > 0, "Training set is empty. Please provide a non-empty training set."
 
         if len(trainset) < self.num_dspy_examples_per_grpo_step:
             logger.warning(
-                f"Number of training examples {len(trainset)} is less than the number of examples per GRPO step {self.num_dspy_examples_per_grpo_step}. "
+                f"Number of training examples {len(trainset)} is less than the number of examples per GRPO step {self.num_dspy_examples_per_grpo_step}. "
                 "Repeating the training set to fill the GRPO step. This could lead to overfitting and training instability."
             )
             multiplier = (self.num_dspy_examples_per_grpo_step + len(trainset) - 1) // len(trainset)
@@ -294,7 +345,9 @@ def compile(
         pred_signature_hash_to_ind = {hash(pred.signature): ind for ind, pred in enumerate(student.predictors())}
         num_student_predictors = len(student.predictors())
 
-        logging.info("Preparing the teacher program(s)... We will ensure that the provided programs have the same program structure as the student program.")
+        logging.info(
+            "Preparing the teacher program(s)... We will ensure that the provided programs have the same program structure as the student program."
+        )
         if (isinstance(teacher, list) and len(teacher) == 0) or teacher is None:
             teacher = student
         teachers = teacher if isinstance(teacher, list) else [teacher]
@@ -303,7 +356,9 @@ def compile(
             all_predictors_have_lms(t)
 
         # Ensure that the teachers list contain the student program
-        assert student in teachers, f"Student program {student} is not in the list of teachers {teachers}. Please provide the student program as one of the teachers. Alternatively, you can leave the teacher argument as None, and the student program will be used as the teacher program."
+        assert (
+            student in teachers
+        ), f"Student program {student} is not in the list of teachers {teachers}. Please provide the student program as one of the teachers. Alternatively, you can leave the teacher argument as None, and the student program will be used as the teacher program."
         assert self.num_rollouts_per_grpo_step % len(teachers) == 0, (
             f"The GRPO group size (num_rollouts_per_grpo_step) {self.num_rollouts_per_grpo_step} is not divisible by the number of teachers {len(teachers)}. "
             "This is required to ensure that each teacher gets the same number of examples."
@@ -357,6 +412,7 @@ def compile(
                 original_trainset=trainset,
                 train_step_idx=train_step_idx,
             )
+
             def _any_available_for_step():
                 for _, job in grpo_training_jobs.items():
                     grpo_status: GRPOStatus = job.get_status()
@@ -372,13 +428,15 @@ def _any_available_for_step():
             logger.info("Bootstrapping data...")
             trace_data = [[[] for _ in range(len(teachers))] for _ in range(len(subsample_training_dataset))]
             for tind, teacher in enumerate(teachers):
-                subsample_training_dataset_repeated = [example for _ in range(num_samples_per_input) for example in subsample_training_dataset]
+                subsample_training_dataset_repeated = [
+                    example for _ in range(num_samples_per_input) for example in subsample_training_dataset
+                ]
                 round_data = bootstrap_trace_data(
                     program=teacher,
                     dataset=subsample_training_dataset_repeated,
                     metric=self.metric,
                     num_threads=self.num_threads,
-                    raise_on_error=False, # TODO(GRPO Team): This should be True, once the dspy format issue is fixed
+                    raise_on_error=False,  # TODO(GRPO Team): This should be True, once the dspy format issue is fixed
                     capture_failed_parses=True,
                     failure_score=self.failure_score,
                     format_failure_score=self.format_failure_score,
@@ -418,38 +476,60 @@ def _any_available_for_step():
                     for sample in teacher_data:
                         # Each sample is a Dict(example, prediction, trace, example_ind, score)
                         # sample['prediction'] is module_level prediction
-                        assert sample["example_ind"] == example_ind, f"Example index {sample['example_ind']} does not match the expected index {example_ind}"
+                        assert (
+                            sample["example_ind"] == example_ind
+                        ), f"Example index {sample['example_ind']} does not match the expected index {example_ind}"
 
-                        trace_instances_for_current_pred = [(*t, sample["score"]) for t in sample["trace"] if hash(t[0].signature) == hash(student.predictors()[pred_id].signature)]
+                        trace_instances_for_current_pred = [
+                            (*t, sample["score"])
+                            for t in sample["trace"]
+                            if hash(t[0].signature) == hash(student.predictors()[pred_id].signature)
+                        ]
 
                         predictor_example_invocations.append(trace_instances_for_current_pred)
 
                     if len(predictor_example_invocations) == 0:
-                        logger.warning(f"Skipping example {example_ind} for predictor {pred_id} as it has no invocations. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
+                        logger.warning(
+                            f"Skipping example {example_ind} for predictor {pred_id} as it has no invocations. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format."
+                        )
                         continue
                     elif len(predictor_example_invocations) != self.num_rollouts_per_grpo_step:
-                        logger.warning(f"Number of predictor example invocations {len(predictor_example_invocations)} does not match the expected batch size {self.num_rollouts_per_grpo_step}. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
+                        logger.warning(
+                            f"Number of predictor example invocations {len(predictor_example_invocations)} does not match the expected batch size {self.num_rollouts_per_grpo_step}. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format."
+                        )
 
-                    min_len = min([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
-                    max_len = max([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
+                    min_len = min(
+                        [len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))]
+                    )
+                    max_len = max(
+                        [len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))]
+                    )
 
                     if min_len == 0:
-                        logger.warning(f"Skipping example {example_ind} for predictor {pred_id} as it has no invocations.")
+                        logger.warning(
+                            f"Skipping example {example_ind} for predictor {pred_id} as it has no invocations."
+                        )
                         continue
 
                     if self.variably_invoked_predictor_grouping_mode == "truncate":
-                        predictor_example_invocations = [invocation[:min_len] for invocation in predictor_example_invocations]
+                        predictor_example_invocations = [
+                            invocation[:min_len] for invocation in predictor_example_invocations
+                        ]
                    elif self.variably_invoked_predictor_grouping_mode == "fill":
                         if self.variably_invoked_predictor_fill_strategy == "randint":
-                            selector = lambda l: self.rng.choice(l) # noqa: E731, E741
+                            selector = lambda l: self.rng.choice(l)  # noqa: E731, E741
                         else:
-                            selector = lambda l: l[-1] # noqa: E731, E741
+                            selector = lambda l: l[-1]  # noqa: E731, E741
                         predictor_example_invocations = [
                             invocation + [selector(invocation) for _ in range(max_len - len(invocation))]
                             for invocation in predictor_example_invocations
                         ]
                     else:
-                        assert self.variably_invoked_predictor_grouping_mode == "ragged", f"Unknown variably invoked predictor grouping mode {self.variably_invoked_predictor_grouping_mode}"
-                        max_len = max([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
+                        assert (
+                            self.variably_invoked_predictor_grouping_mode == "ragged"
+                        ), f"Unknown variably invoked predictor grouping mode {self.variably_invoked_predictor_grouping_mode}"
+                        max_len = max(
+                            [len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))]
+                        )
 
                     example_training_data: list[GRPOGroup] = [[] for _ in range(max_len)]
 
@@ -465,60 +545,78 @@ def _any_available_for_step():
                             predictor = trace_instance[0]
                             pred_lm = predictor.lm
                             adapter = self.adapter[pred_lm] or settings.adapter or XMLAdapter()
-                            assert isinstance(adapter, ChatAdapter), f"Adapter {adapter} is not a ChatAdapter. GRPO training is not supported for this adapter."
+                            assert isinstance(
+                                adapter, ChatAdapter
+                            ), f"Adapter {adapter} is not a ChatAdapter. GRPO training is not supported for this adapter."
                             # TODO(Lakshya): Currently we exclude demos from the training data
                             # TODO(GRPO Team): Use build_call_data_from_trace (from bootstrap_finetune) instead of
                             # dealing with the message formatting ourselves.
                             inp_messages = adapter.format(
                                 signature=trace_instance[0].signature,
                                 inputs=trace_instance[1],
-                                demos=[] # TODO: Add support for demos
+                                demos=[],  # TODO: Add support for demos
                             )
 
                             if isinstance(trace_instance[2], FailedPrediction):
                                 score = trace_instance[2].format_reward or self.format_failure_score
-                                example_training_data[group_idx].append({
-                                    "messages": inp_messages,
-                                    "completion": {
-                                        "role": "assistant",
-                                        "content": trace_instance[2].completion_text,
-                                    },
-                                    "reward": float(score),
-                                })
+                                example_training_data[group_idx].append(
+                                    {
+                                        "messages": inp_messages,
+                                        "completion": {
+                                            "role": "assistant",
+                                            "content": trace_instance[2].completion_text,
+                                        },
+                                        "reward": float(score),
+                                    }
+                                )
-                                logger.warning(f"Adding a format failure example to the training data for predictor {pred_id} and example {example_ind}.")
+                                logger.warning(
+                                    f"Adding a format failure example to the training data for predictor {pred_id} and example {example_ind}."
+                                )
                             else:
                                 all_messages = adapter.format_finetune_data(
                                     signature=trace_instance[0].signature,
                                     inputs=trace_instance[1],
                                     outputs=trace_instance[2],
-                                    demos=[] # TODO: Add support for demos
+                                    demos=[],  # TODO: Add support for demos
                                 )["messages"]
 
-                                assert all_messages[:-1] == inp_messages, f"Input messages {inp_messages} do not match the expected messages {all_messages[:-1]}"
+                                assert (
+                                    all_messages[:-1] == inp_messages
+                                ), f"Input messages {inp_messages} do not match the expected messages {all_messages[:-1]}"
 
-                                example_training_data[group_idx].append({
-                                    "messages": inp_messages,
-                                    "completion": {
-                                        "role": all_messages[-1]["role"],
-                                        "content": all_messages[-1]["content"],
-                                    },
-                                    "reward": float(score),
-                                })
+                                example_training_data[group_idx].append(
+                                    {
+                                        "messages": inp_messages,
+                                        "completion": {
+                                            "role": all_messages[-1]["role"],
+                                            "content": all_messages[-1]["content"],
+                                        },
+                                        "reward": float(score),
+                                    }
+                                )
 
                 train_batch_per_predictor[pred_id].extend(example_training_data)
 
             if not any(train_batch_per_predictor):
-                logger.warning("No training data found for this training step. This means that the model did not generate valid formatted responses for any of the examples in the training set. This is a critical error. Please check the model and the training set.")
+                logger.warning(
+                    "No training data found for this training step. This means that the model did not generate valid formatted responses for any of the examples in the training set. This is a critical error. Please check the model and the training set."
+                )
                 continue
 
             for predictor_train_batch in train_batch_per_predictor:
                 for grpo_train_group in predictor_train_batch:
                     if len(grpo_train_group) != self.num_rollouts_per_grpo_step:
-                        logger.warning(f"Number of completions {len(grpo_train_group)} does not match the expected number num_rollouts_per_grpo_step={self.num_rollouts_per_grpo_step}")
-                        assert len(grpo_train_group) <= self.num_rollouts_per_grpo_step, f"Number of completions {len(grpo_train_group)} is greater than the expected number num_rollouts_per_grpo_step={self.num_rollouts_per_grpo_step}"
+                        logger.warning(
+                            f"Number of completions {len(grpo_train_group)} does not match the expected number num_rollouts_per_grpo_step={self.num_rollouts_per_grpo_step}"
+                        )
+                        assert (
+                            len(grpo_train_group) <= self.num_rollouts_per_grpo_step
+                        ), f"Number of completions {len(grpo_train_group)} is greater than the expected number num_rollouts_per_grpo_step={self.num_rollouts_per_grpo_step}"
                     if len(set(map(repr, grpo_train_group))) < 2:
                         # TODO(GRPO Team): How can we avoid this warning?
-                        logger.warning(f"GRPOGroup has no diversity. This could be due to low temperature, or low number of rollouts, or the cache could be enabled inadvertently. The GRPOGroup is {grpo_train_group}.")
+                        logger.warning(
+                            f"GRPOGroup has no diversity. This could be due to low temperature, or low number of rollouts, or the cache could be enabled inadvertently. The GRPOGroup is {grpo_train_group}."
+                        )
 
             # We now run the GRPO step. Notes:
             # * The job here has a reference to a particular M that's attached
@@ -536,15 +634,19 @@ def _any_available_for_step():
             # LM.
         logger.info("Invoking GRPO training step...")
         for (lm_for_job, data_key), job in grpo_training_jobs.items():
-            train_data: list[GRPOGroup] = sum(train_batch_per_predictor, []) if data_key is None else train_batch_per_predictor[data_key] #noqa: RUF017
+            train_data: list[GRPOGroup] = (
+                sum(train_batch_per_predictor, []) if data_key is None else train_batch_per_predictor[data_key]
+            )  # noqa: RUF017
             for group in train_data:
                 if len(group) != self.num_rollouts_per_grpo_step:
                     # TODO(GRPO Team): This is very undesirable. This occurs only because in some of the generations, the model does not follow the correct dspy format.
                     # The ideal solution is to identify the full response string in that predictor's group, and then assign a high-negative (user-configurable) reward to that group.
                     # Pad the group to the expected number of generations by repeating the whole group, might require multiple iterations
                     while len(group) < self.num_rollouts_per_grpo_step:
-                        group.extend(group[:min(self.num_rollouts_per_grpo_step - len(group), len(group))])
-                    assert len(group) == self.num_rollouts_per_grpo_step, f"Number of completions {len(group)} does not match the expected number self.num_rollouts_per_grpo_step={self.num_rollouts_per_grpo_step}"
+                        group.extend(group[: min(self.num_rollouts_per_grpo_step - len(group), len(group))])
+                    assert (
+                        len(group) == self.num_rollouts_per_grpo_step
+                    ), f"Number of completions {len(group)} does not match the expected number self.num_rollouts_per_grpo_step={self.num_rollouts_per_grpo_step}"
 
             # Determine available batch IDs for this specific job
             grpo_status: GRPOStatus = job.get_status()

diff --git a/dspy/teleprompt/mipro_optimizer_v2.py b/dspy/teleprompt/mipro_optimizer_v2.py
index c96d71d3f5..3a0f585ef6 100644
--- a/dspy/teleprompt/mipro_optimizer_v2.py
+++ b/dspy/teleprompt/mipro_optimizer_v2.py
@@ -1,7 +1,7 @@
 import logging
 import random
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Literal
+from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
 
 import numpy as np
 
@@ -19,6 +19,7 @@
     save_candidate_program,
     set_signature,
 )
+from dspy.primitives import Module
 
 if TYPE_CHECKING:
     import optuna
@@ -43,6 +44,8 @@
 BOLD = "\033[1m"
 ENDC = "\033[0m"  # Resets the color to default
 
+M = TypeVar("M", bound=Module)
+
 
 class MIPROv2(Teleprompter):
     def __init__(
@@ -91,11 +94,13 @@ def __init__(
         self.rng = None
 
         if not self.prompt_model or not self.task_model:
-            raise ValueError("Either provide both prompt_model and task_model or set a default LM through dspy.configure(lm=...)")
+            raise ValueError(
+                "Either provide both prompt_model and task_model or set a default LM through dspy.configure(lm=...)"
+            )
 
     def compile(
         self,
-        student: Any,
+        student: M,
         *,
         trainset: list,
         teacher: Any = None,
@@ -112,28 +117,22 @@ def compile(
         view_data_batch_size: int = 10,
         tip_aware_proposer: bool = True,
         fewshot_aware_proposer: bool = True,
-        requires_permission_to_run: bool | None = None, # deprecated
+        requires_permission_to_run: bool | None = None,  # deprecated
         provide_traceback: bool | None = None,
-    ) -> Any:
+    ) -> M:
         if requires_permission_to_run == False:
-            logger.warning(
-                "'requires_permission_to_run' is deprecated and will be removed in a future version."
-            )
+            logger.warning("'requires_permission_to_run' is deprecated and will be removed in a future version.")
         elif requires_permission_to_run == True:
-            raise ValueError("User confirmation is removed from MIPROv2. Please remove the 'requires_permission_to_run' argument.")
+            raise ValueError(
+                "User confirmation is removed from MIPROv2. Please remove the 'requires_permission_to_run' argument."
+            )
 
-        effective_max_errors = (
-            self.max_errors
-            if self.max_errors is not None
-            else dspy.settings.max_errors
-        )
+        effective_max_errors = self.max_errors if self.max_errors is not None else dspy.settings.max_errors
 
         effective_max_bootstrapped_demos = (
             max_bootstrapped_demos if max_bootstrapped_demos is not None else self.max_bootstrapped_demos
         )
-        effective_max_labeled_demos = (
-            max_labeled_demos if max_labeled_demos is not None else self.max_labeled_demos
-        )
+        effective_max_labeled_demos = max_labeled_demos if max_labeled_demos is not None else self.max_labeled_demos
 
         zeroshot_opt = (effective_max_bootstrapped_demos == 0) and (effective_max_labeled_demos == 0)
 
@@ -157,19 +156,14 @@ def compile(
         seed = seed or self.seed
         self._set_random_seeds(seed)
 
-        # Set training & validation sets
         trainset, valset = self._set_and_validate_datasets(trainset, valset)
 
         num_instruct_candidates = (
-            self.num_instruct_candidates
-            if self.num_instruct_candidates is not None
-            else self.num_candidates
+            self.num_instruct_candidates if self.num_instruct_candidates is not None else self.num_candidates
         )
         num_fewshot_candidates = (
-            self.num_fewshot_candidates
-            if self.num_fewshot_candidates is not None
-            else self.num_candidates
+            self.num_fewshot_candidates if self.num_fewshot_candidates is not None else self.num_candidates
         )
 
         # Set hyperparameters based on run mode (if set)
@@ -419,9 +413,7 @@ def _bootstrap_fewshot_examples(
             num_candidate_sets=num_fewshot_candidates,
             trainset=trainset,
             max_labeled_demos=(LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_labeled_demos),
-            max_bootstrapped_demos=(
-                BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_bootstrapped_demos
-            ),
+            max_bootstrapped_demos=(BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_bootstrapped_demos),
             metric=self.metric,
             max_errors=max_errors,
             teacher=teacher,
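Worth noting for MIPROv2 specifically: compile previously took student: Any and returned Any, which erased the program's static type entirely; the TypeVar restores it. A hypothetical mypy snippet showing the narrowing is below. Summarizer and tune are invented names for illustration, and a configured LM is assumed for an actual run:

from typing import reveal_type  # Python 3.11+; mypy also understands the bare name

import dspy
from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2


class Summarizer(dspy.Module):
    def __init__(self):
        super().__init__()
        self.summarize = dspy.ChainOfThought("document -> summary")


def tune(program: Summarizer, data: list[dspy.Example]) -> Summarizer:
    optimizer = MIPROv2(metric=dspy.evaluate.answer_exact_match, auto="light")
    compiled = optimizer.compile(student=program, trainset=data)
    reveal_type(compiled)  # was Any before this change; now Summarizer
    return compiled  # no cast or `# type: ignore` required anymore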
- + For more details, see: https://dspy.ai/api/optimizers/SIMBA/ """ @@ -82,21 +85,15 @@ def __init__( else: self.strategies = [append_a_rule] - def compile( - self, - student: dspy.Module, - *, - trainset: list[dspy.Example], - seed: int = 0 - ) -> dspy.Module: + def compile(self, student: M, *, trainset: list[dspy.Example], seed: int = 0) -> M: """ Compile and optimize the student module using SIMBA. - + Args: student: The module to optimize trainset: Training examples for optimization seed: Random seed for reproducibility - + Returns: The optimized module with candidate_programs and trial_logs attached """ diff --git a/scratchpad.py b/scratchpad.py new file mode 100644 index 0000000000..b55bf5e1c8 --- /dev/null +++ b/scratchpad.py @@ -0,0 +1,19 @@ +from dspy.teleprompt.teleprompt import Teleprompter +from dspy.primitives import Module +import dspy +import random +from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2 + + +class MyAgent(Module): + def __init__(self): + super().__init__() + self.predict = dspy.ReAct("question -> answer", tools=[self.random_function]) + + def random_function(self): + return random.random() + + +optimizer = MIPROv2(metric=dspy.evaluate.answer_exact_match, auto="light") +compiled_agent = optimizer.compile(student=MyAgent(), trainset=trainset, teacher=teacher, valset=valset) +print(compiled_agent.random_function()) From e45875eeb8d0fb1aada2f6378313305e97b019bc Mon Sep 17 00:00:00 2001 From: Ty Todd Date: Mon, 3 Nov 2025 12:02:19 -0800 Subject: [PATCH 3/3] cleanup --- scratchpad.py | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 scratchpad.py diff --git a/scratchpad.py b/scratchpad.py deleted file mode 100644 index b55bf5e1c8..0000000000 --- a/scratchpad.py +++ /dev/null @@ -1,19 +0,0 @@ -from dspy.teleprompt.teleprompt import Teleprompter -from dspy.primitives import Module -import dspy -import random -from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2 - - -class MyAgent(Module): - def __init__(self): - super().__init__() - self.predict = dspy.ReAct("question -> answer", tools=[self.random_function]) - - def random_function(self): - return random.random() - - -optimizer = MIPROv2(metric=dspy.evaluate.answer_exact_match, auto="light") -compiled_agent = optimizer.compile(student=MyAgent(), trainset=trainset, teacher=teacher, valset=valset) -print(compiled_agent.random_function())