Skip to content

Commit e2f4a1a

Browse files
committed
added as_of interface to the feature store
1 parent ba93f40 commit e2f4a1a

File tree

5 files changed

+177
-8
lines changed

5 files changed

+177
-8
lines changed

ads/feature_store/common/spark_session_singleton.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(self, metastore_id: str = None):
8484
)
8585
.enableHiveSupport()
8686
)
87+
_managed_table_location = None
8788

8889
if not developer_enabled() and metastore_id:
8990
# Get the authentication credentials for the OCI data catalog service
@@ -94,12 +95,11 @@ def __init__(self, metastore_id: str = None):
9495

9596
data_catalog_client = OCIClientFactory(**auth).data_catalog
9697
metastore = data_catalog_client.get_metastore(metastore_id).data
98+
_managed_table_location = metastore.default_managed_table_location
9799
# Configure the Spark session builder object to use the specified metastore
98100
spark_builder.config(
99101
"spark.hadoop.oracle.dcat.metastore.id", metastore_id
100-
).config(
101-
"spark.sql.warehouse.dir", metastore.default_managed_table_location
102-
).config(
102+
).config("spark.sql.warehouse.dir", _managed_table_location).config(
103103
"spark.driver.memory", "16G"
104104
)
105105

@@ -114,7 +114,12 @@ def __init__(self, metastore_id: str = None):
114114

115115
self.spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
116116
self.spark_session.sparkContext.setLogLevel("OFF")
117+
self.managed_table_location = _managed_table_location
117118

118119
def get_spark_session(self):
    """Return the SparkSession instance held by this singleton."""
    return self.spark_session
122+
123+
def get_managed_table_location(self):
    """Return the metastore's default managed-table location captured at
    session construction time (``None`` when no metastore was configured)."""
    return self.managed_table_location

ads/feature_store/common/utils/transformation_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33
import json
4+
45
# Copyright (c) 2023 Oracle and/or its affiliates.
56
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
67

@@ -64,9 +65,13 @@ def apply_transformation(
6465
dataframe.createOrReplaceTempView(temporary_table_view)
6566

6667
transformed_data = spark.sql(
67-
transformation_function_caller(temporary_table_view, **transformation_kwargs_dict)
68+
transformation_function_caller(
69+
temporary_table_view, **transformation_kwargs_dict
70+
)
6871
)
6972
elif transformation.transformation_mode == TransformationMode.PANDAS.value:
70-
transformed_data = transformation_function_caller(dataframe, **transformation_kwargs_dict)
73+
transformed_data = transformation_function_caller(
74+
dataframe, **transformation_kwargs_dict
75+
)
7176

7277
return transformed_data

ads/feature_store/dataset.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@
88

99
import pandas
1010
from great_expectations.core import ExpectationSuite
11+
12+
from ads import deprecated
1113
from ads.common import utils
1214
from ads.common.oci_mixin import OCIModelMixin
1315
from ads.feature_store.common.enums import (
1416
ExecutionEngine,
1517
ExpectationType,
1618
EntityType,
1719
)
20+
from ads.feature_store.common.exceptions import NotMaterializedError
1821
from ads.feature_store.common.utils.utility import (
1922
get_metastore_id,
2023
validate_delta_format_parameters,
@@ -475,6 +478,20 @@ def with_statistics_config(
475478
self.CONST_STATISTICS_CONFIG, statistics_config_in.to_dict()
476479
)
477480

481+
def target_delta_table(self):
    """
    Return the fully-qualified name of the target Delta table.

    The identifier is built by joining the entity ID and this resource's
    name with a dot, i.e. ``'entity_id.table_name'``.

    Returns:
        str: The fully-qualified name of the target delta table.
    """
    return f"{self.entity_id}.{self.name}"
494+
478495
@property
479496
def model_details(self) -> "ModelDetails":
480497
return self.get_spec(self.CONST_MODEL_DETAILS)
@@ -560,7 +577,9 @@ def add_models(self, model_details: ModelDetails) -> "Dataset":
560577
f"Dataset update Failed with : {type(ex)} with error message: {ex}"
561578
)
562579
if existing_model_details:
563-
self.with_model_details(ModelDetails().with_items(existing_model_details["items"]))
580+
self.with_model_details(
581+
ModelDetails().with_items(existing_model_details["items"])
582+
)
564583
else:
565584
self.with_model_details(ModelDetails().with_items([]))
566585
return self
@@ -773,6 +792,7 @@ def materialise(
773792

774793
dataset_execution_strategy.ingest_dataset(self, dataset_job)
775794

795+
@deprecated(details="preview functionality is deprecated. Please use as_of.")
776796
def preview(
777797
self,
778798
row_count: int = 10,
@@ -797,6 +817,8 @@ def preview(
797817
spark dataframe
798818
The preview result in spark dataframe
799819
"""
820+
self.check_resource_materialization()
821+
800822
validate_delta_format_parameters(timestamp, version_number)
801823
target_table = f"{self.entity_id}.{self.name}"
802824

@@ -806,6 +828,43 @@ def preview(
806828

807829
return self.spark_engine.sql(sql_query)
808830

831+
def check_resource_materialization(self):
    """Verify that this resource's target Delta table exists in Spark.

    Raises
    ------
    NotMaterializedError
        If the target Delta table has not been materialized, carrying this
        resource's type and name.
    """
    if self.spark_engine.is_delta_table_exists(self.target_delta_table()):
        return
    raise NotMaterializedError(self.type, self.name)
837+
838+
def as_of(
    self,
    version_number: int = None,
    timestamp: datetime = None,
):
    """Read this dataset's materialized data as of a Delta commit.

    Parameters
    ----------
    timestamp: datetime
        commit date time to preview in format yyyy-MM-dd or yyyy-MM-dd HH:mm:ss
        commit date time is maintained for every ingestion commit using delta lake
    version_number: int
        commit version number for the preview. Version numbers are automatically versioned for every ingestion
        commit using delta lake

    Returns
    -------
    spark dataframe
        The preview result in spark dataframe
    """
    # Fail fast if the dataset was never materialized to a Delta table.
    self.check_resource_materialization()
    validate_delta_format_parameters(timestamp, version_number)

    return self.spark_engine.get_time_version_data(
        self.target_delta_table(), version_number, timestamp
    )
867+
809868
def profile(self):
810869
"""Get the dataset profile information and return the response in dataframe.
811870
@@ -814,6 +873,8 @@ def profile(self):
814873
spark dataframe
815874
The profile result in spark dataframe
816875
"""
876+
self.check_resource_materialization()
877+
817878
target_table = f"{self.entity_id}.{self.name}"
818879
sql_query = f"DESCRIBE DETAIL {target_table}"
819880

@@ -835,6 +896,8 @@ def restore(self, version_number: int = None, timestamp: datetime = None):
835896
spark dataframe
836897
The restore output as spark dataframe
837898
"""
899+
self.check_resource_materialization()
900+
838901
validate_delta_format_parameters(timestamp, version_number, True)
839902
target_table = f"{self.entity_id}.{self.name}"
840903
if version_number is not None:

ads/feature_store/execution_strategy/engine/spark_engine.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import logging
8+
from datetime import datetime
89

910
from ads.common.decorator.runtime_dependency import OptionalDependency
1011

@@ -17,7 +18,7 @@
1718
)
1819
except Exception as e:
1920
raise
20-
from typing import List
21+
from typing import List, Dict
2122

2223
from ads.feature_store.common.utils.feature_schema_mapper import (
2324
map_spark_type_to_feature_type,
@@ -36,6 +37,69 @@ def __init__(self, metastore_id: str = None, spark_session: SparkSession = None)
3637
else:
3738
self.spark = SparkSessionSingleton(metastore_id).get_spark_session()
3839

40+
self.managed_table_location = (
41+
SparkSessionSingleton().get_managed_table_location()
42+
)
43+
44+
def get_time_version_data(
    self,
    delta_table_name: str,
    version_number: int = None,
    timestamp: datetime = None,
):
    """Load a Delta table pinned to a commit version and/or timestamp.

    Args:
        delta_table_name (str): Name of the Delta table ("db.table").
        version_number (int, optional): Delta commit version to read
            (``versionAsOf``).
        timestamp (datetime, optional): Commit timestamp to read
            (``timestampAsOf``).

    Returns:
        DataFrame: The table contents at the requested point in time.
    """
    # Prefer the metastore-managed warehouse location when it is known;
    # otherwise resolve the table's path via DESCRIBE EXTENDED.
    # NOTE(review): plain string concatenation here assumes
    # managed_table_location ends with a path separator — confirm.
    if self.managed_table_location:
        delta_table_path = self.managed_table_location + delta_table_name
    else:
        delta_table_path = self._get_delta_table_path(delta_table_name)

    # Only populate the time-travel options the caller actually supplied.
    read_options = {}
    if version_number is not None:
        read_options["versionAsOf"] = version_number
    if timestamp:
        read_options["timestampAsOf"] = timestamp

    return self._read_delta_table(delta_table_path, read_options)
67+
68+
def _get_delta_table_path(self, delta_table_name: str) -> str:
69+
"""
70+
Get the path of the Delta table using DESCRIBE EXTENDED SQL command.
71+
72+
Args:
73+
delta_table_name (str): The name of the Delta table.
74+
75+
Returns:
76+
str: The path of the Delta table.
77+
"""
78+
delta_table_path = (
79+
self.spark.sql(f"DESCRIBE EXTENDED {delta_table_name}")
80+
.filter("col_name = 'Location'")
81+
.collect()[0][1]
82+
)
83+
return delta_table_path
84+
85+
def _read_delta_table(self, delta_table_path: str, read_options: Dict):
86+
"""
87+
Read the Delta table using specified read options.
88+
89+
Args:
90+
delta_table_path (str): The path of the Delta table.
91+
read_options (dict): Dictionary of read options for Delta table.
92+
93+
Returns:
94+
DataFrame: The loaded DataFrame from the Delta table.
95+
"""
96+
df = (
97+
self.spark.read.format("delta")
98+
.options(**read_options)
99+
.load(delta_table_path)
100+
)
101+
return df
102+
39103
def sql(
40104
self,
41105
query: str,

ads/feature_store/feature_group.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pandas as pd
1414
from great_expectations.core import ExpectationSuite
1515

16+
from ads import deprecated
1617
from ads.common import utils
1718
from ads.common.decorator.runtime_dependency import OptionalDependency
1819
from ads.common.oci_mixin import OCIModelMixin
@@ -989,6 +990,7 @@ def filter(self, f: Union[Filter, Logic]):
989990
"""
990991
return self.select().filter(f)
991992

993+
@deprecated(details="preview functionality is deprecated. Please use as_of.")
992994
def preview(
993995
self,
994996
row_count: int = 10,
@@ -1020,10 +1022,41 @@ def preview(
10201022

10211023
if version_number is not None:
10221024
logger.warning("Time travel queries are not supported in current version")
1025+
10231026
sql_query = f"select * from {target_table} LIMIT {row_count}"
10241027

10251028
return self.spark_engine.sql(sql_query)
10261029

1030+
def as_of(
    self,
    version_number: int = None,
    timestamp: datetime = None,
):
    """Query the materialized feature group data as of a Delta commit.

    Parameters
    ----------
    timestamp: datetime
        commit date time to preview in format yyyy-MM-dd or yyyy-MM-dd HH:mm:ss
        commit date time is maintained for every ingestion commit using delta lake
    version_number: int
        commit version number for the preview. Version numbers are automatically versioned for every ingestion
        commit using delta lake

    Returns
    -------
    spark dataframe
        The preview result in spark dataframe
    """
    # Refuse to time-travel over a table that was never materialized.
    self.check_resource_materialization()
    validate_delta_format_parameters(timestamp, version_number)

    table_name = self.target_delta_table()
    return self.spark_engine.get_time_version_data(
        table_name, version_number, timestamp
    )
1059+
10271060
def profile(self):
10281061
"""get the profile information for feature definition and return the response in dataframe.
10291062
@@ -1085,7 +1118,6 @@ def check_resource_materialization(self):
10851118
"""Checks whether the target Delta table for this resource has been materialized in Spark.
10861119
If the target Delta table doesn't exist, raises a NotMaterializedError with the type and name of this resource.
10871120
"""
1088-
print(self.target_delta_table())
10891121
if not self.spark_engine.is_delta_table_exists(self.target_delta_table()):
10901122
raise NotMaterializedError(self.type, self.name)
10911123

0 commit comments

Comments
 (0)