
Commit 31abe0d

Add init method for the infrastructure classes.
1 parent d9fc1a6 commit 31abe0d

10 files changed: +104 -28 lines changed

ads/common/serializer.py

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ def to_yaml(
         """
         note = kwargs.pop("note", "")

-        yaml_string = note + yaml.dump(self.to_dict(**kwargs), Dumper=dumper)
+        yaml_string = f"{note}\n" + yaml.dump(self.to_dict(**kwargs), Dumper=dumper)
         if uri:
             self._write_to_file(s=yaml_string, uri=uri, **kwargs)
             return None
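For context, a minimal sketch of what the changed line produces (the note and spec values below are illustrative, not taken from this commit): the note passed via kwargs is now followed by a newline, so the dumped YAML no longer starts on the same line as the note.

import yaml

note = "# YAML comment emitted by ADS"              # hypothetical note value
spec = {"kind": "job", "spec": {"name": "my-job"}}  # hypothetical to_dict() output

# Old behavior: note + yaml.dump(spec) glued the first YAML key onto the note line.
# New behavior: a newline separates the note from the YAML document.
yaml_string = f"{note}\n" + yaml.dump(spec)
print(yaml_string)
# # YAML comment emitted by ADS
# kind: job
# spec:
#   name: my-job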

ads/jobs/builders/infrastructure/dataflow.py

Lines changed: 19 additions & 2 deletions
@@ -41,6 +41,7 @@
 DEFAULT_LANGUAGE = "PYTHON"
 DEFAULT_SPARK_VERSION = "3.2.1"
 DEFAULT_NUM_EXECUTORS = 1
+DEFAULT_SHAPE = "VM.Standard.E3.Flex"


 def conda_pack_name_to_dataflow_config(conda_uri):

@@ -366,7 +367,6 @@ def executor(self):


 class DataFlow(Infrastructure):
-
     CONST_COMPARTMENT_ID = "compartment_id"
     CONST_CONFIG = "configuration"
     CONST_EXECUTE = "execute"

@@ -423,7 +423,6 @@ def __init__(self, spec: dict = None, **kwargs):
         self.runtime = None
         self._name = None

-
     def _load_default_properties(self) -> Dict:
         """
         Load default properties from environment variables, notebook session, etc.

@@ -1133,3 +1132,21 @@ def to_yaml(self, **kwargs) -> str:
             YAML stored in a string.
         """
         return yaml.safe_dump(self.to_dict(**kwargs))
+
+    def init(self) -> "DataFlow":
+        """Initializes a starter specification for the DataFlow.
+
+        Returns
+        -------
+        DataFlow
+            The DataFlow instance (self)
+        """
+        return (
+            self.build()
+            .with_compartment_id(self.compartment_id or "{Provide a compartment OCID}")
+            .with_language(self.language or DEFAULT_LANGUAGE)
+            .with_spark_version(self.spark_version or DEFAULT_SPARK_VERSION)
+            .with_num_executors(self.num_executors or DEFAULT_NUM_EXECUTORS)
+            .with_driver_shape(self.driver_shape or DEFAULT_SHAPE)
+            .with_executor_shape(self.executor_shape or DEFAULT_SHAPE)
+        )
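A minimal usage sketch of the new DataFlow.init(), assuming DataFlow is importable from ads.jobs as in the rest of the ADS public API; to_yaml() is the method shown in the context lines above.

from ads.jobs import DataFlow  # import path assumed from the ADS public API

# init() builds the spec and back-fills anything missing with defaults or
# "{...}" placeholders: compartment OCID, language, Spark version, number
# of executors, and the driver/executor shapes (DEFAULT_SHAPE above).
dataflow = DataFlow().init()
print(dataflow.to_yaml())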

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 15 additions & 0 deletions
@@ -1438,6 +1438,21 @@ def build(self) -> DataScienceJob:
         self._update_from_dsc_model(self.dsc_job, overwrite=False)
         return self

+    def init(self) -> DataScienceJob:
+        """Initializes a starter specification for the DataScienceJob.
+
+        Returns
+        -------
+        DataScienceJob
+            The DataScienceJob instance (self)
+        """
+        return (
+            self.build()
+            .with_compartment_id(self.compartment_id or "{Provide a compartment OCID}")
+            .with_project_id(self.project_id or "{Provide a project OCID}")
+            .with_subnet_id(self.subnet_id or "{Provide a subnet OCID}")
+        )
+
     def create(self, runtime, **kwargs) -> DataScienceJob:
         """Creates a job with runtime.
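A sketch of how the starter DataScienceJob infrastructure might be combined with a runtime, mirroring the opctl backend change later in this commit (import paths assumed from the ADS public API; the job name is a placeholder):

from ads.jobs import Job, DataScienceJob, PythonRuntime  # assumed import paths

# init() builds the spec and substitutes "{Provide a ... OCID}" placeholders
# for any compartment/project/subnet value it cannot resolve.
job = (
    Job(name="{Job name}")
    .with_infrastructure(DataScienceJob().init())
    .with_runtime(PythonRuntime().init())
)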

ads/jobs/builders/runtimes/python_runtime.py

Lines changed: 18 additions & 13 deletions
@@ -125,7 +125,10 @@ def init(self) -> "CondaRuntime":
             The runtime instance.
         """
         super().init()
-        return self.with_custom_conda("oci://your_bucket@namespace/object_name")
+        return self.with_custom_conda(
+            "{Path to the custom conda environment. "
+            "Example: oci://your_bucket@namespace/object_name}"
+        )


 class ScriptRuntime(CondaRuntime):

@@ -254,10 +257,10 @@ def init(self) -> "ScriptRuntime":
         super().init()
         return (
             self.with_entrypoint(
-                "{Entry point script. For the MLflow will be replaced with the CMD}"
+                "{Entrypoint script. For MLFlow, it will be replaced with the CMD}"
             )
             .with_script(
-                "{Path to the script. For the MLFlow will be replaced with path to the project}"
+                "{Path to the script. For MLFlow, it will be replaced with the path to the project}"
             )
             .with_argument(key1="val1")
         )

@@ -442,12 +445,12 @@ def init(self) -> "PythonRuntime":
         """
         super().init()
         return (
-            self.with_working_dir("{For the MLflow the project folder will be used.}")
+            self.with_working_dir("{For MLflow the project folder will be used.}")
             .with_entrypoint(
-                "{Entry point script. For the MLFlow will be replaced with the CMD}"
+                "{Entrypoint script. For MLFlow, it will be replaced with the CMD}"
             )
             .with_script(
-                "{Path to the script. For the MLFlow will be replaced with path to the project}"
+                "{Path to the script. For MLFlow, it will be replaced with the path to the project}"
             )
         )

@@ -630,8 +633,8 @@ def init(self) -> "NotebookRuntime":
         """
         super().init()
         return self.with_source(
-            uri="{Path to the source code directory. For the MLFlow will be replaced with path to the project}",
-            notebook="{Entry point notebook. For the MLFlow will be replaced with the CMD}",
+            uri="{Path to the source code directory. For MLflow, it will be replaced with the path to the project}",
+            notebook="{Entrypoint notebook. For MLflow, it will be replaced with the CMD}",
         ).with_exclude_tag("tag1")


@@ -751,9 +754,9 @@ def init(self) -> "GitPythonRuntime":
         """
         super().init()
         return self.with_source(
-            "{Git URI. For the MLFlow will be replaced with the Project URI}"
+            "{Git URI. For MLFlow, it will be replaced with the Project URI}"
         ).with_entrypoint(
-            "{Entry point script. For the MLflow will be replaced with the CMD}"
+            "{Entry point script. For MLFlow, it will be replaced with the CMD}"
         )


@@ -976,10 +979,12 @@ def init(self) -> "DataFlowRuntime":
         self._spec.pop(self.CONST_ENV_VAR, None)
         return (
             self.with_script_uri(
-                "{Path to the executable script. For the MLFlow will be replaced with the CMD}"
+                "{Path to the executable script. For MLFlow, it will be replaced with the CMD}"
+            )
+            .with_script_bucket(
+                "{The object storage bucket to save a script. "
+                "Example: oci://<bucket_name>@<tenancy>/<prefix>}"
             )
-            .with_argument(key1="val1")
-            .with_script_bucket("oci://<bucket_name>@<tenancy>/<prefix>")
             .with_overwrite(True)
             .with_configuration({"spark.driverEnv.env_key": "env_value"})
         )
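The runtime init() methods above only change the placeholder strings they pre-populate. A sketch of how one of them is typically exercised, with the placeholders then overridden by real values (import path assumed; the entrypoint and script URI are illustrative):

from ads.jobs import ScriptRuntime  # assumed import path

# init() pre-populates "{...}" placeholder values; a user typically replaces them.
runtime = (
    ScriptRuntime()
    .init()
    .with_entrypoint("train.py")                      # overrides the entrypoint placeholder
    .with_script("oci://bucket@namespace/train.py")   # overrides the script path placeholder
)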

ads/jobs/serializer.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ def to_yaml(
         """
         note = kwargs.pop("note", "")

-        yaml_string = note + yaml.dump(self.to_dict(**kwargs), Dumper=dumper)
+        yaml_string = f"{note}\n" + yaml.dump(self.to_dict(**kwargs), Dumper=dumper)
         if uri:
             self._write_to_file(s=yaml_string, uri=uri, **kwargs)
             return None

ads/model/deployment/model_deployment_infrastructure.py

Lines changed: 28 additions & 4 deletions
@@ -19,6 +19,11 @@
 MODEL_DEPLOYMENT_INFRASTRUCTURE_TYPE = "datascienceModelDeployment"
 MODEL_DEPLOYMENT_INFRASTRUCTURE_KIND = "infrastructure"

+DEFAULT_BANDWIDTH_MBPS = 10
+DEFAULT_WEB_CONCURRENCY = 10
+DEFAULT_REPLICA = 1
+DEFAULT_SHAPE_NAME = "VM.Standard.E2.4"
+
 logger = logging.getLogger(__name__)


@@ -152,7 +157,7 @@ class ModelDeploymentInfrastructure(Builder):
         CONST_LOG_ID: "log_id",
         CONST_LOG_GROUP_ID: "log_group_id",
         CONST_WEB_CONCURRENCY: "web_concurrency",
-        CONST_SUBNET_ID: "subnet_id"
+        CONST_SUBNET_ID: "subnet_id",
     }

     shape_config_details_attribute_map = {

@@ -211,14 +216,15 @@ def _load_default_properties(self) -> Dict:
         if PROJECT_OCID:
             defaults[self.CONST_PROJECT_ID] = PROJECT_OCID

+        defaults[self.CONST_BANDWIDTH_MBPS] = DEFAULT_BANDWIDTH_MBPS
+        defaults[self.CONST_WEB_CONCURRENCY] = DEFAULT_WEB_CONCURRENCY
+        defaults[self.CONST_REPLICA] = DEFAULT_REPLICA
+
         if NB_SESSION_OCID:
             try:
                 nb_session = DSCNotebookSession.from_ocid(NB_SESSION_OCID)
                 nb_config = nb_session.notebook_session_configuration_details
                 defaults[self.CONST_SHAPE_NAME] = nb_config.shape
-                defaults[self.CONST_BANDWIDTH_MBPS] = 10
-                defaults[self.CONST_WEB_CONCURRENCY] = 10
-                defaults[self.CONST_REPLICA] = 1

                 if nb_config.notebook_session_shape_config_details:
                     notebook_shape_config_details = oci_util.to_dict(

@@ -602,3 +608,21 @@ def subnet_id(self) -> str:
             The model deployment subnet id.
         """
         return self.get_spec(self.CONST_SUBNET_ID, None)
+
+    def init(self) -> "ModelDeploymentInfrastructure":
+        """Initializes a starter specification for the ModelDeploymentInfrastructure.
+
+        Returns
+        -------
+        ModelDeploymentInfrastructure
+            The ModelDeploymentInfrastructure instance (self)
+        """
+        return (
+            self.build()
+            .with_compartment_id(self.compartment_id or "{Provide a compartment OCID}")
+            .with_project_id(self.project_id or "{Provide a project OCID}")
+            .with_bandwidth_mbps(self.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS)
+            .with_web_concurrency(self.web_concurrency or DEFAULT_WEB_CONCURRENCY)
+            .with_replica(self.replica or DEFAULT_REPLICA)
+            .with_shape_name(self.shape_name or DEFAULT_SHAPE_NAME)
+        )
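A sketch of the new ModelDeploymentInfrastructure.init(); the import path follows the file location shown above, and the fallbacks are the module-level defaults introduced in this commit.

# Import path taken from the file location above.
from ads.model.deployment.model_deployment_infrastructure import (
    ModelDeploymentInfrastructure,
)

# init() builds the spec, then falls back to placeholders and the new defaults:
# 10 Mbps bandwidth, web concurrency 10, 1 replica, and the VM.Standard.E2.4 shape.
infra = ModelDeploymentInfrastructure().init()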

ads/opctl/backend/ads_ml_job.py

Lines changed: 3 additions & 2 deletions
@@ -101,14 +101,15 @@ def init(
                 "{Job name. For the MLFlow will be auto replaced with the Project name}"
             )
             .with_infrastructure(
-                DataScienceJob(**(self.config.get("infrastructure", {}) or {}))
+                DataScienceJob(
+                    **(self.config.get("infrastructure", {}) or {})
+                ).init()
             )
             .with_runtime(
                 JobRuntimeFactory.get_runtime(
                     key=runtime_type or PythonRuntime().type
                 ).init()
             )
-            .build()
         )

         note = (

ads/opctl/backend/ads_ml_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -120,7 +120,7 @@ def init(
         pipeline_step = (
             PipelineStep("pipeline_step_name_1")
             .with_description("A step running a python script")
-            .with_infrastructure(CustomScriptStep().build())
+            .with_infrastructure(CustomScriptStep().init())
             .with_runtime(
                 JobRuntimeFactory.get_runtime(
                     key=runtime_type or PythonRuntime().type

@@ -136,7 +136,7 @@ def init(
             )
             .with_step_details([pipeline_step])
             .with_dag(["pipeline_step_name_1"])
-            .build()
+            .init()
         )

         note = (

ads/opctl/backend/ads_model_deployment.py

Lines changed: 1 addition & 2 deletions
@@ -72,14 +72,13 @@ def init(
             .with_infrastructure(
                 ModelDeploymentInfrastructure(
                     **(self.config.get("infrastructure", {}) or {})
-                )
+                ).init()
             )
             .with_runtime(
                 ModelDeploymentRuntimeFactory.get_runtime(
                     key=runtime_type or ModelDeploymentCondaRuntime().type
                 ).init()
             )
-            .build()
         )

         note = (

ads/pipeline/ads_pipeline.py

Lines changed: 16 additions & 1 deletion
@@ -1981,6 +1981,21 @@ def status(self) -> Optional[str]:
             return self.data_science_pipeline.lifecycle_state
         return None

+    def init(self) -> "Pipeline":
+        """Initializes a starter specification for the Pipeline.
+
+        Returns
+        -------
+        Pipeline
+            The Pipeline instance (self)
+        """
+        return (
+            self.build()
+            .with_compartment_id(self.compartment_id or "{Provide a compartment OCID}")
+            .with_project_id(self.project_id or "{Provide a project OCID}")
+        )
+
+

 class DataSciencePipeline(OCIDataScienceMixin, oci.data_science.models.Pipeline):
     @classmethod

@@ -2262,4 +2277,4 @@ def delete(
             operation_kwargs=operation_kwargs,
             waiter_kwargs=waiter_kwargs,
         )
-        return self.sync()
+        return self.sync()
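A sketch of the new Pipeline.init(), assuming Pipeline is importable from ads.pipeline as the file path suggests and that its constructor accepts a name (the name here is a placeholder):

from ads.pipeline import Pipeline  # import path assumed from the file location

# init() builds the pipeline spec and fills "{Provide a ... OCID}" placeholders
# for the compartment and project when those values are not already set.
pipeline = Pipeline(name="{pipeline_name}").init()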
