1212
1313from mindsdb_sdk .agents import Agent
1414from mindsdb_sdk .connect import DEFAULT_LOCAL_API_URL
15- from mindsdb_sdk .skills import Skill
15+ from mindsdb_sdk .skills import SQLSkill
1616from mindsdb_sdk .connectors import rest_api
1717
1818# patch _raise_for_status
@@ -166,7 +166,7 @@ def check_model(self, model, database, mock_post):
166166 # get call before last call
167167 mock_call = mock_post .call_args_list [- 2 ]
168168 assert mock_call [1 ]['json' ][
169- 'query' ] == f"update models_versions set active=1 where name = '{ model2 .name } ' AND version = 3"
169+ 'query' ] == f"update models_versions set active=1 where ( name = '{ model2 .name } ') AND ( version = 3) "
170170
171171 @patch ('requests.Session.post' )
172172 def check_table (self , table , mock_post ):
@@ -1151,8 +1151,8 @@ def test_create(self, mock_post):
11511151 'id' : 0 ,
11521152 'name' : 'test_skill' ,
11531153 'project_id' : 1 ,
1154- 'type' : 'test ' ,
1155- 'params' : {'k1 ' : 'v1 ' },
1154+ 'type' : 'sql ' ,
1155+ 'params' : {'tables ' : [ 'test_table' ], 'database' : 'test_database ' },
11561156 }],
11571157 'params' : {'k1' : 'v1' },
11581158 'created_at' : created_at ,
@@ -1165,7 +1165,7 @@ def test_create(self, mock_post):
11651165 new_agent = server .agents .create (
11661166 name = 'test_agent' ,
11671167 model = Model (None , {'name' :'m1' }),
1168- skills = [Skill ('test_skill' , 'test' , { 'k1' : 'v1' } )],
1168+ skills = [SQLSkill ('test_skill' , [ 'test_table' ], 'test_database' )],
11691169 params = {'k1' : 'v1' }
11701170 )
11711171 # Check API call.
@@ -1185,12 +1185,12 @@ def test_create(self, mock_post):
11851185 assert mock_post .call_args_list [- 2 ].kwargs ['json' ] == {
11861186 'skill' : {
11871187 'name' : 'test_skill' ,
1188- 'type' : 'test ' ,
1189- 'params' : {'k1 ' : 'v1' }
1188+ 'type' : 'sql ' ,
1189+ 'params' : {'database ' : 'test_database' , 'tables' : [ 'test_table' ] }
11901190 }
11911191 }
11921192
1193- expected_skill = Skill ('test_skill' , 'test' , { 'k1' : 'v1' } )
1193+ expected_skill = SQLSkill ('test_skill' , [ 'test_table' ], 'test_database' )
11941194 expected_agent = Agent (
11951195 'test_agent' ,
11961196 'test_model' ,
@@ -1212,15 +1212,15 @@ def test_update(self, mock_get, mock_put, _):
12121212 updated_at = dt .datetime (2001 , 3 , 1 , 9 , 30 )
12131213 data = {
12141214 'id' : 1 ,
1215- 'name' : 'updated_agent ' ,
1215+ 'name' : 'test_agent ' ,
12161216 'project_id' : 1 ,
12171217 'model_name' : 'updated_model' ,
12181218 'skills' : [{
12191219 'id' : 1 ,
12201220 'name' : 'updated_skill' ,
12211221 'project_id' : 1 ,
1222- 'type' : 'test ' ,
1223- 'params' : {'k2 ' : 'v2 ' },
1222+ 'type' : 'sql ' ,
1223+ 'params' : {'tables ' : [ 'updated_table' ], 'database' : 'updated_database ' },
12241224 }],
12251225 'params' : {'k2' : 'v2' },
12261226 'created_at' : created_at ,
@@ -1234,21 +1234,15 @@ def test_update(self, mock_get, mock_put, _):
12341234 'name' : 'test_agent' ,
12351235 'project_id' : 1 ,
12361236 'model_name' : 'test_model' ,
1237- 'skills' : [{
1238- 'id' : 1 ,
1239- 'name' : 'test_skill' ,
1240- 'project_id' : 1 ,
1241- 'type' : 'test' ,
1242- 'params' : {'k1' : 'v1' },
1243- }],
1237+ 'skills' : [],
12441238 'params' : {'k1' : 'v1' },
12451239 })
12461240
12471241 server = mindsdb_sdk .connect ()
12481242 expected_agent = Agent (
1249- 'updated_agent ' ,
1243+ 'test_agent ' ,
12501244 'updated_model' ,
1251- [Skill ('updated_skill' , 'test' , { 'k2' : 'v2' } )],
1245+ [SQLSkill ('updated_skill' , [ 'updated_table' ], 'updated_database' )],
12521246 {'k2' : 'v2' },
12531247 created_at ,
12541248 updated_at
@@ -1259,14 +1253,25 @@ def test_update(self, mock_get, mock_put, _):
12591253 assert mock_put .call_args .args [0 ] == f'{ DEFAULT_LOCAL_API_URL } /api/projects/mindsdb/agents/test_agent'
12601254 assert mock_put .call_args .kwargs ['json' ] == {
12611255 'agent' : {
1262- 'name' : 'updated_agent ' ,
1256+ 'name' : 'test_agent ' ,
12631257 'model_name' : 'updated_model' ,
12641258 'skills_to_add' : ['updated_skill' ],
12651259 'skills_to_remove' : [],
12661260 'params' : {'k2' : 'v2' }
12671261 }
12681262 }
12691263
1264+ print ('UPDATED' )
1265+ print (updated_agent .name )
1266+ print (updated_agent .model_name )
1267+ print (updated_agent .skills )
1268+ print (updated_agent .params )
1269+
1270+ print ('expected' )
1271+ print (expected_agent .name )
1272+ print (expected_agent .model_name )
1273+ print (expected_agent .skills )
1274+ print (expected_agent .params )
12701275 assert updated_agent == expected_agent
12711276
12721277
@@ -1316,18 +1321,14 @@ def test_list(self, mock_get):
13161321 'id' : 1 ,
13171322 'name' : 'test_skill' ,
13181323 'project_id' : 1 ,
1319- 'params' : {'k1 ' : 'v1 ' },
1320- 'type' : 'test '
1324+ 'params' : {'tables ' : [ 'test_table' ], 'database' : 'test_database ' },
1325+ 'type' : 'sql '
13211326 }
13221327 ])
13231328 all_skills = server .skills .list ()
13241329 assert len (all_skills ) == 1
13251330
1326- expected_skill = Skill (
1327- 'test_skill' ,
1328- 'test' ,
1329- params = {'k1' : 'v1' }
1330- )
1331+ expected_skill = SQLSkill ('test_skill' , ['test_table' ], 'test_database' )
13311332 assert all_skills [0 ] == expected_skill
13321333
13331334 @patch ('requests.Session.get' )
@@ -1338,18 +1339,14 @@ def test_get(self, mock_get):
13381339 'id' : 1 ,
13391340 'name' : 'test_skill' ,
13401341 'project_id' : 1 ,
1341- 'params' : {'k1 ' : 'v1 ' },
1342- 'type' : 'test '
1342+ 'params' : {'tables ' : [ 'test_table' ], 'database' : 'test_database ' },
1343+ 'type' : 'sql '
13431344 }
13441345 )
13451346 skill = server .skills .get ('test_skill' )
13461347 # Check API call.
13471348 assert mock_get .call_args .args [0 ] == f'{ DEFAULT_LOCAL_API_URL } /api/projects/mindsdb/skills/test_skill'
1348- expected_skill = Skill (
1349- 'test_skill' ,
1350- 'test' ,
1351- params = {'k1' : 'v1' }
1352- )
1349+ expected_skill = SQLSkill ('test_skill' , ['test_table' ], 'test_database' )
13531350 assert skill == expected_skill
13541351
13551352 @patch ('requests.Session.post' )
@@ -1367,48 +1364,44 @@ def test_create(self, mock_post):
13671364 server = mindsdb_sdk .connect ()
13681365 new_skill = server .skills .create (
13691366 'test_skill' ,
1370- 'test ' ,
1371- params = {'k1 ' : 'v1 ' }
1367+ 'sql ' ,
1368+ params = {'tables ' : [ 'test_table' ], 'database' : 'test_database ' }
13721369 )
13731370 # Check API call.
13741371 assert mock_post .call_args .args [0 ] == f'{ DEFAULT_LOCAL_API_URL } /api/projects/mindsdb/skills'
13751372 assert mock_post .call_args .kwargs ['json' ] == {
13761373 'skill' : {
13771374 'name' : 'test_skill' ,
1378- 'type' : 'test ' ,
1379- 'params' : {'k1 ' : 'v1' }
1375+ 'type' : 'sql ' ,
1376+ 'params' : {'database ' : 'test_database' , 'tables' : [ 'test_table' ] }
13801377 }
13811378 }
1382- expected_skill = Skill ('test_skill' , 'test' , { 'k1' : 'v1' } )
1379+ expected_skill = SQLSkill ('test_skill' , [ 'test_table' ], 'test_database' )
13831380
13841381 assert new_skill == expected_skill
13851382
13861383 @patch ('requests.Session.put' )
13871384 def test_update (self , mock_put ):
13881385 data = {
13891386 'id' : 1 ,
1390- 'name' : 'updated_skill ' ,
1387+ 'name' : 'test_skill ' ,
13911388 'project_id' : 1 ,
1392- 'params' : {'k2 ' : 'v2 ' },
1393- 'type' : 'new_type '
1389+ 'params' : {'tables ' : [ 'updated_table' ], 'database' : 'updated_database ' },
1390+ 'type' : 'sql '
13941391 }
13951392 response_mock (mock_put , data )
13961393
13971394 server = mindsdb_sdk .connect ()
1398- expected_skill = Skill (
1399- 'updated_skill' ,
1400- 'new_type' ,
1401- params = {'k2' : 'v2' }
1402- )
1395+ expected_skill = SQLSkill ('test_skill' , ['updated_table' ], 'updated_database' )
14031396
14041397 updated_skill = server .skills .update ('test_skill' , expected_skill )
14051398 # Check API call.
14061399 assert mock_put .call_args .args [0 ] == f'{ DEFAULT_LOCAL_API_URL } /api/projects/mindsdb/skills/test_skill'
14071400 assert mock_put .call_args .kwargs ['json' ] == {
14081401 'skill' : {
1409- 'name' : 'updated_skill ' ,
1410- 'type' : 'new_type ' ,
1411- 'params' : {'k2 ' : 'v2 ' }
1402+ 'name' : 'test_skill ' ,
1403+ 'type' : 'sql ' ,
1404+ 'params' : {'tables ' : [ 'updated_table' ], 'database' : 'updated_database ' }
14121405 }
14131406 }
14141407
0 commit comments