Skip to content

Commit b26bc35

Browse files
committed
Add ability for agents to do webpage retrieval
1 parent cf74f06 commit b26bc35

File tree

4 files changed

+217
-0
lines changed

4 files changed

+217
-0
lines changed

mindsdb_sdk/agents.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from requests.exceptions import HTTPError
22
from typing import List, Union
3+
from urllib.parse import urlparse
34
from uuid import uuid4
45
import datetime
56
import pandas as pd
@@ -86,6 +87,14 @@ def add_file(self, file_path: str, description: str, knowledge_base: str = None)
8687
"""
8788
self.collection.add_file(self.name, file_path, description, knowledge_base)
8889

90+
def add_webpage(self, url: str, description: str, knowledge_base: str = None):
    """
    Add a crawled URL to the agent for retrieval.

    Delegates to the owning collection's ``add_webpage``, passing this agent's name.

    :param url: URL of the page to be crawled and added.
    :param description: Description of the webpage. Used by agent to know when to do retrieval.
    :param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
    """
    self.collection.add_webpage(self.name, url, description, knowledge_base)
97+
8998
def __repr__(self):
    """
    Return a short representation made of the class name and the instance's name.
    """
    return f'{self.__class__.__name__}(name: {self.name})'
91100

@@ -205,6 +214,44 @@ def add_file(self, name: str, file_path: str, description: str, knowledge_base:
205214
agent.skills.append(file_retrieval_skill)
206215
self.update(agent.name, agent)
207216

217+
def add_webpage(self, name: str, url: str, description: str, knowledge_base: str = None):
    """
    Add a webpage to the agent for retrieval.

    The webpage is inserted into a knowledge base, a uniquely named retrieval skill
    pointing at that knowledge base is created, and the agent is updated with the skill.

    :param name: Name of the agent
    :param url: URL of the webpage to be crawled and added.
    :param description: Description of the webpage. Used by agent to know when to do retrieval.
    :param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
    """
    parsed_url = urlparse(url)
    # A URL without a scheme (e.g. 'docs.mdb.ai') is parsed with an empty netloc and the
    # host inside the path, so '.' and '/' are replaced across both parts. Otherwise the
    # generated knowledge base and skill names would still contain dots.
    identifier = f'{parsed_url.netloc}{parsed_url.path}'.replace('.', '_').replace('/', '_')

    if knowledge_base is not None:
        kb = self.knowledge_bases.get(knowledge_base)
    else:
        kb_name = f'{name}_{identifier}_kb'
        try:
            kb = self.knowledge_bases.get(kb_name)
        except AttributeError:
            # Create KB if it doesn't exist.
            kb = self.knowledge_bases.create(kb_name)
            # Wait for underlying embedding model to finish training.
            kb.model.wait_complete()

    # Insert crawled webpage.
    kb.insert_webpages([url])

    # Make sure skill name is unique.
    skill_name = f'{identifier}_retrieval_skill_{uuid4()}'
    retrieval_params = {
        'source': kb.name,
        'description': description,
    }
    webpage_retrieval_skill = self.skills.create(skill_name, 'retrieval', retrieval_params)
    agent = self.get(name)
    agent.skills.append(webpage_retrieval_skill)
    self.update(agent.name, agent)
254+
208255
def create(
209256
self,
210257
name: str,

mindsdb_sdk/connectors/rest_api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,17 @@ def insert_files_into_knowledge_base(self, project: str, knowledge_base_name: st
338338
_raise_for_status(r)
339339

340340
return r.json()
341+
342+
@_try_relogin
def insert_webpages_into_knowledge_base(self, project: str, knowledge_base_name: str, urls: List[str]):
    """
    Send webpage URLs to a knowledge base through the REST API.

    Issues a PUT to ``/api/projects/{project}/knowledge_bases/{knowledge_base_name}``
    with the URLs under ``knowledge_base.urls`` in the JSON body.

    :param project: Name of the project that owns the knowledge base.
    :param knowledge_base_name: Name of the knowledge base to update.
    :param urls: URLs of the webpages to insert.
    :return: Decoded JSON body of the response.
    """
    endpoint = self.url + f'/api/projects/{project}/knowledge_bases/{knowledge_base_name}'
    payload = {
        'knowledge_base': {
            'urls': urls
        }
    }
    r = self.session.put(endpoint, json=payload)
    # NOTE(review): presumably raises on a non-success HTTP status; helper is defined elsewhere.
    _raise_for_status(r)

    return r.json()

mindsdb_sdk/knowledge_bases.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ def insert_files(self, file_paths: List[str]):
124124
"""
125125
self.api.insert_files_into_knowledge_base(self.project.name, self.name, file_paths)
126126

127+
def insert_webpages(self, urls: List[str]):
    """
    Insert data from crawled URLs to knowledge base

    Delegates to the API connector, passing this knowledge base's project name and name.

    :param urls: URLs of the webpages to insert
    """
    self.api.insert_webpages_into_knowledge_base(self.project.name, self.name, urls)
132+
127133
def insert(self, data: Union[pd.DataFrame, Query, dict]):
128134
"""
129135
Insert data to knowledge base

tests/test_sdk.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,25 @@ def side_effect(*args, **kwargs):
3838
mock.side_effect = side_effect
3939

4040

41+
def responses_mock(mock, data):
    """
    Configure ``mock`` to return one canned HTTP response per call, in order.

    Each item of ``data`` becomes a response object with ``status_code`` 200 whose
    ``json()`` returns the item. DataFrames are first converted to the table payload
    with ``type``, ``column_names`` and ``data`` keys.

    :param mock: object whose ``side_effect`` is set to the list of responses
    :param data: iterable of JSON-like payloads or DataFrames, one per expected call
    """
    responses = []
    for d in data:
        if isinstance(d, pd.DataFrame):
            # to sql/query format (mostly used)
            pd_data = d.to_dict('split')
            d = {
                'type': 'table',
                'column_names': pd_data['columns'],
                'data': pd_data['data']
            }
        # Build the response directly; no per-item closure is needed.
        r_mock = Mock()
        r_mock.status_code = 200
        r_mock.json.return_value = d
        responses.append(r_mock)
    mock.side_effect = responses
59+
4160
def check_sql_call(mock, sql, database=None, call_stack_num=None):
4261
if call_stack_num is not None:
4362
call_args = mock.mock_calls[call_stack_num]
@@ -1436,6 +1455,137 @@ def test_delete(self, mock_delete):
14361455
# Check API call.
14371456
assert mock_delete.call_args[0][0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/agents/test_agent'
14381457

1458+
@patch('requests.Session.get')
@patch('requests.Session.put')
@patch('requests.Session.post')
def test_add_file(self, mock_post, mock_put, mock_get):
    """
    Adding a file to an agent ends with an agent update PUT that adds one skill.

    The skill entry in the expected payload is copied from the actual request because
    its name contains a generated UUID; the remaining fields are compared exactly.
    """
    kb_name = 'test_agent_tokaido_rules_kb'

    def skill_json():
        # Fresh dict per response so no payload object is shared between mocks.
        return {'name': 'new_skill', 'type': 'retrieval', 'params': {'source': kb_name}}

    def agent_json(skills):
        return {
            'name': 'test_agent',
            'model_name': 'test_model',
            'skills': skills,
            'params': {},
            'created_at': None,
            'updated_at': None
        }

    server = mindsdb_sdk.connect()

    get_responses = [
        # File metadata get.
        [{'name': 'tokaido_rules'}],
        # Existing agent get.
        agent_json([]),
        # Skills get in Agent update to check if it exists.
        skill_json(),
        # Existing agent get in Agent update.
        agent_json([]),
    ]
    post_responses = [
        # KB get (POST /sql).
        pd.DataFrame([
            {'name': kb_name, 'storage': None, 'model': None},
        ]),
        # Skill creation.
        skill_json(),
    ]
    put_responses = [
        # KB update.
        {'name': kb_name},
        # Agent update with new skill.
        agent_json([skill_json()]),
    ]
    responses_mock(mock_get, get_responses)
    responses_mock(mock_post, post_responses)
    responses_mock(mock_put, put_responses)

    server.agents.add_file('test_agent', './tokaido_rules.pdf', 'Rules for the board game Tokaido')

    # Check Agent was updated with a new skill (last PUT is the agent update).
    sent_json = mock_put.call_args[-1]['json']
    # Skill name is a generated UUID, so reuse the sent entry in the expectation.
    sent_skill = sent_json['agent']['skills_to_add'][0]
    assert sent_json == {
        'agent': {
            'name': 'test_agent',
            'model_name': 'test_model',
            'skills_to_add': [sent_skill],
            'skills_to_remove': [],
            'params': {},
        }
    }
1523+
1524+
1525+
@patch('requests.Session.get')
@patch('requests.Session.put')
@patch('requests.Session.post')
def test_add_webpage(self, mock_post, mock_put, mock_get):
    """
    Adding a webpage to an agent sends the URL to the knowledge base and then
    updates the agent with one new skill.

    The skill entry in the expected agent payload is copied from the actual request
    because its name contains a generated UUID.
    """
    server = mindsdb_sdk.connect()
    responses_mock(mock_get, [
        # Existing agent get.
        {
            'name': 'test_agent',
            'model_name': 'test_model',
            'skills': [],
            'params': {},
            'created_at': None,
            'updated_at': None
        },
        # Skills get in Agent update to check if it exists.
        # Source matches the webpage knowledge base used by the other mocked responses.
        {'name': 'new_skill', 'type': 'retrieval', 'params': {'source': 'test_agent_docs_mdb_ai_kb'}},
        # Existing agent get in Agent update.
        {
            'name': 'test_agent',
            'model_name': 'test_model',
            'skills': [],
            'params': {},
            'created_at': None,
            'updated_at': None
        },
    ])
    responses_mock(mock_post, [
        # KB get (POST /sql).
        pd.DataFrame([
            {'name': 'test_agent_docs_mdb_ai_kb', 'storage': None, 'model': None},
        ]),
        # Skill creation.
        {'name': 'new_skill', 'type': 'retrieval', 'params': {'source': 'test_agent_docs_mdb_ai_kb'}}
    ])
    responses_mock(mock_put, [
        # KB update.
        {'name': 'test_agent_docs_mdb_ai_kb'},
        # Agent update with new skill.
        {
            'name': 'test_agent',
            'model_name': 'test_model',
            'skills': [{'name': 'new_skill', 'type': 'retrieval', 'params': {'source': 'test_agent_docs_mdb_ai_kb'}}],
            'params': {},
            'created_at': None,
            'updated_at': None
        },
    ])
    server.agents.add_webpage('test_agent', 'docs.mdb.ai', 'Documentation for MindsDB')

    # Check the webpage URL was sent to the knowledge base (first PUT is the KB update).
    kb_update_json = mock_put.call_args_list[0][-1]['json']
    assert kb_update_json == {'knowledge_base': {'urls': ['docs.mdb.ai']}}

    # Check Agent was updated with a new skill.
    agent_update_json = mock_put.call_args[-1]['json']
    expected_agent_json = {
        'agent': {
            'name': 'test_agent',
            'model_name': 'test_model',
            # Skill name is a generated UUID.
            'skills_to_add': [agent_update_json['agent']['skills_to_add'][0]],
            'skills_to_remove': [],
            'params': {},
        }
    }
    assert agent_update_json == expected_agent_json
1588+
14391589

14401590
class TestSkills():
14411591
@patch('requests.Session.get')

0 commit comments

Comments
 (0)