@@ -884,43 +884,7 @@ def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, sp
884884 return bedrock_attrs
885885
886886
887- class EventStreamWrapper (ObjectProxy ):
888- def __iter__ (self ):
889- g = GeneratorProxy (self .__wrapped__ .__iter__ ())
890- g ._nr_ft = getattr (self , "_nr_ft" , None )
891- g ._nr_bedrock_attrs = getattr (self , "_nr_bedrock_attrs" , {})
892- g ._nr_model_extractor = getattr (self , "_nr_model_extractor" , NULL_EXTRACTOR )
893- g ._nr_is_converse = getattr (self , "_nr_is_converse" , False )
894- return g
895-
896-
897- class GeneratorProxy (ObjectProxy ):
898- def __init__ (self , wrapped ):
899- super ().__init__ (wrapped )
900-
901- def __iter__ (self ):
902- return self
903-
904- def __next__ (self ):
905- transaction = current_transaction ()
906- if not transaction :
907- return self .__wrapped__ .__next__ ()
908-
909- return_val = None
910- try :
911- return_val = self .__wrapped__ .__next__ ()
912- self .record_stream_chunk (return_val , transaction )
913- except StopIteration :
914- self .record_events_on_stop_iteration (transaction )
915- raise
916- except Exception as exc :
917- self .record_error (transaction , exc )
918- raise
919- return return_val
920-
921- def close (self ):
922- return super ().close ()
923-
887+ class BedrockRecordEventMixin :
924888 def record_events_on_stop_iteration (self , transaction ):
925889 if hasattr (self , "_nr_ft" ):
926890 bedrock_attrs = getattr (self , "_nr_bedrock_attrs" , {})
@@ -1009,6 +973,44 @@ def converse_record_stream_chunk(self, event, transaction):
1009973 # self.record_events_on_stop_iteration(transaction)
1010974
1011975
976+ class EventStreamWrapper (ObjectProxy ):
977+ def __iter__ (self ):
978+ g = GeneratorProxy (self .__wrapped__ .__iter__ ())
979+ g ._nr_ft = getattr (self , "_nr_ft" , None )
980+ g ._nr_bedrock_attrs = getattr (self , "_nr_bedrock_attrs" , {})
981+ g ._nr_model_extractor = getattr (self , "_nr_model_extractor" , NULL_EXTRACTOR )
982+ g ._nr_is_converse = getattr (self , "_nr_is_converse" , False )
983+ return g
984+
985+
986+ class GeneratorProxy (BedrockRecordEventMixin , ObjectProxy ):
987+ def __init__ (self , wrapped ):
988+ super ().__init__ (wrapped )
989+
990+ def __iter__ (self ):
991+ return self
992+
993+ def __next__ (self ):
994+ transaction = current_transaction ()
995+ if not transaction :
996+ return self .__wrapped__ .__next__ ()
997+
998+ return_val = None
999+ try :
1000+ return_val = self .__wrapped__ .__next__ ()
1001+ self .record_stream_chunk (return_val , transaction )
1002+ except StopIteration :
1003+ self .record_events_on_stop_iteration (transaction )
1004+ raise
1005+ except Exception as exc :
1006+ self .record_error (transaction , exc )
1007+ raise
1008+ return return_val
1009+
1010+ def close (self ):
1011+ return super ().close ()
1012+
1013+
10121014class AsyncEventStreamWrapper (ObjectProxy ):
10131015 def __aiter__ (self ):
10141016 g = AsyncGeneratorProxy (self .__wrapped__ .__aiter__ ())
@@ -1019,13 +1021,7 @@ def __aiter__(self):
10191021 return g
10201022
10211023
1022- class AsyncGeneratorProxy (ObjectProxy ):
1023- # Import these methods from the synchronous GeneratorProxy
1024- # Avoid direct inheritance so we don't implement both __iter__ and __aiter__
1025- record_stream_chunk = GeneratorProxy .record_stream_chunk
1026- record_events_on_stop_iteration = GeneratorProxy .record_events_on_stop_iteration
1027- record_error = GeneratorProxy .record_error
1028-
1024+ class AsyncGeneratorProxy (BedrockRecordEventMixin , ObjectProxy ):
10291025 def __aiter__ (self ):
10301026 return self
10311027
0 commit comments