Skip to content

Commit ba2eefe

Browse files
committed
Propagate API key correctly when creating default KB
1 parent aa5e677 commit ba2eefe

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

mindsdb_sdk/agents.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from urllib.parse import urlparse
44
from uuid import uuid4
55
import datetime
6+
import json
67
import pandas as pd
78

89
from mindsdb_sdk.databases import Databases
9-
from mindsdb_sdk.knowledge_bases import KnowledgeBases
10+
from mindsdb_sdk.knowledge_bases import KnowledgeBase, KnowledgeBases
1011
from mindsdb_sdk.ml_engines import MLEngines
1112
from mindsdb_sdk.models import Model, Models
1213
from mindsdb_sdk.skills import Skill, Skills
@@ -194,6 +195,22 @@ def completion(self, name: str, messages: List[dict]) -> AgentCompletion:
194195
data = self.api.agent_completion(self.project, name, messages)
195196
return AgentCompletion(data['message']['content'])
196197

198+
def _create_default_knowledge_base(self, agent: Agent, name: str) -> KnowledgeBase:
    """
    Create a knowledge base for an agent, forwarding the agent model's API keys.

    Any entry under the agent model's training options ``using`` section whose
    key contains ``api_key`` is passed as a parameter to the new knowledge base.

    :param agent: agent whose model (looked up via ``agent.model_name``) supplies the API keys
    :param name: name of the knowledge base to create

    :return: the created knowledge base, returned after ``kb.model.wait_complete()``
    """
    # Make sure default ML engine for embeddings exists.
    try:
        self.ml_engines.get('langchain_embedding')
    except AttributeError:
        # Presumably the lookup signals a missing engine with AttributeError;
        # in that case create it.
        self.ml_engines.create('langchain_embedding', 'langchain_embedding')

    # Include API keys in embeddings.
    agent_model = self.models.get(agent.model_name)
    raw_options = agent_model.data.get('training_options')
    if isinstance(raw_options, str):
        # A blank string is treated as "no options" rather than invalid JSON.
        training_options = json.loads(raw_options) if raw_options.strip() else {}
    else:
        # Covers an already-decoded dict as well as a missing or null value.
        training_options = raw_options
    if not isinstance(training_options, dict):
        training_options = {}

    training_options_using = training_options.get('using')
    if not isinstance(training_options_using, dict):
        training_options_using = {}

    api_key_params = {
        key: value
        for key, value in training_options_using.items()
        if isinstance(key, str) and 'api_key' in key
    }
    kb = self.knowledge_bases.create(name, params=api_key_params)

    # Wait for underlying embedding model to finish training.
    kb.model.wait_complete()
    return kb
213+
197214
def add_files(self, name: str, file_paths: List[str], description: str, knowledge_base: str = None):
198215
"""
199216
Add a list of files to the agent for retrieval.
@@ -223,14 +240,12 @@ def add_files(self, name: str, file_paths: List[str], description: str, knowledg
223240
self.api.upload_file(filename_no_extension, df)
224241

225242
# Insert uploaded files into new knowledge base.
243+
agent = self.get(name)
226244
if knowledge_base is not None:
227245
kb = self.knowledge_bases.get(knowledge_base)
228246
else:
229247
kb_name = f'{name}_{filename_no_extension}_{uuid4()}_kb'
230-
# Create KB if it doesn't exist.
231-
kb = self.knowledge_bases.create(kb_name)
232-
# Wait for underlying embedding model to finish training.
233-
kb.model.wait_complete()
248+
kb = self._create_default_knowledge_base(agent, kb_name)
234249

235250
# Insert the entire file.
236251
kb.insert_files(all_filenames)
@@ -242,7 +257,6 @@ def add_files(self, name: str, file_paths: List[str], description: str, knowledg
242257
'description': description,
243258
}
244259
file_retrieval_skill = self.skills.create(skill_name, 'retrieval', retrieval_params)
245-
agent = self.get(name)
246260
agent.skills.append(file_retrieval_skill)
247261
self.update(agent.name, agent)
248262

@@ -271,6 +285,7 @@ def add_webpages(self, name: str, urls: List[str], description: str, knowledge_b
271285
return
272286
domain = ''
273287
path = ''
288+
agent = self.get(name)
274289
for url in urls:
275290
# Validate URLs.
276291
parsed_url = urlparse(url)
@@ -280,10 +295,7 @@ def add_webpages(self, name: str, urls: List[str], description: str, knowledge_b
280295
kb = self.knowledge_bases.get(knowledge_base)
281296
else:
282297
kb_name = f'{name}_{domain}{path}_{uuid4()}_kb'
283-
# Create KB if it doesn't exist.
284-
kb = self.knowledge_bases.create(kb_name)
285-
# Wait for underlying embedding model to finish training.
286-
kb.model.wait_complete()
298+
kb = self._create_default_knowledge_base(agent, kb_name)
287299

288300
# Insert crawled webpage.
289301
kb.insert_webpages(urls)
@@ -295,7 +307,6 @@ def add_webpages(self, name: str, urls: List[str], description: str, knowledge_b
295307
'description': description,
296308
}
297309
webpage_retrieval_skill = self.skills.create(skill_name, 'retrieval', retrieval_params)
298-
agent = self.get(name)
299310
agent.skills.append(webpage_retrieval_skill)
300311
self.update(agent.name, agent)
301312

0 commit comments

Comments
 (0)