11"""Module with methods used to trace AWS Bedrock LLMs."""
22
3+ import io
34import json
45import logging
56import time
67from functools import wraps
7- from typing import Any , Dict , Iterator , Optional , Union , TYPE_CHECKING
8+ from typing import TYPE_CHECKING , Any , Dict , Iterator , Optional , Union
9+
10+ from botocore .response import StreamingBody
11+
812
913try :
1014 import boto3
@@ -89,20 +93,7 @@ def handle_non_streaming_invoke(
     inference_id: Optional[str] = None,
     **kwargs,
 ) -> Dict[str, Any]:
-    """Handles the invoke_model method for non-streaming requests.
-
-    Parameters
-    ----------
-    invoke_func : callable
-        The invoke_model method to handle.
-    inference_id : Optional[str], optional
-        A user-generated inference id, by default None
-
-    Returns
-    -------
-    Dict[str, Any]
-        The model invocation response.
-    """
+    """Handles the invoke_model method for non-streaming requests."""
     start_time = time.time()
     response = invoke_func(*args, **kwargs)
     end_time = time.time()
@@ -115,21 +106,27 @@ def handle_non_streaming_invoke(
             body_str = body_str.decode("utf-8")
         body_data = json.loads(body_str) if isinstance(body_str, str) else body_str

-        # Parse the response body
-        response_body = response["body"].read()
-        if isinstance(response_body, bytes):
-            response_body = response_body.decode("utf-8")
-        response_data = json.loads(response_body)
+        # Read the response body ONCE and preserve it
+        original_body = response["body"]
+        response_body_bytes = original_body.read()
+
+        # Parse the response data for tracing
+        if isinstance(response_body_bytes, bytes):
+            response_body_str = response_body_bytes.decode("utf-8")
+        else:
+            response_body_str = response_body_bytes
+        response_data = json.loads(response_body_str)

-        # Extract input and output data
+        # Create a NEW StreamingBody with the same data and type
+        # This preserves the exact botocore.response.StreamingBody type
+        new_stream = io.BytesIO(response_body_bytes)
+        response["body"] = StreamingBody(new_stream, len(response_body_bytes))
+
+        # Extract data for tracing
         inputs = extract_inputs_from_body(body_data)
         output_data = extract_output_data(response_data)
-
-        # Extract tokens and model info
         tokens_info = extract_tokens_info(response_data)
         model_id = kwargs.get("modelId", "unknown")
-
-        # Extract metadata including stop information
         metadata = extract_metadata(response_data)

         trace_args = create_trace_args(
@@ -149,19 +146,12 @@ def handle_non_streaming_invoke(

         add_to_trace(**trace_args)

-    # pylint: disable=broad-except
     except Exception as e:
         logger.error(
             "Failed to trace the Bedrock model invocation with Openlayer. %s", e
         )

-    # Reset response body for return (since we read it)
-    response_bytes = json.dumps(response_data).encode("utf-8")
-    response["body"] = type(
-        "MockBody",
-        (),
-        {"read": lambda size=-1: response_bytes[:size] if size > 0 else response_bytes},
-    )()
+    # Return the response with the properly restored body
+    return response


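For reference, here is a minimal standalone sketch of the read-once-then-restore pattern this change adopts. The payload and the `response` dict are fabricated for illustration; only `io.BytesIO` and `botocore.response.StreamingBody` are real APIs, and this is not the tracer's own code.

import io
import json

from botocore.response import StreamingBody

# Fabricated Bedrock-style payload, used only to illustrate the pattern.
payload = json.dumps({"completion": "Hello from the model"}).encode("utf-8")
response = {"body": StreamingBody(io.BytesIO(payload), len(payload))}

# Reading the body exhausts the stream...
raw = response["body"].read()
traced_data = json.loads(raw.decode("utf-8"))  # parsed copy used for tracing

# ...so rebuild an equivalent StreamingBody before handing the response back.
response["body"] = StreamingBody(io.BytesIO(raw), len(raw))

assert isinstance(response["body"], StreamingBody)
assert response["body"].read() == payload

Unlike the ad-hoc MockBody object that was removed, the rebuilt body keeps the exact botocore type, so downstream isinstance checks and repeated .read() calls behave as they would on an untraced response.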