11import re
22from collections import defaultdict
33from queue import Queue
4- from typing import TYPE_CHECKING , Any , List
4+ from typing import TYPE_CHECKING , Any , List , Optional
55
66from litellm import ModelResponseStream
77
1717class StreamListener :
1818 """Class that listens to the stream to capture the streeaming of a specific output field of a predictor."""
1919
20- def __init__ (self , signature_field_name : str , predict : Any = None , predict_name : str = None ):
20+ def __init__ (self , signature_field_name : str , predict : Any = None , predict_name : Optional [ str ] = None ):
2121 """
2222 Args:
2323 signature_field_name: The name of the field to listen to.
@@ -36,7 +36,7 @@ def __init__(self, signature_field_name: str, predict: Any = None, predict_name:
3636 self .stream_end = False
3737 self .cache_hit = False
3838
39- self .json_adapter_start_identifier = f'{{"{ self .signature_field_name } ":"' # noqa: Q000
39+ self .json_adapter_start_identifier = f'{{"{ self .signature_field_name } ":"'
4040 self .json_adapter_end_identifier = re .compile (r"\w*\",\w*" )
4141
4242 self .chat_adapter_start_identifier = f"[[ ## { self .signature_field_name } ## ]]"
@@ -130,7 +130,7 @@ def flush(self) -> str:
130130 last_tokens = "" .join (self .field_end_queue .queue )
131131 self .field_end_queue = Queue ()
132132 if isinstance (settings .adapter , JSONAdapter ):
133- boundary_index = last_tokens .find ('",' ) # noqa: Q000
133+ boundary_index = last_tokens .find ('",' )
134134 return last_tokens [:boundary_index ]
135135 elif isinstance (settings .adapter , ChatAdapter ) or settings .adapter is None :
136136 boundary_index = last_tokens .find ("[[" )
0 commit comments