Skip to content

Commit d66f40b

Browse files
committed
Use new KB endpoint for inserting files
1 parent f04df80 commit d66f40b

File tree

5 files changed

+46
-22
lines changed

5 files changed

+46
-22
lines changed

mindsdb_sdk/agents.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def add_file(self, name: str, file_path: str, description: str, knowledge_base:
192192
kb.model.wait_complete()
193193

194194
# Insert the entire file.
195-
kb.insert(self.databases.files.tables.get(filename_no_extension))
195+
kb.insert_files([filename_no_extension])
196196

197197
# Make sure skill name is unique.
198198
skill_name = f'{filename_no_extension}_retrieval_skill_{uuid4()}'

mindsdb_sdk/connectors/rest_api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,18 @@ def delete_skill(self, project: str, name: str):
319319
url = self.url + f'/api/projects/{project}/skills/{name}'
320320
r = self.session.delete(url)
321321
_raise_for_status(r)
322+
323+
# Knowledge Base operations.
324+
@_try_relogin
325+
def insert_files_into_knowledge_base(self, project: str, knowledge_base_name: str, file_names: List[str]):
326+
r = self.session.put(
327+
self.url + f'/api/projects/{project}/knowledge_bases/{knowledge_base_name}',
328+
json={
329+
'knowledge_base': {
330+
'files': file_names
331+
}
332+
}
333+
)
334+
_raise_for_status(r)
335+
336+
return r.json()

mindsdb_sdk/knowledge_bases.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ class KnowledgeBase(Query):
3131
3232
"""
3333

34-
def __init__(self, project, data: dict):
35-
34+
def __init__(self, api, project, data: dict):
35+
self.api = api
3636
self.project = project
3737
self.name = data['name']
3838

@@ -118,6 +118,12 @@ def _update_query(self):
118118
ast_query.limit = Constant(self._limit)
119119
self.sql = ast_query.to_string()
120120

121+
def insert_files(self, file_paths: List[str]):
122+
"""
123+
Insert data from file to knowledge base
124+
"""
125+
self.api.insert_files_into_knowledge_base(self.project.name, self.name, file_paths)
126+
121127
def insert(self, data: Union[pd.DataFrame, Query, dict]):
122128
"""
123129
Insert data to knowledge base
@@ -202,7 +208,7 @@ def _list(self, name: str = None) -> List[KnowledgeBase]:
202208
df = df.rename(columns=cols_map)
203209

204210
return [
205-
KnowledgeBase(self.project, item)
211+
KnowledgeBase(self.api, self.project, item)
206212
for item in df.to_dict('records')
207213
]
208214

mindsdb_sdk/skills.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Skill():
1818
1919
Create a new SQL skill:
2020
21-
>>> text_to_sql_skill = skills.create('text_to_sql', 'sql', { 'tables': ['my_table'], 'database': 'my_database' })
21+
>>> text_to_sql_skill = skills.create('text_to_sql', 'sql', { 'tables': ['my_table'], 'database': 'my_database', 'description': 'my_description'})
2222
2323
Update a skill:
2424
@@ -50,11 +50,14 @@ def __repr__(self):
5050

5151
@classmethod
5252
def from_json(cls, json: dict):
53+
name = json['name']
54+
type = json['type']
55+
params = json['params']
5356
if json['type'] == 'sql':
54-
return SQLSkill(json['name'], json['params']['tables'], json['params']['database'], json['params']['description'])
57+
return SQLSkill(name, params['tables'], params['database'], params.get('description', ''))
5558
if json['type'] == 'retrieval':
56-
return RetrievalSkill(json['name'], json['params']['source'], json['params']['description'])
57-
return Skill(json['name'], json['type'], json['params'])
59+
return RetrievalSkill(name, params['source'], params.get('description', ''))
60+
return Skill(name, type, params)
5861

5962

6063
class SQLSkill(Skill):

tests/test_sdk.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,7 +1264,7 @@ def test_create(self, mock_post):
12641264
'name': 'test_skill',
12651265
'project_id': 1,
12661266
'type': 'sql',
1267-
'params': {'tables': ['test_table'], 'database': 'test_database'},
1267+
'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'},
12681268
}],
12691269
'params': {'k1': 'v1'},
12701270
'created_at': created_at,
@@ -1277,7 +1277,7 @@ def test_create(self, mock_post):
12771277
new_agent = server.agents.create(
12781278
name='test_agent',
12791279
model=Model(None, {'name':'m1'}),
1280-
skills=[SQLSkill('test_skill', ['test_table'], 'test_database'), 'test_description'],
1280+
skills=[SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')],
12811281
params={'k1': 'v1'}
12821282
)
12831283
# Check API call.
@@ -1298,11 +1298,11 @@ def test_create(self, mock_post):
12981298
'skill': {
12991299
'name': 'test_skill',
13001300
'type': 'sql',
1301-
'params': {'database': 'test_database', 'tables': ['test_table']}
1301+
'params': {'database': 'test_database', 'tables': ['test_table'], 'description': 'test_description'}
13021302
}
13031303
}
13041304

1305-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_descrition')
1305+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')
13061306
expected_agent = Agent(
13071307
'test_agent',
13081308
'test_model',
@@ -1332,7 +1332,7 @@ def test_update(self, mock_get, mock_put, _):
13321332
'name': 'updated_skill',
13331333
'project_id': 1,
13341334
'type': 'sql',
1335-
'params': {'tables': ['updated_table'], 'database': 'updated_database'},
1335+
'params': {'tables': ['updated_table'], 'database': 'updated_database', 'description': 'test_description'},
13361336
}],
13371337
'params': {'k2': 'v2'},
13381338
'created_at': created_at,
@@ -1422,7 +1422,7 @@ def test_list(self, mock_get):
14221422
'id': 1,
14231423
'name': 'test_skill',
14241424
'project_id': 1,
1425-
'params': {'tables': ['test_table'], 'database': 'test_database'},
1425+
'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description' },
14261426
'type': 'sql'
14271427
}
14281428
])
@@ -1440,7 +1440,7 @@ def test_get(self, mock_get):
14401440
'id': 1,
14411441
'name': 'test_skill',
14421442
'project_id': 1,
1443-
'params': {'tables': ['test_table'], 'database': 'test_database'},
1443+
'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'},
14441444
'type': 'sql'
14451445
}
14461446
)
@@ -1456,7 +1456,7 @@ def test_create(self, mock_post):
14561456
'id': 1,
14571457
'name': 'test_skill',
14581458
'project_id': 1,
1459-
'params': {'k1': 'v1'},
1459+
'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'},
14601460
'type': 'test'
14611461
}
14621462
response_mock(mock_post, data)
@@ -1466,18 +1466,18 @@ def test_create(self, mock_post):
14661466
new_skill = server.skills.create(
14671467
'test_skill',
14681468
'sql',
1469-
params={'tables': ['test_table'], 'database': 'test_database'}
1469+
params={'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'}
14701470
)
14711471
# Check API call.
14721472
assert mock_post.call_args[0][0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/skills'
14731473
assert mock_post.call_args[1]['json'] == {
14741474
'skill': {
14751475
'name': 'test_skill',
14761476
'type': 'sql',
1477-
'params': {'database': 'test_database', 'tables': ['test_table']}
1477+
'params': {'database': 'test_database', 'tables': ['test_table'], 'description': 'test_description'}
14781478
}
14791479
}
1480-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
1480+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')
14811481

14821482
assert new_skill == expected_skill
14831483

@@ -1487,13 +1487,13 @@ def test_update(self, mock_put):
14871487
'id': 1,
14881488
'name': 'test_skill',
14891489
'project_id': 1,
1490-
'params': {'tables': ['updated_table'], 'database': 'updated_database'},
1490+
'params': {'tables': ['updated_table'], 'database': 'updated_database', 'description': 'updated_description'},
14911491
'type': 'sql'
14921492
}
14931493
response_mock(mock_put, data)
14941494

14951495
server = mindsdb_sdk.connect()
1496-
expected_skill = SQLSkill('test_skill', ['updated_table'], 'updated_database', 'test_description')
1496+
expected_skill = SQLSkill('test_skill', ['updated_table'], 'updated_database', 'updated_description')
14971497

14981498
updated_skill = server.skills.update('test_skill', expected_skill)
14991499
# Check API call.
@@ -1502,7 +1502,7 @@ def test_update(self, mock_put):
15021502
'skill': {
15031503
'name': 'test_skill',
15041504
'type': 'sql',
1505-
'params': {'tables': ['updated_table'], 'database': 'updated_database'}
1505+
'params': {'tables': ['updated_table'], 'database': 'updated_database', 'description': 'updated_description'}
15061506
}
15071507
}
15081508

0 commit comments

Comments
 (0)