Skip to content

Commit cf74f06

Browse files
authored
Merge pull request #106 from mindsdb/retrieval-fix
Use Retrieval Skill for Adding Files to Agents
2 parents 57122d6 + d66f40b commit cf74f06

File tree

5 files changed

+55
-30
lines changed

5 files changed

+55
-30
lines changed

mindsdb_sdk/agents.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,15 @@ def add_file(self, name: str, file_path: str, description: str, knowledge_base:
192192
kb.model.wait_complete()
193193

194194
# Insert the entire file.
195-
kb.insert(self.databases.files.tables.get(filename_no_extension))
195+
kb.insert_files([filename_no_extension])
196196

197197
# Make sure skill name is unique.
198198
skill_name = f'{filename_no_extension}_retrieval_skill_{uuid4()}'
199199
retrieval_params = {
200200
'source': kb.name,
201201
'description': description,
202202
}
203-
file_retrieval_skill = self.skills.create(skill_name, 'knowledge_base', retrieval_params)
203+
file_retrieval_skill = self.skills.create(skill_name, 'retrieval', retrieval_params)
204204
agent = self.get(name)
205205
agent.skills.append(file_retrieval_skill)
206206
self.update(agent.name, agent)

mindsdb_sdk/connectors/rest_api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,18 @@ def delete_skill(self, project: str, name: str):
323323
url = self.url + f'/api/projects/{project}/skills/{name}'
324324
r = self.session.delete(url)
325325
_raise_for_status(r)
326+
327+
# Knowledge Base operations.
328+
@_try_relogin
329+
def insert_files_into_knowledge_base(self, project: str, knowledge_base_name: str, file_names: List[str]):
330+
r = self.session.put(
331+
self.url + f'/api/projects/{project}/knowledge_bases/{knowledge_base_name}',
332+
json={
333+
'knowledge_base': {
334+
'files': file_names
335+
}
336+
}
337+
)
338+
_raise_for_status(r)
339+
340+
return r.json()

mindsdb_sdk/knowledge_bases.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ class KnowledgeBase(Query):
3232
3333
"""
3434

35-
def __init__(self, project, data: dict):
36-
35+
def __init__(self, api, project, data: dict):
36+
self.api = api
3737
self.project = project
3838
self.name = data['name']
3939
self.table_name = Identifier(parts=[self.project.name, self.name])
@@ -118,6 +118,12 @@ def _update_query(self):
118118
ast_query.limit = Constant(self._limit)
119119
self.sql = ast_query.to_string()
120120

121+
def insert_files(self, file_paths: List[str]):
122+
"""
123+
Insert data from file to knowledge base
124+
"""
125+
self.api.insert_files_into_knowledge_base(self.project.name, self.name, file_paths)
126+
121127
def insert(self, data: Union[pd.DataFrame, Query, dict]):
122128
"""
123129
Insert data to knowledge base
@@ -210,7 +216,7 @@ def _list(self, name: str = None) -> List[KnowledgeBase]:
210216
df = df.rename(columns=cols_map)
211217

212218
return [
213-
KnowledgeBase(self.project, item)
219+
KnowledgeBase(self.api, self.project, item)
214220
for item in df.to_dict('records')
215221
]
216222

mindsdb_sdk/skills.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Skill():
1818
1919
Create a new SQL skill:
2020
21-
>>> text_to_sql_skill = skills.create('text_to_sql', 'sql', { 'tables': ['my_table'], 'database': 'my_database' })
21+
>>> text_to_sql_skill = skills.create('text_to_sql', 'sql', { 'tables': ['my_table'], 'database': 'my_database', 'description': 'my_description'})
2222
2323
Update a skill:
2424
@@ -50,30 +50,34 @@ def __repr__(self):
5050

5151
@classmethod
5252
def from_json(cls, json: dict):
53+
name = json['name']
54+
type = json['type']
55+
params = json['params']
5356
if json['type'] == 'sql':
54-
return SQLSkill(json['name'], json['params']['tables'], json['params']['database'])
57+
return SQLSkill(name, params['tables'], params['database'], params.get('description', ''))
5558
if json['type'] == 'retrieval':
56-
return RetrievalSkill(json['name'], json['params']['knowledge_base'], json['params']['description'])
57-
return Skill(json['name'], json['type'], json['params'])
59+
return RetrievalSkill(name, params['source'], params.get('description', ''))
60+
return Skill(name, type, params)
5861

5962

6063
class SQLSkill(Skill):
6164
"""Represents a MindsDB skill for agents to interact with MindsDB databases"""
62-
def __init__(self, name: str, tables: List[str], database: str):
65+
def __init__(self, name: str, tables: List[str], database: str, description: str):
6366
params = {
6467
'database': database,
6568
'tables': tables,
69+
'description': description
6670
}
6771
super().__init__(name, 'sql', params)
6872

6973
class RetrievalSkill(Skill):
7074
"""Represents a MindsDB skill for agents to interact with MindsDB data sources"""
7175
def __init__(self, name: str, knowledge_base: str, description: str):
7276
params = {
73-
'knowledge_base': knowledge_base,
77+
'source': knowledge_base,
7478
'description': description
7579
}
76-
super().__init__(name, 'knowledge_base', params)
80+
super().__init__(name, 'retrieval', params)
7781

7882

7983
class Skills(CollectionBase):
@@ -114,7 +118,7 @@ def create(self, name: str, type: str, params: dict = None) -> Skill:
114118
"""
115119
_ = self.api.create_skill(self.project, name, type, params)
116120
if type == 'sql':
117-
return SQLSkill(name, params['tables'], params['database'])
121+
return SQLSkill(name, params['tables'], params['database'], params['description'])
118122
return Skill(name, type, params)
119123

120124
def update(self, name: str, updated_skill: Skill) -> Skill:

tests/test_sdk.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,7 +1295,7 @@ def test_create(self, mock_post):
12951295
'name': 'test_skill',
12961296
'project_id': 1,
12971297
'type': 'sql',
1298-
'params': {'tables': ['test_table'], 'database': 'test_database'},
1298+
'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'},
12991299
}],
13001300
'params': {'k1': 'v1'},
13011301
'created_at': created_at,
@@ -1308,7 +1308,7 @@ def test_create(self, mock_post):
13081308
new_agent = server.agents.create(
13091309
name='test_agent',
13101310
model=Model(None, {'name':'m1'}),
1311-
skills=[SQLSkill('test_skill', ['test_table'], 'test_database')],
1311+
skills=[SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')],
13121312
params={'k1': 'v1'}
13131313
)
13141314
# Check API call.
@@ -1329,11 +1329,11 @@ def test_create(self, mock_post):
13291329
'skill': {
13301330
'name': 'test_skill',
13311331
'type': 'sql',
1332-
'params': {'database': 'test_database', 'tables': ['test_table']}
1332+
'params': {'database': 'test_database', 'tables': ['test_table'], 'description': 'test_description'}
13331333
}
13341334
}
13351335

1336-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
1336+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')
13371337
expected_agent = Agent(
13381338
'test_agent',
13391339
'test_model',
@@ -1363,7 +1363,7 @@ def test_update(self, mock_get, mock_put, _):
13631363
'name': 'updated_skill',
13641364
'project_id': 1,
13651365
'type': 'sql',
1366-
'params': {'tables': ['updated_table'], 'database': 'updated_database'},
1366+
'params': {'tables': ['updated_table'], 'database': 'updated_database', 'description': 'test_description'},
13671367
}],
13681368
'params': {'k2': 'v2'},
13691369
'created_at': created_at,
@@ -1385,7 +1385,7 @@ def test_update(self, mock_get, mock_put, _):
13851385
expected_agent = Agent(
13861386
'test_agent',
13871387
'updated_model',
1388-
[SQLSkill('updated_skill', ['updated_table'], 'updated_database')],
1388+
[SQLSkill('updated_skill', ['updated_table'], 'updated_database', 'test_description')],
13891389
{'k2': 'v2'},
13901390
created_at,
13911391
updated_at
@@ -1453,14 +1453,14 @@ def test_list(self, mock_get):
14531453
'id': 1,
14541454
'name': 'test_skill',
14551455
'project_id': 1,
1456-
'params': {'tables': ['test_table'], 'database': 'test_database'},
1456+
'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description' },
14571457
'type': 'sql'
14581458
}
14591459
])
14601460
all_skills = server.skills.list()
14611461
assert len(all_skills) == 1
14621462

1463-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
1463+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')
14641464
assert all_skills[0] == expected_skill
14651465

14661466
@patch('requests.Session.get')
@@ -1471,14 +1471,14 @@ def test_get(self, mock_get):
14711471
'id': 1,
14721472
'name': 'test_skill',
14731473
'project_id': 1,
1474-
'params': {'tables': ['test_table'], 'database': 'test_database'},
1474+
'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'},
14751475
'type': 'sql'
14761476
}
14771477
)
14781478
skill = server.skills.get('test_skill')
14791479
# Check API call.
14801480
assert mock_get.call_args[0][0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/skills/test_skill'
1481-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
1481+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')
14821482
assert skill == expected_skill
14831483

14841484
@patch('requests.Session.post')
@@ -1487,7 +1487,7 @@ def test_create(self, mock_post):
14871487
'id': 1,
14881488
'name': 'test_skill',
14891489
'project_id': 1,
1490-
'params': {'k1': 'v1'},
1490+
'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'},
14911491
'type': 'test'
14921492
}
14931493
response_mock(mock_post, data)
@@ -1497,18 +1497,18 @@ def test_create(self, mock_post):
14971497
new_skill = server.skills.create(
14981498
'test_skill',
14991499
'sql',
1500-
params={'tables': ['test_table'], 'database': 'test_database'}
1500+
params={'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'}
15011501
)
15021502
# Check API call.
15031503
assert mock_post.call_args[0][0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/skills'
15041504
assert mock_post.call_args[1]['json'] == {
15051505
'skill': {
15061506
'name': 'test_skill',
15071507
'type': 'sql',
1508-
'params': {'database': 'test_database', 'tables': ['test_table']}
1508+
'params': {'database': 'test_database', 'tables': ['test_table'], 'description': 'test_description'}
15091509
}
15101510
}
1511-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
1511+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')
15121512

15131513
assert new_skill == expected_skill
15141514

@@ -1518,13 +1518,13 @@ def test_update(self, mock_put):
15181518
'id': 1,
15191519
'name': 'test_skill',
15201520
'project_id': 1,
1521-
'params': {'tables': ['updated_table'], 'database': 'updated_database'},
1521+
'params': {'tables': ['updated_table'], 'database': 'updated_database', 'description': 'updated_description'},
15221522
'type': 'sql'
15231523
}
15241524
response_mock(mock_put, data)
15251525

15261526
server = mindsdb_sdk.connect()
1527-
expected_skill = SQLSkill('test_skill', ['updated_table'], 'updated_database')
1527+
expected_skill = SQLSkill('test_skill', ['updated_table'], 'updated_database', 'updated_description')
15281528

15291529
updated_skill = server.skills.update('test_skill', expected_skill)
15301530
# Check API call.
@@ -1533,7 +1533,7 @@ def test_update(self, mock_put):
15331533
'skill': {
15341534
'name': 'test_skill',
15351535
'type': 'sql',
1536-
'params': {'tables': ['updated_table'], 'database': 'updated_database'}
1536+
'params': {'tables': ['updated_table'], 'database': 'updated_database', 'description': 'updated_description'}
15371537
}
15381538
}
15391539

0 commit comments

Comments
 (0)