Skip to content

Commit 76052b2

Browse files
committed
Instrument converse streaming
1 parent 40b512f commit 76052b2

File tree

1 file changed

+121
-77
lines changed

1 file changed

+121
-77
lines changed

newrelic/hooks/external_botocore.py

Lines changed: 121 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,16 @@ def _wrap_bedrock_runtime_converse(wrapped, instance, args, kwargs):
819819
bedrock_attrs = extract_bedrock_converse_attrs(kwargs, response, response_headers, model, span_id, trace_id)
820820

821821
try:
822+
if response_streaming:
823+
# Wrap EventStream object here to intercept __iter__ method instead of instrumenting class.
824+
# This class is used in numerous other services in botocore, and would cause conflicts.
825+
response["stream"] = stream = EventStreamWrapper(response["stream"])
826+
stream._nr_ft = ft
827+
stream._nr_bedrock_attrs = bedrock_attrs
828+
stream._nr_model_extractor = stream_extractor
829+
stream._nr_is_converse = True
830+
return response
831+
822832
ft.__exit__(None, None, None)
823833
bedrock_attrs["duration"] = ft.duration * 1000
824834
run_bedrock_response_extractor(response_extractor, {}, bedrock_attrs, False, transaction)
@@ -833,6 +843,7 @@ def _wrap_bedrock_runtime_converse(wrapped, instance, args, kwargs):
833843

834844
def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, span_id, trace_id):
835845
input_message_list = []
846+
output_message_list = None
836847
# If a system message is supplied, it is under its own key in kwargs rather than with the other input messages
837848
if "system" in kwargs.keys():
838849
input_message_list.extend({"role": "system", "content": result["text"]} for result in kwargs.get("system", []))
@@ -843,22 +854,26 @@ def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, sp
843854
[{"role": "user", "content": result["text"]} for result in kwargs["messages"][-1].get("content", [])]
844855
)
845856

846-
output_message_list = [
847-
{"role": "assistant", "content": result["text"]}
848-
for result in response.get("output").get("message").get("content", [])
849-
]
857+
if "output" in response:
858+
output_message_list = [
859+
{"role": "assistant", "content": result["text"]}
860+
for result in response.get("output").get("message").get("content", [])
861+
]
850862

851863
bedrock_attrs = {
852864
"request_id": response_headers.get("x-amzn-requestid"),
853865
"model": model,
854866
"span_id": span_id,
855867
"trace_id": trace_id,
856868
"response.choices.finish_reason": response.get("stopReason"),
857-
"output_message_list": output_message_list,
858869
"request.max_tokens": kwargs.get("inferenceConfig", {}).get("maxTokens", None),
859870
"request.temperature": kwargs.get("inferenceConfig", {}).get("temperature", None),
860871
"input_message_list": input_message_list,
861872
}
873+
874+
if output_message_list is not None:
875+
bedrock_attrs["output_message_list"] = output_message_list
876+
862877
return bedrock_attrs
863878

864879

@@ -868,6 +883,7 @@ def __iter__(self):
868883
g._nr_ft = getattr(self, "_nr_ft", None)
869884
g._nr_bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
870885
g._nr_model_extractor = getattr(self, "_nr_model_extractor", NULL_EXTRACTOR)
886+
g._nr_is_converse = getattr(self, "_nr_is_converse", False)
871887
return g
872888

873889

@@ -886,31 +902,122 @@ def __next__(self):
886902
return_val = None
887903
try:
888904
return_val = self.__wrapped__.__next__()
889-
record_stream_chunk(self, return_val, transaction)
905+
self.record_stream_chunk(return_val, transaction)
890906
except StopIteration:
891-
record_events_on_stop_iteration(self, transaction)
907+
self.record_events_on_stop_iteration(transaction)
892908
raise
893909
except Exception as exc:
894-
record_error(self, transaction, exc)
910+
self.record_error(transaction, exc)
895911
raise
896912
return return_val
897913

898914
def close(self):
899915
return super().close()
900916

917+
def record_events_on_stop_iteration(self, transaction):
918+
if hasattr(self, "_nr_ft"):
919+
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
920+
self._nr_ft.__exit__(None, None, None)
921+
922+
# If there are no bedrock attrs exit early as there's no data to record.
923+
if not bedrock_attrs:
924+
return
925+
926+
try:
927+
bedrock_attrs["duration"] = self._nr_ft.duration * 1000
928+
handle_chat_completion_event(transaction, bedrock_attrs)
929+
except Exception:
930+
_logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE, exc_info=True)
931+
932+
# Clear cached data as this can be very large.
933+
self._nr_bedrock_attrs.clear()
934+
935+
def record_error(self, transaction, exc):
936+
if hasattr(self, "_nr_ft"):
937+
try:
938+
ft = self._nr_ft
939+
error_attributes = getattr(self, "_nr_bedrock_attrs", {})
940+
941+
# If there are no bedrock attrs exit early as there's no data to record.
942+
if not error_attributes:
943+
return
944+
945+
error_attributes = bedrock_error_attributes(exc, error_attributes)
946+
notice_error_attributes = {
947+
"http.statusCode": error_attributes.get("http.statusCode"),
948+
"error.message": error_attributes.get("error.message"),
949+
"error.code": error_attributes.get("error.code"),
950+
}
951+
notice_error_attributes.update({"completion_id": str(uuid.uuid4())})
952+
953+
ft.notice_error(attributes=notice_error_attributes)
954+
955+
ft.__exit__(*sys.exc_info())
956+
error_attributes["duration"] = ft.duration * 1000
957+
958+
handle_chat_completion_event(transaction, error_attributes)
959+
960+
# Clear cached data as this can be very large.
961+
error_attributes.clear()
962+
except Exception:
963+
_logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE, exc_info=True)
964+
965+
def record_stream_chunk(self, event, transaction):
966+
if event:
967+
try:
968+
if getattr(self, "_nr_is_converse", False):
969+
return self.converse_record_stream_chunk(event, transaction)
970+
else:
971+
return self.invoke_record_stream_chunk(event, transaction)
972+
except Exception:
973+
_logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE, exc_info=True)
974+
975+
def invoke_record_stream_chunk(self, event, transaction):
976+
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
977+
chunk = json.loads(event["chunk"]["bytes"].decode("utf-8"))
978+
self._nr_model_extractor(chunk, bedrock_attrs)
979+
# In Langchain, the bedrock iterator exits early if type is "content_block_stop".
980+
# So we need to call the record events here since stop iteration will not be raised.
981+
_type = chunk.get("type")
982+
if _type == "content_block_stop":
983+
self.record_events_on_stop_iteration(transaction)
984+
985+
def converse_record_stream_chunk(self, event, transaction):
986+
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
987+
if "contentBlockDelta" in event:
988+
if not bedrock_attrs:
989+
return
990+
991+
content = ((event.get("contentBlockDelta") or {}).get("delta") or {}).get("text", "")
992+
if "output_message_list" not in bedrock_attrs:
993+
bedrock_attrs["output_message_list"] = [{"role": "assistant", "content": ""}]
994+
bedrock_attrs["output_message_list"][0]["content"] += content
995+
996+
if "messageStop" in event:
997+
bedrock_attrs["response.choices.finish_reason"] = (event.get("messageStop") or {}).get("stopReason", "")
998+
999+
# TODO: Is this also subject to the content_block_stop behavior from Langchain?
1000+
# If so, that would preclude us from ever capturing the messageStop event with the stopReason.
1001+
# if "contentBlockStop" in event:
1002+
# self.record_events_on_stop_iteration(transaction)
1003+
9011004

9021005
class AsyncEventStreamWrapper(ObjectProxy):
9031006
def __aiter__(self):
9041007
g = AsyncGeneratorProxy(self.__wrapped__.__aiter__())
9051008
g._nr_ft = getattr(self, "_nr_ft", None)
9061009
g._nr_bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
9071010
g._nr_model_extractor = getattr(self, "_nr_model_extractor", NULL_EXTRACTOR)
1011+
g._nr_is_converse = getattr(self, "_nr_is_converse", False)
9081012
return g
9091013

9101014

9111015
class AsyncGeneratorProxy(ObjectProxy):
912-
def __init__(self, wrapped):
913-
super().__init__(wrapped)
1016+
# Import these methods from the synchronous GeneratorProxy
1017+
# Avoid direct inheritance so we don't implement both __iter__ and __aiter__
1018+
record_stream_chunk = GeneratorProxy.record_stream_chunk
1019+
record_events_on_stop_iteration = GeneratorProxy.record_events_on_stop_iteration
1020+
record_error = GeneratorProxy.record_error
9141021

9151022
def __aiter__(self):
9161023
return self
@@ -922,83 +1029,19 @@ async def __anext__(self):
9221029
return_val = None
9231030
try:
9241031
return_val = await self.__wrapped__.__anext__()
925-
record_stream_chunk(self, return_val, transaction)
1032+
self.record_stream_chunk(return_val, transaction)
9261033
except StopAsyncIteration:
927-
record_events_on_stop_iteration(self, transaction)
1034+
self.record_events_on_stop_iteration(transaction)
9281035
raise
9291036
except Exception as exc:
930-
record_error(self, transaction, exc)
1037+
self.record_error(transaction, exc)
9311038
raise
9321039
return return_val
9331040

9341041
async def aclose(self):
9351042
return await super().aclose()
9361043

9371044

938-
def record_stream_chunk(self, return_val, transaction):
939-
if return_val:
940-
try:
941-
chunk = json.loads(return_val["chunk"]["bytes"].decode("utf-8"))
942-
self._nr_model_extractor(chunk, self._nr_bedrock_attrs)
943-
# In Langchain, the bedrock iterator exits early if type is "content_block_stop".
944-
# So we need to call the record events here since stop iteration will not be raised.
945-
_type = chunk.get("type")
946-
if _type == "content_block_stop":
947-
record_events_on_stop_iteration(self, transaction)
948-
except Exception:
949-
_logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE, exc_info=True)
950-
951-
952-
def record_events_on_stop_iteration(self, transaction):
953-
if hasattr(self, "_nr_ft"):
954-
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
955-
self._nr_ft.__exit__(None, None, None)
956-
957-
# If there are no bedrock attrs exit early as there's no data to record.
958-
if not bedrock_attrs:
959-
return
960-
961-
try:
962-
bedrock_attrs["duration"] = self._nr_ft.duration * 1000
963-
handle_chat_completion_event(transaction, bedrock_attrs)
964-
except Exception:
965-
_logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE, exc_info=True)
966-
967-
# Clear cached data as this can be very large.
968-
self._nr_bedrock_attrs.clear()
969-
970-
971-
def record_error(self, transaction, exc):
972-
if hasattr(self, "_nr_ft"):
973-
try:
974-
ft = self._nr_ft
975-
error_attributes = getattr(self, "_nr_bedrock_attrs", {})
976-
977-
# If there are no bedrock attrs exit early as there's no data to record.
978-
if not error_attributes:
979-
return
980-
981-
error_attributes = bedrock_error_attributes(exc, error_attributes)
982-
notice_error_attributes = {
983-
"http.statusCode": error_attributes.get("http.statusCode"),
984-
"error.message": error_attributes.get("error.message"),
985-
"error.code": error_attributes.get("error.code"),
986-
}
987-
notice_error_attributes.update({"completion_id": str(uuid.uuid4())})
988-
989-
ft.notice_error(attributes=notice_error_attributes)
990-
991-
ft.__exit__(*sys.exc_info())
992-
error_attributes["duration"] = ft.duration * 1000
993-
994-
handle_chat_completion_event(transaction, error_attributes)
995-
996-
# Clear cached data as this can be very large.
997-
error_attributes.clear()
998-
except Exception:
999-
_logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE, exc_info=True)
1000-
1001-
10021045
def handle_embedding_event(transaction, bedrock_attrs):
10031046
embedding_id = str(uuid.uuid4())
10041047

@@ -1529,6 +1572,7 @@ def wrap_serialize_to_request(wrapped, instance, args, kwargs):
15291572
response_streaming=True
15301573
),
15311574
("bedrock-runtime", "converse"): wrap_bedrock_runtime_converse(response_streaming=False),
1575+
("bedrock-runtime", "converse_stream"): wrap_bedrock_runtime_converse(response_streaming=True),
15321576
}
15331577

15341578

0 commit comments

Comments
 (0)