Skip to content

Commit 2ce18ca

Browse files
authored
Merge pull request #102 from mindsdb/agents-retrieval
Add File Retrieval to Agents
2 parents 207f224 + 1868035 commit 2ce18ca

File tree

5 files changed

+99
-9
lines changed

5 files changed

+99
-9
lines changed

mindsdb_sdk/agents.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from requests.exceptions import HTTPError
22
from typing import List, Union
3+
from uuid import uuid4
34
import datetime
5+
import pandas as pd
46

7+
from mindsdb_sdk.databases import Databases
8+
from mindsdb_sdk.knowledge_bases import KnowledgeBases
59
from mindsdb_sdk.models import Model
610
from mindsdb_sdk.skills import Skill, Skills
711
from mindsdb_sdk.utils.objects_collection import CollectionBase
@@ -74,6 +78,14 @@ def __init__(
7478
def completion(self, messages: List[dict]) -> AgentCompletion:
7579
return self.collection.completion(self.name, messages)
7680

81+
def add_file(self, file_path: str, description: str, knowledge_base: str = None):
    """
    Make a file available to this agent for retrieval.

    Forwards the request to the owning collection, passing this agent's name
    along with the arguments given here.

    :param file_path: Path to the file to be added.
    :param description: Description of the file, forwarded unchanged to the collection.
    :param knowledge_base: Optional knowledge base name, forwarded unchanged to the collection.
    """
    owner = self.collection
    owner.add_file(self.name, file_path, description, knowledge_base)
88+
7789
def __repr__(self):
7890
return f'{self.__class__.__name__}(name: {self.name})'
7991

@@ -105,10 +117,12 @@ def from_json(cls, json: dict, collection: CollectionBase):
105117

106118
class Agents(CollectionBase):
107119
"""Collection for agents"""
108-
def __init__(self, api, project: str, skills: Skills = None):
120+
def __init__(self, api, project: str, knowledge_bases: KnowledgeBases, databases: Databases, skills: Skills = None):
109121
self.api = api
110122
self.project = project
111123
self.skills = skills or Skills(self.api, project)
124+
self.databases = databases
125+
self.knowledge_bases = knowledge_bases
112126

113127
def list(self) -> List[Agent]:
114128
"""
@@ -142,6 +156,55 @@ def completion(self, name: str, messages: List[dict]) -> AgentCompletion:
142156
data = self.api.agent_completion(self.project, name, messages)
143157
return AgentCompletion(data['message']['content'])
144158

159+
def add_file(self, name: str, file_path: str, description: str, knowledge_base: str = None):
    """
    Add a file to the agent for retrieval.

    Steps performed:
      1. Derive the stored file name from ``file_path`` (base name, cut at the first dot).
      2. Upload the file unless the server already reports a file with that name.
      3. Insert the file into the given knowledge base, or into a per-agent,
         per-file default one that is created when it does not exist yet.
      4. Create a uniquely named 'knowledge_base' skill pointing at that
         knowledge base and attach it to the agent.

    :param name: Name of the agent
    :param file_path: Path to the file to be added, or name of existing file.
    :param description: Description of the file. Used by agent to know when to do retrieval
    :param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
    :raises ValueError: if no file name can be derived from ``file_path``.
    :raises HTTPError: if the file lookup fails with an error status other than 404.
    """
    # Function-scope import so this method does not depend on the module's import block.
    import os

    # os.path.basename honours the platform's path separators, so Windows-style
    # paths work as well as POSIX ones.
    filename = os.path.basename(file_path)
    # Only the part before the FIRST dot is kept: 'report.v2.pdf' is stored as 'report'.
    filename_no_extension = filename.split('.')[0]
    if not filename_no_extension:
        raise ValueError(f'Cannot derive a file name from path: {file_path!r}')

    try:
        self.api.get_file_metadata(filename_no_extension)
    except HTTPError as e:
        if e.response.status_code >= 400 and e.response.status_code != 404:
            raise
        # Upload file if it doesn't exist.
        with open(file_path, 'rb') as file:
            content = file.read()
        # The raw bytes are sent as a single-row, single-column frame.
        df = pd.DataFrame.from_records([{'content': content}])
        self.api.upload_file(filename_no_extension, df)

    # Pick the knowledge base the file is inserted into.
    if knowledge_base is not None:
        kb = self.knowledge_bases.get(knowledge_base)
    else:
        kb_name = f'{name}_{filename_no_extension}_kb'
        try:
            kb = self.knowledge_bases.get(kb_name)
        except AttributeError:
            # NOTE(review): relies on knowledge_bases.get raising AttributeError
            # for a missing knowledge base - confirm against KnowledgeBases.get.
            # Create KB if it doesn't exist.
            kb = self.knowledge_bases.create(kb_name)
            # Wait for underlying embedding model to finish training.
            kb.model.wait_complete()

    # Insert the entire file.
    kb.insert(self.databases.files.tables.get(filename_no_extension))

    # Make sure skill name is unique.
    skill_name = f'{filename_no_extension}_retrieval_skill_{uuid4()}'
    retrieval_params = {
        'source': kb.name,
        'description': description,
    }
    file_retrieval_skill = self.skills.create(skill_name, 'knowledge_base', retrieval_params)

    # Attach the new skill and persist the agent.
    agent = self.get(name)
    agent.skills.append(file_retrieval_skill)
    self.update(agent.name, agent)
207+
145208
def create(
146209
self,
147210
name: str,
@@ -200,7 +263,7 @@ def update(self, name: str, updated_agent: Agent):
200263
updated_skills.add(skill.name)
201264

202265
existing_agent = self.api.agent(self.project, name)
203-
existing_skills = set([s.name for s in existing_agent['skills']])
266+
existing_skills = set([s['name'] for s in existing_agent['skills']])
204267
skills_to_add = updated_skills.difference(existing_skills)
205268
skills_to_remove = existing_skills.difference(updated_skills)
206269
data = self.api.update_agent(

mindsdb_sdk/connectors/rest_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,19 @@ def upload_file(self, name: str, df: pd.DataFrame):
146146
)
147147
_raise_for_status(r)
148148

149+
@_try_relogin
def get_file_metadata(self, name: str) -> dict:
    """
    Return the metadata entry of the uploaded file called ``name``.

    All file metadata is fetched from ``/api/files/`` and searched for an
    entry whose 'name' equals ``name``.

    :param name: Name of the file to look up.
    :return: The first metadata entry whose 'name' matches.
    :raises requests.HTTPError: when no entry matches; the attached response
        has its status code overwritten with 404 so callers can detect
        "not found".
    """
    # No endpoint currently to get single file, so list them all and filter.
    endpoint = self.url + '/api/files/'
    response = self.session.get(endpoint)
    _raise_for_status(response)

    found = next(
        (entry for entry in response.json() if entry.get('name', None) == name),
        None,
    )
    if found is not None:
        return found

    # Reuse the successful list response as the carrier of a synthetic 404.
    response.status_code = 404
    raise requests.HTTPError(f'Not found: No file named {name} found', response=response)
161+
149162
@_try_relogin
150163
def upload_byom(self, name: str, code: str, requirements: str):
151164

mindsdb_sdk/projects.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mindsdb_sdk.utils.sql import dict_to_binary_op
88

99
from mindsdb_sdk.agents import Agents
10+
from mindsdb_sdk.databases import Databases
1011
from mindsdb_sdk.skills import Skills
1112
from mindsdb_sdk.utils.objects_collection import CollectionBase
1213

@@ -48,7 +49,7 @@ class Project:
4849
4950
"""
5051

51-
def __init__(self, api, name):
52+
def __init__(self, api, name, agents: Agents = None, skills: Skills = None, knowledge_bases: KnowledgeBases = None, databases: Databases = None):
5253
self.name = name
5354
self.api = api
5455

@@ -76,9 +77,11 @@ def __init__(self, api, name):
7677
self.create_job = self.jobs.create
7778
self.drop_job = self.jobs.drop
7879

79-
self.skills = Skills(api, name)
80-
self.agents = Agents(api, name, self.skills)
81-
self.knowledge_bases = KnowledgeBases(self, api)
80+
self.databases = databases or Databases(api)
81+
self.knowledge_bases = knowledge_bases or KnowledgeBases(self, api)
82+
83+
self.skills = skills or Skills(api, name)
84+
self.agents = agents or Agents(api, name, self.knowledge_bases, self.databases, self.skills)
8285

8386
def __repr__(self):
8487
return f'{self.__class__.__name__}({self.name})'

mindsdb_sdk/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ class Server(Project):
2525
2626
"""
2727

28-
def __init__(self, api):
28+
def __init__(self, api, skills: Skills = None, agents: Agents = None):
2929
# server is also mindsdb project
3030
project_name = 'mindsdb'
31-
super().__init__(api, project_name)
31+
self.databases = Databases(api)
32+
super().__init__(api, project_name, skills=skills, agents=agents, databases=self.databases)
3233

3334
self.projects = Projects(api)
3435

@@ -38,7 +39,6 @@ def __init__(self, api):
3839
self.create_project = self.projects.create
3940
self.drop_project = self.projects.drop
4041

41-
self.databases = Databases(api)
4242

4343
# old api
4444
self.get_database = self.databases.get

mindsdb_sdk/skills.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def __repr__(self):
5252
def from_json(cls, json: dict):
5353
if json['type'] == 'sql':
5454
return SQLSkill(json['name'], json['params']['tables'], json['params']['database'])
55+
if json['type'] == 'retrieval':
56+
return RetrievalSkill(json['name'], json['params']['knowledge_base'], json['params']['description'])
5557
return Skill(json['name'], json['type'], json['params'])
5658

5759

@@ -64,6 +66,15 @@ def __init__(self, name: str, tables: List[str], database: str):
6466
}
6567
super().__init__(name, 'sql', params)
6668

69+
class RetrievalSkill(Skill):
    """Skill of type 'knowledge_base' that carries a knowledge base name and a description."""

    # NOTE(review): Skill.from_json builds this class for type 'retrieval', while
    # this constructor registers the skill as 'knowledge_base' - confirm which
    # type string the server actually uses.
    def __init__(self, name: str, knowledge_base: str, description: str):
        """
        :param name: Name of the skill.
        :param knowledge_base: Stored under the 'knowledge_base' params key.
        :param description: Stored under the 'description' params key.
        """
        super().__init__(
            name,
            'knowledge_base',
            {
                'knowledge_base': knowledge_base,
                'description': description,
            },
        )
77+
6778

6879
class Skills(CollectionBase):
6980
"""Collection for skills"""

0 commit comments

Comments
 (0)