4242DEFAULT_SPARK_VERSION = "3.2.1"
4343DEFAULT_NUM_EXECUTORS = 1
4444DEFAULT_SHAPE = "VM.Standard.E3.Flex"
45+ DATAFLOW_SHAPE_FAMILY = [
46+ "Standard.E3" ,
47+ "Standard.E4" ,
48+ "Standard3" ,
49+ "Standard.A1" ,
50+ "Standard2"
51+ ]
4552
4653
4754def conda_pack_name_to_dataflow_config (conda_uri ):
@@ -860,6 +867,15 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
860867 raise ValueError (
861868 "Compartment id is required. Specify compartment id via 'with_compartment_id()'."
862869 )
870+ self ._validate_shapes (payload )
871+ payload .pop ("id" , None )
872+ logger .debug (f"Creating a DataFlow Application with payload { payload } " )
873+ self .df_app = DataFlowApp (** payload ).create ()
874+ self .with_id (self .df_app .id )
875+ return self
876+
877+ @staticmethod
878+ def _validate_shapes (payload : Dict ):
863879 if "executor_shape" not in payload :
864880 payload ["executor_shape" ] = DEFAULT_SHAPE
865881 if "driver_shape" not in payload :
@@ -868,15 +884,22 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
868884 executor_shape_config = payload .get ("executor_shape_config" , {})
869885 driver_shape = payload ["driver_shape" ]
870886 driver_shape_config = payload .get ("driver_shape_config" , {})
871- if executor_shape != driver_shape :
872- raise ValueError ("`executor_shape` and `driver_shape` must be from the same shape family." )
873- if (not executor_shape .endswith ("Flex" ) and executor_shape_config ) or (not driver_shape .endswith ("Flex" ) and driver_shape_config ):
874- raise ValueError ("Shape config is not required for non flex shape from user end." )
875- payload .pop ("id" , None )
876- logger .debug (f"Creating a DataFlow Application with payload { payload } " )
877- self .df_app = DataFlowApp (** payload ).create ()
878- self .with_id (self .df_app .id )
879- return self
887+ same_shape_family = False
888+ for shape in DATAFLOW_SHAPE_FAMILY :
889+ if shape in executor_shape and shape in driver_shape :
890+ same_shape_family = True
891+ break
892+ if not same_shape_family :
893+ raise ValueError (
894+ "`executor_shape` and `driver_shape` must be from the same shape family."
895+ )
896+ if (
897+ (not executor_shape .endswith ("Flex" ) and executor_shape_config )
898+ or (not driver_shape .endswith ("Flex" ) and driver_shape_config )
899+ ):
900+ raise ValueError (
901+ "Shape config is not required for non flex shape from user end."
902+ )
880903
881904 @staticmethod
882905 def _upload_file (local_path , bucket , overwrite = False ):
0 commit comments