Skip to content

Commit aa229db

Browse files
author
Krista Opsahl-Ong
committed
setting prompt_model to default model when no prompt_model is specified
1 parent ddd3418 commit aa229db

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

dspy/teleprompt/signature_opt.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
1818
Note that this teleprompter takes in the following parameters:
1919
20-
* prompt_model: The model used for prompt generation.
20+
* prompt_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (i.e., dspy.settings.configure(lm=task_model)).
2121
* metric: The task metric used for optimization.
2222
* breadth: The number of new prompts to generate at each iteration. Default=10.
2323
* depth: The number of times we should ask our prompt model to generate new prompts, with the history of the past prompts as input. Default=3.
@@ -47,7 +47,7 @@ class GenerateInstructionGivenAttempts(dspy.Signature):
4747
proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task")
4848

4949
class SignatureOptimizer(Teleprompter):
50-
def __init__(self, prompt_model, metric=None, breadth=10, depth=3, init_temperature=1.4, verbose=False, track_stats=False):
50+
def __init__(self, prompt_model=None, metric=None, breadth=10, depth=3, init_temperature=1.4, verbose=False, track_stats=False):
5151
self.metric = metric
5252
self.breadth = breadth
5353
self.depth = depth
@@ -109,15 +109,18 @@ def compile(self, student, *, devset, eval_kwargs):
109109
else:
110110
basic_instruction = predictor.extended_signature1.instructions
111111
basic_prefix = predictor.extended_signature1.fields[-1].name
112-
with dspy.settings.context(lm=self.prompt_model):
112+
if self.prompt_model:
113+
with dspy.settings.context(lm=self.prompt_model):
114+
instruct = dspy.Predict(BasicGenerateInstruction, n=self.breadth-1, temperature=self.init_temperature)(basic_instruction=basic_instruction)
115+
else:
113116
instruct = dspy.Predict(BasicGenerateInstruction, n=self.breadth-1, temperature=self.init_temperature)(basic_instruction=basic_instruction)
114117
# Add in our initial prompt as a candidate as well
115118
instruct.completions.proposed_instruction.append(basic_instruction)
116119
instruct.completions.proposed_prefix_for_output_field.append(basic_prefix)
117120
candidates[id(predictor)] = instruct.completions
118121
evaluated_candidates[id(predictor)] = {}
119122

120-
if self.verbose: print(f"{self.prompt_model.inspect_history(n=1)}")
123+
if self.verbose and self.prompt_model: print(f"{self.prompt_model.inspect_history(n=1)}")
121124

122125
latest_candidates = candidates
123126
all_candidates = candidates
@@ -164,7 +167,7 @@ def compile(self, student, *, devset, eval_kwargs):
164167
if self.verbose: print()
165168
if self.verbose: print(f"At Depth {d}/{self.depth}, Evaluating Prompt Candidate #{c_i}/{len(candidates_)} for Predictor {p_i} of {len(module.predictors())}.")
166169
score = evaluate(module_clone, devset=devset, **eval_kwargs)
167-
print(f"prompt_model.inspect_history(n=1) {self.prompt_model.inspect_history(n=1)}")
170+
if self.verbose and self.prompt_model: print(f"prompt_model.inspect_history(n=1) {self.prompt_model.inspect_history(n=1)}")
168171
total_calls += 1
169172
if self.verbose: print(f"----------------")
170173

@@ -249,15 +252,19 @@ def compile(self, student, *, devset, eval_kwargs):
249252
attempts.append(f'Resulting Score #{shortest_len-i}: {best_predictors[i]["score"]}')
250253

251254
# Generate next batch of potential prompts to optimize, with previous attempts as input
252-
with dspy.settings.context(lm=self.prompt_model):
255+
if self.prompt_model:
256+
with dspy.settings.context(lm=self.prompt_model):
257+
instr = dspy.Predict(GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature)(attempted_instructions=attempts)
258+
else:
253259
instr = dspy.Predict(GenerateInstructionGivenAttempts, n=self.breadth, temperature=self.init_temperature)(attempted_instructions=attempts)
254-
if self.verbose: print(f"{self.prompt_model.inspect_history(n=1)}")
260+
261+
if self.verbose and self.prompt_model: print(f"{self.prompt_model.inspect_history(n=1)}")
255262
# Get candidates for each predictor
256263
new_candidates[id(p_base)] = instr.completions
257264
all_candidates[id(p_base)].proposed_instruction.extend(instr.completions.proposed_instruction)
258265
all_candidates[id(p_base)].proposed_prefix_for_output_field.extend(instr.completions.proposed_prefix_for_output_field)
259266

260-
if self.verbose: print(f"{self.prompt_model.inspect_history(n=1)}")
267+
if self.verbose and self.prompt_model: print(f"{self.prompt_model.inspect_history(n=1)}")
261268
latest_candidates = new_candidates
262269

263270
candidates = []

0 commit comments

Comments
 (0)