|
9 | 9 | from mindsdb_sql.parser.dialects.mindsdb import RetrainPredictor, FinetunePredictor |
10 | 10 | from mindsdb_sql.parser.ast import Identifier, Select, Star, Join, Update, Describe, Constant |
11 | 11 | from mindsdb_sql import parse_sql |
| 12 | +from mindsdb_sql.exceptions import ParsingException |
12 | 13 |
|
13 | 14 | from .ml_engines import MLEngine |
14 | 15 |
|
15 | 16 | from mindsdb_sdk.utils.objects_collection import CollectionBase |
16 | | -from mindsdb_sdk.utils.sql import dict_to_binary_op |
| 17 | +from mindsdb_sdk.utils.sql import dict_to_binary_op, query_to_native_query |
17 | 18 |
|
18 | 19 | from .query import Query |
19 | 20 |
|
@@ -132,7 +133,10 @@ def predict(self, data: Union[pd.DataFrame, Query, dict], params: dict = None) - |
132 | 133 | """ |
133 | 134 | if isinstance(data, Query): |
134 | 135 | # create join from select if it is simple select |
135 | | - ast_query = parse_sql(data.sql, dialect='mindsdb') |
| 136 | + try: |
| 137 | + ast_query = parse_sql(data.sql, dialect='mindsdb') |
| 138 | + except ParsingException: |
| 139 | + ast_query = None |
136 | 140 |
|
137 | 141 | # injection of join disabled yet |
138 | 142 | # if isinstance(ast_query, Select) and isinstance(ast_query.from_table, Identifier): |
@@ -165,24 +169,41 @@ def predict(self, data: Union[pd.DataFrame, Query, dict], params: dict = None) - |
165 | 169 | # ast_query.targets = [Identifier(parts=['m', Star()])] |
166 | 170 | # |
167 | 171 |
|
168 | | - # wrap query to subselect |
169 | 172 | model_identifier = self._get_identifier() |
170 | 173 | model_identifier.alias = Identifier('m') |
171 | 174 |
|
172 | | - ast_query.parentheses = True |
173 | | - ast_query.alias = Identifier('t') |
174 | | - upper_query = Select( |
175 | | - targets=[Identifier(parts=['m', Star()])], |
176 | | - from_table=Join( |
177 | | - join_type='join', |
178 | | - left=ast_query, |
179 | | - right=model_identifier |
| 175 | + if data.database is not None or ast_query is None or not isinstance(ast_query, Select): |
| 176 | + # use native query |
| 177 | + native_query = query_to_native_query(data) |
| 178 | + native_query.parentheses = True |
| 179 | + native_query.alias = Identifier('t') |
| 180 | + upper_query = Select( |
| 181 | + targets=[Identifier(parts=['m', Star()])], |
| 182 | + from_table=Join( |
| 183 | + join_type='join', |
| 184 | + left=native_query, |
| 185 | + right=model_identifier |
| 186 | + ) |
| 187 | + ) |
| 188 | + else: |
| 189 | + # wrap query to subselect |
| 190 | + model_identifier = self._get_identifier() |
| 191 | + model_identifier.alias = Identifier('m') |
| 192 | + |
| 193 | + ast_query.parentheses = True |
| 194 | + ast_query.alias = Identifier('t') |
| 195 | + upper_query = Select( |
| 196 | + targets=[Identifier(parts=['m', Star()])], |
| 197 | + from_table=Join( |
| 198 | + join_type='join', |
| 199 | + left=ast_query, |
| 200 | + right=model_identifier |
| 201 | + ) |
180 | 202 | ) |
181 | | - ) |
182 | 203 | if params is not None: |
183 | 204 | upper_query.using = params |
184 | 205 | # execute in query's database |
185 | | - return self.project.api.sql_query(upper_query.to_string(), database=data.database) |
| 206 | + return self.project.api.sql_query(upper_query.to_string(), database=None) |
186 | 207 |
|
187 | 208 | elif isinstance(data, dict): |
188 | 209 | data = pd.DataFrame([data]) |
|
0 commit comments