Skip to content

Commit fc4f59c

Browse files
committed
create job with sdk commands
1 parent 0e74165 commit fc4f59c

File tree

5 files changed

+209
-43
lines changed

5 files changed

+209
-43
lines changed

mindsdb_sdk/jobs.py

Lines changed: 94 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,32 @@
11
import datetime as dt
22
from typing import Union, List
33

4+
45
import pandas as pd
56

67
from mindsdb_sql.parser.dialects.mindsdb import CreateJob, DropJob
78
from mindsdb_sql.parser.ast import Identifier, Star, Select
89

10+
from mindsdb_sdk.query import Query
911
from mindsdb_sdk.utils.sql import dict_to_binary_op
1012
from mindsdb_sdk.utils.objects_collection import CollectionBase
13+
from mindsdb_sdk.utils.context import set_saving
1114

1215

1316
class Job:
14-
def __init__(self, project, data):
17+
def __init__(self, project, name, data=None, create_callback=None):
1518
self.project = project
19+
self.name = name
1620
self.data = data
17-
self._update(data)
21+
22+
self.query_str = None
23+
if data is not None:
24+
self._update(data)
25+
self._queries = []
26+
self._create_callback = create_callback
1827

1928
def _update(self, data):
20-
self.name = data['name']
29+
# self.name = data['name']
2130
self.query_str = data['query']
2231
self.start_at = data['start_at']
2332
self.end_at = data['end_at']
@@ -27,13 +36,52 @@ def _update(self, data):
2736
def __repr__(self):
2837
return f"{self.__class__.__name__}({self.name}, query='{self.query_str}')"
2938

39+
def __enter__(self):
    """
    Open the job's 'with' block.

    Stores the marker 'job-<name>' as the 'saving' value via set_saving
    and returns the job itself so queries can be added to it.

    :raises RuntimeError: if the job was built without a create callback
    :return: this job
    """
    callback = self._create_callback
    if callback is None:
        raise RuntimeError('It can not be used to create context')

    saving_marker = f'job-{self.name}'
    set_saving(saving_marker)
    return self
45+
46+
def __exit__(self, type, value, traceback):
    """
    Close the job's 'with' block.

    The 'saving' value is always reset first. If the block finished without
    an exception, the collected queries are joined with '; ', passed to the
    create callback, and the job data is refreshed afterwards.

    :raises RuntimeError: if the block finished cleanly but no queries were added
    """
    # reset the marker before anything else so it is cleared on every path
    set_saving(None)

    if type is not None:
        # the block raised: nothing is created, the exception propagates
        return

    if not self._queries:
        raise RuntimeError('No queries were added to job')

    self._create_callback('; '.join(self._queries))
    self.refresh()
57+
3058
def refresh(self):
3159
"""
3260
Retrieve job data from mindsdb server
3361
"""
3462
job = self.project.get_job(self.name)
3563
self._update(job.data)
3664

65+
def add_query(self, query: Union[Query, str]):
    """
    Add a query to the job. The method is meant to be used inside the job's 'with' block.

    >>> with con.jobs.create('j1') as job:
    >>>     job.add_query(table1.insert(table2))

    :param query: sql string or Query object. Query.database should be empty or the same as the job's project
    :raises ValueError: if a Query object has a database that differs from the job's project
    :raises RuntimeError: if query is neither a Query object nor a string
    """
    if isinstance(query, Query):
        if query.database is not None and query.database != self.project.name:
            # we can't execute this query in jobs project
            raise ValueError(f"Wrong query database: {query.database}. You could try to use sql string instead")

        # only the sql text of the Query object is stored
        query = query.sql
    elif not isinstance(query, str):
        raise RuntimeError(f'Unable to add this object as a query: {query}. Try to use sql string instead')
    self._queries.append(query)
84+
3785
def get_history(self) -> pd.DataFrame:
3886
"""
3987
Get history of job execution
@@ -69,7 +117,7 @@ def _list(self, name: str = None) -> List[Job]:
69117
df = df.rename(columns=cols_map)
70118

71119
return [
72-
Job(self.project, item)
120+
Job(self.project, item.pop('name'), item)
73121
for item in df.to_dict('records')
74122
]
75123

@@ -101,7 +149,7 @@ def get(self, name: str) -> Job:
101149
def create(
102150
self,
103151
name: str,
104-
query_str: str,
152+
query_str: str = None,
105153
start_at: dt.datetime = None,
106154
end_at: dt.datetime = None,
107155
repeat_str: str = None,
@@ -113,7 +161,25 @@ def create(
113161
If it is not possible (job executed and not accessible anymore):
114162
return None
115163
116-
More info: https://docs.mindsdb.com/sql/create/jobs
164+
Usage options:
165+
166+
Option 1: to use string query
167+
All job tasks can be passed as a string with sql queries. The job is created immediately
168+
169+
>>> job = con.jobs.create('j1', query_str='retrain m1; show models', repeat_min=1)
170+
171+
Option 2: to use 'with' block.
172+
It allows to pass sdk commands to job tasks.
173+
Not all sdk commands could be accepted here,
174+
only those which are converted into sql in the sdk and sent to the /query endpoint
175+
Adding query sql string is accepted as well
176+
Job will be created after exit from 'with' block
177+
178+
>>> with con.jobs.create('j1', repeat_min=1) as job:
179+
>>> job.add_query(table1.insert(table2))
180+
>>> job.add_query('retrain m1') # using string
181+
182+
More info about jobs: https://docs.mindsdb.com/sql/create/jobs
117183
118184
:param name: name of the job
119185
:param query_str: str, job's query (or list of queries with ';' delimiter) which job have to execute
@@ -137,20 +203,30 @@ def create(
137203
if repeat_min is not None:
138204
repeat_str = f'{repeat_min} minutes'
139205

140-
ast_query = CreateJob(
141-
name=Identifier(name),
142-
query_str=query_str,
143-
start_str=start_str,
144-
end_str=end_str,
145-
repeat_str=repeat_str
146-
)
206+
def _create_callback(query):
207+
ast_query = CreateJob(
208+
name=Identifier(name),
209+
query_str=query,
210+
start_str=start_str,
211+
end_str=end_str,
212+
repeat_str=repeat_str
213+
)
214+
215+
self.api.sql_query(ast_query.to_string(), database=self.project.name)
216+
217+
if query_str is None:
218+
# allow to create context with job
219+
job = Job(self.project, name, create_callback=_create_callback)
220+
return job
221+
else:
222+
# create it
223+
_create_callback(query_str)
147224

148-
self.api.sql_query(ast_query.to_string(), database=self.project.name)
225+
# job can be executed and remove it is not repeatable
226+
jobs = self._list(name)
227+
if len(jobs) == 1:
228+
return jobs[0]
149229

150-
# job can be executed and remove it is not repeatable
151-
jobs = self._list(name)
152-
if len(jobs) == 1:
153-
return jobs[0]
154230

155231
def drop(self, name: str):
156232
"""

mindsdb_sdk/models.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from mindsdb_sdk.utils.objects_collection import CollectionBase
1717
from mindsdb_sdk.utils.sql import dict_to_binary_op, query_to_native_query
18+
from mindsdb_sdk.utils.context import is_saving
1819

1920
from .query import Query
2021

@@ -106,15 +107,15 @@ def __repr__(self):
106107
version = ''
107108
if self.version is not None:
108109
version = f', version={self.version}'
109-
return f'{self.__class__.__name__}({self.name}{version}, status={self.data["status"]})'
110+
return f'{self.__class__.__name__}({self.name}{version}, status={self.data.get("status")})'
110111

111112
def _get_identifier(self):
112113
parts = [self.project.name, self.name]
113114
if self.version is not None:
114115
parts.append(str(self.version))
115116
return Identifier(parts=parts)
116117

117-
def predict(self, data: Union[pd.DataFrame, Query, dict], params: dict = None) -> pd.DataFrame:
118+
def predict(self, data: Union[pd.DataFrame, Query, dict], params: dict = None) -> Union[pd.DataFrame, Query]:
118119
"""
119120
Make prediction using model
120121
@@ -203,7 +204,11 @@ def predict(self, data: Union[pd.DataFrame, Query, dict], params: dict = None) -
203204
if params is not None:
204205
upper_query.using = params
205206
# execute in query's database
206-
return self.project.api.sql_query(upper_query.to_string(), database=None)
207+
sql = upper_query.to_string()
208+
if is_saving():
209+
return Query(self, sql)
210+
211+
return self.project.api.sql_query(sql, database=None)
207212

208213
elif isinstance(data, dict):
209214
data = pd.DataFrame([data])
@@ -310,15 +315,19 @@ def _retrain(self,
310315
integration_name=database,
311316
using=options or None,
312317
)
318+
sql = ast_query.to_string()
319+
320+
if is_saving():
321+
return Query(self, sql)
313322

314-
data = self.project.query(ast_query.to_string()).fetch()
323+
data = self.project.api.sql_query(sql)
315324
data = {k.lower(): v for k, v in data.items()}
316325

317326
# return new instance
318327
base_class = self.__class__
319328
return base_class(self.project, data)
320329

321-
def describe(self, type: str = None) -> pd.DataFrame:
330+
def describe(self, type: str = None) -> Union[pd.DataFrame, Query]:
322331
"""
323332
Return description of the model
324333
@@ -332,7 +341,12 @@ def describe(self, type: str = None) -> pd.DataFrame:
332341
if type is not None:
333342
identifier.parts.append(type)
334343
ast_query = Describe(identifier)
335-
return self.project.query(ast_query.to_string()).fetch()
344+
345+
sql = ast_query.to_string()
346+
if is_saving():
347+
return Query(self, sql)
348+
349+
return self.project.api.sql_query(sql)
336350

337351
def list_versions(self) -> List[ModelVersion]:
338352
"""
@@ -374,7 +388,7 @@ def set_active(self, version: int):
374388
:param version: version to set active
375389
"""
376390
ast_query = Update(
377-
table=Identifier('models_versions'),
391+
table=Identifier(parts=[self.project.name, 'models_versions']),
378392
update_columns={
379393
'active': Constant(1)
380394
},
@@ -383,7 +397,11 @@ def set_active(self, version: int):
383397
'version': version
384398
})
385399
)
386-
self.project.query(ast_query.to_string()).fetch()
400+
sql = ast_query.to_string()
401+
if is_saving():
402+
return Query(self, sql)
403+
404+
self.project.api.sql_query(sql)
387405
self.refresh()
388406

389407

@@ -430,7 +448,7 @@ def create(
430448
database: str = None,
431449
options: dict = None,
432450
timeseries_options: dict = None, **kwargs
433-
) -> Model:
451+
) -> Union[Model, Query]:
434452
"""
435453
Create new model in project and return it
436454
@@ -486,7 +504,7 @@ def create(
486504
targets = None
487505

488506
ast_query = CreatePredictor(
489-
name=Identifier(name),
507+
name=Identifier(parts=[self.project.name, name]),
490508
query_str=query,
491509
integration_name=database,
492510
targets=targets,
@@ -522,7 +540,13 @@ def create(
522540

523541
options['engine'] = engine
524542
ast_query.using = options
525-
df = self.project.query(ast_query.to_string()).fetch()
543+
544+
sql = ast_query.to_string()
545+
546+
if is_saving():
547+
return Query(self, sql)
548+
549+
df = self.project.api.sql_query(sql)
526550
if len(df) > 0:
527551
data = dict(df.iloc[0])
528552
# to lowercase
@@ -559,8 +583,12 @@ def drop(self, name: str):
559583
560584
:param name: name of the model
561585
"""
562-
ast_query = DropPredictor(name=Identifier(name))
563-
self.project.query(ast_query.to_string()).fetch()
586+
ast_query = DropPredictor(name=Identifier(parts=[self.project.name, name]))
587+
sql = ast_query.to_string()
588+
if is_saving():
589+
return Query(self, sql)
590+
591+
self.project.api.sql_query(sql)
564592

565593

566594
def list(self, with_versions: bool = False,

mindsdb_sdk/tables.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from mindsdb_sdk.utils.sql import dict_to_binary_op, add_condition, query_to_native_query
1111
from mindsdb_sdk.utils.objects_collection import CollectionBase
12+
from mindsdb_sdk.utils.context import is_saving
1213

1314
from .query import Query
1415

@@ -159,7 +160,11 @@ def delete(self, **kwargs):
159160
where=dict_to_binary_op(kwargs)
160161
)
161162
sql = ast_query.to_string()
162-
self.api.sql_query(sql, 'mindsdb')
163+
164+
if is_saving():
165+
return Query(self, sql)
166+
167+
self.api.sql_query(sql)
163168

164169
def update(self, values: Union[dict, Query], on: list = None, filters: dict = None):
165170
'''
@@ -218,6 +223,11 @@ def update(self, values: Union[dict, Query], on: list = None, filters: dict = No
218223
else:
219224
raise NotImplementedError
220225

226+
if is_saving():
227+
return Query(self, sql)
228+
229+
self.api.sql_query(sql)
230+
221231

222232
class Tables(CollectionBase):
223233
"""

mindsdb_sdk/utils/context.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from contextvars import ContextVar
2+
3+
# Holds a dict of named values for the current execution context.
context_storage = ContextVar('create_context')


def set_context(name, value):
    """
    Store a named value in the current context.

    The stored dict is copied before it is changed: a context created with
    copy_context() references the same dict object, so updating it in place
    would also change the value seen by every other context sharing it.

    :param name: key of the value
    :param value: value to store
    """
    data = dict(context_storage.get({}))
    data[name] = value

    context_storage.set(data)


def get_context(name):
    """
    Read a named value from the current context.

    :param name: key of the value
    :return: stored value, or None if nothing was stored under this key
    """
    return context_storage.get({}).get(name)


def set_saving(name):
    """
    Set the 'saving' value of the current context.

    :param name: value to store; None switches saving off
    """
    set_context('saving', name)


def is_saving():
    """
    Check whether a 'saving' value is set in the current context.

    :return: True if the 'saving' value is not None
    """
    return get_context('saving') is not None

0 commit comments

Comments
 (0)