Changes from all commits (27 commits)
f14dd1c
feat: customize model list using environment
computerdane Mar 18, 2025
69977bf
Merge commit 'f14dd1cb4e8fedf0dd67bb3cb13ec9ccad1893da' into merge-al…
ofawx May 20, 2025
796670b
Fixed blank CUSTOM_MODEL_LIST handling
2underscores May 28, 2025
d77567e
Merge branch 'aws-samples:main' into main
ofawx Jun 16, 2025
d133e2f
Configurable AWS retries via environment variable
ofawx Jun 16, 2025
ff70559
Merge pull request #1 from Constantinople-AI/GGGG-29-aws-retry
ofawx Jun 17, 2025
ff528a8
Merge remote-tracking branch 'upstream/main' into fix-sonnet-4
jake-marsden Aug 7, 2025
f48f37e
fix: remove custom models bug
jake-marsden Aug 7, 2025
9c89baf
Add pagination to list_inference_profiles calls
Aug 7, 2025
2d785ab
Merge pull request #2
ofawx Aug 11, 2025
8c7bd77
impl model list caching
ofawx Sep 23, 2025
4934ec0
impl uvicorn worker count configurability in containerised deployments
ofawx Sep 23, 2025
f2bda42
Merge pull request #3 from Constantinople-AI/GGGG-370-workers-model-c…
ofawx Sep 23, 2025
ecc8578
Merge remote-tracking branch 'upstream/main'
Adam-Schildkraut Oct 14, 2025
53c9cf8
Merge pull request #4 from Constantinople-AI/GGGG-417-sonnet-4-5-infe…
Adam-Schildkraut Oct 14, 2025
221e935
feat: add langfuse observability to bedrock gateway
jake-marsden Nov 13, 2025
56db50e
chore: add debugging statements
jake-marsden Nov 13, 2025
7d7a69d
fix: add missing think tag
jake-marsden Nov 13, 2025
4c98a98
fix: add HTTPException
jake-marsden Nov 13, 2025
e79317a
fix: resolve incorrect import
jake-marsden Nov 14, 2025
abedb3f
fix: correct langfuse import path and use client API
jake-marsden Nov 14, 2025
f057556
fix: downgrade to Langfuse 2.x for API compatibility
jake-marsden Nov 14, 2025
2afb11f
refactor: explicitly initialize langfuse client
jake-marsden Nov 14, 2025
c78cee3
fix: read env variables
jake-marsden Nov 14, 2025
55ae73e
fix: add trace-level metadata for all chat messages
jake-marsden Nov 19, 2025
1bf4859
fix: ensure Bedrock Converse observation nests properly in chat trace
jake-marsden Nov 20, 2025
ff97409
fix: pass user email for all chats
jake-marsden Nov 20, 2025
2 changes: 1 addition & 1 deletion src/Dockerfile_ecs
@@ -19,4 +19,4 @@ ENV PORT=8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8080/health').read()"

CMD ["sh", "-c", "uvicorn api.app:app --host 0.0.0.0 --port ${PORT}"]
CMD ["sh", "-c", "uvicorn api.app:app --host 0.0.0.0 --port ${PORT} --workers ${WORKERS}"]
260 changes: 230 additions & 30 deletions src/api/models/bedrock.py
@@ -1,6 +1,7 @@
import base64
import json
import logging
import os
import re
import time
from abc import ABC
@@ -13,6 +14,8 @@
import tiktoken
from botocore.config import Config
from fastapi import HTTPException
from langfuse import Langfuse
from langfuse.decorators import langfuse_context, observe
from starlette.concurrency import run_in_threadpool

from api.models.base import BaseChatModel, BaseEmbeddingsModel
@@ -46,15 +49,32 @@
DEFAULT_MODEL,
ENABLE_CROSS_REGION_INFERENCE,
ENABLE_APPLICATION_INFERENCE_PROFILES,
MAX_RETRIES_AWS,
MODEL_CACHE_TTL,
)

logger = logging.getLogger(__name__)

# Explicitly initialize Langfuse client for @observe decorator
# This ensures the consumer and auth check happen at module load time
try:
_langfuse = Langfuse(
public_key=os.environ.get("LANGFUSE_PUBLIC_KEY"),
secret_key=os.environ.get("LANGFUSE_SECRET_KEY"),
host=os.environ.get("LANGFUSE_HOST"),
debug=DEBUG
)
if DEBUG:
logger.info("Langfuse client initialized successfully")
except Exception as e:
logger.warning(f"Failed to initialize Langfuse client: {e}")
_langfuse = None

config = Config(
connect_timeout=60, # Connection timeout: 60 seconds
read_timeout=900, # Read timeout: 15 minutes (suitable for long streaming responses)
retries={
'max_attempts': 8, # Maximum retry attempts
'max_attempts': MAX_RETRIES_AWS, # Maximum retry attempts
'mode': 'adaptive' # Adaptive retry mode
},
max_pool_connections=50 # Maximum connection pool size
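The retry budget above now comes from MAX_RETRIES_AWS, and the model cache further down uses MODEL_CACHE_TTL; both are imported from api.setting, whose definitions sit outside this hunk. The Langfuse client, by contrast, is configured directly from the LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY and LANGFUSE_HOST environment variables. A plausible sketch of the two new settings, assuming they follow an environment-variable pattern with backwards-compatible defaults (the values shown are assumptions, with 8 matching the previously hard-coded retry count):

# api/setting.py additions (sketch; the real definitions are not shown in this diff)
import os

MAX_RETRIES_AWS = int(os.environ.get("MAX_RETRIES_AWS", "8"))    # 8 was the old hard-coded value
MODEL_CACHE_TTL = int(os.environ.get("MODEL_CACHE_TTL", "300"))  # seconds; default is an assumption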
@@ -178,15 +198,40 @@ def list_bedrock_models() -> dict:
return model_list


# In-memory cache
_model_cache = {
"data": None,
"timestamp": 0
}

def _get_cached_models():
"""Get models from in-memory cache if still valid."""
global _model_cache

current_time = time.time()
cache_age = current_time - _model_cache["timestamp"]

if _model_cache["data"] is None or cache_age > MODEL_CACHE_TTL:
fresh_models = list_bedrock_models()
if fresh_models:
_model_cache["data"] = fresh_models
_model_cache["timestamp"] = current_time
return fresh_models
else:
# Cache hit
return _model_cache["data"]

# Initialize the model list.
bedrock_model_list = list_bedrock_models()
bedrock_model_list = _get_cached_models()
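
A standalone sketch of the same TTL pattern with a stub fetcher, illustrating the miss / hit / expiry behaviour without touching Bedrock (the 2-second TTL and model id are illustrative only):

# ttl_cache_demo.py - illustrative sketch of the caching behaviour above
import time

MODEL_CACHE_TTL = 2  # seconds, for demonstration
_model_cache = {"data": None, "timestamp": 0}


def fetch_models() -> dict:
    print("fetching model list (stub)")
    return {"anthropic.claude-3-sonnet-20240229-v1:0": {"modalities": ["TEXT"]}}


def get_cached_models() -> dict:
    cache_age = time.time() - _model_cache["timestamp"]
    if _model_cache["data"] is None or cache_age > MODEL_CACHE_TTL:
        _model_cache["data"] = fetch_models()
        _model_cache["timestamp"] = time.time()
    return _model_cache["data"]


get_cached_models()                 # miss: calls the fetcher
get_cached_models()                 # hit: served from memory
time.sleep(MODEL_CACHE_TTL + 1)
get_cached_models()                 # stale: calls the fetcher again

Note that the cache lives in process memory, so with WORKERS greater than 1 each uvicorn worker refreshes its own copy independently.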


class BedrockModel(BaseChatModel):
def list_models(self) -> list[str]:
"""Always refresh the latest model list"""
"""Get model list using in-memory cache with TTL"""
global bedrock_model_list
bedrock_model_list = list_bedrock_models()
cached_models = _get_cached_models()
if cached_models:
bedrock_model_list = cached_models
return list(bedrock_model_list.keys())

def validate(self, chat_request: ChatRequest):
Expand Down Expand Up @@ -223,38 +268,98 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
# Run the blocking boto3 call in a thread pool
response = await run_in_threadpool(bedrock_runtime.converse, **args)
except bedrock_runtime.exceptions.ValidationException as e:
logger.error("Bedrock validation error for model %s: %s", chat_request.model, str(e))
error_message = f"Bedrock validation error for model {chat_request.model}: {str(e)}"
logger.error(error_message)
raise HTTPException(status_code=400, detail=str(e))
except bedrock_runtime.exceptions.ThrottlingException as e:
logger.warning("Bedrock throttling for model %s: %s", chat_request.model, str(e))
error_message = f"Bedrock throttling for model {chat_request.model}: {str(e)}"
logger.warning(error_message)
raise HTTPException(status_code=429, detail=str(e))
except Exception as e:
logger.error("Bedrock invocation failed for model %s: %s", chat_request.model, str(e))
error_message = f"Bedrock invocation failed for model {chat_request.model}: {str(e)}"
logger.error(error_message)
raise HTTPException(status_code=500, detail=str(e))
return response

async def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Default implementation for Chat API."""

"""Default implementation for Chat API.

Note: Works within the parent trace context created by @observe
decorator on chat_completions endpoint. Updates that trace context
with the response data.
"""
message_id = self.generate_message_id()
response = await self._invoke_bedrock(chat_request)

output_message = response["output"]["message"]
input_tokens = response["usage"]["inputTokens"]
output_tokens = response["usage"]["outputTokens"]
finish_reason = response["stopReason"]

chat_response = self._create_response(
model=chat_request.model,
message_id=message_id,
content=output_message["content"],
finish_reason=finish_reason,
input_tokens=input_tokens,
output_tokens=output_tokens,
)
if DEBUG:
logger.info("Proxy response :" + chat_response.model_dump_json())
return chat_response

try:
if DEBUG:
logger.info(f"Langfuse: Starting non-streaming request for model={chat_request.model}")

response = await self._invoke_bedrock(chat_request)

output_message = response["output"]["message"]
input_tokens = response["usage"]["inputTokens"]
output_tokens = response["usage"]["outputTokens"]
finish_reason = response["stopReason"]

# Build metadata including usage info
trace_metadata = {
"model": chat_request.model,
"stream": False,
"stopReason": finish_reason,
"usage": {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens
},
"ResponseMetadata": response.get("ResponseMetadata", {})
}

# Check for reasoning content in response
has_reasoning = False
reasoning_text = ""
if output_message and "content" in output_message:
for content_block in output_message.get("content", []):
if "reasoningContent" in content_block:
has_reasoning = True
reasoning_text = content_block.get("reasoningContent", {}).get("reasoningText", {}).get("text", "")
break

if has_reasoning and reasoning_text:
trace_metadata["has_extended_thinking"] = True
trace_metadata["reasoning_content"] = reasoning_text
trace_metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4

# Update trace with metadata
langfuse_context.update_current_trace(
metadata=trace_metadata
)

if DEBUG:
logger.info(f"Langfuse: Non-streaming response - "
f"input_tokens={input_tokens}, "
f"output_tokens={output_tokens}, "
f"has_reasoning={has_reasoning}, "
f"stop_reason={finish_reason}")

chat_response = self._create_response(
model=chat_request.model,
message_id=message_id,
content=output_message["content"],
finish_reason=finish_reason,
input_tokens=input_tokens,
output_tokens=output_tokens,
)
if DEBUG:
logger.info("Proxy response :" + chat_response.model_dump_json())
return chat_response
except HTTPException:
# Re-raise HTTPException as-is
raise
except Exception as e:
logger.error("Chat error for model %s: %s", chat_request.model, str(e))
if DEBUG:
logger.info(f"Langfuse: Error in non-streaming - error={str(e)[:100]}")
raise
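
Both chat() and chat_stream() rely on a parent trace created by the @observe decorator on the chat_completions endpoint and attach their metadata to it through langfuse_context.update_current_trace(). A hedged sketch of that wiring, with illustrative route and import paths (the router itself is not part of this hunk):

# Illustrative router sketch; module paths and names are assumptions
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from langfuse.decorators import observe

from api.models.bedrock import BedrockModel
from api.schema import ChatRequest  # assumed location of the request schema

router = APIRouter()


@router.post("/chat/completions")
@observe(name="chat_completions")   # creates the parent trace that the model code updates
async def chat_completions(chat_request: ChatRequest):
    model = BedrockModel()
    if chat_request.stream:
        return StreamingResponse(model.chat_stream(chat_request), media_type="text/event-stream")
    return await model.chat(chat_request)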

async def _async_iterate(self, stream):
"""Helper method to convert sync iterator to async iterator"""
@@ -263,17 +368,56 @@ async def _async_iterate(self, stream):
yield chunk

async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Default implementation for Chat Stream API"""
"""Default implementation for Chat Stream API

Note: For streaming, we work within the parent trace context created by @observe
decorator on chat_completions endpoint. We update that trace context with
streaming data as it arrives.
"""
try:
if DEBUG:
logger.info(f"Langfuse: Starting streaming request for model={chat_request.model}")

# Parse request for metadata to log in parent trace
args = self._parse_request(chat_request)
messages = args.get('messages', [])
model_id = args.get('modelId', 'unknown')

response = await self._invoke_bedrock(chat_request, stream=True)
message_id = self.generate_message_id()
stream = response.get("stream")
self.think_emitted = False

# Track streaming output and usage for Langfuse
accumulated_output = []
accumulated_reasoning = []
final_usage = None
finish_reason = None
has_reasoning = False

async for chunk in self._async_iterate(stream):
args = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk}
stream_response = self._create_response_stream(**args)
args_chunk = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk}
stream_response = self._create_response_stream(**args_chunk)
if not stream_response:
continue

# Accumulate output content for Langfuse tracking
if stream_response.choices:
for choice in stream_response.choices:
if choice.delta and choice.delta.content:
content = choice.delta.content
# Check if this is reasoning content (wrapped in <think> tags)
if "<think>" in content or self.think_emitted:
accumulated_reasoning.append(content)
has_reasoning = True
accumulated_output.append(content)
if choice.finish_reason:
finish_reason = choice.finish_reason

# Capture final usage metrics for Langfuse tracking
if stream_response.usage:
final_usage = stream_response.usage

if DEBUG:
logger.info("Proxy response :" + stream_response.model_dump_json())
if stream_response.choices:
@@ -287,11 +431,66 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
# All other chunks will also include a usage field, but with a null value.
yield self.stream_response_to_bytes(stream_response)

# Update Langfuse trace with final streaming output
# This updates the parent trace from chat_completions
if final_usage or accumulated_output:
final_output = "".join(accumulated_output) if accumulated_output else None
trace_output = {
"message": {
"role": "assistant",
"content": final_output,
},
"finish_reason": finish_reason,
}

# Build metadata including usage info
trace_metadata = {
"model": model_id,
"stream": True,
}
if finish_reason:
trace_metadata["finish_reason"] = finish_reason
if final_usage:
trace_metadata["usage"] = {
"prompt_tokens": final_usage.prompt_tokens,
"completion_tokens": final_usage.completion_tokens,
"total_tokens": final_usage.total_tokens
}
if has_reasoning and accumulated_reasoning:
reasoning_text = "".join(accumulated_reasoning)
trace_metadata["has_extended_thinking"] = True
trace_metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4

langfuse_context.update_current_trace(
output=trace_output,
metadata=trace_metadata
)

if DEBUG:
output_length = len(accumulated_output)
logger.info(f"Langfuse: Updated trace with streaming output - "
f"chunks_count={output_length}, "
f"output_chars={len(final_output) if final_output else 0}, "
f"input_tokens={final_usage.prompt_tokens if final_usage else 'N/A'}, "
f"output_tokens={final_usage.completion_tokens if final_usage else 'N/A'}, "
f"has_reasoning={has_reasoning}, "
f"finish_reason={finish_reason}")

# return an [DONE] message at the end.
yield self.stream_response_to_bytes()
self.think_emitted = False # Cleanup
except HTTPException:
# Re-raise HTTPException as-is
raise
except Exception as e:
logger.error("Stream error for model %s: %s", chat_request.model, str(e))
# Update Langfuse with error
langfuse_context.update_current_trace(
output={"error": str(e)},
metadata={"error": True, "error_type": type(e).__name__}
)
if DEBUG:
logger.info(f"Langfuse: Updated trace with streaming error - error={str(e)[:100]}")
error_event = Error(error=ErrorMessage(message=str(e)))
yield self.stream_response_to_bytes(error_event)
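
From the client's side, the accumulated deltas above arrive as ordinary OpenAI-style SSE chunks, with any extended-thinking output wrapped in <think> ... </think> tags. A hedged consumption sketch using the OpenAI SDK, where the base URL, API key and model id are placeholders for a particular deployment:

# client_stream_demo.py - illustrative only
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/api/v1", api_key="placeholder")

stream = client.chat.completions.create(
    model="anthropic.claude-3-7-sonnet-20250219-v1:0",  # any id returned by the gateway's model list
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)  # may include <think> sections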

@@ -720,6 +919,7 @@ def _create_response_stream(
# Port of "signature_delta"
if self.think_emitted:
message = ChatResponseMessage(content="\n </think> \n\n")
self.think_emitted = False # Reset flag after closing </think>
else:
return None # Ignore signature if no <think> started
else: