Skip to content

Commit a357269

Browse files
committed
Instrument converse streaming
1 parent debafbc commit a357269

File tree

1 file changed

+121
-77
lines changed

1 file changed

+121
-77
lines changed

newrelic/hooks/external_botocore.py

Lines changed: 121 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -826,6 +826,16 @@ def _wrap_bedrock_runtime_converse(wrapped, instance, args, kwargs):
826826
bedrock_attrs = extract_bedrock_converse_attrs(kwargs, response, response_headers, model, span_id, trace_id)
827827

828828
try:
829+
if response_streaming:
830+
# Wrap EventStream object here to intercept __iter__ method instead of instrumenting class.
831+
# This class is used in numerous other services in botocore, and would cause conflicts.
832+
response["stream"] = stream = EventStreamWrapper(response["stream"])
833+
stream._nr_ft = ft
834+
stream._nr_bedrock_attrs = bedrock_attrs
835+
stream._nr_model_extractor = stream_extractor
836+
stream._nr_is_converse = True
837+
return response
838+
829839
ft.__exit__(None, None, None)
830840
bedrock_attrs["duration"] = ft.duration * 1000
831841
run_bedrock_response_extractor(response_extractor, {}, bedrock_attrs, False, transaction)
@@ -840,6 +850,7 @@ def _wrap_bedrock_runtime_converse(wrapped, instance, args, kwargs):
840850

841851
def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, span_id, trace_id):
842852
input_message_list = []
853+
output_message_list = None
843854
# If a system message is supplied, it is under its own key in kwargs rather than with the other input messages
844855
if "system" in kwargs.keys():
845856
input_message_list.extend({"role": "system", "content": result["text"]} for result in kwargs.get("system", []))
@@ -850,22 +861,26 @@ def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, sp
850861
[{"role": "user", "content": result["text"]} for result in kwargs["messages"][-1].get("content", [])]
851862
)
852863

853-
output_message_list = [
854-
{"role": "assistant", "content": result["text"]}
855-
for result in response.get("output").get("message").get("content", [])
856-
]
864+
if "output" in response:
865+
output_message_list = [
866+
{"role": "assistant", "content": result["text"]}
867+
for result in response.get("output").get("message").get("content", [])
868+
]
857869

858870
bedrock_attrs = {
859871
"request_id": response_headers.get("x-amzn-requestid"),
860872
"model": model,
861873
"span_id": span_id,
862874
"trace_id": trace_id,
863875
"response.choices.finish_reason": response.get("stopReason"),
864-
"output_message_list": output_message_list,
865876
"request.max_tokens": kwargs.get("inferenceConfig", {}).get("maxTokens", None),
866877
"request.temperature": kwargs.get("inferenceConfig", {}).get("temperature", None),
867878
"input_message_list": input_message_list,
868879
}
880+
881+
if output_message_list is not None:
882+
bedrock_attrs["output_message_list"] = output_message_list
883+
869884
return bedrock_attrs
870885

871886

@@ -875,6 +890,7 @@ def __iter__(self):
875890
g._nr_ft = getattr(self, "_nr_ft", None)
876891
g._nr_bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
877892
g._nr_model_extractor = getattr(self, "_nr_model_extractor", NULL_EXTRACTOR)
893+
g._nr_is_converse = getattr(self, "_nr_is_converse", False)
878894
return g
879895

880896

@@ -893,31 +909,122 @@ def __next__(self):
893909
return_val = None
894910
try:
895911
return_val = self.__wrapped__.__next__()
896-
record_stream_chunk(self, return_val, transaction)
912+
self.record_stream_chunk(return_val, transaction)
897913
except StopIteration:
898-
record_events_on_stop_iteration(self, transaction)
914+
self.record_events_on_stop_iteration(transaction)
899915
raise
900916
except Exception as exc:
901-
record_error(self, transaction, exc)
917+
self.record_error(transaction, exc)
902918
raise
903919
return return_val
904920

905921
def close(self):
    """Delegate to the parent class's ``close()`` and return its result.

    NOTE(review): the enclosing class header and its bases are outside this
    view; confirm that the parent class actually provides ``close``.
    """
    parent = super()
    return parent.close()
907923

924+
def record_events_on_stop_iteration(self, transaction):
    """Close the function trace and record the chat completion event for a finished stream.

    Args:
        self: Stream proxy carrying ``_nr_ft`` (the function trace) and
            ``_nr_bedrock_attrs`` (the attributes accumulated so far).
        transaction: Passed through to ``handle_chat_completion_event``.

    Returns:
        None. Does nothing when no function trace is attached. Failures while
        recording the event are logged rather than raised.
    """
    ft = getattr(self, "_nr_ft", None)
    if ft is None:
        # The attribute may be missing or explicitly None. Either way there is
        # no trace to close or time, and raising here would replace the
        # StopIteration the caller is about to re-raise.
        return

    bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
    ft.__exit__(None, None, None)

    # If there are no bedrock attrs exit early as there's no data to record.
    if not bedrock_attrs:
        return

    try:
        # The trace duration is multiplied by 1000 before being stored.
        bedrock_attrs["duration"] = ft.duration * 1000
        handle_chat_completion_event(transaction, bedrock_attrs)
    except Exception:
        _logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE, exc_info=True)

    # Clear cached data as this can be very large. An empty dict also makes any
    # later call to this method return early instead of recording twice.
    bedrock_attrs.clear()
941+
942+
def record_error(self, transaction, exc):
    """Report a streaming failure on the function trace and record the chat completion event.

    Args:
        self: Stream proxy carrying ``_nr_ft`` (the function trace) and
            ``_nr_bedrock_attrs`` (the attributes accumulated so far).
        transaction: Passed through to ``handle_chat_completion_event``.
        exc: The exception raised while reading from the stream.

    Returns:
        None. Does nothing when no function trace is attached. Any failure
        inside this handler is logged rather than raised, so the caller can
        re-raise the original exception.
    """
    ft = getattr(self, "_nr_ft", None)
    if ft is None:
        # The attribute may be missing or explicitly None; without a trace
        # there is nothing to report the error on.
        return

    try:
        error_attributes = getattr(self, "_nr_bedrock_attrs", {})

        # If there are no bedrock attrs exit early as there's no data to record.
        # NOTE(review): unlike record_events_on_stop_iteration, this early
        # return does not call ft.__exit__ - confirm that is intended.
        if not error_attributes:
            return

        error_attributes = bedrock_error_attributes(exc, error_attributes)
        notice_error_attributes = {
            "http.statusCode": error_attributes.get("http.statusCode"),
            "error.message": error_attributes.get("error.message"),
            "error.code": error_attributes.get("error.code"),
            "completion_id": str(uuid.uuid4()),
        }

        ft.notice_error(attributes=notice_error_attributes)

        # Exit the trace with the exception currently being handled.
        ft.__exit__(*sys.exc_info())
        error_attributes["duration"] = ft.duration * 1000

        handle_chat_completion_event(transaction, error_attributes)

        # Clear cached data as this can be very large.
        error_attributes.clear()
    except Exception:
        _logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE, exc_info=True)
971+
972+
def record_stream_chunk(self, event, transaction):
    """Route one stream event to the converse or invoke chunk recorder.

    Falsy events are ignored. The converse recorder is used when
    ``_nr_is_converse`` is truthy on ``self``; otherwise the invoke recorder
    is used. Any exception raised while recording is logged as a warning and
    swallowed.

    Returns:
        Whatever the selected recorder returns, or None when the event is
        falsy or recording failed.
    """
    if not event:
        return None

    try:
        if getattr(self, "_nr_is_converse", False):
            handler = self.converse_record_stream_chunk
        else:
            handler = self.invoke_record_stream_chunk
        return handler(event, transaction)
    except Exception:
        _logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE, exc_info=True)
    return None
981+
982+
def invoke_record_stream_chunk(self, event, transaction):
    """Decode one stream event and feed it to the model extractor.

    The event's ``["chunk"]["bytes"]`` payload is decoded as UTF-8 JSON and
    passed, together with the accumulated bedrock attributes, to
    ``self._nr_model_extractor``.
    """
    attrs = getattr(self, "_nr_bedrock_attrs", {})
    raw_bytes = event["chunk"]["bytes"]
    payload = json.loads(raw_bytes.decode("utf-8"))
    self._nr_model_extractor(payload, attrs)

    # In Langchain, the bedrock iterator exits early if type is "content_block_stop".
    # So we need to call the record events here since stop iteration will not be raised.
    if payload.get("type") == "content_block_stop":
        self.record_events_on_stop_iteration(transaction)
991+
992+
def converse_record_stream_chunk(self, event, transaction):
    """Fold one converse stream event into the accumulated bedrock attributes.

    ``contentBlockDelta`` events append their delta text to a single assistant
    message stored under ``output_message_list``. ``messageStop`` events set
    ``response.choices.finish_reason`` from the event's ``stopReason``.

    Args:
        self: Stream proxy carrying ``_nr_bedrock_attrs``.
        event: Mapping for one stream event.
        transaction: Unused here; kept so the signature matches
            ``invoke_record_stream_chunk``.

    Returns:
        None. The attributes dict is updated in place.
    """
    bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})

    # Empty attrs means there is no data to record. Check this before handling
    # any event type so that a messageStop event cannot write a finish reason
    # into an otherwise empty dict and make it look like recordable data.
    if not bedrock_attrs:
        return

    if "contentBlockDelta" in event:
        content = ((event.get("contentBlockDelta") or {}).get("delta") or {}).get("text", "")
        if "output_message_list" not in bedrock_attrs:
            bedrock_attrs["output_message_list"] = [{"role": "assistant", "content": ""}]
        bedrock_attrs["output_message_list"][0]["content"] += content

    if "messageStop" in event:
        bedrock_attrs["response.choices.finish_reason"] = (event.get("messageStop") or {}).get("stopReason", "")

    # TODO: Is this also subject to the content_block_stop behavior from Langchain?
    # If so, that would preclude us from ever capturing the messageStop event with the stopReason.
    # if "contentBlockStop" in event:
    #     self.record_events_on_stop_iteration(transaction)
1010+
9081011

9091012
class AsyncEventStreamWrapper(ObjectProxy):
    """Proxy whose ``__aiter__`` returns an instrumented ``AsyncGeneratorProxy``."""

    def __aiter__(self):
        """Wrap the underlying async iterator and copy the ``_nr_*`` attributes onto it.

        An attribute that is not set on this wrapper falls back to the default
        listed beside it below.
        """
        proxy = AsyncGeneratorProxy(self.__wrapped__.__aiter__())
        carried_over = (
            ("_nr_ft", None),
            ("_nr_bedrock_attrs", {}),
            ("_nr_model_extractor", NULL_EXTRACTOR),
            ("_nr_is_converse", False),
        )
        for attr_name, fallback in carried_over:
            setattr(proxy, attr_name, getattr(self, attr_name, fallback))
        return proxy
9161020

9171021

9181022
class AsyncGeneratorProxy(ObjectProxy):
919-
def __init__(self, wrapped):
920-
super().__init__(wrapped)
1023+
# Import these methods from the synchronous GeneratorProxy
1024+
# Avoid direct inheritance so we don't implement both __iter__ and __aiter__
1025+
record_stream_chunk = GeneratorProxy.record_stream_chunk
1026+
record_events_on_stop_iteration = GeneratorProxy.record_events_on_stop_iteration
1027+
record_error = GeneratorProxy.record_error
9211028

9221029
def __aiter__(self):
9231030
return self
@@ -929,83 +1036,19 @@ async def __anext__(self):
9291036
return_val = None
9301037
try:
9311038
return_val = await self.__wrapped__.__anext__()
932-
record_stream_chunk(self, return_val, transaction)
1039+
self.record_stream_chunk(return_val, transaction)
9331040
except StopAsyncIteration:
934-
record_events_on_stop_iteration(self, transaction)
1041+
self.record_events_on_stop_iteration(transaction)
9351042
raise
9361043
except Exception as exc:
937-
record_error(self, transaction, exc)
1044+
self.record_error(transaction, exc)
9381045
raise
9391046
return return_val
9401047

9411048
async def aclose(self):
9421049
return await super().aclose()
9431050

9441051

945-
def record_stream_chunk(self, return_val, transaction):
946-
if return_val:
947-
try:
948-
chunk = json.loads(return_val["chunk"]["bytes"].decode("utf-8"))
949-
self._nr_model_extractor(chunk, self._nr_bedrock_attrs)
950-
# In Langchain, the bedrock iterator exits early if type is "content_block_stop".
951-
# So we need to call the record events here since stop iteration will not be raised.
952-
_type = chunk.get("type")
953-
if _type == "content_block_stop":
954-
record_events_on_stop_iteration(self, transaction)
955-
except Exception:
956-
_logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE, exc_info=True)
957-
958-
959-
def record_events_on_stop_iteration(self, transaction):
960-
if hasattr(self, "_nr_ft"):
961-
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
962-
self._nr_ft.__exit__(None, None, None)
963-
964-
# If there are no bedrock attrs exit early as there's no data to record.
965-
if not bedrock_attrs:
966-
return
967-
968-
try:
969-
bedrock_attrs["duration"] = self._nr_ft.duration * 1000
970-
handle_chat_completion_event(transaction, bedrock_attrs)
971-
except Exception:
972-
_logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE, exc_info=True)
973-
974-
# Clear cached data as this can be very large.
975-
self._nr_bedrock_attrs.clear()
976-
977-
978-
def record_error(self, transaction, exc):
979-
if hasattr(self, "_nr_ft"):
980-
try:
981-
ft = self._nr_ft
982-
error_attributes = getattr(self, "_nr_bedrock_attrs", {})
983-
984-
# If there are no bedrock attrs exit early as there's no data to record.
985-
if not error_attributes:
986-
return
987-
988-
error_attributes = bedrock_error_attributes(exc, error_attributes)
989-
notice_error_attributes = {
990-
"http.statusCode": error_attributes.get("http.statusCode"),
991-
"error.message": error_attributes.get("error.message"),
992-
"error.code": error_attributes.get("error.code"),
993-
}
994-
notice_error_attributes.update({"completion_id": str(uuid.uuid4())})
995-
996-
ft.notice_error(attributes=notice_error_attributes)
997-
998-
ft.__exit__(*sys.exc_info())
999-
error_attributes["duration"] = ft.duration * 1000
1000-
1001-
handle_chat_completion_event(transaction, error_attributes)
1002-
1003-
# Clear cached data as this can be very large.
1004-
error_attributes.clear()
1005-
except Exception:
1006-
_logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE, exc_info=True)
1007-
1008-
10091052
def handle_embedding_event(transaction, bedrock_attrs):
10101053
embedding_id = str(uuid.uuid4())
10111054

@@ -1551,6 +1594,7 @@ def wrap_serialize_to_request(wrapped, instance, args, kwargs):
15511594
response_streaming=True
15521595
),
15531596
("bedrock-runtime", "converse"): wrap_bedrock_runtime_converse(response_streaming=False),
1597+
("bedrock-runtime", "converse_stream"): wrap_bedrock_runtime_converse(response_streaming=True),
15541598
}
15551599

15561600

0 commit comments

Comments
 (0)