Skip to content

Commit 7dabab8

Browse files
Allow stream listener to work on any type (#8833)
* support streaming on any type * add decent testing * add comments * clean up * fix end condition * robust finish handling * comments * comments * add todo for clean code
1 parent c7119b5 commit 7dabab8

File tree

2 files changed

+456
-54
lines changed

2 files changed

+456
-54
lines changed

dspy/streaming/streaming_listener.py

Lines changed: 99 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from queue import Queue
44
from typing import TYPE_CHECKING, Any
55

6+
import jiter
67
from litellm import ModelResponseStream
78

89
from dspy.adapters.chat_adapter import ChatAdapter
@@ -49,6 +50,8 @@ def __init__(
4950
self.cache_hit = False
5051
self.allow_reuse = allow_reuse
5152

53+
self.json_adapter_state = {"field_accumulated_messages": ""}
54+
5255
self.adapter_identifiers = {
5356
"ChatAdapter": {
5457
"start_identifier": f"[[ ## {self.signature_field_name} ## ]]",
@@ -62,7 +65,7 @@ def __init__(
6265
"end_identifier": re.compile(r"\w*\"(,|\s*})"),
6366
"start_indicator": '"',
6467
"end_pattern_prefixes": ['"', '",', '" ', '"}'],
65-
"end_pattern_contains": None,
68+
"end_pattern_contains": "}",
6669
},
6770
"XMLAdapter": {
6871
"start_identifier": f"<{self.signature_field_name}>",
@@ -126,6 +129,7 @@ def receive(self, chunk: ModelResponseStream):
126129
self.cache_hit = False
127130
self.field_start_queue = []
128131
self.field_end_queue = Queue()
132+
self.json_adapter_state["field_accumulated_messages"] = ""
129133
self.stream_start = False
130134
else:
131135
return
@@ -147,7 +151,7 @@ def receive(self, chunk: ModelResponseStream):
147151
is_last_chunk=self.stream_end,
148152
)
149153

150-
if chunk_message and start_identifier in chunk_message:
154+
if chunk_message and start_identifier in chunk_message and not isinstance(settings.adapter, JSONAdapter):
151155
# If the cache is hit, the chunk_message could be the full response. When that happens we can
152156
# directly end the stream listening. In some models like gemini, each stream chunk can contain multiple
153157
# tokens, so it's possible the response only has one chunk; in that case we also fall back to this logic.
@@ -180,10 +184,13 @@ def receive(self, chunk: ModelResponseStream):
180184
# Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
181185
value_start_index = concat_message.find(start_identifier) + len(start_identifier)
182186
chunk_message = concat_message[value_start_index:].lstrip()
183-
if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
184-
# For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
185-
# because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
186-
chunk_message = chunk_message[1:]
187+
188+
if isinstance(settings.adapter, JSONAdapter):
189+
# For JSONAdapter, we rely on partial json parsing to detect the end of the field we are listening
190+
# to, so we need to maintain a few extra states to help us with that.
191+
# We add an extra "{" to the beginning of the field_accumulated_messages, so we can detect the
192+
# appearance of the next key.
193+
self.json_adapter_state["field_accumulated_messages"] += "{" + start_identifier
187194

188195
elif self._buffered_message_end_with_start_identifier(concat_message.strip(), start_identifier):
189196
# If the buffered message ends with part of the start_identifier, we keep looking for the
@@ -196,30 +203,101 @@ def receive(self, chunk: ModelResponseStream):
196203

197204
if self.stream_start and chunk_message:
198205
# The stream has started; we keep returning the token until we see the start of the next field.
199-
token = None
200206
self.field_end_queue.put(chunk_message)
201207

208+
token = None
202209
concat_message = "".join(self.field_end_queue.queue).strip()
203-
if re.search(end_identifier, concat_message):
204-
# The next field is identified, we can end the stream and flush out all tokens in the buffer.
205-
self.stream_end = True
206-
token = self.flush()
207-
token = token.rstrip() # Remove the trailing \n\n
208-
elif not self._could_form_end_identifier(concat_message, adapter_name):
210+
211+
if not self._could_form_end_identifier(concat_message, adapter_name):
209212
# Buffer cannot form end identifier, safe to flush out the tokens in the buffer.
210213
token = self.flush()
211214
elif self.field_end_queue.qsize() > 10:
212-
# Buffer could form end identifier, but we've exceeded max buffer size
213-
# Yield the oldest token to prevent unbounded buffering
215+
# We keep the last 10 tokens in the buffer if they can potentially form the end_identifier to avoid
216+
# sending the DSPy boilerplate tokens to users. 10 is a heuristic number that is sufficient to capture
217+
# the end_identifier for all LMs.
214218
token = self.field_end_queue.get()
215219

216-
if token:
220+
# TODO: Put adapter streaming handling into individual classes, e.g., `JSONAdapterStreamListener`,
221+
# `ChatAdapterStreamListener`, `XMLAdapterStreamListener`, instead of having ad-hoc code in the
222+
# `StreamListener` class.
223+
if isinstance(settings.adapter, JSONAdapter):
224+
# JSONAdapter uses partial json parsing to detect the end of the field we are listening to, instead of
225+
# relying on the end_identifier.
226+
return self._json_adapter_handle_stream_chunk(token, chunk_message)
227+
else:
228+
# Other adapters rely on the end_identifier to detect the end of the field we are listening to.
229+
return self._default_handle_stream_chunk(token, end_identifier)
230+
231+
def _json_adapter_handle_stream_chunk(self, token: str, chunk_message: str) -> StreamResponse | None:
232+
self.json_adapter_state["field_accumulated_messages"] += chunk_message
233+
if self.json_adapter_state["field_accumulated_messages"].rstrip().endswith("}"):
234+
# When the accumulated tokens end with a curly bracket, that means the streaming for the `dspy.Predict` we
235+
# are listening to is probably finished, we need to run a check and decide whether to end the stream.
236+
try:
237+
# If the parse doesn't raise an error, that means the accumulated tokens is a valid json object. Because
238+
# we add an extra "{" to the beginning of the field_accumulated_messages, so we know the streaming is
239+
# finished.
240+
jiter.from_json(self.json_adapter_state["field_accumulated_messages"].encode("utf-8"))
241+
self.stream_end = True
242+
last_token = self.flush()
243+
right_curly_bracket_index = last_token.rfind("}")
244+
token = (
245+
token + last_token[:right_curly_bracket_index] if token else last_token[:right_curly_bracket_index]
246+
)
217247
return StreamResponse(
218-
self.predict_name,
219-
self.signature_field_name,
220-
token,
221-
is_last_chunk=self.stream_end,
248+
self.predict_name, self.signature_field_name, token, is_last_chunk=self.stream_end
222249
)
250+
except ValueError:
251+
pass
252+
253+
try:
254+
parsed = jiter.from_json(
255+
self.json_adapter_state["field_accumulated_messages"].encode("utf-8"),
256+
partial_mode="trailing-strings",
257+
)
258+
if len(parsed) > 1:
259+
# If partial json parsing finds a second key, that means the streaming for the field we are listening to
260+
# is finished.
261+
self.stream_end = True
262+
last_token = self.flush()
263+
264+
keys = list(parsed.keys())
265+
next_field_name = None
266+
for key in keys:
267+
if key != self.signature_field_name:
268+
next_field_name = key
269+
break
270+
271+
last_token_index = last_token.find(next_field_name)
272+
token = token + last_token[:last_token_index] if token else last_token[:last_token_index]
273+
except ValueError:
274+
pass
275+
276+
if token:
277+
return StreamResponse(
278+
self.predict_name,
279+
self.signature_field_name,
280+
token,
281+
is_last_chunk=self.stream_end,
282+
)
283+
284+
def _default_handle_stream_chunk(self, token: str, end_identifier: str) -> StreamResponse | None:
285+
concat_message = "".join(self.field_end_queue.queue).strip()
286+
287+
if re.search(end_identifier, concat_message):
288+
# The next field is identified, we can end the stream and flush out all tokens in the buffer.
289+
self.stream_end = True
290+
last_token = self.flush()
291+
token = token + last_token if token else last_token
292+
token = token.rstrip() # Remove the trailing \n\n
293+
294+
if token:
295+
return StreamResponse(
296+
self.predict_name,
297+
self.signature_field_name,
298+
token,
299+
is_last_chunk=self.stream_end,
300+
)
223301

224302
def flush(self) -> str:
225303
"""Flush all tokens in the field end queue.
@@ -231,12 +309,7 @@ def flush(self) -> str:
231309
last_tokens = "".join(self.field_end_queue.queue)
232310
self.field_end_queue = Queue()
233311
if isinstance(settings.adapter, JSONAdapter):
234-
match = re.search(r'",|"\s*}', last_tokens)
235-
if match:
236-
boundary_index = match.start()
237-
else:
238-
boundary_index = len(last_tokens)
239-
return last_tokens[:boundary_index]
312+
return last_tokens
240313
elif isinstance(settings.adapter, XMLAdapter):
241314
boundary_index = last_tokens.find(f"</{self.signature_field_name}>")
242315
if boundary_index == -1:
@@ -314,13 +387,6 @@ def find_predictor_for_stream_listeners(
314387
f"Signature field {field_name} is not unique in the program, cannot automatically determine which "
315388
"predictor to use for streaming. Please specify the predictor to listen to."
316389
)
317-
318-
if not _is_streamable(field_info.annotation):
319-
raise ValueError(
320-
f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, "
321-
f"but your field {field_name} is of type {field_info.annotation}."
322-
)
323-
324390
field_name_to_named_predictor[field_name] = (name, predictor)
325391

326392
predict_id_to_listener = defaultdict(list)
@@ -337,13 +403,3 @@ def find_predictor_for_stream_listeners(
337403
listener.predict_name, listener.predict = field_name_to_named_predictor[listener.signature_field_name]
338404
predict_id_to_listener[id(listener.predict)].append(listener)
339405
return predict_id_to_listener
340-
341-
342-
def _is_streamable(field_type: type | None) -> bool:
343-
if field_type is None:
344-
return False
345-
if field_type is str:
346-
return True
347-
if issubclass(field_type, Type):
348-
return field_type.is_streamable()
349-
return False

0 commit comments

Comments
 (0)