Skip to content

Commit 231b120

Browse files
committed
Add support for logprob output
1 parent 657c086 commit 231b120

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

dsp/modules/gpt3.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,12 @@ def __call__(
185185
if only_completed and len(completed_choices):
186186
choices = completed_choices
187187

188-
completions = [self._get_choice_text(c) for c in choices]
188+
# if logprobs:
189+
if kwargs.get("logprobs", False):
190+
completions = [{'text': self._get_choice_text(c), 'logprobs': c["logprobs"]} for c in choices]
191+
else:
192+
completions = [self._get_choice_text(c) for c in choices]
193+
189194
if return_sorted and kwargs.get("n", 1) > 1:
190195
scored_completions = []
191196

@@ -200,10 +205,12 @@ def __call__(
200205
tokens, logprobs = tokens[:index], logprobs[:index]
201206

202207
avglog = sum(logprobs) / len(logprobs)
203-
scored_completions.append((avglog, self._get_choice_text(c)))
204-
208+
scored_completions.append((avglog, self._get_choice_text(c), logprobs))
205209
scored_completions = sorted(scored_completions, reverse=True)
206-
completions = [c for _, c in scored_completions]
210+
if logprobs:
211+
completions = [{'text': c, 'logprobs': lp} for _, c, lp in scored_completions]
212+
else:
213+
completions = [c for _, c in scored_completions]
207214

208215
return completions
209216

0 commit comments

Comments
 (0)