11import logging
22import os
3+ from collections .abc import AsyncGenerator
34from typing import Annotated
45
56import azure .identity
6- from fastapi import Depends
7+ from fastapi import Depends , Request
78from openai import AsyncAzureOpenAI , AsyncOpenAI
89from pydantic import BaseModel
910from sqlalchemy .ext .asyncio import AsyncEngine , AsyncSession , async_sessionmaker
1011
11- from fastapi_app .openai_clients import create_openai_chat_client , create_openai_embed_client
12- from fastapi_app .postgres_engine import create_postgres_engine_from_env
13-
1412logger = logging .getLogger ("ragapp" )
1513
1614
@@ -67,7 +65,7 @@ async def common_parameters():
6765 )
6866
6967
70- def get_azure_credentials () -> azure .identity .DefaultAzureCredential | azure .identity .ManagedIdentityCredential :
68+ async def get_azure_credentials () -> azure .identity .DefaultAzureCredential | azure .identity .ManagedIdentityCredential :
7169 azure_credential : azure .identity .DefaultAzureCredential | azure .identity .ManagedIdentityCredential
7270 try :
7371 if client_id := os .getenv ("APP_IDENTITY_ID" ):
@@ -86,35 +84,55 @@ def get_azure_credentials() -> azure.identity.DefaultAzureCredential | azure.ide
8684 raise e
8785
8886
89- azure_credentials = get_azure_credentials ()
87+ async def create_async_sessionmaker (engine : AsyncEngine ) -> async_sessionmaker [AsyncSession ]:
88+ """Get the agent database"""
89+ return async_sessionmaker (
90+ engine ,
91+ expire_on_commit = False ,
92+ autoflush = False ,
93+ )
9094
9195
92- async def get_engine ():
93- """Get the agent database engine"""
94- engine = await create_postgres_engine_from_env ( azure_credentials )
95- return engine
96+ async def get_async_sessionmaker (
97+ request : Request ,
98+ ) -> AsyncGenerator [ async_sessionmaker [ AsyncSession ], None ]:
99+ yield request . state . sessionmaker
96100
97101
98- async def get_async_session (engine : Annotated [AsyncEngine , Depends (get_engine )]):
99- """Get the agent database"""
100- async_session_maker = async_sessionmaker (engine , expire_on_commit = False )
101- async with async_session_maker () as async_session :
102- yield async_session
102+ async def get_context (
103+ request : Request ,
104+ ) -> FastAPIAppContext :
105+ return request .state .context
106+
107+
108+ async def get_async_db_session (
109+ sessionmaker : Annotated [async_sessionmaker [AsyncSession ], Depends (get_async_sessionmaker )],
110+ ) -> AsyncGenerator [AsyncSession , None ]:
111+ async with sessionmaker () as session :
112+ try :
113+ yield session
114+ except :
115+ await session .rollback ()
116+ raise
117+ else :
118+ await session .commit ()
103119
104120
105- async def get_openai_chat_client ():
121+ async def get_openai_chat_client (
122+ request : Request ,
123+ ) -> OpenAIClient :
106124 """Get the OpenAI chat client"""
107- chat_client = await create_openai_chat_client (azure_credentials )
108- return OpenAIClient (client = chat_client )
125+ return OpenAIClient (client = request .state .chat_client )
109126
110127
111- async def get_openai_embed_client ():
128+ async def get_openai_embed_client (
129+ request : Request ,
130+ ) -> OpenAIClient :
112131 """Get the OpenAI embed client"""
113- embed_client = await create_openai_embed_client (azure_credentials )
114- return OpenAIClient (client = embed_client )
132+ return OpenAIClient (client = request .state .embed_client )
115133
116134
117- CommonDeps = Annotated [FastAPIAppContext , Depends (common_parameters )]
118- DBSession = Annotated [AsyncSession , Depends (get_async_session )]
135+ CommonDeps = Annotated [FastAPIAppContext , Depends (get_context )]
136+ DBSession = Annotated [AsyncSession , Depends (get_async_db_session )]
119137ChatClient = Annotated [OpenAIClient , Depends (get_openai_chat_client )]
120138EmbeddingsClient = Annotated [OpenAIClient , Depends (get_openai_embed_client )]
0 commit comments