11#!/usr/bin/env python
22# -*- coding: utf-8; -*-
3- import json
43import logging
54from copy import deepcopy
65from datetime import datetime
76from typing import Dict , List , Union
87
98import pandas
9+ import pandas as pd
1010from great_expectations .core import ExpectationSuite
1111
1212from ads import deprecated
13+ from oci .feature_store .models import (
14+ DatasetFeatureGroupCollection ,
15+ DatasetFeatureGroupSummary ,
16+ )
17+
1318from ads .common import utils
1419from ads .common .oci_mixin import OCIModelMixin
1520from ads .feature_store .common .enums import (
2934 OciExecutionStrategyProvider ,
3035)
3136from ads .feature_store .feature import DatasetFeature
37+ from ads .feature_store .feature_group import FeatureGroup
3238from ads .feature_store .feature_group_expectation import Expectation
3339from ads .feature_store .feature_option_details import FeatureOptionDetails
3440from ads .feature_store .service .oci_dataset import OCIDataset
@@ -116,6 +122,7 @@ class Dataset(Builder):
116122 CONST_ITEMS = "items"
117123 CONST_LAST_JOB_ID = "jobId"
118124 CONST_MODEL_DETAILS = "modelDetails"
125+ CONST_FEATURE_GROUP = "datasetFeatureGroups"
119126
120127 attribute_map = {
121128 CONST_ID : "id" ,
@@ -133,6 +140,7 @@ class Dataset(Builder):
133140 CONST_LIFECYCLE_STATE : "lifecycle_state" ,
134141 CONST_MODEL_DETAILS : "model_details" ,
135142 CONST_PARTITION_KEYS : "partition_keys" ,
143+ CONST_FEATURE_GROUP : "dataset_feature_groups" ,
136144 }
137145
138146 def __init__ (self , spec : Dict = None , ** kwargs ) -> None :
@@ -151,6 +159,7 @@ def __init__(self, spec: Dict = None, **kwargs) -> None:
151159 super ().__init__ (spec = spec , ** deepcopy (kwargs ))
152160 # Specify oci Dataset instance
153161 self .dataset_job = None
162+ self ._is_manual_association : bool = False
154163 self ._spark_engine = None
155164 self .oci_dataset = self ._to_oci_dataset (** kwargs )
156165 self .lineage = OCILineage (** kwargs )
@@ -183,6 +192,16 @@ def spark_engine(self):
183192 self ._spark_engine = SparkEngine (get_metastore_id (self .feature_store_id ))
184193 return self ._spark_engine
185194
195+ @property
196+ def is_manual_association (self ):
197+ collection : DatasetFeatureGroupCollection = self .get_spec (
198+ self .CONST_FEATURE_GROUP
199+ )
200+ if collection and collection .is_manual_association is not None :
201+ return collection .is_manual_association
202+ else :
203+ return self ._is_manual_association
204+
186205 @property
187206 def kind (self ) -> str :
188207 """The kind of the object as showing in a YAML."""
@@ -530,6 +549,54 @@ def with_model_details(self, model_details: ModelDetails) -> "Dataset":
530549
531550 return self .set_spec (self .CONST_MODEL_DETAILS , model_details .to_dict ())
532551
552+ @property
553+ def feature_groups (self ) -> List ["FeatureGroup" ]:
554+ collection : "DatasetFeatureGroupCollection" = self .get_spec (
555+ self .CONST_FEATURE_GROUP
556+ )
557+ feature_groups : List ["FeatureGroup" ] = []
558+ if collection and collection .items :
559+ for datasetFGSummary in collection .items :
560+ feature_groups .append (
561+ FeatureGroup .from_id (datasetFGSummary .feature_group_id )
562+ )
563+
564+ return feature_groups
565+
566+ @feature_groups .setter
567+ def feature_groups (self , feature_groups : List ["FeatureGroup" ]):
568+ self .with_feature_groups (feature_groups )
569+
570+ def with_feature_groups (self , feature_groups : List ["FeatureGroup" ]) -> "Dataset" :
571+ """Sets the model details for the dataset.
572+
573+ Parameters
574+ ----------
575+ feature_groups: List of feature groups
576+ Returns
577+ -------
578+ Dataset
579+ The Dataset instance (self).
580+
581+ """
582+ collection : List ["DatasetFeatureGroupSummary" ] = []
583+ for group in feature_groups :
584+ collection .append (DatasetFeatureGroupSummary (feature_group_id = group .id ))
585+
586+ self ._is_manual_association = True
587+ return self .set_spec (
588+ self .CONST_FEATURE_GROUP ,
589+ DatasetFeatureGroupCollection (items = collection , is_manual_association = True ),
590+ )
591+
592+ def feature_groups_to_df (self ):
593+ return pd .DataFrame .from_records (
594+ [
595+ feature_group .oci_feature_group .to_df_record ()
596+ for feature_group in self .feature_groups
597+ ]
598+ )
599+
533600 @property
534601 def partition_keys (self ) -> List [str ]:
535602 return self .get_spec (self .CONST_PARTITION_KEYS )
@@ -641,7 +708,7 @@ def show(self, rankdir: str = GraphOrientation.LEFT_RIGHT) -> None:
641708 f"Can't get lineage information for Feature group id { self .id } "
642709 )
643710
644- def create (self , ** kwargs ) -> "Dataset" :
711+ def create (self , validate_sql = False , ** kwargs ) -> "Dataset" :
645712 """Creates dataset resource.
646713
647714 !!! note "Lazy"
@@ -654,6 +721,8 @@ def create(self, **kwargs) -> "Dataset":
654721 kwargs
655722 Additional kwargs arguments.
656723 Can be any attribute that `oci.feature_store.models.Dataset` accepts.
724+ validate_sql:
725+ Boolean value indicating whether to validate sql before creating dataset
657726
658727 Returns
659728 -------
@@ -674,13 +743,17 @@ def create(self, **kwargs) -> "Dataset":
674743 if self .statistics_config is None :
675744 self .statistics_config = StatisticsConfig ()
676745
746+ if validate_sql is True :
747+ self .spark_engine .sql (self .get_spec (self .CONST_QUERY ))
748+
677749 payload = deepcopy (self ._spec )
678750 payload .pop ("id" , None )
679751 logger .debug (f"Creating a dataset resource with payload { payload } " )
680752
681753 # Create dataset
682754 logger .info ("Saving dataset." )
683755 self .oci_dataset = self ._to_oci_dataset (** kwargs ).create ()
756+ self ._update_from_oci_dataset_model (self .oci_dataset )
684757 self .with_id (self .oci_dataset .id )
685758 return self
686759
@@ -793,8 +866,7 @@ def _update_from_oci_dataset_model(self, oci_dataset: OCIDataset) -> "Dataset":
793866
794867 value = {self .CONST_ITEMS : features_list }
795868 else :
796- value = dataset_details [infra_attr ]
797-
869+ value = getattr (self .oci_dataset , dsc_attr )
798870 self .set_spec (infra_attr , value )
799871
800872 return self
@@ -1134,6 +1206,10 @@ def to_dict(self) -> Dict:
11341206 for key , value in spec .items ():
11351207 if hasattr (value , "to_dict" ):
11361208 value = value .to_dict ()
1209+ if hasattr (value , "attribute_map" ):
1210+ value = self .oci_dataset .client .base_client .sanitize_for_serialization (
1211+ value
1212+ )
11371213 spec [key ] = value
11381214
11391215 return {
0 commit comments