|
1 | 1 | from requests.exceptions import HTTPError |
2 | 2 | from typing import List, Union |
| 3 | +from uuid import uuid4 |
3 | 4 | import datetime |
| 5 | +import pandas as pd |
4 | 6 |
|
| 7 | +from mindsdb_sdk.databases import Databases |
| 8 | +from mindsdb_sdk.knowledge_bases import KnowledgeBases |
5 | 9 | from mindsdb_sdk.models import Model |
6 | 10 | from mindsdb_sdk.skills import Skill, Skills |
7 | 11 | from mindsdb_sdk.utils.objects_collection import CollectionBase |
@@ -74,6 +78,14 @@ def __init__( |
def completion(self, messages: List[dict]) -> AgentCompletion:
    """
    Ask this agent for a completion.

    Delegates to the owning collection, identifying the agent by its name.

    :param messages: List of message dicts, forwarded unchanged to the collection.

    :return: Whatever the collection's ``completion`` call returns for this agent.
    """
    agents = self.collection
    return agents.completion(self.name, messages)
76 | 80 |
|
def add_file(self, file_path: str, description: str, knowledge_base: str = None):
    """
    Add a file to this agent for retrieval.

    Thin convenience wrapper: the work is done by the owning collection's
    ``add_file``, which is called with this agent's name.

    :param file_path: Path to the file to be added.
    :param description: Description of the file, passed through to the collection.
    :param knowledge_base: Optional knowledge base name, passed through to the collection.
    """
    agents = self.collection
    agents.add_file(self.name, file_path, description, knowledge_base)
| 88 | + |
def __repr__(self):
    """Return a short debug string of the form ``ClassName(name: <agent name>)``."""
    class_name = self.__class__.__name__
    return f'{class_name}(name: {self.name})'
79 | 91 |
|
@@ -105,10 +117,12 @@ def from_json(cls, json: dict, collection: CollectionBase): |
105 | 117 |
|
106 | 118 | class Agents(CollectionBase): |
107 | 119 | """Collection for agents""" |
def __init__(self, api, project: str, knowledge_bases: KnowledgeBases, databases: Databases, skills: Skills = None):
    """
    Set up the agents collection for one project.

    :param api: API client object stored on the collection and used for requests.
    :param project: Name of the project the agents belong to.
    :param knowledge_bases: Knowledge base collection, stored for later use.
    :param databases: Databases collection, stored for later use.
    :param skills: Skills collection; a new ``Skills(api, project)`` is created when omitted or falsy.
    """
    self.api = api
    self.project = project
    # Collaborating collections supplied by the caller.
    self.knowledge_bases = knowledge_bases
    self.databases = databases
    # Fall back to a project-scoped skills collection when none was given.
    if skills:
        self.skills = skills
    else:
        self.skills = Skills(self.api, project)
112 | 126 |
|
113 | 127 | def list(self) -> List[Agent]: |
114 | 128 | """ |
@@ -142,6 +156,55 @@ def completion(self, name: str, messages: List[dict]) -> AgentCompletion: |
142 | 156 | data = self.api.agent_completion(self.project, name, messages) |
143 | 157 | return AgentCompletion(data['message']['content']) |
144 | 158 |
|
def add_file(self, name: str, file_path: str, description: str, knowledge_base: str = None):
    """
    Add a file to the agent for retrieval.

    Steps performed:

    1. Upload the file under its extension-less name, unless file metadata
       for that name can already be fetched.
    2. Insert the file's table into a knowledge base (the given one, or a
       default one named ``{name}_{file}_kb`` that is created on demand).
    3. Create a uniquely named ``knowledge_base`` skill pointing at that
       knowledge base and attach it to the agent.

    :param name: Name of the agent
    :param file_path: Path to the file to be added, or name of existing file.
    :param description: Description of the file. Used by agent to know when to do retrieval
    :param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.

    :raises ValueError: If no file name can be derived from ``file_path``.
    :raises HTTPError: If the file metadata lookup fails with anything other than a 404.
    """
    import os

    # basename() honours the separators of the running platform, unlike
    # splitting on '/' by hand.
    filename = os.path.basename(file_path)
    # Everything from the first '.' onwards is dropped, so the stored name
    # never contains a dot.
    filename_no_extension = filename.split('.')[0]
    if not filename_no_extension:
        raise ValueError(f'Could not derive a file name from path: {file_path!r}')

    try:
        self.api.get_file_metadata(filename_no_extension)
    except HTTPError as error:
        # An HTTPError may carry no response at all; treat that like any
        # other unexpected failure instead of crashing on attribute access.
        status_code = getattr(error.response, 'status_code', None)
        if status_code is None or (status_code >= 400 and status_code != 404):
            # Bare raise keeps the original traceback intact.
            raise
        # Upload file if it doesn't exist.
        with open(file_path, 'rb') as file:
            content = file.read()
        df = pd.DataFrame.from_records([{'content': content}])
        self.api.upload_file(filename_no_extension, df)

    # Pick the knowledge base that will hold the file.
    if knowledge_base is not None:
        kb = self.knowledge_bases.get(knowledge_base)
    else:
        kb_name = f'{name}_{filename_no_extension}_kb'
        try:
            kb = self.knowledge_bases.get(kb_name)
        except AttributeError:
            # Create KB if it doesn't exist.
            # NOTE(review): relies on get() signalling a missing knowledge
            # base with AttributeError - confirm against KnowledgeBases.get.
            kb = self.knowledge_bases.create(kb_name)
            # Wait for underlying embedding model to finish training.
            kb.model.wait_complete()

    # Insert the entire file.
    kb.insert(self.databases.files.tables.get(filename_no_extension))

    # Make sure skill name is unique.
    skill_name = f'{filename_no_extension}_retrieval_skill_{uuid4()}'
    retrieval_params = {
        'source': kb.name,
        'description': description,
    }
    file_retrieval_skill = self.skills.create(skill_name, 'knowledge_base', retrieval_params)

    # Attach the new skill to the agent and persist the change.
    agent = self.get(name)
    agent.skills.append(file_retrieval_skill)
    self.update(agent.name, agent)
| 207 | + |
145 | 208 | def create( |
146 | 209 | self, |
147 | 210 | name: str, |
@@ -200,7 +263,7 @@ def update(self, name: str, updated_agent: Agent): |
200 | 263 | updated_skills.add(skill.name) |
201 | 264 |
|
202 | 265 | existing_agent = self.api.agent(self.project, name) |
203 | | - existing_skills = set([s.name for s in existing_agent['skills']]) |
| 266 | + existing_skills = set([s['name'] for s in existing_agent['skills']]) |
204 | 267 | skills_to_add = updated_skills.difference(existing_skills) |
205 | 268 | skills_to_remove = existing_skills.difference(updated_skills) |
206 | 269 | data = self.api.update_agent( |
|
0 commit comments