Skip to content

Commit 207f224

Browse files
authored
Merge pull request #98 from mindsdb/add-kb
KB support
2 parents 9c535ce + 2890a95 commit 207f224

File tree

7 files changed

+434
-11
lines changed

7 files changed

+434
-11
lines changed

mindsdb_sdk/knowledge_bases.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
import copy
2+
import json
3+
from typing import Union, List
4+
5+
import pandas as pd
6+
7+
from mindsdb_sql.parser.dialects.mindsdb import CreateKnowledgeBase, DropKnowledgeBase
8+
from mindsdb_sql.parser.ast import Identifier, Star, Select, BinaryOperation, Constant, Insert
9+
10+
from mindsdb_sdk.utils.sql import dict_to_binary_op
11+
from mindsdb_sdk.utils.objects_collection import CollectionBase
12+
13+
from .models import Model
14+
from .tables import Table
15+
from .query import Query
16+
from .databases import Database
17+
18+
19+
class KnowledgeBase(Query):
    """
    Knowledge base object, used to update or query a knowledge base.

    Add data to knowledge base:

    >>> kb.insert(pd.read_csv('house_sales.csv'))

    Query relevant results

    >>> df = kb.find('flats').fetch()

    """

    def __init__(self, project, data: dict):
        """
        Build a knowledge base object from its description record.

        :param project: project that owns the knowledge base (its ``name`` and ``api`` are used)
        :param data: record with the keys 'name', 'storage', 'model' and, optionally, 'params'.
                     'params' may be a dict or a JSON-encoded string
        """
        self.project = project
        self.name = data['name']

        self.storage = None
        if data['storage'] is not None:
            # storage is expected as '<database>.<table>'.
            # If one of the names itself contains '.', the split does not give
            # exactly two parts and storage is left as None.
            parts = data['storage'].split('.')
            if len(parts) == 2:
                database_name, table_name = parts
                database = Database(project, database_name)
                table = Table(database, table_name)
                self.storage = table

        self.model = None
        if data['model'] is not None:
            self.model = Model(self.project, {'name': data['model']})

        params = data.get('params', {})
        if isinstance(params, str):
            try:
                params = json.loads(params)
            except json.JSONDecodeError:
                params = {}

        # 'params' can be present but empty (None), or decode to a JSON value
        # that is not an object; fall back to no params instead of failing on pop().
        # Work on a copy so the caller's record is not modified by the pops below.
        params = dict(params) if isinstance(params, dict) else {}

        # columns
        self.metadata_columns = params.pop('metadata_columns', [])
        self.content_columns = params.pop('content_columns', [])
        self.id_column = params.pop('id_column', None)

        # whatever is left after extracting the column settings
        self.params = params

        # query behavior
        self._query = None
        self._limit = None

        database = project.name
        # self.sql has to exist before the parent constructor receives it
        self._update_query()

        super().__init__(project.api, self.sql, database)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.project.name}.{self.name})'

    def find(self, query: str, limit: int = 100):
        """
        Query data from knowledge base.
        Knowledge base should return the most relevant results for the query

        >>> # query knowledge base
        >>> query = my_kb.find('dogs')
        >>> # fetch dataframe to client
        >>> print(query.fetch())

        :param query: text query
        :param limit: count of rows in result, default 100
        :return: Query object
        """

        # the search is configured on a copy, this object keeps its own query
        kb = copy.deepcopy(self)
        kb._query = query
        kb._limit = limit
        kb._update_query()

        return kb

    def _update_query(self):
        """
        Rebuild ``self.sql`` from the current ``_query`` and ``_limit`` values.

        The statement selects everything from ``<project>.<knowledge base>``,
        adds a ``content = <query>`` condition when a text query is set and
        a limit when one is set.
        """
        ast_query = Select(
            targets=[Star()],
            from_table=Identifier(parts=[
                self.project.name, self.name
            ])
        )
        if self._query is not None:
            ast_query.where = BinaryOperation(op='=', args=[
                Identifier('content'),
                Constant(self._query)
            ])

        if self._limit is not None:
            ast_query.limit = Constant(self._limit)
        self.sql = ast_query.to_string()

    def insert(self, data: Union[pd.DataFrame, Query, dict]):
        """
        Insert data to knowledge base

        >>> # insert using query
        >>> my_kb.insert(server.databases.example_db.tables.houses_sales.filter(type='house'))
        >>> # using dataframe
        >>> my_kb.insert(pd.read_csv('house_sales.csv'))
        >>> # using dict
        >>> my_kb.insert({'type': 'house', 'date': '2020-02-02'})

        If the id of a row (defined by id_column param, see create knowledge base)
        already exists in the knowledge base, that row will be replaced

        :param data: Dataframe or Query object or dict.
        """

        if isinstance(data, dict):
            # a single record is handled as a one-row dataframe
            data = pd.DataFrame([data])

        if isinstance(data, pd.DataFrame):
            # insert data
            data_split = data.to_dict('split')

            ast_query = Insert(
                table=Identifier(self.name),
                columns=data_split['columns'],
                values=data_split['data']
            )

            sql = ast_query.to_string()
            self.api.sql_query(sql, self.database)
        else:
            # insert from select: the statement is executed in the database of the source query
            table = Identifier(parts=[self.database, self.name])
            self.api.sql_query(
                f'INSERT INTO {table.to_string()} ({data.sql})',
                database=data.database
            )
160+
161+
162+
class KnowledgeBases(CollectionBase):
    """
    **Knowledge bases**

    Get list:

    >>> kb_list = server.knowledge_bases.list()
    >>> kb = kb_list[0]

    Get by name:

    >>> kb = server.knowledge_bases.get('my_kb')
    >>> # or :
    >>> kb = server.knowledge_bases.my_kb

    Create:

    >>> kb = server.knowledge_bases.create('my_kb')

    Drop:

    >>> server.knowledge_bases.drop('my_kb')

    """

    def __init__(self, project, api):
        """
        :param project: project whose knowledge bases are managed by this collection
        :param api: api object used to execute sql queries
        """
        self.project = project
        self.api = api

    def _list(self, name: str = None) -> List[KnowledgeBase]:
        """
        Read knowledge base records from ``information_schema.knowledge_bases``.

        :param name: if set, only records with this name are selected
        :return: list of KnowledgeBase objects
        """

        # TODO add filter by project. for now 'project' is empty
        ast_query = Select(targets=[Star()], from_table=Identifier(parts=['information_schema', 'knowledge_bases']))
        if name is not None:
            ast_query.where = dict_to_binary_op({'name': name})

        df = self.api.sql_query(ast_query.to_string(), database=self.project.name)

        # columns to lower case
        cols_map = {i: i.lower() for i in df.columns}
        df = df.rename(columns=cols_map)

        return [
            KnowledgeBase(self.project, item)
            for item in df.to_dict('records')
        ]

    def list(self) -> List[KnowledgeBase]:
        """
        Get list of knowledge bases inside of project:

        >>> kb_list = project.knowledge_bases.list()

        :return: list of knowledge bases
        """
        return self._list()

    def get(self, name: str) -> KnowledgeBase:
        """
        Get knowledge base by name

        :param name: name of the knowledge base
        :return: KnowledgeBase object
        :raises AttributeError: if there is no knowledge base with this name
        :raises RuntimeError: if more than one knowledge base has this name
        """
        item = self._list(name)
        if len(item) == 1:
            return item[0]
        elif len(item) == 0:
            raise AttributeError("KnowledgeBase doesn't exist")
        else:
            raise RuntimeError("Several knowledgeBases with the same name")

    # NOTE: inside the class body the name ``list`` refers to the method defined above,
    # not to the builtin, so typing.List is used for the annotations below.
    def create(
        self,
        name: str,
        model: Model = None,
        storage: Table = None,
        metadata_columns: List[str] = None,
        content_columns: List[str] = None,
        id_column: str = None,
        params: dict = None,
    ) -> KnowledgeBase:
        """
        Create knowledge base:

        >>> kb = server.knowledge_bases.create(
        ...   'my_kb',
        ...   model=server.models.emb_model,
        ...   storage=server.databases.pvec.tables.tbl1,
        ...   metadata_columns=['date', 'author'],
        ...   content_columns=['review', 'description'],
        ...   id_column='number',
        ...   params={'a': 1}
        ...)

        :param name: name of the knowledge base
        :param model: embedding model, optional. Default: 'sentence_transformers' will be used (defined in mindsdb server)
        :param storage: vector storage, optional. Default: chromadb database will be created
        :param metadata_columns: columns to use as metadata, optional. Default: all columns which are not content and id
        :param content_columns: columns to use as content, optional. Default: all columns except id column
        :param id_column: the column to use as id, optional. Default: 'id', if exists
        :param params: other parameters to knowledge base
        :return: created KnowledgeBase object
        """

        params_out = {}

        if metadata_columns is not None:
            params_out['metadata_columns'] = metadata_columns

        if content_columns is not None:
            params_out['content_columns'] = content_columns

        if id_column is not None:
            params_out['id_column'] = id_column

        # applied last: keys given in params take precedence over the arguments above
        if params is not None:
            params_out.update(params)

        if model is not None:
            model_name = Identifier(parts=[model.project.name, model.name])
        else:
            model_name = None

        if storage is not None:
            storage_name = Identifier(parts=[storage.db.name, storage.name])
        else:
            storage_name = None

        ast_query = CreateKnowledgeBase(
            Identifier(name),
            model=model_name,
            storage=storage_name,
            params=params_out
        )

        self.api.sql_query(ast_query.to_string(), database=self.project.name)

        return self.get(name)

    def drop(self, name: str):
        """
        Drop knowledge base:

        >>> server.knowledge_bases.drop('my_kb')

        :param name: name of the knowledge base to drop
        """

        ast_query = DropKnowledgeBase(Identifier(name))

        # run in the context of the owning project, the same way create() and _list() do
        self.api.sql_query(ast_query.to_string(), database=self.project.name)

mindsdb_sdk/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Model:
3737
3838
>>> model.refresh()
3939
40-
Usng model
40+
**Using model**
4141
4242
Dataframe on input
4343

mindsdb_sdk/projects.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from mindsdb_sql.parser.dialects.mindsdb import CreateDatabase
44
from mindsdb_sql.parser.ast import DropDatabase
5-
from mindsdb_sql.parser.ast import Identifier, Delete
5+
from mindsdb_sql.parser.ast import Identifier, Delete
66

77
from mindsdb_sdk.utils.sql import dict_to_binary_op
88

@@ -14,6 +14,7 @@
1414
from .query import Query
1515
from .views import Views
1616
from .jobs import Jobs
17+
from .knowledge_bases import KnowledgeBases
1718

1819

1920
class Project:
@@ -47,7 +48,7 @@ class Project:
4748
4849
"""
4950

50-
def __init__(self, api, name, agents: Agents = None, skills: Skills = None):
51+
def __init__(self, api, name):
5152
self.name = name
5253
self.api = api
5354

@@ -75,8 +76,9 @@ def __init__(self, api, name, agents: Agents = None, skills: Skills = None):
7576
self.create_job = self.jobs.create
7677
self.drop_job = self.jobs.drop
7778

78-
self.skills = skills or Skills(api, name)
79-
self.agents = agents or Agents(api, name, self.skills)
79+
self.skills = Skills(api, name)
80+
self.agents = Agents(api, name, self.skills)
81+
self.knowledge_bases = KnowledgeBases(self, api)
8082

8183
def __repr__(self):
8284
return f'{self.__class__.__name__}({self.name})'

mindsdb_sdk/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ class Server(Project):
2525
2626
"""
2727

28-
def __init__(self, api, skills: Skills = None, agents: Agents = None):
28+
def __init__(self, api):
2929
# server is also mindsdb project
3030
project_name = 'mindsdb'
31-
super().__init__(api, project_name, skills=skills, agents=agents)
31+
super().__init__(api, project_name)
3232

3333
self.projects = Projects(api)
3434

mindsdb_sdk/views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def create(self, name: str, sql: Union[str, Query], database: str = None) -> Vie
102102
database = sql.database
103103
sql = sql.sql
104104
elif not isinstance(sql, str):
105-
raise ValueError()
105+
raise ValueError(sql)
106106

107107
if database is not None:
108108
database = Identifier(database)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
requests
22
pandas >= 1.3.5
3-
mindsdb-sql >= 0.12.0, < 0.13.0
3+
mindsdb-sql >= 0.13.0, < 0.14.0

0 commit comments

Comments
 (0)