
Commit 88ae78a

Authored by Noah Ziems, Lakshya A Agrawal, Dilara Soylu, and chenmoneygithub
Add GRPO Optimizer to DSPy (#8171)
* D1 for GRPO
* Improve type for arbor
* Add temp test script for grpo
* Add note about assumption of same inputs to all predictors
* Disable LM cache in GRPO
* Add support for valset
* Add configurable variable module invocation handling strategy
* Noah's dspy.LM changes and dspy.ArborProvider implementation
* Add latest arbor changes
* First working grpo version
* Add modules
* Add training args in initialize
* Fix grpo
* Add batches
* Update finetuning infra
* Revise server interface
* Update example script
* Move temporary interface to a separate file
* Add LM-level reinforce interface
* Update testing script
* Update api_base access for finetune
* Style check
* Style check all
* Add test script with MATH dataset
* Ensure grpo trainer does not crash due to format issues (temporary fix)
* Add error log
* Fix termination
* Delete temp files
* Add diff
* Add model update endpoint support
* Remove experimental flag
* Remove extra files
* Add GRPO error resiliency to avoid parsing failures leading to crashes
* Param Passthrough and Consistent Tutorial Script (#3)
* Add param passthrough and default banking77 tutorial
* Add more threads
* Update banking tutorial

---------

Co-authored-by: Noah Ziems <nziems2@nziems2@nd.edu>

* Lower beta param for banking tutorial
* Add warning on no training data
* Add train logging to GRPO
* Add max_prompt_length and max_completion_length support
* fix litellm retries
* no jsonadapter
* fix errors
* fix tests
* fix tests
* add the retry strategy back
* Add working implementation of format errors and negative rewards
* Fix bugs in validation
* Add validation logic to grpo
* Add more supported args
* Support max grad norm
* Add train shuffling logic
* Add lora support
* Add soft format rewards
* Disable provide_traceback in all GRPO-invoked evaluates
* Remove temporary tutorial script
* Revert classification finetuning tutorial
* Comment out json adapter test
* Fix ruff errors
* Add teacher (#8)
* Modify teacher preparation logic
* Re-add teachers to GRPO
* Style fix
* Update tutorial script
* Housekeeping
* Revert number of train steps
* Address PR comments
* Add wandb support for GRPO training runs
* Add completion logging
* Add logging steps support
* Update report_to to be default none
* Add max_context_length
* Fix num_samples_per_input computation
* Checkpointing Endpoints (#10)
* Fix typo
* Fix checkpoint url
* Fix merge conflict leftover
* Shorten the warning message in json adapter
* Fix the error piping

---------

Co-authored-by: Lakshya A Agrawal <lakshyaaagrawal@berkeley.edu>
Co-authored-by: Dilara Soylu <21346670+dilarasoylu@users.noreply.github.com>
Co-authored-by: Noah Ziems <nziems2@nziems2@nd.edu>
Co-authored-by: chenmoneygithub <chen.qian@databricks.com>
1 parent 637d759 commit 88ae78a
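
Many of the bullets in the commit message above add pass-through training arguments for GRPO (beta, max_prompt_length, max_completion_length, max grad norm, LoRA, report_to, logging steps). As a rough illustration only, a train_kwargs dictionary handed to the new LM-level reinforce interface might look like the sketch below; the exact key names and values are assumptions inferred from the bullet names, not a confirmed schema.

# Hypothetical GRPO train_kwargs; the key names are inferred from the commit
# message bullets above and are assumptions, not the confirmed DSPy/Arbor schema.
train_kwargs = {
    "beta": 0.04,                  # KL coefficient ("Lower beta param for banking tutorial")
    "max_prompt_length": 1024,     # "Add max_prompt_length and max_completion_length support"
    "max_completion_length": 512,
    "max_grad_norm": 1.0,          # "Support max grad norm"
    "report_to": "none",           # "Update report_to to be default none"
    "logging_steps": 10,           # "Add logging steps support"
    "use_lora": True,              # "Add lora support" (exact key name assumed)
}
# This dictionary would be handed to the new LM-level interface,
# e.g. lm.reinforce(train_kwargs=train_kwargs), shown in dspy/clients/lm.py below.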

File tree

11 files changed: +1338 −37 lines

docs/docs/tutorials/classification_finetuning/index.ipynb

Lines changed: 1 addition & 1 deletion

@@ -211,7 +211,7 @@
     },
     {
      "cell_type": "code",
-     "execution_count": 5,
+     "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [

dspy/adapters/json_adapter.py

Lines changed: 2 additions & 2 deletions

@@ -62,8 +62,8 @@ def __call__(
                 structured_output_model = _get_structured_outputs_response_format(signature)
                 lm_kwargs["response_format"] = structured_output_model
                 return super().__call__(lm, lm_kwargs, signature, demos, inputs)
-        except Exception as e:
-            logger.warning(f"Failed to use structured output format. Falling back to JSON mode. Error: {e}")
+        except Exception:
+            logger.warning("Failed to use structured output format, falling back to JSON mode.")
         try:
             lm_kwargs["response_format"] = {"type": "json_object"}
             return super().__call__(lm, lm_kwargs, signature, demos, inputs)

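The json_adapter change above drops the exception details from the fallback warning but keeps the two-step strategy: try provider-enforced structured outputs first, then fall back to generic JSON mode. A minimal standalone sketch of that pattern follows; the helper name call_with_response_format and the lm(...) call convention are illustrative stand-ins for the adapter's super().__call__ path, not the actual adapter code.

import logging

logger = logging.getLogger(__name__)


def call_with_response_format(lm, lm_kwargs, response_format):
    # Illustrative stand-in for the adapter's super().__call__ with a
    # response_format injected into the LM kwargs; `lm` is any callable LM client.
    return lm(**dict(lm_kwargs, response_format=response_format))


def robust_json_call(lm, lm_kwargs, structured_output_model):
    # First attempt: provider-enforced structured outputs (a schema/model object).
    try:
        return call_with_response_format(lm, lm_kwargs, structured_output_model)
    except Exception:
        logger.warning("Failed to use structured output format, falling back to JSON mode.")
    # Second attempt: plain JSON mode, widely supported by OpenAI-compatible providers.
    return call_with_response_format(lm, lm_kwargs, {"type": "json_object"})
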
dspy/clients/lm.py

Lines changed: 12 additions & 5 deletions

@@ -11,7 +11,7 @@
 import dspy
 from dspy.clients.cache import request_cache
 from dspy.clients.openai import OpenAIProvider
-from dspy.clients.provider import Provider, TrainingJob
+from dspy.clients.provider import Provider, TrainingJob, ReinforceJob
 from dspy.clients.utils_finetune import TrainDataFormat
 from dspy.dsp.utils.settings import settings
 from dspy.utils.callback import BaseCallback
@@ -188,10 +188,6 @@ def finetune(
     ) -> TrainingJob:
         from dspy import settings as settings
 
-        err = "Fine-tuning is an experimental feature."
-        err += " Set `dspy.settings.experimental` to `True` to use it."
-        assert settings.experimental, err
-
         err = f"Provider {self.provider} does not support fine-tuning."
         assert self.provider.finetunable, err
 
@@ -212,6 +208,17 @@ def thread_function_wrapper():
 
         return job
 
+    def reinforce(self, train_kwargs) -> ReinforceJob:
+        # TODO(GRPO Team): Should we return an initialized job here?
+        from dspy import settings as settings
+
+        err = f"Provider {self.provider} does not implement the reinforcement learning interface."
+        assert self.provider.reinforceable, err
+
+        job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs)
+        job.initialize()
+        return job
+
     def _run_finetune_job(self, job: TrainingJob):
         # TODO(enhance): We should listen for keyboard interrupts somewhere.
         # Requires TrainingJob.cancel() to be implemented for each provider.

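LM.reinforce, as added above, relies on three things from the provider: a reinforceable flag, a ReinforceJob attribute it can construct with lm= and train_kwargs=, and an initialize() method on the resulting job. The sketch below shows a bare-bones provider satisfying that duck-typed interface; MyProvider and MyReinforceJob are hypothetical placeholders, not the ArborProvider implementation this PR ships.

from dspy.clients.provider import Provider


class MyReinforceJob:
    # Hypothetical job object: LM.reinforce constructs it with lm= and
    # train_kwargs=, then calls initialize() before returning it.
    def __init__(self, lm, train_kwargs):
        self.lm = lm
        self.train_kwargs = train_kwargs

    def initialize(self):
        # A real implementation (e.g. the Arbor provider) would start an RL
        # training run here; this placeholder only records the call.
        print("initializing RL training with", self.train_kwargs)


class MyProvider(Provider):
    # LM.reinforce looks the job class up on the provider instance.
    ReinforceJob = MyReinforceJob

    def __init__(self):
        super().__init__()
        # Flag checked by the assert in LM.reinforce above.
        self.reinforceable = True


# Assuming dspy.LM accepts a provider argument (the model name below is illustrative):
#   lm = dspy.LM("openai/arbor:my-model", provider=MyProvider())
#   job = lm.reinforce(train_kwargs={"beta": 0.04})  # returns the initialized job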