Skip to content

Commit 63d5878

Browse files
committed
Don't use generic skills if possible
1 parent faef616 commit 63d5878

File tree

2 files changed

+45
-52
lines changed

2 files changed

+45
-52
lines changed

mindsdb_sdk/skills.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, name: str, tables: List[str], database: str):
3636
'database': database,
3737
'tables': tables,
3838
}
39-
super().__init__(name, params)
39+
super().__init__(name, 'sql', params)
4040

4141

4242
class Skills(CollectionBase):

tests/test_sdk.py

Lines changed: 44 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from mindsdb_sdk.agents import Agent
1414
from mindsdb_sdk.connect import DEFAULT_LOCAL_API_URL
15-
from mindsdb_sdk.skills import Skill
15+
from mindsdb_sdk.skills import SQLSkill
1616
from mindsdb_sdk.connectors import rest_api
1717

1818
# patch _raise_for_status
@@ -166,7 +166,7 @@ def check_model(self, model, database, mock_post):
166166
# get call before last call
167167
mock_call = mock_post.call_args_list[-2]
168168
assert mock_call[1]['json'][
169-
'query'] == f"update models_versions set active=1 where name = '{model2.name}' AND version = 3"
169+
'query'] == f"update models_versions set active=1 where (name = '{model2.name}') AND (version = 3)"
170170

171171
@patch('requests.Session.post')
172172
def check_table(self, table, mock_post):
@@ -1151,8 +1151,8 @@ def test_create(self, mock_post):
11511151
'id': 0,
11521152
'name': 'test_skill',
11531153
'project_id': 1,
1154-
'type': 'test',
1155-
'params': {'k1': 'v1'},
1154+
'type': 'sql',
1155+
'params': {'tables': ['test_table'], 'database': 'test_database'},
11561156
}],
11571157
'params': {'k1': 'v1'},
11581158
'created_at': created_at,
@@ -1165,7 +1165,7 @@ def test_create(self, mock_post):
11651165
new_agent = server.agents.create(
11661166
name='test_agent',
11671167
model=Model(None, {'name':'m1'}),
1168-
skills=[Skill('test_skill', 'test', {'k1': 'v1'})],
1168+
skills=[SQLSkill('test_skill', ['test_table'], 'test_database')],
11691169
params={'k1': 'v1'}
11701170
)
11711171
# Check API call.
@@ -1185,12 +1185,12 @@ def test_create(self, mock_post):
11851185
assert mock_post.call_args_list[-2].kwargs['json'] == {
11861186
'skill': {
11871187
'name': 'test_skill',
1188-
'type': 'test',
1189-
'params': {'k1': 'v1'}
1188+
'type': 'sql',
1189+
'params': {'database': 'test_database', 'tables': ['test_table']}
11901190
}
11911191
}
11921192

1193-
expected_skill = Skill('test_skill', 'test', {'k1': 'v1'})
1193+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
11941194
expected_agent = Agent(
11951195
'test_agent',
11961196
'test_model',
@@ -1212,15 +1212,15 @@ def test_update(self, mock_get, mock_put, _):
12121212
updated_at = dt.datetime(2001, 3, 1, 9, 30)
12131213
data = {
12141214
'id': 1,
1215-
'name': 'updated_agent',
1215+
'name': 'test_agent',
12161216
'project_id': 1,
12171217
'model_name': 'updated_model',
12181218
'skills': [{
12191219
'id': 1,
12201220
'name': 'updated_skill',
12211221
'project_id': 1,
1222-
'type': 'test',
1223-
'params': {'k2': 'v2'},
1222+
'type': 'sql',
1223+
'params': {'tables': ['updated_table'], 'database': 'updated_database'},
12241224
}],
12251225
'params': {'k2': 'v2'},
12261226
'created_at': created_at,
@@ -1234,21 +1234,15 @@ def test_update(self, mock_get, mock_put, _):
12341234
'name': 'test_agent',
12351235
'project_id': 1,
12361236
'model_name': 'test_model',
1237-
'skills': [{
1238-
'id': 1,
1239-
'name': 'test_skill',
1240-
'project_id': 1,
1241-
'type': 'test',
1242-
'params': {'k1': 'v1'},
1243-
}],
1237+
'skills': [],
12441238
'params': {'k1': 'v1'},
12451239
})
12461240

12471241
server = mindsdb_sdk.connect()
12481242
expected_agent = Agent(
1249-
'updated_agent',
1243+
'test_agent',
12501244
'updated_model',
1251-
[Skill('updated_skill', 'test', {'k2': 'v2'})],
1245+
[SQLSkill('updated_skill', ['updated_table'], 'updated_database')],
12521246
{'k2': 'v2'},
12531247
created_at,
12541248
updated_at
@@ -1259,14 +1253,25 @@ def test_update(self, mock_get, mock_put, _):
12591253
assert mock_put.call_args.args[0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/agents/test_agent'
12601254
assert mock_put.call_args.kwargs['json'] == {
12611255
'agent': {
1262-
'name': 'updated_agent',
1256+
'name': 'test_agent',
12631257
'model_name': 'updated_model',
12641258
'skills_to_add': ['updated_skill'],
12651259
'skills_to_remove': [],
12661260
'params': {'k2': 'v2'}
12671261
}
12681262
}
12691263

1264+
print('UPDATED')
1265+
print(updated_agent.name)
1266+
print(updated_agent.model_name)
1267+
print(updated_agent.skills)
1268+
print(updated_agent.params)
1269+
1270+
print('expected')
1271+
print(expected_agent.name)
1272+
print(expected_agent.model_name)
1273+
print(expected_agent.skills)
1274+
print(expected_agent.params)
12701275
assert updated_agent == expected_agent
12711276

12721277

@@ -1316,18 +1321,14 @@ def test_list(self, mock_get):
13161321
'id': 1,
13171322
'name': 'test_skill',
13181323
'project_id': 1,
1319-
'params': {'k1': 'v1'},
1320-
'type': 'test'
1324+
'params': {'tables': ['test_table'], 'database': 'test_database'},
1325+
'type': 'sql'
13211326
}
13221327
])
13231328
all_skills = server.skills.list()
13241329
assert len(all_skills) == 1
13251330

1326-
expected_skill = Skill(
1327-
'test_skill',
1328-
'test',
1329-
params={'k1': 'v1'}
1330-
)
1331+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
13311332
assert all_skills[0] == expected_skill
13321333

13331334
@patch('requests.Session.get')
@@ -1338,18 +1339,14 @@ def test_get(self, mock_get):
13381339
'id': 1,
13391340
'name': 'test_skill',
13401341
'project_id': 1,
1341-
'params': {'k1': 'v1'},
1342-
'type': 'test'
1342+
'params': {'tables': ['test_table'], 'database': 'test_database'},
1343+
'type': 'sql'
13431344
}
13441345
)
13451346
skill = server.skills.get('test_skill')
13461347
# Check API call.
13471348
assert mock_get.call_args.args[0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/skills/test_skill'
1348-
expected_skill = Skill(
1349-
'test_skill',
1350-
'test',
1351-
params={'k1': 'v1'}
1352-
)
1349+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
13531350
assert skill == expected_skill
13541351

13551352
@patch('requests.Session.post')
@@ -1367,48 +1364,44 @@ def test_create(self, mock_post):
13671364
server = mindsdb_sdk.connect()
13681365
new_skill = server.skills.create(
13691366
'test_skill',
1370-
'test',
1371-
params={'k1': 'v1'}
1367+
'sql',
1368+
params={'tables': ['test_table'], 'database': 'test_database'}
13721369
)
13731370
# Check API call.
13741371
assert mock_post.call_args.args[0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/skills'
13751372
assert mock_post.call_args.kwargs['json'] == {
13761373
'skill': {
13771374
'name': 'test_skill',
1378-
'type': 'test',
1379-
'params': {'k1': 'v1'}
1375+
'type': 'sql',
1376+
'params': {'database': 'test_database', 'tables': ['test_table']}
13801377
}
13811378
}
1382-
expected_skill = Skill('test_skill', 'test', {'k1': 'v1'})
1379+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
13831380

13841381
assert new_skill == expected_skill
13851382

13861383
@patch('requests.Session.put')
13871384
def test_update(self, mock_put):
13881385
data = {
13891386
'id': 1,
1390-
'name': 'updated_skill',
1387+
'name': 'test_skill',
13911388
'project_id': 1,
1392-
'params': {'k2': 'v2'},
1393-
'type': 'new_type'
1389+
'params': {'tables': ['updated_table'], 'database': 'updated_database'},
1390+
'type': 'sql'
13941391
}
13951392
response_mock(mock_put, data)
13961393

13971394
server = mindsdb_sdk.connect()
1398-
expected_skill = Skill(
1399-
'updated_skill',
1400-
'new_type',
1401-
params={'k2': 'v2'}
1402-
)
1395+
expected_skill = SQLSkill('test_skill', ['updated_table'], 'updated_database')
14031396

14041397
updated_skill = server.skills.update('test_skill', expected_skill)
14051398
# Check API call.
14061399
assert mock_put.call_args.args[0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/skills/test_skill'
14071400
assert mock_put.call_args.kwargs['json'] == {
14081401
'skill': {
1409-
'name': 'updated_skill',
1410-
'type': 'new_type',
1411-
'params': {'k2': 'v2'}
1402+
'name': 'test_skill',
1403+
'type': 'sql',
1404+
'params': {'tables': ['updated_table'], 'database': 'updated_database'}
14121405
}
14131406
}
14141407

0 commit comments

Comments
 (0)