
Commit f076735

fix: ensure Bedrock Converse observation nests properly in chat trace
1 parent 55ae73e commit f076735
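
Context for the fix: previously _invoke_bedrock carried its own @observe(as_type="generation") decorator, so the "Bedrock Converse" generation was created and updated inside that helper; this commit removes the decorator and instead opens the generation explicitly in chat() with the Langfuse v3 client, so it nests as a child of the trace that the @observe-decorated chat_completions endpoint already started. A minimal sketch of that nesting pattern follows; it assumes the Langfuse Python SDK v3 imports used in the diff, and the endpoint wrapper shown here is illustrative rather than the repository's actual code.

# Minimal nesting sketch (assumption: Langfuse Python SDK v3, as used in the diff).
# The chat_completions wrapper below is illustrative, not the project's real endpoint code.
from langfuse import get_client, observe


@observe(name="chat_completions")              # parent trace for the HTTP request
async def chat_completions(chat_request):
    model = BedrockModel()
    return await model.chat(chat_request)


class BedrockModel:
    async def chat(self, chat_request):
        langfuse = get_client()
        # Opening the generation here (instead of decorating _invoke_bedrock)
        # makes it a child observation of whatever trace is currently active.
        with langfuse.start_as_current_observation(
            as_type="generation", name="Bedrock Converse"
        ):
            return await self._invoke_bedrock(chat_request)  # plain call, no decorator

    async def _invoke_bedrock(self, chat_request):
        ...  # boto3 Converse call; no Langfuse decorator after this commit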

File tree

1 file changed: +142 -137 lines changed


src/api/models/bedrock.py

Lines changed: 142 additions & 137 deletions
@@ -248,7 +248,6 @@ def validate(self, chat_request: ChatRequest):
                 detail=error,
             )

-    @observe(as_type="generation", name="Bedrock Converse")
     async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         """Common logic for invoke bedrock models"""
         if DEBUG:
@@ -259,29 +258,6 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         if DEBUG:
             logger.info("Bedrock request: " + json.dumps(str(args)))

-        # Extract model metadata for Langfuse
-        args_clone = args.copy()
-        messages = args_clone.get('messages', [])
-        model_id = args_clone.get('modelId', 'unknown')
-        model_parameters = {
-            **args_clone.get('inferenceConfig', {}),
-            **args_clone.get('additionalModelRequestFields', {})
-        }
-
-        # Update Langfuse generation with input metadata
-        langfuse_context.update_current_observation(
-            input=messages,
-            model=model_id,
-            model_parameters=model_parameters,
-            metadata={
-                'system': args_clone.get('system', []),
-                'toolConfig': args_clone.get('toolConfig', {}),
-                'stream': stream
-            }
-        )
-        if DEBUG:
-            logger.info(f"Langfuse: Updated observation with input - model={model_id}, stream={stream}, messages_count={len(messages)}")
-
         try:
             if stream:
                 # Run the blocking boto3 call in a thread pool
@@ -291,93 +267,118 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
             else:
                 # Run the blocking boto3 call in a thread pool
                 response = await run_in_threadpool(bedrock_runtime.converse, **args)
-
-            # For non-streaming, extract response metadata immediately
-            if response and not stream:
-                output_message = response.get("output", {}).get("message", {})
-                usage = response.get("usage", {})
-
-                # Build metadata
-                metadata = {
-                    "stopReason": response.get("stopReason"),
-                    "ResponseMetadata": response.get("ResponseMetadata", {})
-                }
-
-                # Check for reasoning content in response
-                has_reasoning = False
-                reasoning_text = ""
-                if output_message and "content" in output_message:
-                    for content_block in output_message.get("content", []):
-                        if "reasoningContent" in content_block:
-                            has_reasoning = True
-                            reasoning_text = content_block.get("reasoningContent", {}).get("reasoningText", {}).get("text", "")
-                            break
-
-                if has_reasoning and reasoning_text:
-                    metadata["has_extended_thinking"] = True
-                    metadata["reasoning_content"] = reasoning_text
-                    metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4
-
-                langfuse_context.update_current_observation(
-                    output=output_message,
-                    usage={
-                        "input": usage.get("inputTokens", 0),
-                        "output": usage.get("outputTokens", 0),
-                        "total": usage.get("totalTokens", 0)
-                    },
-                    metadata=metadata
-                )
-                if DEBUG:
-                    logger.info(f"Langfuse: Updated observation with output - "
-                                f"input_tokens={usage.get('inputTokens', 0)}, "
-                                f"output_tokens={usage.get('outputTokens', 0)}, "
-                                f"has_reasoning={has_reasoning}, "
-                                f"stop_reason={response.get('stopReason')}")
         except bedrock_runtime.exceptions.ValidationException as e:
             error_message = f"Bedrock validation error for model {chat_request.model}: {str(e)}"
             logger.error(error_message)
-            langfuse_context.update_current_observation(level="ERROR", status_message=error_message)
-            if DEBUG:
-                logger.info("Langfuse: Updated observation with ValidationException error")
             raise HTTPException(status_code=400, detail=str(e))
         except bedrock_runtime.exceptions.ThrottlingException as e:
             error_message = f"Bedrock throttling for model {chat_request.model}: {str(e)}"
             logger.warning(error_message)
-            langfuse_context.update_current_observation(level="WARNING", status_message=error_message)
-            if DEBUG:
-                logger.info("Langfuse: Updated observation with ThrottlingException warning")
             raise HTTPException(status_code=429, detail=str(e))
         except Exception as e:
             error_message = f"Bedrock invocation failed for model {chat_request.model}: {str(e)}"
             logger.error(error_message)
-            langfuse_context.update_current_observation(level="ERROR", status_message=error_message)
-            if DEBUG:
-                logger.info("Langfuse: Updated observation with generic Exception error")
             raise HTTPException(status_code=500, detail=str(e))
         return response

     async def chat(self, chat_request: ChatRequest) -> ChatResponse:
         """Default implementation for Chat API."""
-
+        from langfuse import get_client
+
+        langfuse = get_client()
         message_id = self.generate_message_id()
-        response = await self._invoke_bedrock(chat_request)
-
-        output_message = response["output"]["message"]
-        input_tokens = response["usage"]["inputTokens"]
-        output_tokens = response["usage"]["outputTokens"]
-        finish_reason = response["stopReason"]
-
-        chat_response = self._create_response(
-            model=chat_request.model,
-            message_id=message_id,
-            content=output_message["content"],
-            finish_reason=finish_reason,
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
-        )
-        if DEBUG:
-            logger.info("Proxy response :" + chat_response.model_dump_json())
-        return chat_response
+
+        # Create Langfuse observation for Bedrock call
+        with langfuse.start_as_current_observation(as_type="generation", name="Bedrock Converse") as obs:
+            # Parse request for metadata
+            args = self._parse_request(chat_request)
+            messages = args.get('messages', [])
+            model_id = args.get('modelId', 'unknown')
+            model_parameters = {
+                **args.get('inferenceConfig', {}),
+                **args.get('additionalModelRequestFields', {})
+            }
+
+            # Update observation with input
+            langfuse_context.update_current_observation(
+                input=messages,
+                model=model_id,
+                model_parameters=model_parameters,
+                metadata={
+                    'system': args.get('system', []),
+                    'toolConfig': args.get('toolConfig', {}),
+                    'stream': False
+                }
+            )
+            if DEBUG:
+                logger.info(f"Langfuse: Updated observation with input - model={model_id}, stream=False, messages_count={len(messages)}")
+
+            try:
+                response = await self._invoke_bedrock(chat_request)
+
+                output_message = response["output"]["message"]
+                input_tokens = response["usage"]["inputTokens"]
+                output_tokens = response["usage"]["outputTokens"]
+                finish_reason = response["stopReason"]
+
+                # Update observation with output
+                metadata = {
+                    "stopReason": finish_reason,
+                    "ResponseMetadata": response.get("ResponseMetadata", {})
+                }
+
+                # Check for reasoning content in response
+                has_reasoning = False
+                reasoning_text = ""
+                if output_message and "content" in output_message:
+                    for content_block in output_message.get("content", []):
+                        if "reasoningContent" in content_block:
+                            has_reasoning = True
+                            reasoning_text = content_block.get("reasoningContent", {}).get("reasoningText", {}).get("text", "")
+                            break
+
+                if has_reasoning and reasoning_text:
+                    metadata["has_extended_thinking"] = True
+                    metadata["reasoning_content"] = reasoning_text
+                    metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4
+
+                langfuse_context.update_current_observation(
+                    output=output_message,
+                    usage={
+                        "input": input_tokens,
+                        "output": output_tokens,
+                        "total": input_tokens + output_tokens
+                    },
+                    metadata=metadata
+                )
+                if DEBUG:
+                    logger.info(f"Langfuse: Updated observation with output - "
+                                f"input_tokens={input_tokens}, "
+                                f"output_tokens={output_tokens}, "
+                                f"has_reasoning={has_reasoning}, "
+                                f"stop_reason={finish_reason}")
+
+                chat_response = self._create_response(
+                    model=chat_request.model,
+                    message_id=message_id,
+                    content=output_message["content"],
+                    finish_reason=finish_reason,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                )
+                if DEBUG:
+                    logger.info("Proxy response :" + chat_response.model_dump_json())
+                return chat_response
+            except HTTPException as e:
+                langfuse_context.update_current_observation(level="ERROR", status_message=str(e.detail))
+                if DEBUG:
+                    logger.info(f"Langfuse: Updated observation with error - {str(e.detail)[:100]}")
+                raise
+            except Exception as e:
+                langfuse_context.update_current_observation(level="ERROR", status_message=str(e))
+                if DEBUG:
+                    logger.info(f"Langfuse: Updated observation with error - {str(e)[:100]}")
+                raise

     async def _async_iterate(self, stream):
         """Helper method to convert sync iterator to async iterator"""
@@ -386,10 +387,21 @@ async def _async_iterate(self, stream):
             yield chunk

     async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
-        """Default implementation for Chat Stream API"""
+        """Default implementation for Chat Stream API
+
+        Note: For streaming, we work within the parent trace context created by @observe
+        decorator on chat_completions endpoint. We update that trace context with
+        streaming data as it arrives.
+        """
         try:
             if DEBUG:
                 logger.info(f"Langfuse: Starting streaming request for model={chat_request.model}")
+
+            # Parse request for metadata to log in parent trace
+            args = self._parse_request(chat_request)
+            messages = args.get('messages', [])
+            model_id = args.get('modelId', 'unknown')
+
             response = await self._invoke_bedrock(chat_request, stream=True)
             message_id = self.generate_message_id()
             stream = response.get("stream")
@@ -403,8 +415,8 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
             has_reasoning = False

             async for chunk in self._async_iterate(stream):
-                args = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk}
-                stream_response = self._create_response_stream(**args)
+                args_chunk = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk}
+                stream_response = self._create_response_stream(**args_chunk)
                 if not stream_response:
                     continue

@@ -438,49 +450,46 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
                 # All other chunks will also include a usage field, but with a null value.
                 yield self.stream_response_to_bytes(stream_response)

-            # Update Langfuse with final streaming metadata (both observation and trace)
+            # Update Langfuse trace with final streaming output
+            # This updates the parent trace from chat_completions
             if final_usage or accumulated_output:
-                update_params = {}
-                if accumulated_output:
-                    final_output = "".join(accumulated_output)
-                    update_params["output"] = final_output
-                if final_usage:
-                    update_params["usage"] = {
-                        "input": final_usage.prompt_tokens,
-                        "output": final_usage.completion_tokens,
-                        "total": final_usage.total_tokens
-                    }
-                # Build metadata
-                metadata = {}
-                if finish_reason:
-                    metadata["finish_reason"] = finish_reason
-                if has_reasoning and accumulated_reasoning:
-                    reasoning_text = "".join(accumulated_reasoning)
-                    metadata["has_extended_thinking"] = True
-                    metadata["reasoning_content"] = reasoning_text
-                    # Estimate reasoning tokens (rough approximation: ~4 chars per token)
-                    metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4
-                if metadata:
-                    update_params["metadata"] = metadata
-
-                # Update the child observation (Bedrock Converse)
-                langfuse_context.update_current_observation(**update_params)
-
-                # Also update the parent trace (chat_completion) with final output
+                final_output = "".join(accumulated_output) if accumulated_output else None
                 trace_output = {
                     "message": {
                         "role": "assistant",
-                        "content": final_output if accumulated_output else None,
+                        "content": final_output,
                     },
                     "finish_reason": finish_reason,
                 }
-                langfuse_context.update_current_trace(output=trace_output)
+
+                # Build metadata including usage info
+                trace_metadata = {
+                    "model": model_id,
+                    "stream": True,
+                }
+                if finish_reason:
+                    trace_metadata["finish_reason"] = finish_reason
+                if final_usage:
+                    trace_metadata["usage"] = {
+                        "prompt_tokens": final_usage.prompt_tokens,
+                        "completion_tokens": final_usage.completion_tokens,
+                        "total_tokens": final_usage.total_tokens
+                    }
+                if has_reasoning and accumulated_reasoning:
+                    reasoning_text = "".join(accumulated_reasoning)
+                    trace_metadata["has_extended_thinking"] = True
+                    trace_metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4
+
+                langfuse_context.update_current_trace(
+                    output=trace_output,
+                    metadata=trace_metadata
+                )

             if DEBUG:
                 output_length = len(accumulated_output)
-                logger.info(f"Langfuse: Updated observation and trace with streaming output - "
+                logger.info(f"Langfuse: Updated trace with streaming output - "
                             f"chunks_count={output_length}, "
-                            f"output_chars={len(final_output) if accumulated_output else 0}, "
+                            f"output_chars={len(final_output) if final_output else 0}, "
                             f"input_tokens={final_usage.prompt_tokens if final_usage else 'N/A'}, "
                             f"output_tokens={final_usage.completion_tokens if final_usage else 'N/A'}, "
                             f"has_reasoning={has_reasoning}, "
@@ -490,21 +499,17 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
             yield self.stream_response_to_bytes()
             self.think_emitted = False # Cleanup
         except HTTPException:
-            # HTTPException already has Langfuse updated in _invoke_bedrock, re-raise it
+            # Re-raise HTTPException as-is
             raise
         except Exception as e:
             logger.error("Stream error for model %s: %s", chat_request.model, str(e))
-            # Update Langfuse with error (both observation and trace)
-            langfuse_context.update_current_observation(
-                level="ERROR",
-                status_message=f"Stream error: {str(e)}"
-            )
+            # Update Langfuse with error
             langfuse_context.update_current_trace(
                 output={"error": str(e)},
-                metadata={"error": True}
+                metadata={"error": True, "error_type": type(e).__name__}
             )
             if DEBUG:
-                logger.info(f"Langfuse: Updated observation with streaming error - error={str(e)[:100]}")
+                logger.info(f"Langfuse: Updated trace with streaming error - error={str(e)[:100]}")
             error_event = Error(error=ErrorMessage(message=str(e)))
             yield self.stream_response_to_bytes(error_event)

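Finally, a small sketch of the two token-count mappings this file now uses, assuming only the field names shown in the diff: the non-streaming path maps Bedrock's camelCase usage counters onto Langfuse generation usage keys, while the streaming path records the OpenAI-style prompt/completion/total counters inside the parent trace's metadata. The helper names below are hypothetical.

# Hypothetical helpers illustrating the two usage mappings in the diff above.
def to_generation_usage(input_tokens: int, output_tokens: int) -> dict:
    """Non-streaming path: response["usage"] counters -> Langfuse generation usage."""
    return {
        "input": input_tokens,
        "output": output_tokens,
        "total": input_tokens + output_tokens,
    }


def to_trace_usage_metadata(prompt_tokens: int, completion_tokens: int, total_tokens: int) -> dict:
    """Streaming path: OpenAI-style usage counters -> trace metadata "usage" entry."""
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }


print(to_generation_usage(42, 128))        # {'input': 42, 'output': 128, 'total': 170}
print(to_trace_usage_metadata(42, 12, 54)) # {'prompt_tokens': 42, 'completion_tokens': 12, 'total_tokens': 54}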