Skip to content

Commit 015db36

Browse files
committed
use full names for model.predict
1 parent b96cf64 commit 015db36

File tree

1 file changed

+34
-13
lines changed

1 file changed

+34
-13
lines changed

mindsdb_sdk/models.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from mindsdb_sql.parser.dialects.mindsdb import RetrainPredictor, FinetunePredictor
1010
from mindsdb_sql.parser.ast import Identifier, Select, Star, Join, Update, Describe, Constant
1111
from mindsdb_sql import parse_sql
12+
from mindsdb_sql.exceptions import ParsingException
1213

1314
from .ml_engines import MLEngine
1415

1516
from mindsdb_sdk.utils.objects_collection import CollectionBase
16-
from mindsdb_sdk.utils.sql import dict_to_binary_op
17+
from mindsdb_sdk.utils.sql import dict_to_binary_op, query_to_native_query
1718

1819
from .query import Query
1920

@@ -132,7 +133,10 @@ def predict(self, data: Union[pd.DataFrame, Query, dict], params: dict = None) -
132133
"""
133134
if isinstance(data, Query):
134135
# create join from select if it is simple select
135-
ast_query = parse_sql(data.sql, dialect='mindsdb')
136+
try:
137+
ast_query = parse_sql(data.sql, dialect='mindsdb')
138+
except ParsingException:
139+
ast_query = None
136140

137141
# injection of join disabled yet
138142
# if isinstance(ast_query, Select) and isinstance(ast_query.from_table, Identifier):
@@ -165,24 +169,41 @@ def predict(self, data: Union[pd.DataFrame, Query, dict], params: dict = None) -
165169
# ast_query.targets = [Identifier(parts=['m', Star()])]
166170
#
167171

168-
# wrap query to subselect
169172
model_identifier = self._get_identifier()
170173
model_identifier.alias = Identifier('m')
171174

172-
ast_query.parentheses = True
173-
ast_query.alias = Identifier('t')
174-
upper_query = Select(
175-
targets=[Identifier(parts=['m', Star()])],
176-
from_table=Join(
177-
join_type='join',
178-
left=ast_query,
179-
right=model_identifier
175+
if data.database is not None or ast_query is None or not isinstance(ast_query, Select):
176+
# use native query
177+
native_query = query_to_native_query(data)
178+
native_query.parentheses = True
179+
native_query.alias = Identifier('t')
180+
upper_query = Select(
181+
targets=[Identifier(parts=['m', Star()])],
182+
from_table=Join(
183+
join_type='join',
184+
left=native_query,
185+
right=model_identifier
186+
)
187+
)
188+
else:
189+
# wrap query to subselect
190+
model_identifier = self._get_identifier()
191+
model_identifier.alias = Identifier('m')
192+
193+
ast_query.parentheses = True
194+
ast_query.alias = Identifier('t')
195+
upper_query = Select(
196+
targets=[Identifier(parts=['m', Star()])],
197+
from_table=Join(
198+
join_type='join',
199+
left=ast_query,
200+
right=model_identifier
201+
)
180202
)
181-
)
182203
if params is not None:
183204
upper_query.using = params
184205
# execute in query's database
185-
return self.project.api.sql_query(upper_query.to_string(), database=data.database)
206+
return self.project.api.sql_query(upper_query.to_string(), database=None)
186207

187208
elif isinstance(data, dict):
188209
data = pd.DataFrame([data])

0 commit comments

Comments
 (0)