@@ -167,6 +167,9 @@ def validate(self) -> List[str]:
167167 if not self .failed_validations :
168168 self ._validate_bundle_resources ()
169169
170+ if not self .failed_validations :
171+ self ._validate_resource_consistency ()
172+
170173 if not self .failed_validations :
171174 logger .info (
172175 "----------------------------------------------------------------------------\n "
@@ -401,6 +404,137 @@ def _load_model_config_from_bundle(self) -> Dict[str, Any]:
401404
402405 return model_config
403406
407+ def _validate_resource_consistency (self ):
408+ """Validates that the resources in the bundle are consistent with each other.
409+
410+ For example, if the `classNames` field on the dataset configs are consistent
411+ with the one on the model config.
412+ """
413+ resource_consistency_failed_validations = []
414+
415+ if (
416+ "training" in self ._bundle_resources
417+ and "validation" in self ._bundle_resources
418+ ):
419+ # Loading the relevant configs
420+ model_config = {}
421+ if "model" in self ._bundle_resources :
422+ model_config = self ._load_model_config_from_bundle ()
423+ training_dataset_config = self ._load_dataset_config_from_bundle ("training" )
424+ validation_dataset_config = self ._load_dataset_config_from_bundle (
425+ "validation"
426+ )
427+ model_feature_names = model_config .get ("featureNames" )
428+ model_class_names = model_config .get ("classNames" )
429+ training_feature_names = training_dataset_config .get ("featureNames" )
430+ training_class_names = training_dataset_config .get ("classNames" )
431+ validation_feature_names = validation_dataset_config .get ("featureNames" )
432+ validation_class_names = validation_dataset_config .get ("classNames" )
433+
434+ # Validating the `featureNames` field
435+ if training_feature_names or validation_feature_names :
436+ if not self ._feature_names_consistent (
437+ model_feature_names = model_feature_names ,
438+ training_feature_names = training_feature_names ,
439+ validation_feature_names = validation_feature_names ,
440+ ):
441+ resource_consistency_failed_validations .append (
442+ "The `featureNames` in the provided resources are inconsistent."
443+ " The training and validation set feature names must have some overlap."
444+ " Furthermore, if a model is provided, its feature names must be a subset"
445+ " of the feature names in the training and validation sets."
446+ )
447+
448+ # Validating the `classNames` field
449+ if not self ._class_names_consistent (
450+ model_class_names = model_class_names ,
451+ training_class_names = training_class_names ,
452+ validation_class_names = validation_class_names ,
453+ ):
454+ resource_consistency_failed_validations .append (
455+ "The `classNames` in the provided resources are inconsistent."
456+ " The validation set's class names need to contain the training set's."
457+ " Furthermore, if a model is provided, its class names must be contained"
458+ " in the training and validation sets' class names."
459+ " Note that the order of the items in the `classNames` list matters."
460+ )
461+
462+ # Print results of the validation
463+ if resource_consistency_failed_validations :
464+ logger .error ("Bundle resource consistency failed validations:" )
465+ _list_failed_validation_messages (resource_consistency_failed_validations )
466+
467+ # Add the bundle resource consistency failed validations to the list of all failed validations
468+ self .failed_validations .extend (resource_consistency_failed_validations )
469+
470+ @staticmethod
471+ def _feature_names_consistent (
472+ model_feature_names : Optional [List [str ]],
473+ training_feature_names : List [str ],
474+ validation_feature_names : List [str ],
475+ ) -> bool :
476+ """Checks whether the feature names in the training, validation and model
477+ configs are consistent.
478+
479+ Parameters
480+ ----------
481+ model_feature_names : List[str]
482+ The feature names in the model config.
483+ training_feature_names : List[str]
484+ The feature names in the training dataset config.
485+ validation_feature_names : List[str]
486+ The feature names in the validation dataset config.
487+
488+ Returns
489+ -------
490+ bool
491+ True if the feature names are consistent, False otherwise.
492+ """
493+ train_val_intersection = set (training_feature_names ).intersection (
494+ set (validation_feature_names )
495+ )
496+ if model_feature_names is None :
497+ return len (train_val_intersection ) != 0
498+ return set (model_feature_names ).issubset (train_val_intersection )
499+
500+ @staticmethod
501+ def _class_names_consistent (
502+ model_class_names : Optional [List [str ]],
503+ training_class_names : List [str ],
504+ validation_class_names : List [str ],
505+ ) -> bool :
506+ """Checks whether the class names in the training and model configs
507+ are consistent.
508+
509+ Parameters
510+ ----------
511+ model_class_names : List[str]
512+ The class names in the model config.
513+ training_class_names : List[str]
514+ The class names in the training dataset config.
515+ validation_class_names : List[str]
516+ The class names in the validation dataset config.
517+
518+ Returns
519+ -------
520+ bool
521+ True if the class names are consistent, False otherwise.
522+ """
523+ if model_class_names is not None :
524+ num_model_classes = len (model_class_names )
525+ try :
526+ return (
527+ training_class_names [:num_model_classes ] == model_class_names
528+ and validation_class_names [:num_model_classes ] == model_class_names
529+ )
530+ except IndexError :
531+ return False
532+ num_training_classes = len (training_class_names )
533+ try :
534+ return validation_class_names [:num_training_classes ] == training_class_names
535+ except IndexError :
536+ return False
537+
404538
405539class CommitValidator :
406540 """Validates the commit prior to the upload.
0 commit comments