|
| 1 | +import copy |
| 2 | +import json |
| 3 | +from typing import Union, List |
| 4 | + |
| 5 | +import pandas as pd |
| 6 | + |
| 7 | +from mindsdb_sql.parser.dialects.mindsdb import CreateKnowledgeBase, DropKnowledgeBase |
| 8 | +from mindsdb_sql.parser.ast import Identifier, Star, Select, BinaryOperation, Constant, Insert |
| 9 | + |
| 10 | +from mindsdb_sdk.utils.sql import dict_to_binary_op |
| 11 | +from mindsdb_sdk.utils.objects_collection import CollectionBase |
| 12 | + |
| 13 | +from .models import Model |
| 14 | +from .tables import Table |
| 15 | +from .query import Query |
| 16 | +from .databases import Database |
| 17 | + |
| 18 | + |
class KnowledgeBase(Query):
    """
    Knowledge base object, used to update or query a knowledge base.

    Add data to knowledge base:

    >>> kb.insert(pd.read_csv('house_sales.csv'))

    Query relevant results:

    >>> df = kb.find('flats').fetch()

    """

    def __init__(self, project, data: dict):
        """
        :param project: project the knowledge base belongs to (its ``name`` and ``api`` are used)
        :param data: description of the knowledge base with keys 'name', 'storage', 'model'
                     and, optionally, 'params' (a dict or a JSON string)
        """

        self.project = project
        self.name = data['name']

        self.storage = None
        storage = data['storage']
        # only a string can be parsed; any other value (None, NaN, ...) leaves storage unset
        if isinstance(storage, str):
            # if a name contains '.' the split below yields more than 2 parts and storage stays unset
            parts = storage.split('.')
            if len(parts) == 2:
                database_name, table_name = parts
                database = Database(project, database_name)
                self.storage = Table(database, table_name)

        self.model = None
        if data['model'] is not None:
            self.model = Model(self.project, {'name': data['model']})

        params = data.get('params')
        if isinstance(params, str):
            try:
                params = json.loads(params)
            except json.JSONDecodeError:
                params = {}

        # Work on a private copy: the pops below must not modify the caller's dict.
        # Anything that is not a dict (None, a JSON list/scalar, ...) is treated as "no params".
        params = dict(params) if isinstance(params, dict) else {}

        # columns
        self.metadata_columns = params.pop('metadata_columns', [])
        self.content_columns = params.pop('content_columns', [])
        self.id_column = params.pop('id_column', None)

        # whatever is left after the column settings were extracted
        self.params = params

        # query behavior
        self._query = None
        self._limit = None

        # self.sql has to exist before it is handed to the parent constructor
        self._update_query()

        super().__init__(project.api, self.sql, project.name)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.project.name}.{self.name})'

    def find(self, query: str, limit: int = 100):
        """
        Query data from knowledge base.
        Knowledge base should return the most relevant results for the query

        >>> # query knowledge base
        >>> query = my_kb.find('dogs')
        >>> # fetch dataframe to client
        >>> print(query.fetch())

        :param query: text query
        :param limit: count of rows in result, default 100
        :return: Query object
        """

        # the search settings are applied to a copy, this object stays unchanged
        kb = copy.deepcopy(self)
        kb._query = query
        kb._limit = limit
        kb._update_query()

        return kb

    def _update_query(self):
        """
        Rebuild ``self.sql`` from the current search text and limit:
        ``SELECT * FROM <project>.<name>``, plus ``WHERE content = <query>`` and ``LIMIT <limit>`` when they are set.
        """

        ast_query = Select(
            targets=[Star()],
            from_table=Identifier(parts=[self.project.name, self.name])
        )

        if self._query is not None:
            ast_query.where = BinaryOperation(
                op='=',
                args=[Identifier('content'), Constant(self._query)]
            )

        if self._limit is not None:
            ast_query.limit = Constant(self._limit)

        self.sql = ast_query.to_string()

    def insert(self, data: Union[pd.DataFrame, Query, dict]):
        """
        Insert data to knowledge base

        >>> # insert using query
        >>> my_kb.insert(server.databases.example_db.tables.houses_sales.filter(type='house'))
        >>> # using dataframe
        >>> my_kb.insert(pd.read_csv('house_sales.csv'))
        >>> # using dict
        >>> my_kb.insert({'type': 'house', 'date': '2020-02-02'})

        If the id of a record (defined by id_column param, see create knowledge base) already exists
        in the knowledge base, the record will be replaced

        :param data: Dataframe or Query object or dict.
        """

        if isinstance(data, dict):
            # a single record is handled as a one-row dataframe
            data = pd.DataFrame([data])

        if isinstance(data, pd.DataFrame):
            # insert data: 'split' orientation gives the column names and the rows separately
            data_split = data.to_dict('split')

            ast_query = Insert(
                table=Identifier(self.name),
                columns=data_split['columns'],
                values=data_split['data']
            )

            # NOTE(review): self.api and self.database are presumably set by Query.__init__ - confirm
            self.api.sql_query(ast_query.to_string(), self.database)
            return

        # insert from select: the statement runs in the database of the source query
        table = Identifier(parts=[self.database, self.name])
        self.api.sql_query(
            f'INSERT INTO {table.to_string()} ({data.sql})',
            database=data.database
        )
| 160 | + |
| 161 | + |
class KnowledgeBases(CollectionBase):
    """
    **Knowledge bases**

    Get list:

    >>> kb_list = server.knowledge_bases.list()
    >>> kb = kb_list[0]

    Get by name:

    >>> kb = server.knowledge_bases.get('my_kb')
    >>> # or :
    >>> kb = server.knowledge_bases.my_kb

    Create:

    >>> kb = server.knowledge_bases.create('my_kb')

    Drop:

    >>> server.knowledge_bases.drop('my_kb')

    """

    def __init__(self, project, api):
        """
        :param project: project whose knowledge bases are managed
        :param api: object used to execute sql queries (``sql_query`` method)
        """
        self.project = project
        self.api = api

    def _list(self, name: str = None) -> List[KnowledgeBase]:
        """
        Read ``information_schema.knowledge_bases`` and wrap every row into a KnowledgeBase.

        :param name: if set, only rows with this name are selected
        :return: list of knowledge bases
        """

        # TODO add filter by project. for now 'project' is empty
        ast_query = Select(
            targets=[Star()],
            from_table=Identifier(parts=['information_schema', 'knowledge_bases'])
        )
        if name is not None:
            ast_query.where = dict_to_binary_op({'name': name})

        df = self.api.sql_query(ast_query.to_string(), database=self.project.name)

        # columns to lower case: KnowledgeBase reads lower-case keys
        df = df.rename(columns={col: col.lower() for col in df.columns})

        return [
            KnowledgeBase(self.project, record)
            for record in df.to_dict('records')
        ]

    def list(self) -> List[KnowledgeBase]:
        """
        Get list of knowledge bases inside of project:

        >>> kb_list = project.knowledge_bases.list()

        :return: list of knowledge bases
        """
        return self._list()

    def get(self, name: str) -> KnowledgeBase:
        """
        Get knowledge base by name

        :param name: name of the knowledge base
        :return: KnowledgeBase object
        :raises AttributeError: if there is no knowledge base with this name
        :raises RuntimeError: if more than one knowledge base has this name
        """
        found = self._list(name)

        if len(found) == 0:
            raise AttributeError("KnowledgeBase doesn't exist")
        if len(found) > 1:
            raise RuntimeError("Several knowledgeBases with the same name")

        return found[0]

    def create(
        self,
        name: str,
        model: Model = None,
        storage: Table = None,
        metadata_columns: list = None,
        content_columns: list = None,
        id_column: str = None,
        params: dict = None,
    ) -> KnowledgeBase:
        """
        Create knowledge base:

        >>> kb = server.knowledge_bases.create(
        ...   'my_kb',
        ...   model=server.models.emb_model,
        ...   storage=server.databases.pvec.tables.tbl1,
        ...   metadata_columns=['date', 'author'],
        ...   content_columns=['review', 'description'],
        ...   id_column='number',
        ...   params={'a': 1}
        ...)

        :param name: name of the knowledge base
        :param model: embedding model, optional. Default: 'sentence_transformers' will be used (defined in mindsdb server)
        :param storage: vector storage, optional. Default: chromadb database will be created
        :param metadata_columns: columns to use as metadata, optional. Default: all columns which are not content and id
        :param content_columns: columns to use as content, optional. Default: all columns except id column
        :param id_column: the column to use as id, optional. Default: 'id', if exists
        :param params: other parameters to knowledge base
        :return: created KnowledgeBase object
        """

        # column settings are sent to the server as a part of the params
        params_out = {}

        if metadata_columns is not None:
            params_out['metadata_columns'] = metadata_columns

        if content_columns is not None:
            params_out['content_columns'] = content_columns

        if id_column is not None:
            params_out['id_column'] = id_column

        # applied last: keys in params override the dedicated arguments above
        if params is not None:
            params_out.update(params)

        model_name = None
        if model is not None:
            model_name = Identifier(parts=[model.project.name, model.name])

        storage_name = None
        if storage is not None:
            storage_name = Identifier(parts=[storage.db.name, storage.name])

        ast_query = CreateKnowledgeBase(
            Identifier(name),
            model=model_name,
            storage=storage_name,
            params=params_out
        )

        self.api.sql_query(ast_query.to_string(), database=self.project.name)

        # read the created knowledge base back from the server
        return self.get(name)

    def drop(self, name: str):
        """
        Drop knowledge base:

        >>> server.knowledge_bases.drop('my_kb')

        :param name: name of the knowledge base
        """

        ast_query = DropKnowledgeBase(Identifier(name))

        # run inside of the project, the same way as create() and _list() do
        self.api.sql_query(ast_query.to_string(), database=self.project.name)
0 commit comments