
Commit 13d62a4

gustavocidornelas authored and whoseoyster committed
Completes OPEN-5006 Infer config properties from dataset/model
1 parent ce7db20 commit 13d62a4

File tree

4 files changed: +34 -14 lines changed


openlayer/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -718,6 +718,7 @@ def add_dataset(
                 "Either `dataset_config` or `dataset_config_file_path` must be"
                 " provided."
             )
+
         # Validate dataset
         dataset_validator = dataset_validators.get_validator(
             task_type=task_type,
@@ -739,6 +740,8 @@ def add_dataset(
         dataset_data = DatasetSchema().load(
             {"task_type": task_type.value, **dataset_config}
         )
+        if dataset_data.get("columnNames") is None:
+            dataset_data["columnNames"] = utils.get_column_names(file_path)
 
         # Copy relevant resources to temp directory
         with tempfile.TemporaryDirectory() as temp_dir:
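
The net effect in `add_dataset` is that `columnNames` no longer has to be spelled out in the config: when the key is missing, it is read from the dataset file's header. A minimal sketch of that inference, using a throwaway CSV and an illustrative config dict (neither comes from the package):

```python
import pandas as pd

from openlayer import utils

# Toy dataset written to disk just for this example.
pd.DataFrame({"text": ["good", "bad"], "label": [1, 0]}).to_csv(
    "validation.csv", index=False
)

# Dataset config that omits `columnNames` (keys here are illustrative).
dataset_data = {"label": "validation"}

if dataset_data.get("columnNames") is None:
    # Same inference the commit adds after loading the schema: header-only read.
    dataset_data["columnNames"] = utils.get_column_names("validation.csv")

print(dataset_data["columnNames"])  # ['text', 'label']
```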

openlayer/schemas.py

Lines changed: 19 additions & 3 deletions
@@ -76,7 +76,8 @@ class BaseDatasetSchema(ma.Schema):
 
     columnNames = ma.fields.List(
         ma.fields.Str(validate=COLUMN_NAME_VALIDATION_LIST),
-        required=True,
+        allow_none=True,
+        load_default=None,
     )
     label = ma.fields.Str(
         validate=ma.validate.OneOf(
@@ -233,11 +234,12 @@ class BaseModelSchema(ma.Schema):
     """Common schema for models for all task types."""
 
     name = ma.fields.Str(
-        required=True,
         validate=ma.validate.Length(
             min=1,
             max=64,
         ),
+        allow_none=True,
+        load_default="Model",
     )
     metadata = ma.fields.Dict(
         allow_none=True,
@@ -252,7 +254,8 @@ class BaseModelSchema(ma.Schema):
             + " If you can't find your framework, specify 'custom' for your model's"
             + " `architectureType`.",
         ),
-        required=True,
+        allow_none=True,
+        load_default="custom",
     )
 
 
@@ -307,6 +310,19 @@ class LLMModelSchema(BaseModelSchema):
         ma.fields.Str(validate=COLUMN_NAME_VALIDATION_LIST),
         load_default=[],
     )
+    # Important that here the architectureType defaults to `llm` and not `custom` since
+    # the architectureType is used to check if the model is an LLM or not.
+    architectureType = ma.fields.Str(
+        validate=ma.validate.OneOf(
+            [model_framework.value for model_framework in ModelType],
+            error="`architectureType` must be one of the supported frameworks."
+            + " Check out our API reference for a full list."
+            + " If you can't find your framework, specify 'custom' for your model's"
+            + " `architectureType`.",
+        ),
+        allow_none=True,
+        load_default="llm",
+    )
 
     @ma.validates_schema
     def validates_model_type_fields(self, data, **kwargs):
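
Swapping `required=True` for `allow_none=True` plus `load_default` means marshmallow fills these fields in when the user's config omits them: `columnNames` defaults to `None` (which triggers the header inference above), `name` falls back to "Model", and `architectureType` to "custom", or "llm" for LLM models. A standalone sketch of that load-time behavior with a toy schema, not the package's actual classes:

```python
import marshmallow as ma


class ToyModelSchema(ma.Schema):
    # Mirrors the pattern the commit introduces: optional fields whose defaults
    # are applied at load time when the user leaves them out.
    name = ma.fields.Str(allow_none=True, load_default="Model")
    architectureType = ma.fields.Str(allow_none=True, load_default="custom")


print(ToyModelSchema().load({}))
# {'name': 'Model', 'architectureType': 'custom'}
```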

openlayer/utils.py

Lines changed: 12 additions & 0 deletions
@@ -77,6 +77,18 @@ def camel_to_snake_str(name: str) -> str:
     return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
 
 
+def get_column_names(file_path: str) -> list:
+    """Returns the column names of the specified file.
+
+    Args:
+        file_path (str): the path to the file.
+
+    Returns:
+        list: the column names of the specified file.
+    """
+    return pd.read_csv(file_path, nrows=0).columns.tolist()
+
+
 def get_exception_stacktrace(err: Exception):
     """Returns the stacktrace of the most recent exception.
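
The `nrows=0` argument makes pandas parse only the header row, so `get_column_names` returns the column list without loading any data rows. A quick illustration (the file name is made up):

```python
import pandas as pd

pd.DataFrame({"a": [1, 2], "b": [3, 4]}).to_csv("tiny.csv", index=False)

# Header-only read: zero data rows are materialized.
header = pd.read_csv("tiny.csv", nrows=0)
print(len(header))              # 0
print(header.columns.tolist())  # ['a', 'b']
```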
openlayer/validators/dataset_validators.py

Lines changed: 0 additions & 11 deletions
@@ -162,7 +162,6 @@ def _validate_dataset_and_config_consistency(self):
 
         # Dataset-wide validations
         self._validate_dataset_dtypes()
-        self._validate_dataset_columns()
 
         self._validate_inputs()
         self._validate_outputs()
@@ -180,16 +179,6 @@ def _validate_dataset_dtypes(self):
                 " Please cast the columns in your dataset to conform to these dtypes."
             )
 
-    def _validate_dataset_columns(self):
-        """Checks whether all columns in the dataset are specified in `columnNames`."""
-        dataset_columns = set(self.dataset_df.columns)
-        column_names = set(self.column_names)
-        if dataset_columns != column_names:
-            self.failed_validations.append(
-                "Not all columns in the dataset are specified in `columnNames`. "
-                "Please specify all dataset columns in `columnNames`."
-            )
-
     @abstractmethod
     def _validate_inputs(self):
         """To be implemented by InputValidator child classes."""

0 commit comments
