Skip to content

Commit f04df80

Browse files
committed
Add description to sql skill
1 parent 3d84d5b commit f04df80

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

mindsdb_sdk/skills.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,19 @@ def __repr__(self):
5151
@classmethod
5252
def from_json(cls, json: dict):
5353
if json['type'] == 'sql':
54-
return SQLSkill(json['name'], json['params']['tables'], json['params']['database'])
54+
return SQLSkill(json['name'], json['params']['tables'], json['params']['database'], json['params']['description'])
5555
if json['type'] == 'retrieval':
5656
return RetrievalSkill(json['name'], json['params']['source'], json['params']['description'])
5757
return Skill(json['name'], json['type'], json['params'])
5858

5959

6060
class SQLSkill(Skill):
6161
"""Represents a MindsDB skill for agents to interact with MindsDB databases"""
62-
def __init__(self, name: str, tables: List[str], database: str):
62+
def __init__(self, name: str, tables: List[str], database: str, description: str):
6363
params = {
6464
'database': database,
6565
'tables': tables,
66+
'description': description
6667
}
6768
super().__init__(name, 'sql', params)
6869

@@ -114,7 +115,7 @@ def create(self, name: str, type: str, params: dict = None) -> Skill:
114115
"""
115116
_ = self.api.create_skill(self.project, name, type, params)
116117
if type == 'sql':
117-
return SQLSkill(name, params['tables'], params['database'])
118+
return SQLSkill(name, params['tables'], params['database'], params['description'])
118119
return Skill(name, type, params)
119120

120121
def update(self, name: str, updated_skill: Skill) -> Skill:

tests/test_sdk.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,7 @@ def test_create(self, mock_post):
12771277
new_agent = server.agents.create(
12781278
name='test_agent',
12791279
model=Model(None, {'name':'m1'}),
1280-
skills=[SQLSkill('test_skill', ['test_table'], 'test_database')],
1280+
skills=[SQLSkill('test_skill', ['test_table'], 'test_database'), 'test_description'],
12811281
params={'k1': 'v1'}
12821282
)
12831283
# Check API call.
@@ -1302,7 +1302,7 @@ def test_create(self, mock_post):
13021302
}
13031303
}
13041304

1305-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
1305+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_descrition')
13061306
expected_agent = Agent(
13071307
'test_agent',
13081308
'test_model',
@@ -1354,7 +1354,7 @@ def test_update(self, mock_get, mock_put, _):
13541354
expected_agent = Agent(
13551355
'test_agent',
13561356
'updated_model',
1357-
[SQLSkill('updated_skill', ['updated_table'], 'updated_database')],
1357+
[SQLSkill('updated_skill', ['updated_table'], 'updated_database', 'test_description')],
13581358
{'k2': 'v2'},
13591359
created_at,
13601360
updated_at
@@ -1429,7 +1429,7 @@ def test_list(self, mock_get):
14291429
all_skills = server.skills.list()
14301430
assert len(all_skills) == 1
14311431

1432-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
1432+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')
14331433
assert all_skills[0] == expected_skill
14341434

14351435
@patch('requests.Session.get')
@@ -1447,7 +1447,7 @@ def test_get(self, mock_get):
14471447
skill = server.skills.get('test_skill')
14481448
# Check API call.
14491449
assert mock_get.call_args[0][0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/skills/test_skill'
1450-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database')
1450+
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')
14511451
assert skill == expected_skill
14521452

14531453
@patch('requests.Session.post')
@@ -1493,7 +1493,7 @@ def test_update(self, mock_put):
14931493
response_mock(mock_put, data)
14941494

14951495
server = mindsdb_sdk.connect()
1496-
expected_skill = SQLSkill('test_skill', ['updated_table'], 'updated_database')
1496+
expected_skill = SQLSkill('test_skill', ['updated_table'], 'updated_database', 'test_description')
14971497

14981498
updated_skill = server.skills.update('test_skill', expected_skill)
14991499
# Check API call.

0 commit comments

Comments
 (0)