25 changes: 14 additions & 11 deletions dspy/teleprompt/bettertogether.py
@@ -1,6 +1,6 @@
 import logging
 import random
-from typing import Callable
+from typing import Callable, TypeVar
 
 import dspy
 from dspy.primitives.example import Example
@@ -14,20 +14,22 @@
 )
 from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch
 from dspy.teleprompt.teleprompt import Teleprompter
+from dspy.primitives import Module
 
+M = TypeVar("M", bound=Module)
 logger = logging.getLogger(__name__)
 
 
 class BetterTogether(Teleprompter):
 
     STRAT_SEP = " -> "
 
-    def __init__(self,
+    def __init__(
+        self,
         metric: Callable,
         prompt_optimizer: Teleprompter | None = None,
         weight_optimizer: Teleprompter | None = None,
         seed: int | None = None,
-        ):
+    ):
         if not dspy.settings.experimental:
             raise ValueError("This is an experimental optimizer. Set `dspy.settings.experimental` to `True` to use it.")
 
@@ -37,7 +39,9 @@ def __init__(self,
         # a BootstrapFinetune without a metric, say, if there aren't labels
         # available for the training data. Should this be noted somewhere?
         # TODO: We should re-consider if the metric should be required.
-        self.prompt_optimizer = prompt_optimizer if prompt_optimizer else BootstrapFewShotWithRandomSearch(metric=metric)
+        self.prompt_optimizer = (
+            prompt_optimizer if prompt_optimizer else BootstrapFewShotWithRandomSearch(metric=metric)
+        )
         self.weight_optimizer = weight_optimizer if weight_optimizer else BootstrapFinetune(metric=metric)
 
         is_supported_prompt = isinstance(self.prompt_optimizer, BootstrapFewShotWithRandomSearch)
@@ -52,11 +56,11 @@ def __init__(self,
 
     def compile(
         self,
-        student: Module,
+        student: M,
         trainset: list[Example],
         strategy: str = "p -> w -> p",
-        valset_ratio = 0.1,
-    ) -> Module:
+        valset_ratio=0.1,
+    ) -> M:
Comment on lines 57 to +63

Copilot AI Nov 10, 2025

This method requires at least 3 positional arguments, whereas overridden Teleprompter.compile requires 2.
         # TODO: We could record acc on a different valset to pick the best
         # strategy within the provided strategy
         logger.info("Validating the strategy")
@@ -91,10 +95,9 @@ def _run_strategies(self, parsed_strategy, student, trainset, valset_ratio) -> M
         launched_flag = False
 
         for ind, step_code in enumerate(parsed_strategy):
-            current_strategy = self.STRAT_SEP.join(parsed_strategy[:ind + 1])
Collaborator

Let's not include unrelated changes

Author

Must be my formatter. Let me see if I can undo the other formatting changes.

+            current_strategy = self.STRAT_SEP.join(parsed_strategy[: ind + 1])
             logger.info(
-                f"\n########## Step {ind + 1} of {len(parsed_strategy)} - Strategy "
-                f"'{current_strategy}' ##########"
+                f"\n########## Step {ind + 1} of {len(parsed_strategy)} - Strategy " f"'{current_strategy}' ##########"
             )

logger.info("Shuffling the trainset...")
8 changes: 4 additions & 4 deletions dspy/teleprompt/bootstrap_finetune.py
@@ -1,6 +1,6 @@
 import logging
 from collections import defaultdict
-from typing import Any, Callable
+from typing import Any, Callable, TypeVar
 
 import dspy
 from dspy.adapters.base import Adapter
@@ -16,6 +16,8 @@

 logger = logging.getLogger(__name__)
 
+M = TypeVar("M", bound=Module)
+
 
 class FinetuneTeleprompter(Teleprompter):
     def __init__(
@@ -57,9 +59,7 @@ def __init__(
         self.exclude_demos = exclude_demos
         self.num_threads = num_threads
 
-    def compile(
-        self, student: Module, trainset: list[Example], teacher: Module | list[Module] | None = None
-    ) -> Module:
+    def compile(self, student: M, trainset: list[Example], teacher: Module | list[Module] | None = None) -> M:
         # TODO: Print statements can be converted to logger.info if we ensure
         # that the default DSPy logger logs info level messages in notebook
         # environments.
Comment on lines +62 to 65

Copilot AI Nov 10, 2025

This method requires at least 3 positional arguments, whereas overridden Teleprompter.compile requires 2.

Suggested change
-    def compile(self, student: M, trainset: list[Example], teacher: Module | list[Module] | None = None) -> M:
-        # TODO: Print statements can be converted to logger.info if we ensure
-        # that the default DSPy logger logs info level messages in notebook
-        # environments.
+    def compile(self, student: M, trainset: list[Example], **kwargs) -> M:
+        # TODO: Print statements can be converted to logger.info if we ensure
+        # that the default DSPy logger logs info level messages in notebook
+        # environments.
+        teacher = kwargs.get('teacher', None)
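
A note on the trade-off in this suggestion: routing `teacher` through `**kwargs` makes the override call-compatible with a two-argument base `compile`, but callers lose the explicit, type-checked keyword. A hypothetical alternative, not part of this PR and again using simplified stand-ins for the DSPy classes, is to make the base signature permissive and keep the keyword explicit in the subclass:

from typing import TypeVar

class Module:  # stand-in for dspy.primitives.Module
    pass

M = TypeVar("M", bound=Module)

class Teleprompter:
    # Hypothetical permissive base signature, not the actual DSPy API:
    def compile(self, student: M, trainset: list, **kwargs) -> M:
        raise NotImplementedError

class BootstrapFinetune(Teleprompter):
    # An extra keyword with a default stays call-compatible: every call
    # that is valid on the base class is still valid here.
    def compile(self, student: M, trainset: list, teacher: Module | list[Module] | None = None, **kwargs) -> M:
        return student
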
