Skip to content

Commit 99c8970

Browse files
authored
Merge pull request #111 from mindsdb/schema-refactor
Adapt for schema refactor
2 parents d1fbab7 + 7558f6a commit 99c8970

File tree

3 files changed

+19
-33
lines changed

3 files changed

+19
-33
lines changed

mindsdb_sdk/models.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from mindsdb_sql.parser.dialects.mindsdb import CreatePredictor, DropPredictor
99
from mindsdb_sql.parser.dialects.mindsdb import RetrainPredictor, FinetunePredictor
10-
from mindsdb_sql.parser.ast import Identifier, Select, Star, Join, Update, Describe, Constant
10+
from mindsdb_sql.parser.ast import Identifier, Select, Star, Join, Describe, Set
1111
from mindsdb_sql import parse_sql
1212
from mindsdb_sql.exceptions import ParsingException
1313

@@ -388,15 +388,9 @@ def set_active(self, version: int):
388388
389389
:param version: version to set active
390390
"""
391-
ast_query = Update(
392-
table=Identifier(parts=[self.project.name, 'models_versions']),
393-
update_columns={
394-
'active': Constant(1)
395-
},
396-
where=dict_to_binary_op({
397-
'name': self.name,
398-
'version': version
399-
})
391+
ast_query = Set(
392+
category='active',
393+
value=Identifier(parts=[self.project.name, self.name, str(version)])
400394
)
401395
sql = ast_query.to_string()
402396
if is_saving():
@@ -609,21 +603,22 @@ def list(self, with_versions: bool = False,
609603
:return: list of Model or ModelVersion objects
610604
"""
611605

612-
table = 'models'
613606
model_class = Model
614-
if with_versions:
615-
table = 'models_versions'
616-
model_class = ModelVersion
617607

618-
filters = { }
608+
filters = {}
619609
if name is not None:
620610
filters['NAME'] = name
621611
if version is not None:
622612
filters['VERSION'] = version
623613

614+
if with_versions:
615+
model_class = ModelVersion
616+
else:
617+
filters['ACTIVE'] = '1'
618+
624619
ast_query = Select(
625620
targets=[Star()],
626-
from_table=Identifier(table),
621+
from_table=Identifier('models'),
627622
where=dict_to_binary_op(filters)
628623
)
629624
df = self.project.query(ast_query.to_string()).fetch()

mindsdb_sdk/projects.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from typing import List
22

3-
from mindsdb_sql.parser.dialects.mindsdb import CreateDatabase
3+
from mindsdb_sql.parser.dialects.mindsdb import CreateDatabase, DropPredictor
44
from mindsdb_sql.parser.ast import DropDatabase
5-
from mindsdb_sql.parser.ast import Identifier, Delete
6-
7-
from mindsdb_sdk.utils.sql import dict_to_binary_op
5+
from mindsdb_sql.parser.ast import Identifier
86

97
from mindsdb_sdk.agents import Agents
108
from mindsdb_sdk.databases import Databases
@@ -96,21 +94,15 @@ def query(self, sql: str) -> Query:
9694
"""
9795
return Query(self.api, sql, database=self.name)
9896

99-
10097
def drop_model_version(self, name: str, version: int):
10198
"""
10299
Drop version of the model
103100
104101
:param name: name of the model
105102
:param version: version to drop
106103
"""
107-
ast_query = Delete(
108-
table=Identifier('models_versions'),
109-
where=dict_to_binary_op({
110-
'name': name,
111-
'version': version
112-
})
113-
)
104+
ast_query = DropPredictor(Identifier(parts=[name, str(version)]))
105+
114106
self.query(ast_query.to_string()).fetch()
115107

116108

tests/test_sdk.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def check_model(self, model, database, mock_post):
183183

184184
# list all versions
185185
models = model.list_versions()
186-
check_sql_call(mock_post, f"SELECT * FROM models_versions WHERE NAME = '{model.name}'",
186+
check_sql_call(mock_post, f"SELECT * FROM models WHERE NAME = '{model.name}'",
187187
database=model.project.name)
188188
model2 = models[0] # Model object
189189

@@ -194,8 +194,7 @@ def check_model(self, model, database, mock_post):
194194

195195
# get call before last call
196196
mock_call = mock_post.call_args_list[-2]
197-
assert mock_call[1]['json'][
198-
'query'] == f"update {model2.project.name}.models_versions set active=1 where name = '{model2.name}' AND version = 3"
197+
assert mock_call[1]['json']['query'] == f"SET active {model2.project.name}.{model2.name}.`3`"
199198

200199
@patch('requests.Session.post')
201200
def check_table(self, table, mock_post):
@@ -494,7 +493,7 @@ def check_project_models_versions(self, project, database, mock_post):
494493
self.check_model(model, database)
495494

496495
project.drop_model_version('m1', 1)
497-
check_sql_call(mock_post, f"delete from models_versions where name='m1' and version=1")
496+
check_sql_call(mock_post, f"DROP PREDICTOR m1.`1`")
498497

499498

500499
@patch('requests.Session.post')
@@ -961,7 +960,7 @@ def check_project_models_versions(self, project, database, mock_post):
961960
self.check_model(model, database)
962961

963962
project.models.m1.drop_version(1)
964-
check_sql_call(mock_post, f"delete from models_versions where name='m1' and version=1")
963+
check_sql_call(mock_post, f"DROP PREDICTOR m1.`1`")
965964

966965
@patch('requests.Session.post')
967966
def check_database(self, database, mock_post):

0 commit comments

Comments
 (0)