Skip to content

Commit a61ce6c

Browse files
committed
fixed the query builder for joins
1 parent 8e8fa68 commit a61ce6c

File tree

6 files changed

+36
-14
lines changed

6 files changed

+36
-14
lines changed

ads/feature_store/execution_strategy/spark/spark_execution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ def _save_offline_dataframe(
179179

180180
if isinstance(data_frame, pd.DataFrame):
181181
if not feature_group.is_infer_schema:
182-
convert_pandas_datatype_with_schema(feature_group.input_feature_details, data_frame)
182+
convert_pandas_datatype_with_schema(
183+
feature_group.input_feature_details, data_frame
184+
)
183185

184186
# TODO: Get event timestamp column and apply filtering basis from and to timestamp
185187

ads/feature_store/feature_group.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class FeatureGroup(Builder):
155155
CONST_LIFECYCLE_STATE: "lifecycle_state",
156156
CONST_OUTPUT_FEATURE_DETAILS: "output_feature_details",
157157
CONST_STATISTICS_CONFIG: "statistics_config",
158-
CONST_INFER_SCHEMA: "is_infer_schema"
158+
CONST_INFER_SCHEMA: "is_infer_schema",
159159
}
160160

161161
def __init__(self, spec: Dict = None, **kwargs) -> None:
@@ -1225,7 +1225,8 @@ def _get_job_id(self, job_id: str = None) -> str:
12251225

12261226
if self.job_id is None:
12271227
raise ValueError(
1228-
"Unable to retrieve the last job. Please provide the job ID and make sure you materialized the data.")
1228+
"Unable to retrieve the last job. Please provide the job ID and make sure you materialized the data."
1229+
)
12291230

12301231
return self.job_id
12311232

@@ -1278,9 +1279,13 @@ def get_validation_output(self, job_id: str = None) -> "ValidationOutput":
12781279
# Retrieve the validation output JSON from data_flow_batch_execution_output.
12791280
fg_job = FeatureGroupJob.from_id(validation_job_id)
12801281
output_details = fg_job.job_output_details
1281-
validation_output = output_details.get("validationOutput") if output_details else None
1282+
validation_output = (
1283+
output_details.get("validationOutput") if output_details else None
1284+
)
12821285

1283-
validation_output_json = json.loads(validation_output) if validation_output else None
1286+
validation_output_json = (
1287+
json.loads(validation_output) if validation_output else None
1288+
)
12841289

12851290
return ValidationOutput(validation_output_json)
12861291

ads/feature_store/query/generator/query_generator.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,13 @@ def generate_query(self, is_online: bool = False) -> str:
7474
selected_features_map = {}
7575
index = 0
7676
on_condition = []
77-
left_table = f"`{self.query.entity_id}`.{self.query.left_feature_group.name}"
78-
left_table_alias = self.get_table_alias(len(self.query.joins))
77+
table = f"`{self.query.entity_id}`.{self.query.left_feature_group.name}"
78+
table_alias = self.get_table_alias(len(self.query.joins))
79+
left_table_alias = table_alias
7980

8081
# store the left features in the map
81-
selected_features_map[left_table_alias] = self.query.left_features
82-
feature_group_id_map = {self.query.left_feature_group.id: left_table_alias}
82+
selected_features_map[table_alias] = self.query.left_features
83+
feature_group_id_map = {self.query.left_feature_group.id: table_alias}
8384

8485
for join in self.query.joins:
8586
# Ge table and alias and map the features
@@ -99,6 +100,8 @@ def generate_query(self, is_online: bool = False) -> str:
99100
)
100101
)
101102

103+
left_table_alias = right_table_alias
104+
102105
selected_columns = self._get_selected_columns(selected_features_map)
103106
filters = (
104107
None
@@ -109,7 +112,7 @@ def generate_query(self, is_online: bool = False) -> str:
109112
)
110113

111114
# Define the SQL query as an f-string
112-
query = f"SELECT {selected_columns} FROM {left_table} {left_table_alias}"
115+
query = f"SELECT {selected_columns} FROM {table} {table_alias}"
113116

114117
if on_condition:
115118
# If there is an ON condition, add it to the query

ads/feature_store/query/query.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,20 @@ def with_left_feature_group(self, feature_group) -> "Query":
109109
"""
110110
return self.set_spec(self.CONST_LEFT_FEATURE_GROUP, feature_group)
111111

112+
def get_last_joined_feature_group(self):
113+
"""
114+
Retrieves the last joined feature group from the list of joins,
115+
or returns the left feature group if no joins are present.
116+
117+
Returns:
118+
The last joined feature group if the list of joins is non-empty,
119+
otherwise returns the left feature group.
120+
"""
121+
if self.joins:
122+
return self.joins[-1].sub_query.left_feature_group
123+
else:
124+
return self.left_feature_group
125+
112126
@property
113127
def _filter(self):
114128
return self.get_spec(self.CONST_FILTER)
@@ -291,7 +305,7 @@ def join(
291305
"""
292306

293307
join = Join(sub_query, on, left_on, right_on, join_type)
294-
QueryValidator.validate_query_join(self.left_feature_group, join)
308+
QueryValidator.validate_query_join(self.get_last_joined_feature_group(), join)
295309
self.joins.append(join)
296310
return self
297311

ads/feature_store/response/response_builder.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,9 @@ def to_pandas(self) -> pd.DataFrame:
5151

5252
@property
5353
def kind(self) -> str:
54-
5554
return "response_builder"
5655

5756
def to_dict(self) -> Dict:
58-
5957
spec = deepcopy(self._spec)
6058
for key, value in spec.items():
6159
if hasattr(value, "to_dict"):

ads/feature_store/statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ def kind(self) -> str:
2222
str
2323
The kind of the statistics object, which is always "statistics".
2424
"""
25-
return "statistics"
25+
return "statistics"

0 commit comments

Comments
 (0)