Skip to content

Commit 5bc17d8

Browse files
Merge pull request stanfordnlp#999 from tom-doerr/add_logprob_support
Add logprob support for OpenAI models
2 parents 32449e1 + 4b48281 commit 5bc17d8

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

dsp/modules/gpt3.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,11 @@ def __call__(
185185
if only_completed and len(completed_choices):
186186
choices = completed_choices
187187

188-
completions = [self._get_choice_text(c) for c in choices]
188+
if kwargs.get("logprobs", False):
189+
completions = [{'text': self._get_choice_text(c), 'logprobs': c["logprobs"]} for c in choices]
190+
else:
191+
completions = [self._get_choice_text(c) for c in choices]
192+
189193
if return_sorted and kwargs.get("n", 1) > 1:
190194
scored_completions = []
191195

@@ -200,10 +204,12 @@ def __call__(
200204
tokens, logprobs = tokens[:index], logprobs[:index]
201205

202206
avglog = sum(logprobs) / len(logprobs)
203-
scored_completions.append((avglog, self._get_choice_text(c)))
204-
207+
scored_completions.append((avglog, self._get_choice_text(c), logprobs))
205208
scored_completions = sorted(scored_completions, reverse=True)
206-
completions = [c for _, c in scored_completions]
209+
if logprobs:
210+
completions = [{'text': c, 'logprobs': lp} for _, c, lp in scored_completions]
211+
else:
212+
completions = [c for _, c in scored_completions]
207213

208214
return completions
209215

0 commit comments

Comments
 (0)