Skip to content

Commit 90397f1

Browse files
Completes OPEN-3507 Validate column names to not allow special characters
1 parent 8dab074 commit 90397f1

File tree

3 files changed

+52
-17
lines changed

3 files changed

+52
-17
lines changed

openlayer/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class OpenlayerClient(object):
3636

3737
def __init__(self, api_key: str = None):
3838
self.api = api.Api(api_key)
39-
# self.subscription_plan = self.api.get_request("me/subscription-plan")
4039

4140
if not os.path.exists(OPENLAYER_DIR):
4241
os.makedirs(OPENLAYER_DIR)

openlayer/schemas.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,18 @@
66
from .models import ModelType
77
from .tasks import TaskType
88

9-
9+
# ---------------------------- Regular expressions --------------------------- #
10+
COLUMN_NAME_REGEX = validate = ma.validate.Regexp(
11+
r"^[a-zA-Z0-9_-]+$",
12+
error="strings that are not alphanumeric with underscores or hyphens."
13+
+ " Spaces and special characters are not allowed.",
14+
)
15+
LANGUAGE_CODE_REGEX = ma.validate.Regexp(
16+
r"^[a-z]{2}(-[A-Z]{2})?$",
17+
error="`language` of the dataset is not in the ISO 639-1 (alpha-2 code) format.",
18+
)
19+
20+
# ---------------------------------- Schemas --------------------------------- #
1021
class CommitSchema(ma.Schema):
1122
"""Schema for commits."""
1223

@@ -23,11 +34,13 @@ class DatasetSchema(ma.Schema):
2334
"""Schema for datasets."""
2435

2536
categoricalFeatureNames = ma.fields.List(
26-
ma.fields.Str(), allow_none=True, load_default=[]
37+
ma.fields.Str(validate=COLUMN_NAME_REGEX),
38+
allow_none=True,
39+
load_default=[],
2740
)
2841
classNames = ma.fields.List(ma.fields.Str(), required=True)
2942
columnNames = ma.fields.List(
30-
ma.fields.Str(),
43+
ma.fields.Str(validate=COLUMN_NAME_REGEX),
3144
required=True,
3245
)
3346
label = ma.fields.Str(
@@ -40,21 +53,26 @@ class DatasetSchema(ma.Schema):
4053
required=True,
4154
)
4255
featureNames = ma.fields.List(
43-
ma.fields.Str(),
56+
ma.fields.Str(validate=COLUMN_NAME_REGEX),
4457
load_default=[],
4558
)
46-
labelColumnName = ma.fields.Str(required=True)
59+
labelColumnName = ma.fields.Str(
60+
validate=COLUMN_NAME_REGEX,
61+
required=True,
62+
)
4763
language = ma.fields.Str(
4864
load_default="en",
49-
validate=ma.validate.Regexp(
50-
r"^[a-z]{2}(-[A-Z]{2})?$",
51-
error="`language` of the dataset is not in the ISO 639-1 (alpha-2 code) format.",
52-
),
65+
validate=LANGUAGE_CODE_REGEX,
5366
)
5467
metadata = ma.fields.Dict(allow_none=True, load_default={})
55-
predictionsColumnName = ma.fields.Str(allow_none=True, load_default=None)
68+
predictionsColumnName = ma.fields.Str(
69+
validate=COLUMN_NAME_REGEX,
70+
allow_none=True,
71+
load_default=None,
72+
)
5673
sep = ma.fields.Str(load_default=",")
5774
textColumnName = ma.fields.Str(
75+
validate=COLUMN_NAME_REGEX,
5876
allow_none=True,
5977
)
6078

@@ -70,7 +88,10 @@ def validates_label_column_not_in_feature_names(self, data, **kwargs):
7088
class ModelSchema(ma.Schema):
7189
"""Schema for models with artifacts (i.e., model_package)."""
7290

73-
categoricalFeatureNames = ma.fields.List(ma.fields.Str(), load_default=[])
91+
categoricalFeatureNames = ma.fields.List(
92+
ma.fields.Str(validate=COLUMN_NAME_REGEX),
93+
load_default=[],
94+
)
7495
classNames = ma.fields.List(
7596
ma.fields.Str(),
7697
)
@@ -81,7 +102,11 @@ class ModelSchema(ma.Schema):
81102
max=64,
82103
),
83104
)
84-
featureNames = ma.fields.List(ma.fields.Str(), allow_none=True, load_default=[])
105+
featureNames = ma.fields.List(
106+
ma.fields.Str(validate=COLUMN_NAME_REGEX),
107+
allow_none=True,
108+
load_default=[],
109+
)
85110
metadata = ma.fields.Dict(
86111
allow_none=True,
87112
load_default={},

openlayer/validators.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def validate(self) -> List[str]:
235235
self._validate_bundle_resources()
236236

237237
if not self.failed_validations:
238-
print("All validations passed!")
238+
print("All commit bundle validations passed!")
239239

240240
return self.failed_validations
241241

@@ -287,7 +287,7 @@ def validate(self) -> List[str]:
287287
self._validate_commit_message()
288288

289289
if not self.failed_validations:
290-
print("All validations passed!")
290+
print("All commit validations passed!")
291291

292292
return self.failed_validations
293293

@@ -454,6 +454,9 @@ def _validate_dataset_and_config_consistency(self):
454454
if self.dataset_config and self.dataset_df is not None:
455455
# Extract vars
456456
dataset_df = self.dataset_df
457+
categorical_feature_names = self.dataset_config.get(
458+
"categoricalFeatureNames"
459+
)
457460
class_names = self.dataset_config.get("classNames")
458461
column_names = self.dataset_config.get("columnNames")
459462
label_column_name = self.dataset_config.get("labelColumnName")
@@ -582,6 +585,14 @@ def _validate_dataset_and_config_consistency(self):
582585
"There are features specified in `featureNames` which are "
583586
"not in the dataset."
584587
)
588+
if categorical_feature_names:
589+
if self._columns_not_in_dataset_df(
590+
dataset_df, categorical_feature_names
591+
):
592+
dataset_and_config_consistency_failed_validations.append(
593+
"There are categorical features specified in `categoricalFeatureNames` "
594+
"which are not in the dataset."
595+
)
585596

586597
# Print results of the validation
587598
if dataset_and_config_consistency_failed_validations:
@@ -723,7 +734,7 @@ def validate(self) -> List[str]:
723734
self._validate_dataset_and_config_consistency()
724735

725736
if not self.failed_validations:
726-
print("All validations passed!")
737+
print("All dataset validations passed!")
727738

728739
return self.failed_validations
729740

@@ -1070,7 +1081,7 @@ def validate(self):
10701081
self._validate_project_config()
10711082

10721083
if not self.failed_validations:
1073-
print("All validations passed!")
1084+
print("All model validations passed!")
10741085

10751086
return self.failed_validations
10761087

0 commit comments

Comments
 (0)