
Commit 8220c8f

Add GRPO error resiliency to avoid parsing failures leading to crashes
1 parent 0ef5932 commit 8220c8f

2 files changed: +56 -36 lines changed

dspy/teleprompt/bootstrap_finetune.py

Lines changed: 13 additions & 4 deletions
@@ -49,6 +49,7 @@ def __init__(
         # attached to LMs themselves -- an LM could know which adapter it should
         # be used with along with the train_kwargs. This will lead the only
         # required argument for LM.finetune() to be the train dataset.
+
         super().__init__(train_kwargs=train_kwargs)
         self.metric = metric
         self.multitask = multitask
@@ -200,13 +201,12 @@ def build_call_data_from_trace(
         )
     return call_data
 
-
 def bootstrap_trace_data(
     program: Program,
     dataset: List[Example],
     metric: Optional[Callable] = None,
     num_threads: Optional[int] = None,
-    raise_on_error=True,
+    raise_on_error=True
 ) -> List[Dict[str, Any]]:
     # Return a list of dicts with the following keys: example_ind, example, prediction, trace, and score
     # (if metric != None)
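For context, this is how the helper is typically invoked (names below are illustrative, not from the commit). With raise_on_error=False, a failing example is logged and skipped rather than re-raised, so the returned list can be shorter than the dataset:

rollouts = bootstrap_trace_data(
    program=teacher_program,  # a dspy Program whose predictors have LMs attached
    dataset=train_examples,   # List[Example]
    metric=my_metric,         # optional scoring callable
    num_threads=8,
    raise_on_error=False,     # log-and-skip instead of crashing the run
)
for d in rollouts:
    print(d["example_ind"], d.get("score"))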
@@ -216,6 +216,7 @@ def bootstrap_trace_data(
         display_progress=True,
         return_outputs=True,
         provide_traceback=True, # TODO(check with team)
+        max_errors=len(dataset)*10, # TODO(check with team)
     )
 
     def wrapped_metric(example, prediction, trace=None):
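The max_errors budget is what makes this resilient end to end: the underlying evaluator aborts once its error count crosses max_errors, so a budget of ten times the dataset size effectively lets every example fail (for instance, on a malformed response) without killing the bootstrap. A minimal sketch of the idea, assuming the dspy.evaluate.Evaluate import path and reusing the keyword arguments shown above:

from dspy.evaluate import Evaluate  # import path assumed

def evaluate_leniently(program, dataset, metric, num_threads=8):
    evaluator = Evaluate(
        devset=dataset,
        num_threads=num_threads,
        display_progress=True,
        return_outputs=True,
        provide_traceback=True,
        max_errors=len(dataset) * 10,  # generous error budget, as in this commit
    )
    # Failed examples consume the budget instead of raising immediately.
    return evaluator(program, metric=metric)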
@@ -274,11 +275,13 @@ def wrapped_program(**kwargs):
     # return data_dict
 
 
+# Note: Shared below are useful functions for preparing student/teacher programs
+# Similar methods are implemented separately and used by other DSPy
+# teleprompters. These can be moved to shared locations.
 def all_predictors_have_lms(program: Program) -> bool:
     """Return True if all predictors in the program have an LM set."""
     return all(pred.lm for pred in program.predictors())
 
-
 def copy_program_with_lms(program: Program) -> Program:
     pred_lms = [pred.lm for pred in program.predictors()]
     program = program.deepcopy()
@@ -290,13 +293,19 @@ def copy_program_with_lms(program: Program) -> Program:
 def prepare_student(student: Program) -> Program:
     if getattr(student, "_compiled", False):
         raise ValueError("The student program should not be compiled.")
+
+    # TODO: Should we use reset_copy here? How would it affect the student
+    # program's predictor LMs, if they are set?
+
+    # TODO: Should there be a deepcopy here?
+    # student = student.deepcopy()
     return student
 
 
 def prepare_teacher(student: Program, teacher: Optional[Program] = None) -> Program:
     if teacher is None:
         return copy_program_with_lms(student)
-
+
     # We avoid modifying the original teacher program by making a copy
     teacher = copy_program_with_lms(teacher)
 
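Taken together, the helpers in this file are meant to be used at the start of compile() roughly like this (a sketch that assumes the student's predictors already have LMs set; only functions shown in this diff are called):

student = prepare_student(student)        # rejects already-compiled programs
teacher = prepare_teacher(student, None)  # None falls back to an LM-preserving copy of the student
assert all_predictors_have_lms(teacher)   # every predictor keeps its original LM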
dspy/teleprompt/grpo.py

Lines changed: 43 additions & 32 deletions
@@ -58,6 +58,25 @@ def __init__(
         assert variably_invoked_predictor_fill_strategy in ['randint', 'max'], "variably_invoked_predictor_fill_strategy must be either 'randint' or 'max'"
         self.variably_invoked_predictor_fill_strategy = variably_invoked_predictor_fill_strategy
 
+    def validate_trace_data_and_log_issues(
+        self,
+        trace_data: List[List[List[Dict[str, Any]]]],
+        subsample_training_dataset: List[Example],
+        num_teachers: int,
+    ):
+        # At this point, trace_data: List[example_idx -> List[teacher_idx -> [num_samples_per_input * Dict(example, prediction, trace, example_ind, score)]]]
+        # Shape of trace is: [dspy_module_invocation_idx -> Tuple[Predictor, PredictorInputs, Prediction]]
+        assert len(trace_data) == len(subsample_training_dataset), f"Trace data length {len(trace_data)} does not match the number of examples {len(subsample_training_dataset)}"
+        assert len(trace_data[0]) == num_teachers, f"Trace data length {len(trace_data[0])} does not match the number of teachers {num_teachers}"
+        # TODO(GRPO Team): Ideally, once the dspy format issue is fixed, this change should be reverted back to being a normal assert.
+        if len(trace_data[0][0]) == 0:
+            logger.warning(f"Trace data for example {0} and teacher {0} is empty. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
+        elif len(trace_data[0][0]) != self.num_samples_per_input:
+            logger.warning(f"Trace data length {len(trace_data[0][0])} does not match the expected number of samples per input {self.num_samples_per_input}")
+        assert "trace" in trace_data[0][0][0], "Trace data does not contain the 'trace' key"
+        assert len(trace_data[0][0][0]["trace"]) > 0, "Trace data is empty"
+        assert len(trace_data[0][0][0]["trace"][0]) == 3, f"Trace tuple length {len(trace_data[0][0][0]['trace'][0])} does not match the expected length 3"
+
     def compile(
         self,
         student: Program,
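For reference, a shape-only sketch of a trace_data value that passes these checks; the predictor, prediction, and example objects are stand-ins, not real DSPy objects:

predictor, prediction, example = object(), object(), object()
trace_data = [  # one entry per example in the subsample
    [           # one entry per teacher
        [       # up to num_samples_per_input sampled rollouts
            {
                "example": example,
                "prediction": prediction,
                "trace": [(predictor, {"question": "..."}, prediction)],  # 3-tuples
                "example_ind": 0,
                "score": 1.0,
            },
        ],
    ],
]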
@@ -134,7 +153,6 @@ def compile(
                 train_kwargs = self.train_kwargs[pred.lm]
                 job = pred.lm.reinforce(train_kwargs=train_kwargs)
                 grpo_training_jobs[job_key] = job
-
         if valset is None and self.use_train_as_val:
             logger.info("Using the training set as the validation set.")
             valset = trainset
@@ -151,8 +169,8 @@ def compile(
             display_progress=True,
             return_outputs=False,
             provide_traceback=True, # TODO(check with team)
+            max_errors=len(valset)*10, # TODO(check with team)
         )
-
         logger.info("Evaluating the student program on the validation set before training loop...")
         valset_evaluation = valset_evaluator(student, metric=self.metric)
         logger.info(f"Student program validation set score before training loop: {valset_evaluation}")
@@ -166,32 +184,26 @@ def compile(
             logger.info("Bootstrapping data...")
             trace_data = [[[] for _ in range(len(teachers))] for _ in range(len(subsample_training_dataset))]
             for tind, teacher in enumerate(teachers):
-                for _ in range(self.num_samples_per_input):
-                    # We rely on disabled caches to ensure that we get different
-                    # traces
-                    round_data = bootstrap_trace_data(
-                        program=teacher,
-                        dataset=subsample_training_dataset,
-                        metric=self.metric,
-                        num_threads=self.num_threads,
-                        raise_on_error=False, # TODO(GRPO Team): This should be True, once the dspy format issue is fixed
-                    )
-                    for data_dict in round_data:
-                        trace_data[data_dict['example_ind']][tind].append(data_dict)
+                subsample_training_dataset_repeated = [example for _ in range(self.num_samples_per_input) for example in subsample_training_dataset]
+                round_data = bootstrap_trace_data(
+                    program=teacher,
+                    dataset=subsample_training_dataset_repeated,
+                    metric=self.metric,
+                    num_threads=self.num_threads,
+                    raise_on_error=False, # TODO(GRPO Team): This should be True, once the dspy format issue is fixed
+                )
+                for data_dict in round_data:
+                    example_ind_in_subsample = data_dict['example_ind'] % len(subsample_training_dataset)
+                    data_dict["example_ind"] = example_ind_in_subsample
+                    trace_data[example_ind_in_subsample][tind].append(data_dict)
 
             # At this point, trace_data: List[example_idx -> List[teacher_idx -> [num_samples_per_input * Dict(example, prediction, trace, example_ind, score)]]]
             # Shape of trace is: [dspy_module_invocation_idx -> Tuple[Predictor, PredictorInputs, Prediction]]
-            assert len(trace_data) == len(subsample_training_dataset), f"Trace data length {len(trace_data)} does not match the number of examples {len(subsample_training_dataset)}"
-            assert len(trace_data[0]) == len(teachers), f"Trace data length {len(trace_data[0])} does not match the number of teachers {len(teachers)}"
-
-            # TODO(GRPO Team): Ideally, once the dspy format issue is fixed, this change should be reverted back to being a normal assert.
-            if len(trace_data[0][0]) == 0:
-                logger.warning(f"Trace data for example {0} and teacher {0} is empty. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
-            elif len(trace_data[0][0]) != self.num_samples_per_input:
-                logger.warning(f"Trace data length {len(trace_data[0][0])} does not match the expected number of samples per input {self.num_samples_per_input}")
-            assert "trace" in trace_data[0][0][0], "Trace data does not contain the 'trace' key"
-            assert len(trace_data[0][0][0]["trace"]) > 0, "Trace data is empty"
-            assert len(trace_data[0][0][0]["trace"][0]) == 3, f"Trace tuple length {len(trace_data[0][0][0]['trace'][0])} does not match the expected length 3"
+            self.validate_trace_data_and_log_issues(
+                trace_data=trace_data,
+                subsample_training_dataset=subsample_training_dataset,
+                num_teachers=len(teachers),
+            )
 
             logger.info("Preparing the training data batch from bootstrapped examples for GRPO...")
             # Now, we need to prepare batches of data to be sent for training
215227

216228
predictor_example_invocations.append(trace_instances_for_current_pred)
217229

218-
if len(predictor_example_invocations) != num_generations:
219-
logger.warning(f"Number of predictor example invocations {len(predictor_example_invocations)} does not match the expected batch size {num_generations}")
230+
if len(predictor_example_invocations) == 0:
231+
logger.warning(f"Skipping example {example_ind} for predictor {pred_id} as it has no invocations. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
232+
continue
233+
elif len(predictor_example_invocations) != num_generations:
234+
logger.warning(f"Number of predictor example invocations {len(predictor_example_invocations)} does not match the expected batch size {num_generations}. This is likely due to all examples in the training set input, resulting in the model generating output not following the dspy response format.")
220235

221236
min_len = min([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
222237
max_len = max([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
@@ -237,7 +252,6 @@ def compile(
237252
]
238253
else:
239254
assert self.variably_invoked_predictor_grouping_mode == 'ragged', f"Unknown variably invoked predictor grouping mode {self.variably_invoked_predictor_grouping_mode}"
240-
241255
max_len = max([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
242256

243257
example_training_data: List[GRPOGroup] = [[] for _ in range(max_len)]
@@ -276,15 +290,13 @@ def compile(
276290
# example_training_data[group_idx]["input"]["messages"] = [{"role": msg['role'], 'content': msg['content']} for msg in inp_messages]
277291
# elif example_training_data[group_idx]["input"]["messages"] != inp_messages:
278292
# logger.info(f"Input messages {inp_messages} do not match the expected messages {example_training_data[group_idx]['input']['messages']}")
279-
280293
# response_msg = all_messages[-1]
281294
# assert 'role' in response_msg and 'content' in response_msg, f"Response message {response_msg} does not contain the expected keys 'role' and 'content'"
282295
# example_training_data[group_idx]["completions"].append({
283296
# "role": response_msg["role"],
284297
# "content": response_msg["content"],
285298
# "reward": score,
286299
# })
287-
288300
example_training_data[group_idx].append({
289301
"messages": inp_messages,
290302
"completion": {
@@ -333,7 +345,7 @@ def compile(
333345
assert len(group) == num_generations, f"Number of completions {len(group)} does not match the expected number num_samples_per_input*len(teachers)={num_generations}"
334346

335347
job.step(train_data=train_data, train_data_format=TrainDataFormat.GRPO_CHAT)
336-
348+
337349
for (lm, data_key), job in grpo_training_jobs.items():
338350
if (train_step_idx + 1) % self.train_kwargs[lm]["update_interval"] == 0 and train_step_idx != 0:
339351
logger.info(f"Current train step is {train_step_idx + 1}. Updating the model...")
@@ -344,7 +356,6 @@ def compile(
344356
logger.info(f"Evaluating the student program on the validation set after training step {train_step_idx + 1}/{self.num_train_steps}")
345357
valset_evaluation = valset_evaluator(student, metric=self.metric)
346358
logger.info(f"Student program validation set score after training step {train_step_idx + 1}/{self.num_train_steps}: {valset_evaluation}")
347-
348359
logger.info("Done with the iterations! Retrieving the final model(s)...")
349360
for (lm, data_key), job in grpo_training_jobs.items():
350361
job.terminate()
