diff --git a/src/Dockerfile_ecs b/src/Dockerfile_ecs
index e71acbfd..8294ebf3 100644
--- a/src/Dockerfile_ecs
+++ b/src/Dockerfile_ecs
@@ -19,4 +19,4 @@ ENV PORT=8080
 HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8080/health').read()"
 
-CMD ["sh", "-c", "uvicorn api.app:app --host 0.0.0.0 --port ${PORT}"]
+CMD ["sh", "-c", "uvicorn api.app:app --host 0.0.0.0 --port ${PORT} --workers ${WORKERS}"]
diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index fba048bb..0f98a987 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -1,6 +1,7 @@
 import base64
 import json
 import logging
+import os
 import re
 import time
 from abc import ABC
@@ -13,6 +14,8 @@
 import tiktoken
 from botocore.config import Config
 from fastapi import HTTPException
+from langfuse import Langfuse
+from langfuse.decorators import langfuse_context, observe
 from starlette.concurrency import run_in_threadpool
 
 from api.models.base import BaseChatModel, BaseEmbeddingsModel
@@ -46,15 +49,32 @@
     DEFAULT_MODEL,
     ENABLE_CROSS_REGION_INFERENCE,
     ENABLE_APPLICATION_INFERENCE_PROFILES,
+    MAX_RETRIES_AWS,
+    MODEL_CACHE_TTL,
 )
 
 logger = logging.getLogger(__name__)
 
+# Explicitly initialize Langfuse client for @observe decorator
+# This ensures the consumer and auth check happen at module load time
+try:
+    _langfuse = Langfuse(
+        public_key=os.environ.get("LANGFUSE_PUBLIC_KEY"),
+        secret_key=os.environ.get("LANGFUSE_SECRET_KEY"),
+        host=os.environ.get("LANGFUSE_HOST"),
+        debug=DEBUG
+    )
+    if DEBUG:
+        logger.info("Langfuse client initialized successfully")
+except Exception as e:
+    logger.warning(f"Failed to initialize Langfuse client: {e}")
+    _langfuse = None
+
 config = Config(
     connect_timeout=60,  # Connection timeout: 60 seconds
     read_timeout=900,  # Read timeout: 15 minutes (suitable for long streaming responses)
     retries={
-        'max_attempts': 8,  # Maximum retry attempts
+        'max_attempts': MAX_RETRIES_AWS,  # Maximum retry attempts
         'mode': 'adaptive'  # Adaptive retry mode
     },
     max_pool_connections=50  # Maximum connection pool size
@@ -178,15 +198,40 @@ def list_bedrock_models() -> dict:
     return model_list
 
 
+# In-memory cache
+_model_cache = {
+    "data": None,
+    "timestamp": 0
+}
+
+
+def _get_cached_models():
+    """Get models from in-memory cache if still valid."""
+    global _model_cache
+
+    current_time = time.time()
+    cache_age = current_time - _model_cache["timestamp"]
+
+    if _model_cache["data"] is None or cache_age > MODEL_CACHE_TTL:
+        fresh_models = list_bedrock_models()
+        if fresh_models:
+            _model_cache["data"] = fresh_models
+            _model_cache["timestamp"] = current_time
+        return fresh_models
+    else:
+        # Cache hit
+        return _model_cache["data"]
+
+
 # Initialize the model list.
-bedrock_model_list = list_bedrock_models()
+bedrock_model_list = _get_cached_models()
 
 
 class BedrockModel(BaseChatModel):
     def list_models(self) -> list[str]:
-        """Always refresh the latest model list"""
+        """Get model list using in-memory cache with TTL"""
         global bedrock_model_list
-        bedrock_model_list = list_bedrock_models()
+        cached_models = _get_cached_models()
+        if cached_models:
+            bedrock_model_list = cached_models
         return list(bedrock_model_list.keys())
 
     def validate(self, chat_request: ChatRequest):
@@ -223,38 +268,98 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
             # Run the blocking boto3 call in a thread pool
             response = await run_in_threadpool(bedrock_runtime.converse, **args)
         except bedrock_runtime.exceptions.ValidationException as e:
-            logger.error("Bedrock validation error for model %s: %s", chat_request.model, str(e))
+            error_message = f"Bedrock validation error for model {chat_request.model}: {str(e)}"
+            logger.error(error_message)
             raise HTTPException(status_code=400, detail=str(e))
         except bedrock_runtime.exceptions.ThrottlingException as e:
-            logger.warning("Bedrock throttling for model %s: %s", chat_request.model, str(e))
+            error_message = f"Bedrock throttling for model {chat_request.model}: {str(e)}"
+            logger.warning(error_message)
            raise HTTPException(status_code=429, detail=str(e))
         except Exception as e:
-            logger.error("Bedrock invocation failed for model %s: %s", chat_request.model, str(e))
+            error_message = f"Bedrock invocation failed for model {chat_request.model}: {str(e)}"
+            logger.error(error_message)
             raise HTTPException(status_code=500, detail=str(e))
         return response
 
     async def chat(self, chat_request: ChatRequest) -> ChatResponse:
-        """Default implementation for Chat API."""
-
+        """Default implementation for Chat API.
+
+        Note: Works within the parent trace context created by @observe
+        decorator on chat_completions endpoint. Updates that trace context
+        with the response data.
+ """ message_id = self.generate_message_id() - response = await self._invoke_bedrock(chat_request) - - output_message = response["output"]["message"] - input_tokens = response["usage"]["inputTokens"] - output_tokens = response["usage"]["outputTokens"] - finish_reason = response["stopReason"] - - chat_response = self._create_response( - model=chat_request.model, - message_id=message_id, - content=output_message["content"], - finish_reason=finish_reason, - input_tokens=input_tokens, - output_tokens=output_tokens, - ) - if DEBUG: - logger.info("Proxy response :" + chat_response.model_dump_json()) - return chat_response + + try: + if DEBUG: + logger.info(f"Langfuse: Starting non-streaming request for model={chat_request.model}") + + response = await self._invoke_bedrock(chat_request) + + output_message = response["output"]["message"] + input_tokens = response["usage"]["inputTokens"] + output_tokens = response["usage"]["outputTokens"] + finish_reason = response["stopReason"] + + # Build metadata including usage info + trace_metadata = { + "model": chat_request.model, + "stream": False, + "stopReason": finish_reason, + "usage": { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens + }, + "ResponseMetadata": response.get("ResponseMetadata", {}) + } + + # Check for reasoning content in response + has_reasoning = False + reasoning_text = "" + if output_message and "content" in output_message: + for content_block in output_message.get("content", []): + if "reasoningContent" in content_block: + has_reasoning = True + reasoning_text = content_block.get("reasoningContent", {}).get("reasoningText", {}).get("text", "") + break + + if has_reasoning and reasoning_text: + trace_metadata["has_extended_thinking"] = True + trace_metadata["reasoning_content"] = reasoning_text + trace_metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4 + + # Update trace with metadata + langfuse_context.update_current_trace( + metadata=trace_metadata + ) + + if DEBUG: + logger.info(f"Langfuse: Non-streaming response - " + f"input_tokens={input_tokens}, " + f"output_tokens={output_tokens}, " + f"has_reasoning={has_reasoning}, " + f"stop_reason={finish_reason}") + + chat_response = self._create_response( + model=chat_request.model, + message_id=message_id, + content=output_message["content"], + finish_reason=finish_reason, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + if DEBUG: + logger.info("Proxy response :" + chat_response.model_dump_json()) + return chat_response + except HTTPException: + # Re-raise HTTPException as-is + raise + except Exception as e: + logger.error("Chat error for model %s: %s", chat_request.model, str(e)) + if DEBUG: + logger.info(f"Langfuse: Error in non-streaming - error={str(e)[:100]}") + raise async def _async_iterate(self, stream): """Helper method to convert sync iterator to async iterator""" @@ -263,17 +368,56 @@ async def _async_iterate(self, stream): yield chunk async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: - """Default implementation for Chat Stream API""" + """Default implementation for Chat Stream API + + Note: For streaming, we work within the parent trace context created by @observe + decorator on chat_completions endpoint. We update that trace context with + streaming data as it arrives. 
+ """ try: + if DEBUG: + logger.info(f"Langfuse: Starting streaming request for model={chat_request.model}") + + # Parse request for metadata to log in parent trace + args = self._parse_request(chat_request) + messages = args.get('messages', []) + model_id = args.get('modelId', 'unknown') + response = await self._invoke_bedrock(chat_request, stream=True) message_id = self.generate_message_id() stream = response.get("stream") self.think_emitted = False + + # Track streaming output and usage for Langfuse + accumulated_output = [] + accumulated_reasoning = [] + final_usage = None + finish_reason = None + has_reasoning = False + async for chunk in self._async_iterate(stream): - args = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk} - stream_response = self._create_response_stream(**args) + args_chunk = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk} + stream_response = self._create_response_stream(**args_chunk) if not stream_response: continue + + # Accumulate output content for Langfuse tracking + if stream_response.choices: + for choice in stream_response.choices: + if choice.delta and choice.delta.content: + content = choice.delta.content + # Check if this is reasoning content (wrapped in tags) + if "" in content or self.think_emitted: + accumulated_reasoning.append(content) + has_reasoning = True + accumulated_output.append(content) + if choice.finish_reason: + finish_reason = choice.finish_reason + + # Capture final usage metrics for Langfuse tracking + if stream_response.usage: + final_usage = stream_response.usage + if DEBUG: logger.info("Proxy response :" + stream_response.model_dump_json()) if stream_response.choices: @@ -287,11 +431,66 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: # All other chunks will also include a usage field, but with a null value. yield self.stream_response_to_bytes(stream_response) + # Update Langfuse trace with final streaming output + # This updates the parent trace from chat_completions + if final_usage or accumulated_output: + final_output = "".join(accumulated_output) if accumulated_output else None + trace_output = { + "message": { + "role": "assistant", + "content": final_output, + }, + "finish_reason": finish_reason, + } + + # Build metadata including usage info + trace_metadata = { + "model": model_id, + "stream": True, + } + if finish_reason: + trace_metadata["finish_reason"] = finish_reason + if final_usage: + trace_metadata["usage"] = { + "prompt_tokens": final_usage.prompt_tokens, + "completion_tokens": final_usage.completion_tokens, + "total_tokens": final_usage.total_tokens + } + if has_reasoning and accumulated_reasoning: + reasoning_text = "".join(accumulated_reasoning) + trace_metadata["has_extended_thinking"] = True + trace_metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4 + + langfuse_context.update_current_trace( + output=trace_output, + metadata=trace_metadata + ) + + if DEBUG: + output_length = len(accumulated_output) + logger.info(f"Langfuse: Updated trace with streaming output - " + f"chunks_count={output_length}, " + f"output_chars={len(final_output) if final_output else 0}, " + f"input_tokens={final_usage.prompt_tokens if final_usage else 'N/A'}, " + f"output_tokens={final_usage.completion_tokens if final_usage else 'N/A'}, " + f"has_reasoning={has_reasoning}, " + f"finish_reason={finish_reason}") + # return an [DONE] message at the end. 
             yield self.stream_response_to_bytes()
             self.think_emitted = False  # Cleanup
+        except HTTPException:
+            # Re-raise HTTPException as-is
+            raise
         except Exception as e:
             logger.error("Stream error for model %s: %s", chat_request.model, str(e))
+            # Update Langfuse with error
+            langfuse_context.update_current_trace(
+                output={"error": str(e)},
+                metadata={"error": True, "error_type": type(e).__name__}
+            )
+            if DEBUG:
+                logger.info(f"Langfuse: Updated trace with streaming error - error={str(e)[:100]}")
             error_event = Error(error=ErrorMessage(message=str(e)))
             yield self.stream_response_to_bytes(error_event)
@@ -720,6 +919,7 @@ def _create_response_stream(
                 # Port of "signature_delta"
                 if self.think_emitted:
                     message = ChatResponseMessage(content="\n</think>\n\n")
+                    self.think_emitted = False  # Reset flag after closing
                 else:
                     return None  # Ignore signature if no <think> started
             else:
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index 530f75d6..23071d4c 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -1,7 +1,8 @@
 from typing import Annotated
 
-from fastapi import APIRouter, Body, Depends
+from fastapi import APIRouter, Body, Depends, Header, Request
 from fastapi.responses import StreamingResponse
+from langfuse.decorators import langfuse_context, observe
 
 from api.auth import api_key_auth
 from api.models.bedrock import BedrockModel
@@ -15,10 +16,63 @@
 )
 
 
+def extract_langfuse_metadata(chat_request: ChatRequest, headers: dict) -> dict:
+    """Extract Langfuse tracing metadata from request body and headers.
+
+    Metadata can be provided via:
+    1. extra_body.langfuse_metadata dict in the request
+    2. user field in the request (PRIORITIZED for user_id - ensures consistent email usage)
+    3. HTTP headers: X-Chat-Id, X-User-Id, X-Session-Id, X-Message-Id
+
+    Returns a dict with: user_id, session_id, chat_id, message_id, and any custom metadata
+
+    Note: The 'user' field is prioritized for user_id to ensure email addresses
+    are consistently used across all messages instead of generated IDs from headers.
+ """ + metadata = {} + + # Extract from extra_body if present + if chat_request.extra_body and isinstance(chat_request.extra_body, dict): + langfuse_meta = chat_request.extra_body.get("langfuse_metadata", {}) + if isinstance(langfuse_meta, dict): + metadata.update(langfuse_meta) + + # PRIORITY: Set user_id from the 'user' field FIRST + # This ensures we always use the email address when available + if chat_request.user: + metadata["user_id"] = chat_request.user + + # Extract from headers + headers_lower = {k.lower(): v for k, v in headers.items()} + + # Map headers to metadata fields - support both standard and OpenWebUI-prefixed headers + header_mapping = { + "x-chat-id": "chat_id", + "x-openwebui-chat-id": "chat_id", # OpenWebUI sends this format + "x-user-id": "user_id", + "x-openwebui-user-id": "user_id", # OpenWebUI sends this format + "x-session-id": "session_id", + "x-openwebui-session-id": "session_id", # OpenWebUI sends this format + "x-message-id": "message_id", + "x-openwebui-message-id": "message_id", # OpenWebUI sends this format + } + + for header_key, meta_key in header_mapping.items(): + if header_key in headers_lower and headers_lower[header_key]: + # Don't override if already set + # (chat_request.user takes precedence for user_id) + if meta_key not in metadata: + metadata[meta_key] = headers_lower[header_key] + + return metadata + + @router.post( "/completions", response_model=ChatResponse | ChatStreamResponse | Error, response_model_exclude_unset=True ) +@observe(as_type="generation", name="chat_completion") async def chat_completions( + request: Request, chat_request: Annotated[ ChatRequest, Body( @@ -34,12 +88,42 @@ async def chat_completions( ), ], ): - if chat_request.model.lower().startswith("gpt-"): - chat_request.model = DEFAULT_MODEL + # Extract metadata for Langfuse tracing + metadata = extract_langfuse_metadata(chat_request, dict(request.headers)) + + # Create trace name using chat_id if available + trace_name = f"chat:{metadata.get('chat_id', 'unknown')}" + + # Update trace with metadata, user_id, and session_id + langfuse_context.update_current_trace( + name=trace_name, + user_id=metadata.get("user_id"), + session_id=metadata.get("session_id"), + metadata=metadata, + input={ + "model": chat_request.model, + "messages": [msg.model_dump() for msg in chat_request.messages], + "temperature": chat_request.temperature, + "max_tokens": chat_request.max_tokens, + "tools": [tool.model_dump() for tool in chat_request.tools] if chat_request.tools else None, + } + ) # Exception will be raised if model not supported. 
     model = BedrockModel()
     model.validate(chat_request)
+
     if chat_request.stream:
         return StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
-    return await model.chat(chat_request)
+
+    response = await model.chat(chat_request)
+
+    # Update trace with output for non-streaming
+    langfuse_context.update_current_trace(
+        output={
+            "message": response.choices[0].message.model_dump() if response.choices else None,
+            "finish_reason": response.choices[0].finish_reason if response.choices else None,
+        }
+    )
+
+    return response
diff --git a/src/api/schema.py b/src/api/schema.py
index 233e1139..f7c829fc 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -99,7 +99,7 @@ class ChatRequest(BaseModel):
     stream_options: StreamOptions | None = None
     temperature: float | None = Field(default=1.0, le=2.0, ge=0.0)
     top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
-    user: str | None = None  # Not used
+    user: str | None = None
     max_tokens: int | None = 2048
     max_completion_tokens: int | None = None
     reasoning_effort: Literal["low", "medium", "high"] | None = None
diff --git a/src/api/setting.py b/src/api/setting.py
index 43fd2b7f..92dfcc0b 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -13,5 +13,9 @@
 AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
 DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
 DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
+CUSTOM_MODEL_CSV = os.environ.get("CUSTOM_MODEL_LIST", "")
+CUSTOM_MODEL_LIST = [m.strip() for m in CUSTOM_MODEL_CSV.split(",") if m.strip()]
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
+MAX_RETRIES_AWS = int(os.environ.get("MAX_RETRIES_AWS", "3"))
 ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
+MODEL_CACHE_TTL = int(os.environ.get("MODEL_CACHE_TTL", "3600"))  # 1 hour default
diff --git a/src/requirements.txt b/src/requirements.txt
index 9aa0e2da..09c2d996 100644
--- a/src/requirements.txt
+++ b/src/requirements.txt
@@ -7,3 +7,4 @@ requests==2.32.4
 numpy==2.2.5
 boto3==1.40.4
 botocore==1.40.4
+langfuse<3.0.0
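
Usage sketch (illustrative, not part of the patch): the snippet below shows how a client could exercise the Langfuse metadata plumbing added in api/routers/chat.py through the OpenAI-compatible endpoint. The base URL, API key, model ID, and header values are assumed placeholders; the header names, the user-field mapping, and the extra_body.langfuse_metadata shape follow what extract_langfuse_metadata reads.

# Hypothetical client-side example; connection details are assumptions, adjust to your deployment.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8080/api/v1",  # assumed gateway address and route prefix
    api_key="example-proxy-key",              # assumed key accepted by api_key_auth
)

response = client.chat.completions.create(
    model="anthropic.claude-3-sonnet-20240229-v1:0",  # any model id the gateway exposes
    messages=[{"role": "user", "content": "Hello"}],
    user="jane.doe@example.com",  # becomes the Langfuse user_id (takes precedence over headers)
    extra_headers={
        "X-Chat-Id": "chat-123",        # or X-OpenWebUI-Chat-Id
        "X-Session-Id": "session-456",  # or X-OpenWebUI-Session-Id
    },
    extra_body={"langfuse_metadata": {"source": "example-script"}},  # merged into trace metadata
)
print(response.choices[0].message.content)

With a request like this, the @observe-decorated chat_completions handler should create a trace named chat:chat-123 carrying user_id jane.doe@example.com and session_id session-456, and the BedrockModel chat/chat_stream paths then update that same trace with usage, reasoning, and output details as the response completes.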