
Commit 054a4db

gustavocidornelas authored and whoseoyster committed
Completes OPEN-5039 Improve create_inference_pipeline and create_project methods
1 parent e0bb27c commit 054a4db

File tree: 1 file changed, +113 -93 lines changed


openlayer/__init__.py

Lines changed: 113 additions & 93 deletions
@@ -118,39 +118,48 @@ def create_project(
         datasets to it. Refer to :obj:`add_model` and :obj:`add_dataset` or
         :obj:`add_dataframe` for detailed examples.
         """
-        # Validate project
-        project_config = {
-            "name": name,
-            "description": description,
-            "task_type": task_type,
-        }
-        project_validator = project_validators.ProjectValidator(
-            project_config=project_config
-        )
-        failed_validations = project_validator.validate()
+        try:
+            project = self.load_project(name)
+            warnings.warn(
+                f"Found an existing project with name '{name}'. Loading it instead."
+            )
+            return project
+        except exceptions.OpenlayerResourceNotFound:
+            # Validate project
+            project_config = {
+                "name": name,
+                "description": description,
+                "task_type": task_type,
+            }
+            project_validator = project_validators.ProjectValidator(
+                project_config=project_config
+            )
+            failed_validations = project_validator.validate()
 
-        if failed_validations:
-            raise exceptions.OpenlayerValidationError(
-                "There are issues with the project. \n"
-                "Make sure to fix all of the issues listed above before creating it.",
-            ) from None
+            if failed_validations:
+                raise exceptions.OpenlayerValidationError(
+                    "There are issues with the project. \n"
+                    "Make sure to fix all of the issues listed above before creating it.",
+                ) from None
 
-        endpoint = "projects"
-        payload = {
-            "name": name,
-            "description": description,
-            "taskType": task_type.value,
-        }
-        project_data = self.api.post_request(endpoint, body=payload)
+            endpoint = "projects"
+            payload = {
+                "name": name,
+                "description": description,
+                "taskType": task_type.value,
+            }
+            project_data = self.api.post_request(endpoint, body=payload)
 
-        project = Project(project_data, self.api.upload, self)
+            project = Project(project_data, self.api.upload, self)
 
-        # Check if the staging area exists
-        project_dir = os.path.join(constants.OPENLAYER_DIR, f"{project.id}/staging")
-        os.makedirs(project_dir)
+            # Check if the staging area exists
+            project_dir = os.path.join(constants.OPENLAYER_DIR, f"{project.id}/staging")
+            os.makedirs(project_dir)
 
-        print(f"Created your project. Navigate to {project.links['app']} to see it.")
-        return project
+            print(
+                f"Created your project. Navigate to {project.links['app']} to see it."
+            )
+            return project
 
     def load_project(self, name: str) -> Project:
         """Loads an existing project from the Openlayer platform.
@@ -1391,7 +1400,7 @@ def create_inference_pipeline(
         self,
         project_id: str,
         task_type: TaskType,
-        name: Optional[str] = None,
+        name: str = "Production",
         description: Optional[str] = None,
         reference_df: Optional[pd.DataFrame] = None,
         reference_dataset_file_path: Optional[str] = None,
@@ -1404,7 +1413,7 @@ def create_inference_pipeline(
 
         Parameters
         ----------
-        name : str, optional
+        name : str
             Name of your inference pipeline. If not specified, the name will be
             set to ``"Production"``.
 
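
These two hunks declare the ``"Production"`` default in the signature itself rather than relying solely on the `name or "Production"` fallback in the body, so the docstring drops "optional" and, more importantly, the try-before-create lookup in the next hunk always has a concrete name to pass to `load_inference_pipeline`. Continuing the sketch above, both calls below are equivalent:

# `name` can be omitted; the signature default supplies "Production".
pipeline = client.create_inference_pipeline(
    project_id=project.id,
    task_type=TaskType.TabularClassification,
)
pipeline = client.create_inference_pipeline(
    project_id=project.id,
    task_type=TaskType.TabularClassification,
    name="Production",
)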
@@ -1476,82 +1485,93 @@ def create_inference_pipeline(
                 " file path."
             )
 
-        # Validate inference pipeline
-        inference_pipeline_config = {
-            "name": name or "Production",
-            "description": description or "Monitoring production data.",
-        }
-        inference_pipeline_validator = (
-            inference_pipeline_validators.InferencePipelineValidator(
-                inference_pipeline_config=inference_pipeline_config
+        try:
+            inference_pipeline = self.load_inference_pipeline(
+                name=name, project_id=project_id, task_type=task_type
             )
-        )
-        failed_validations = inference_pipeline_validator.validate()
-        if failed_validations:
-            raise exceptions.OpenlayerValidationError(
-                "There are issues with the inference pipeline. \n"
-                "Make sure to fix all of the issues listed above before creating it.",
-            ) from None
-
-        # Validate reference dataset and augment config
-        if reference_dataset_config_file_path is not None:
-            dataset_validator = dataset_validators.get_validator(
-                task_type=task_type,
-                dataset_config_file_path=reference_dataset_config_file_path,
-                dataset_df=reference_df,
+            warnings.warn(
+                f"Found an existing inference pipeline with name '{name}'. "
+                "Loading it instead."
             )
-            failed_validations = dataset_validator.validate()
-
+        except exceptions.OpenlayerResourceNotFound:
+            # Validate inference pipeline
+            inference_pipeline_config = {
+                "name": name or "Production",
+                "description": description or "Monitoring production data.",
+            }
+            inference_pipeline_validator = (
+                inference_pipeline_validators.InferencePipelineValidator(
+                    inference_pipeline_config=inference_pipeline_config
+                )
+            )
+            failed_validations = inference_pipeline_validator.validate()
             if failed_validations:
                 raise exceptions.OpenlayerValidationError(
-                    "There are issues with the reference dataset and its config. \n"
-                    "Make sure to fix all of the issues listed above before the upload.",
+                    "There are issues with the inference pipeline. \n"
+                    "Make sure to fix all of the issues listed above before"
+                    " creating it.",
                 ) from None
 
-            # Load dataset config and augment with defaults
-            reference_dataset_config = utils.read_yaml(
-                reference_dataset_config_file_path
-            )
-            reference_dataset_data = DatasetSchema().load(
-                {"task_type": task_type.value, **reference_dataset_config}
-            )
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            # Copy relevant files to tmp dir if reference dataset is provided
+            # Validate reference dataset and augment config
             if reference_dataset_config_file_path is not None:
-                utils.write_yaml(
-                    reference_dataset_data, f"{tmp_dir}/dataset_config.yaml"
+                dataset_validator = dataset_validators.get_validator(
+                    task_type=task_type,
+                    dataset_config_file_path=reference_dataset_config_file_path,
+                    dataset_df=reference_df,
+                )
+                failed_validations = dataset_validator.validate()
+
+                if failed_validations:
+                    raise exceptions.OpenlayerValidationError(
+                        "There are issues with the reference dataset and its config. \n"
+                        "Make sure to fix all of the issues listed above before the"
+                        " upload.",
+                    ) from None
+
+                # Load dataset config and augment with defaults
+                reference_dataset_config = utils.read_yaml(
+                    reference_dataset_config_file_path
                 )
-            if reference_df is not None:
-                reference_df.to_csv(f"{tmp_dir}/dataset.csv", index=False)
-            else:
-                shutil.copy(
-                    reference_dataset_file_path,
-                    f"{tmp_dir}/dataset.csv",
+                reference_dataset_data = DatasetSchema().load(
+                    {"task_type": task_type.value, **reference_dataset_config}
+                )
+
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                # Copy relevant files to tmp dir if reference dataset is provided
+                if reference_dataset_config_file_path is not None:
+                    utils.write_yaml(
+                        reference_dataset_data, f"{tmp_dir}/dataset_config.yaml"
                     )
+                if reference_df is not None:
+                    reference_df.to_csv(f"{tmp_dir}/dataset.csv", index=False)
+                else:
+                    shutil.copy(
+                        reference_dataset_file_path,
+                        f"{tmp_dir}/dataset.csv",
+                    )
 
-            tar_file_path = os.path.join(tmp_dir, "tarfile")
-            with tarfile.open(tar_file_path, mode="w:gz") as tar:
-                tar.add(tmp_dir, arcname=os.path.basename("reference_dataset"))
+                tar_file_path = os.path.join(tmp_dir, "tarfile")
+                with tarfile.open(tar_file_path, mode="w:gz") as tar:
+                    tar.add(tmp_dir, arcname=os.path.basename("reference_dataset"))
 
-            endpoint = f"projects/{project_id}/inference-pipelines"
-            inference_pipeline_data = self.api.upload(
-                endpoint=endpoint,
-                file_path=tar_file_path,
-                object_name="tarfile",
-                body=inference_pipeline_config,
-                storage_uri_key="referenceDatasetUri",
-                method="POST",
+                endpoint = f"projects/{project_id}/inference-pipelines"
+                inference_pipeline_data = self.api.upload(
+                    endpoint=endpoint,
+                    file_path=tar_file_path,
+                    object_name="tarfile",
+                    body=inference_pipeline_config,
+                    storage_uri_key="referenceDatasetUri",
+                    method="POST",
+                )
+            inference_pipeline = InferencePipeline(
+                inference_pipeline_data, self.api.upload, self, task_type
             )
-            inference_pipeline = InferencePipeline(
-                inference_pipeline_data, self.api.upload, self, task_type
-            )
 
-        print(
-            "Created your inference pipeline. Navigate to"
-            f" {inference_pipeline.links['app']} to see it."
-        )
-        return inference_pipeline
+            print(
+                "Created your inference pipeline. Navigate to"
+                f" {inference_pipeline.links['app']} to see it."
+            )
+            return inference_pipeline
 
     def load_inference_pipeline(
         self,