Skip to content

Commit c051c47

Browse files
committed
improve test coverage
1 parent 040fd43 commit c051c47

File tree

2 files changed

+52
-37
lines changed

2 files changed

+52
-37
lines changed

mindsdb_sdk/utils/objects_collection.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,39 +19,39 @@ def __getattr__(self, name):
1919
return self.get(name)
2020

2121

22-
class MethodCollection(CollectionBase):
23-
24-
def __init__(self, name, methods):
25-
self.name = name
26-
self.methods = methods
27-
28-
def __repr__(self):
29-
return f'{self.__class__.__name__}({self.name})'
30-
31-
def get(self, *args, **kwargs):
32-
method = self.methods.get('get')
33-
if method is None:
34-
raise NotImplementedError()
35-
36-
return method(*args, **kwargs)
37-
38-
def list(self, *args, **kwargs):
39-
method = self.methods.get('list')
40-
if method is None:
41-
raise NotImplementedError()
42-
43-
return method(*args, **kwargs)
44-
45-
def create(self, *args, **kwargs):
46-
method = self.methods.get('create')
47-
if method is None:
48-
raise NotImplementedError()
49-
50-
return method(*args, **kwargs)
51-
52-
def drop(self, name):
53-
method = self.methods.get('drop')
54-
if method is None:
55-
raise NotImplementedError()
56-
57-
return method(name)
22+
# class MethodCollection(CollectionBase):
23+
#
24+
# def __init__(self, name, methods):
25+
# self.name = name
26+
# self.methods = methods
27+
#
28+
# def __repr__(self):
29+
# return f'{self.__class__.__name__}({self.name})'
30+
#
31+
# def get(self, *args, **kwargs):
32+
# method = self.methods.get('get')
33+
# if method is None:
34+
# raise NotImplementedError()
35+
#
36+
# return method(*args, **kwargs)
37+
#
38+
# def list(self, *args, **kwargs):
39+
# method = self.methods.get('list')
40+
# if method is None:
41+
# raise NotImplementedError()
42+
#
43+
# return method(*args, **kwargs)
44+
#
45+
# def create(self, *args, **kwargs):
46+
# method = self.methods.get('create')
47+
# if method is None:
48+
# raise NotImplementedError()
49+
#
50+
# return method(*args, **kwargs)
51+
#
52+
# def drop(self, name):
53+
# method = self.methods.get('drop')
54+
# if method is None:
55+
# raise NotImplementedError()
56+
#
57+
# return method(name)

tests/test_sdk.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,14 @@ class Test(BaseFlow):
181181
@patch('requests.Session.put')
182182
@patch('requests.Session.post')
183183
def test_flow(self, mock_post, mock_put):
184+
# check local
185+
server = mindsdb_sdk.connect()
184186

187+
assert server.api.url == 'http://127.0.0.1:47334'
188+
189+
# check cloud login
185190
server = mindsdb_sdk.connect(login='a@b.com')
186191

187-
# check login
188192
call_args = mock_post.call_args
189193
assert call_args[0][0] == 'https://cloud.mindsdb.com/cloud/login'
190194
assert call_args[1]['json']['email'] == 'a@b.com'
@@ -197,6 +201,7 @@ def test_flow(self, mock_post, mock_put):
197201
check_sql_call(mock_post, "select NAME from information_schema.databases where TYPE='data'")
198202

199203
database = databases[0]
204+
str(database)
200205
assert database.name == 'db1'
201206
self.check_database(database)
202207

@@ -364,6 +369,7 @@ def check_project_models(self, project, database, mock_post):
364369
f'CREATE PREDICTOR m2 FROM example_db (select * from t1) PREDICT price ORDER BY date GROUP BY a, b WINDOW 10 HORIZON 2 USING module="LightGBM", `engine`="lightwood"'
365370
)
366371
assert model.name == 'm2'
372+
model.wait_complete()
367373
self.check_model(model, database)
368374

369375
# create, using deferred query.
@@ -373,6 +379,7 @@ def check_project_models(self, project, database, mock_post):
373379
predict='price',
374380
query=query,
375381
)
382+
str(query)
376383

377384
check_sql_call(
378385
mock_post,
@@ -963,6 +970,7 @@ def check_project_jobs(self, project, mock_post):
963970
assert job.query_str == 'select 1'
964971

965972
job = project.jobs.job1
973+
str(job)
966974
assert job.name == 'job1'
967975
assert job.query_str == 'select 1'
968976

@@ -972,6 +980,13 @@ def check_project_jobs(self, project, mock_post):
972980
f"select * from jobs where name = 'job1'"
973981
)
974982

983+
job.get_history()
984+
985+
check_sql_call(
986+
mock_post,
987+
f"select * from jobs_history where name = 'job1'"
988+
)
989+
975990
project.jobs.create(
976991
name='job2',
977992
query_str='retrain m1',

0 commit comments

Comments
 (0)