Skip to content

Commit 1e04cbb

Browse files
Merge pull request #938 from sysid/main
fix(dsp): indicate other completions only if choices is list, not string
2 parents 7b76165 + 59241d5 commit 1e04cbb

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

dsp/modules/aws_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def __init__(
4949
self._max_context_size: int = max_context_size
5050
self._max_new_tokens: int = max_new_tokens
5151

52+
# make it consistent with equivalent LM::max_token
53+
self.kwargs["max_tokens"] = max_new_tokens
54+
5255
self.kwargs = {
5356
**self.kwargs,
5457
**kwargs,
@@ -63,7 +66,7 @@ def _call_model(self, body: str) -> str | list[str]:
6366
"""Call model, get generated input without the formatted prompt."""
6467

6568
def _estimate_tokens(self, text: str) -> int:
66-
return len(text)/CHARS2TOKENS
69+
return len(text) / CHARS2TOKENS
6770

6871
def _extract_input_parameters(
6972
self,

dsp/modules/lm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def inspect_history(self, n: int = 1, skip: int = 0):
7676
text = choices
7777
elif provider == "openai" or provider == "ollama":
7878
text = ' ' + self._get_choice_text(choices[0]).strip()
79-
elif provider == "clarifai" or provider == "claude" :
80-
text=choices
79+
elif provider == "clarifai" or provider == "claude":
80+
text = choices
8181
elif provider == "groq":
8282
text = ' ' + choices
8383
elif provider == "google":
@@ -88,7 +88,7 @@ def inspect_history(self, n: int = 1, skip: int = 0):
8888
text = choices[0]["text"]
8989
printing_value += self.print_green(text, end="")
9090

91-
if len(choices) > 1:
91+
if len(choices) > 1 and isinstance(choices, list):
9292
printing_value += self.print_red(f" \t (and {len(choices)-1} other completions)", end="")
9393

9494
printing_value += "\n\n\n"

0 commit comments

Comments
 (0)