Skip to content

Commit 637d759

Browse files
authored
fix style errors of dspy/streaming (#8187)
1 parent f971600 commit 637d759

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

dspy/streaming/streamify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator:
267267
"""
268268
async for value in streamer:
269269
if isinstance(value, Prediction):
270-
data = {"prediction": {k: v for k, v in value.items(include_dspy=False)}}
270+
data = {"prediction": dict(value.items(include_dspy=False))}
271271
yield f"data: {ujson.dumps(data)}\n\n"
272272
elif isinstance(value, litellm.ModelResponseStream):
273273
data = {"chunk": value.json()}

dspy/streaming/streaming_listener.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from collections import defaultdict
33
from queue import Queue
4-
from typing import TYPE_CHECKING, Any, List
4+
from typing import TYPE_CHECKING, Any, List, Optional
55

66
from litellm import ModelResponseStream
77

@@ -17,7 +17,7 @@
1717
class StreamListener:
1818
"""Class that listens to the stream to capture the streeaming of a specific output field of a predictor."""
1919

20-
def __init__(self, signature_field_name: str, predict: Any = None, predict_name: str = None):
20+
def __init__(self, signature_field_name: str, predict: Any = None, predict_name: Optional[str] = None):
2121
"""
2222
Args:
2323
signature_field_name: The name of the field to listen to.
@@ -36,7 +36,7 @@ def __init__(self, signature_field_name: str, predict: Any = None, predict_name:
3636
self.stream_end = False
3737
self.cache_hit = False
3838

39-
self.json_adapter_start_identifier = f'{{"{self.signature_field_name}":"' # noqa: Q000
39+
self.json_adapter_start_identifier = f'{{"{self.signature_field_name}":"'
4040
self.json_adapter_end_identifier = re.compile(r"\w*\",\w*")
4141

4242
self.chat_adapter_start_identifier = f"[[ ## {self.signature_field_name} ## ]]"
@@ -130,7 +130,7 @@ def flush(self) -> str:
130130
last_tokens = "".join(self.field_end_queue.queue)
131131
self.field_end_queue = Queue()
132132
if isinstance(settings.adapter, JSONAdapter):
133-
boundary_index = last_tokens.find('",') # noqa: Q000
133+
boundary_index = last_tokens.find('",')
134134
return last_tokens[:boundary_index]
135135
elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
136136
boundary_index = last_tokens.find("[[")

0 commit comments

Comments
 (0)