|
31 | 31 | from .projects import Project |
32 | 32 | from .version import __version__ |
33 | 33 |
|
34 | | -from .schemas import DatasetSchema, ModelSchema |
| 34 | +from .schemas import DatasetSchema, ModelSchema, ProjectSchema |
35 | 35 | from marshmallow import ValidationError |
36 | 36 |
|
37 | 37 |
|
@@ -86,6 +86,13 @@ def create_project( |
86 | 86 | name: str, |
87 | 87 | description: str, |
88 | 88 | ): |
| 89 | + # ----------------------------- Schema validation ---------------------------- # |
| 90 | + project_schema = ProjectSchema() |
| 91 | + try: |
| 92 | + project_schema.load({"name": name, "description": description}) |
| 93 | + except ValidationError as err: |
| 94 | + raise UnboxValidationError(self._format_error_message(err)) |
| 95 | + |
89 | 96 | endpoint = "initialize_project" |
90 | 97 | payload = dict( |
91 | 98 | name=name, |
@@ -346,6 +353,15 @@ def add_model( |
346 | 353 | >>> model.to_dict() |
347 | 354 | """ |
348 | 355 | # ---------------------------- Schema validations ---------------------------- # |
| 356 | + if task_type not in [ |
| 357 | + TaskType.TabularClassification, |
| 358 | + TaskType.TextClassification, |
| 359 | + ]: |
| 360 | + raise UnboxValidationError( |
| 361 | + "`task_type` must be either TaskType.TabularClassification or TaskType.TextClassification. \n" |
| 362 | + ) |
| 363 | + if model_type not in []: |
| 364 | + pass |
349 | 365 | model_schema = ModelSchema() |
350 | 366 | try: |
351 | 367 | model_schema.load( |
@@ -711,14 +727,21 @@ def add_dataset( |
711 | 727 | >>> dataset.to_dict() |
712 | 728 | """ |
713 | 729 | # ---------------------------- Schema validations ---------------------------- # |
| 730 | + if task_type not in [ |
| 731 | + TaskType.TabularClassification, |
| 732 | + TaskType.TextClassification, |
| 733 | + ]: |
| 734 | + raise UnboxValidationError( |
| 735 | + "`task_type` must be either TaskType.TabularClassification or TaskType.TextClassification. \n" |
| 736 | + ) |
714 | 737 | dataset_schema = DatasetSchema() |
715 | 738 | try: |
716 | 739 | dataset_schema.load( |
717 | 740 | { |
718 | 741 | "name": name, |
719 | 742 | "file_path": file_path, |
720 | | - "description": description, |
721 | 743 | "task_type": task_type.value, |
| 744 | + "description": description, |
722 | 745 | "class_names": class_names, |
723 | 746 | "label_column_name": label_column_name, |
724 | 747 | "tag_column_name": tag_column_name, |
|
0 commit comments