Skip to content

Commit aa5e677

Browse files
committed
Added add_files, add_webpages, & add_database methods to agents
1 parent dd2684a commit aa5e677

File tree

4 files changed

+254
-67
lines changed

4 files changed

+254
-67
lines changed

mindsdb_sdk/agents.py

Lines changed: 168 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from mindsdb_sdk.databases import Databases
99
from mindsdb_sdk.knowledge_bases import KnowledgeBases
10-
from mindsdb_sdk.models import Model
10+
from mindsdb_sdk.ml_engines import MLEngines
11+
from mindsdb_sdk.models import Model, Models
1112
from mindsdb_sdk.skills import Skill, Skills
1213
from mindsdb_sdk.utils.objects_collection import CollectionBase
1314

@@ -79,6 +80,14 @@ def __init__(
7980
def completion(self, messages: List[dict]) -> AgentCompletion:
8081
return self.collection.completion(self.name, messages)
8182

83+
def add_files(self, file_paths: List[str], description: str, knowledge_base: str = None):
    """
    Add a list of files to the agent for retrieval.

    Forwards the request, together with this agent's name, to the collection the agent belongs to.

    :param file_paths: List of paths to the files to be added.
    :param description: Description of the files. Used by the agent to know when to do retrieval.
    :param knowledge_base: Name of an existing knowledge base to be used. Passed through unchanged to the collection.
    """
    owning_collection = self.collection
    owning_collection.add_files(self.name, file_paths, description, knowledge_base)
90+
8291
def add_file(self, file_path: str, description: str, knowledge_base: str = None):
8392
"""
8493
Add a file to the agent for retrieval.
@@ -87,6 +96,14 @@ def add_file(self, file_path: str, description: str, knowledge_base: str = None)
8796
"""
8897
self.collection.add_file(self.name, file_path, description, knowledge_base)
8998

99+
def add_webpages(self, urls: List[str], description: str, knowledge_base: str = None):
    """
    Add a list of crawled URLs to the agent for retrieval.

    Forwards the request, together with this agent's name, to the collection the agent belongs to.

    :param urls: List of URLs to be crawled and added.
    :param description: Description of the webpages. Used by the agent to know when to do retrieval.
    :param knowledge_base: Name of an existing knowledge base to be used. Passed through unchanged to the collection.
    """
    owning_collection = self.collection
    owning_collection.add_webpages(self.name, urls, description, knowledge_base)
106+
90107
def add_webpage(self, url: str, description: str, knowledge_base: str = None):
91108
"""
92109
Add a crawled URL to the agent for retrieval.
@@ -95,6 +112,16 @@ def add_webpage(self, url: str, description: str, knowledge_base: str = None):
95112
"""
96113
self.collection.add_webpage(self.name, url, description, knowledge_base)
97114

115+
def add_database(self, database: str, tables: List[str], description: str):
    """
    Add a database to the agent for retrieval.

    Forwards the request, together with this agent's name, to the collection the agent belongs to.

    :param database: Name of the database to be added.
    :param tables: List of tables to be added.
    :param description: Description of the database tables. Used by the agent to know when to use SQL skill.
    """
    owning_collection = self.collection
    owning_collection.add_database(self.name, database, tables, description)
124+
98125
def __repr__(self):
99126
return f'{self.__class__.__name__}(name: {self.name})'
100127

@@ -126,12 +153,14 @@ def from_json(cls, json: dict, collection: CollectionBase):
126153

127154
class Agents(CollectionBase):
128155
"""Collection for agents"""
129-
def __init__(self, api, project: str, knowledge_bases: KnowledgeBases, databases: Databases, skills: Skills = None):
156+
def __init__(self, api, project: str, knowledge_bases: KnowledgeBases, databases: Databases, models: Models, ml_engines: MLEngines, skills: Skills = None):
130157
self.api = api
131158
self.project = project
132159
self.skills = skills or Skills(self.api, project)
133160
self.databases = databases
134161
self.knowledge_bases = knowledge_bases
162+
self.ml_engines = ml_engines
163+
self.models = models
135164

136165
def list(self) -> List[Agent]:
137166
"""
@@ -165,43 +194,46 @@ def completion(self, name: str, messages: List[dict]) -> AgentCompletion:
165194
data = self.api.agent_completion(self.project, name, messages)
166195
return AgentCompletion(data['message']['content'])
167196

168-
def add_file(self, name: str, file_path: str, description: str, knowledge_base: str = None):
197+
def add_files(self, name: str, file_paths: List[str], description: str, knowledge_base: str = None):
169198
"""
170-
Add a file to the agent for retrieval.
199+
Add a list of files to the agent for retrieval.
171200
172201
:param name: Name of the agent
173-
:param file_path: Path to the file to be added, or name of existing file.
202+
:param file_paths: List of paths to the files to be added.
174203
:param description: Description of the file. Used by agent to know when to do retrieval
175204
:param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
176205
"""
177-
filename = file_path.split('/')[-1]
178-
filename_no_extension = filename.split('.')[0]
179-
try:
180-
_ = self.api.get_file_metadata(filename_no_extension)
181-
except HTTPError as e:
182-
if e.response.status_code >= 400 and e.response.status_code != 404:
183-
raise e
184-
# Upload file if it doesn't exist.
185-
with open(file_path, 'rb') as file:
186-
content = file.read()
187-
df = pd.DataFrame.from_records([{'content': content}])
188-
self.api.upload_file(filename_no_extension, df)
189-
190-
# Insert uploaded file into new knowledge base.
206+
if not file_paths:
207+
return
208+
filename_no_extension = ''
209+
all_filenames = []
210+
for file_path in file_paths:
211+
filename = file_path.split('/')[-1]
212+
filename_no_extension = filename.split('.')[0]
213+
all_filenames.append(filename_no_extension)
214+
try:
215+
_ = self.api.get_file_metadata(filename_no_extension)
216+
except HTTPError as e:
217+
if e.response.status_code >= 400 and e.response.status_code != 404:
218+
raise e
219+
# Upload file if it doesn't exist.
220+
with open(file_path, 'rb') as file:
221+
content = file.read()
222+
df = pd.DataFrame.from_records([{'content': content}])
223+
self.api.upload_file(filename_no_extension, df)
224+
225+
# Insert uploaded files into new knowledge base.
191226
if knowledge_base is not None:
192227
kb = self.knowledge_bases.get(knowledge_base)
193228
else:
194-
kb_name = f'{name}_{filename_no_extension}_kb'
195-
try:
196-
kb = self.knowledge_bases.get(kb_name)
197-
except AttributeError as e:
198-
# Create KB if it doesn't exist.
199-
kb = self.knowledge_bases.create(kb_name)
200-
# Wait for underlying embedding model to finish training.
201-
kb.model.wait_complete()
229+
kb_name = f'{name}_{filename_no_extension}_{uuid4()}_kb'
230+
# Create KB if it doesn't exist.
231+
kb = self.knowledge_bases.create(kb_name)
232+
# Wait for underlying embedding model to finish training.
233+
kb.model.wait_complete()
202234

203235
# Insert the entire file.
204-
kb.insert_files([filename_no_extension])
236+
kb.insert_files(all_filenames)
205237

206238
# Make sure skill name is unique.
207239
skill_name = f'{filename_no_extension}_retrieval_skill_{uuid4()}'
@@ -214,32 +246,47 @@ def add_file(self, name: str, file_path: str, description: str, knowledge_base:
214246
agent.skills.append(file_retrieval_skill)
215247
self.update(agent.name, agent)
216248

217-
def add_webpage(self, name: str, url: str, description: str, knowledge_base: str = None):
249+
250+
def add_file(self, name: str, file_path: str, description: str, knowledge_base: str = None):
218251
"""
219-
Add a webpage to the agent for retrieval.
252+
Add a file to the agent for retrieval.
220253
221254
:param name: Name of the agent
222-
:param file_path: URL of the webpage to be added, or name of existing webpage.
223-
:param description: Description of the webpage. Used by agent to know when to do retrieval.
255+
:param file_path: Path to the file to be added, or name of existing file.
256+
:param description: Description of the file. Used by agent to know when to do retrieval
224257
:param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
225258
"""
226-
parsed_url = urlparse(url)
227-
domain = parsed_url.netloc.replace('.', '_')
228-
path = parsed_url.path.replace('/', '_')
259+
self.add_files(name, [file_path], description, knowledge_base)
260+
261+
def add_webpages(self, name: str, urls: List[str], description: str, knowledge_base: str = None):
262+
"""
263+
Add a list of webpages to the agent for retrieval.
264+
265+
:param name: Name of the agent
266+
:param urls: List of URLs of the webpages to be added.
267+
:param description: Description of the webpages. Used by agent to know when to do retrieval.
268+
:param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
269+
"""
270+
if not urls:
271+
return
272+
domain = ''
273+
path = ''
274+
for url in urls:
275+
# Validate URLs.
276+
parsed_url = urlparse(url)
277+
domain = parsed_url.netloc.replace('.', '_')
278+
path = parsed_url.path.replace('/', '_')
229279
if knowledge_base is not None:
230280
kb = self.knowledge_bases.get(knowledge_base)
231281
else:
232-
kb_name = f'{name}_{domain}{path}_kb'
233-
try:
234-
kb = self.knowledge_bases.get(kb_name)
235-
except AttributeError:
236-
# Create KB if it doesn't exist.
237-
kb = self.knowledge_bases.create(kb_name)
238-
# Wait for underlying embedding model to finish training.
239-
kb.model.wait_complete()
282+
kb_name = f'{name}_{domain}{path}_{uuid4()}_kb'
283+
# Create KB if it doesn't exist.
284+
kb = self.knowledge_bases.create(kb_name)
285+
# Wait for underlying embedding model to finish training.
286+
kb.model.wait_complete()
240287

241288
# Insert crawled webpage.
242-
kb.insert_webpages([url])
289+
kb.insert_webpages(urls)
243290

244291
# Make sure skill name is unique.
245292
skill_name = f'{domain}{path}_retrieval_skill_{uuid4()}'
@@ -252,10 +299,84 @@ def add_webpage(self, name: str, url: str, description: str, knowledge_base: str
252299
agent.skills.append(webpage_retrieval_skill)
253300
self.update(agent.name, agent)
254301

302+
def add_webpage(self, name: str, url: str, description: str, knowledge_base: str = None):
    """
    Add a single webpage to the agent for retrieval.

    Convenience wrapper: wraps ``url`` in a one-element list and hands it to ``add_webpages``.

    :param name: Name of the agent
    :param url: URL of the webpage to be added.
    :param description: Description of the webpage. Used by agent to know when to do retrieval.
    :param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
    """
    single_url = [url]
    self.add_webpages(name, single_url, description, knowledge_base)
312+
313+
def add_database(self, name: str, database: str, tables: List[str], description: str):
    """
    Add a database to the agent for retrieval.

    Creates a new skill of type 'sql' scoped to the given tables and appends it to the agent's skills.

    :param name: Name of the agent
    :param database: Name of the database to be added.
    :param tables: List of tables to be added.
    :param description: Description of the database. Used by agent to know when to do retrieval.
    :raises ValueError: If one of the requested tables is not listed in the database.
    """
    # Make sure database exists.
    db = self.databases.get(database)
    # Make sure tables exist.
    known_tables = {table.name for table in db.tables.list()}
    for table in tables:
        if table not in known_tables:
            raise ValueError(f'Table {table} does not exist in database {database}.')

    # Look the agent up before creating anything, so a failed lookup
    # does not leave an orphaned skill behind.
    agent = self.get(name)

    # Make sure skill name is unique.
    skill_name = f'{database}_sql_skill_{uuid4()}'
    sql_params = {
        'database': database,
        'tables': tables,
        'description': description,
    }
    database_sql_skill = self.skills.create(skill_name, 'sql', sql_params)
    agent.skills.append(database_sql_skill)
    self.update(agent.name, agent)
341+
342+
def _create_ml_engine_if_not_exists(self, name: str = 'langchain'):
    """
    Make sure an ML engine called ``name`` exists.

    Tries to fetch the engine; if fetching fails, creates it using the 'langchain' handler.

    :param name: Name of the ML engine to look up and, if needed, create. Defaults to 'langchain'.
    """
    try:
        self.ml_engines.get(name)
    except Exception:
        # Any failure to fetch is treated as "engine missing". Kept broad on purpose:
        # the exception type raised for an unknown engine is not visible from here.
        self.ml_engines.create(name, handler='langchain')
348+
349+
def _create_model_if_not_exists(self, name: str, model: Union[Model, dict]) -> Model:
    """
    Resolve the model an agent should use, creating one when none is supplied.

    :param name: Name of the agent. A created model is named ``{name}_default_model``.
    :param model: ``None`` to create a model from the default parameters, a dict of parameters
        that override the defaults, or any other object, which is returned unchanged.
    :return: The newly created model, or ``model`` itself when it is neither ``None`` nor a dict.
    """
    # Create langchain engine if it doesn't exist.
    self._create_ml_engine_if_not_exists()

    if model is not None and not isinstance(model, dict):
        # Caller supplied a ready-made model: nothing to create.
        return model

    model_params = {
        'predict': 'answer',
        'mode': 'retrieval',
        'engine': 'langchain',
        'prompt_template': "Answer the user's question in a helpful way: {{question}}",
        # Use GPT-4 by default.
        'provider': 'openai',
        'model_name': 'gpt-4',
    }
    if model is not None:
        # Passed in params take precedence over the defaults.
        model_params.update(model)
    return self.models.create(f'{name}_default_model', **model_params)
375+
255376
def create(
256377
self,
257378
name: str,
258-
model: Model,
379+
model: Union[Model, dict] = None,
259380
skills: List[Union[Skill, str]] = None,
260381
params: dict = None) -> Agent:
261382
"""
@@ -280,6 +401,8 @@ def create(
280401
_ = self.skills.create(skill.name, skill.type, skill.params)
281402
skill_names.append(skill.name)
282403

404+
# Create a default model if it doesn't exist.
405+
model = self._create_model_if_not_exists(name, model)
283406
data = self.api.create_agent(self.project, name, model.name, skill_names, params)
284407
return Agent.from_json(data, self)
285408

mindsdb_sdk/projects.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from mindsdb_sdk.agents import Agents
1010
from mindsdb_sdk.databases import Databases
11+
from mindsdb_sdk.ml_engines import MLEngines
1112
from mindsdb_sdk.skills import Skills
1213
from mindsdb_sdk.utils.objects_collection import CollectionBase
1314

@@ -49,7 +50,7 @@ class Project:
4950
5051
"""
5152

52-
def __init__(self, api, name, agents: Agents = None, skills: Skills = None, knowledge_bases: KnowledgeBases = None, databases: Databases = None):
53+
def __init__(self, api, name, agents: Agents = None, skills: Skills = None, knowledge_bases: KnowledgeBases = None, databases: Databases = None, ml_engines: MLEngines = None):
5354
self.name = name
5455
self.api = api
5556

@@ -81,7 +82,7 @@ def __init__(self, api, name, agents: Agents = None, skills: Skills = None, know
8182
self.knowledge_bases = knowledge_bases or KnowledgeBases(self, api)
8283

8384
self.skills = skills or Skills(api, name)
84-
self.agents = agents or Agents(api, name, self.knowledge_bases, self.databases, self.skills)
85+
self.agents = agents or Agents(api, name, self.knowledge_bases, self.databases, self.models, ml_engines, self.skills)
8586

8687
def __repr__(self):
8788
return f'{self.__class__.__name__}({self.name})'

mindsdb_sdk/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def __init__(self, api, skills: Skills = None, agents: Agents = None):
2929
# server is also mindsdb project
3030
project_name = 'mindsdb'
3131
self.databases = Databases(api)
32-
super().__init__(api, project_name, skills=skills, agents=agents, databases=self.databases)
32+
self.ml_engines = MLEngines(api)
33+
super().__init__(api, project_name, skills=skills, agents=agents, databases=self.databases, ml_engines=self.ml_engines)
3334

3435
self.projects = Projects(api)
3536

@@ -46,7 +47,6 @@ def __init__(self, api, skills: Skills = None, agents: Agents = None):
4647
self.create_database = self.databases.create
4748
self.drop_database = self.databases.drop
4849

49-
self.ml_engines = MLEngines(self.api)
5050

5151
self.ml_handlers = Handlers(self.api, 'ml')
5252
self.data_handlers = Handlers(self.api, 'data')

0 commit comments

Comments
 (0)