77from azure .identity import AzureDeveloperCliCredential , get_bearer_token_provider
88from dotenv_azd import load_azd_env
99from openai import AzureOpenAI , OpenAI
10+ from openai .types .chat import ChatCompletionToolParam
1011from sqlalchemy import create_engine , select
1112from sqlalchemy .orm import Session
1213
1516logger = logging .getLogger ("ragapp" )
1617
1718
18- def qa_pairs_tool (num_questions : int = 1 ) -> dict :
19+ def qa_pairs_tool (num_questions : int = 1 ) -> ChatCompletionToolParam :
1920 return {
2021 "type" : "function" ,
2122 "function" : {
@@ -45,7 +46,7 @@ def qa_pairs_tool(num_questions: int = 1) -> dict:
4546 }
4647
4748
48- def source_retriever () -> Generator [dict , None , None ]:
49+ def source_retriever () -> Generator [str , None , None ]:
4950 # Connect to the database
5051 DBHOST = os .environ ["POSTGRES_HOST" ]
5152 DBUSER = os .environ ["POSTGRES_USERNAME" ]
@@ -76,8 +77,9 @@ def answer_formatter(answer, source) -> str:
7677 return f"{ answer } [{ source ['id' ]} ]"
7778
7879
79- def get_openai_client () -> AzureOpenAI | OpenAI :
80+ def get_openai_client () -> tuple [ AzureOpenAI | OpenAI , str ] :
8081 """Return an OpenAI client based on the environment variables"""
82+ openai_client : AzureOpenAI | OpenAI
8183 OPENAI_CHAT_HOST = os .getenv ("OPENAI_CHAT_HOST" )
8284 if OPENAI_CHAT_HOST == "azure" :
8385 if api_key := os .getenv ("AZURE_OPENAI_KEY" ):
@@ -101,8 +103,7 @@ def get_openai_client() -> AzureOpenAI | OpenAI:
101103 raise NotImplementedError ("Ollama OpenAI Service is not supported. Switch to Azure or OpenAI.com" )
102104 else :
103105 logger .info ("Using OpenAI Service with API Key from OPENAICOM_KEY" )
104- openai_config = {"api_type" : "openai" , "api_key" : os .environ ["OPENAICOM_KEY" ]}
105- openai_client = OpenAI (** openai_config )
106+ openai_client = OpenAI (api_key = os .environ ["OPENAICOM_KEY" ])
106107 model = os .environ ["OPENAICOM_CHAT_MODEL" ]
107108 return openai_client , model
108109
@@ -127,6 +128,9 @@ def generate_ground_truth_data(num_questions_total: int, num_questions_per_sourc
127128 ],
128129 tools = [qa_pairs_tool (num_questions = 2 )],
129130 )
131+ if not result .choices [0 ].message .tool_calls :
132+ logger .warning ("No tool calls found in response, skipping" )
133+ continue
130134 qa_pairs = json .loads (result .choices [0 ].message .tool_calls [0 ].function .arguments )["qa_list" ]
131135 qa_pairs = [{"question" : qa_pair ["question" ], "truth" : qa_pair ["answer" ]} for qa_pair in qa_pairs ]
132136 qa .extend (qa_pairs )
0 commit comments