6363 SchemaOptional ("features" ): str ,
6464 SchemaOptional ("label_values_or_threshold" ): [Or (int , float , str )],
6565 SchemaOptional ("probability_threshold" ): float ,
66+ SchemaOptional ("segment_config" ): [
67+ {
68+ SchemaOptional ("config_name" ): str ,
69+ "name_or_index" : Or (str , int ),
70+ "segments" : [[Or (str , int )]],
71+ SchemaOptional ("display_aliases" ): [str ],
72+ }
73+ ],
6674 SchemaOptional ("facet" ): [
6775 {
6876 "name_or_index" : Or (str , int ),
@@ -316,6 +324,74 @@ class DatasetType(Enum):
316324 IMAGE = "application/x-image"
317325
318326
327+ class SegmentationConfig :
328+ """Config object that defines segment(s) of the dataset on which metrics are computed."""
329+
330+ def __init__ (
331+ self ,
332+ name_or_index : Union [str , int ],
333+ segments : List [List [Union [str , int ]]],
334+ config_name : Optional [str ] = None ,
335+ display_aliases : Optional [List [str ]] = None ,
336+ ):
337+ """Initializes a segmentation configuration for a dataset column.
338+
339+ Args:
340+ name_or_index (str or int): The name or index of the column in the dataset on which
341+ the segment(s) is defined.
342+ segments (List[List[str or int]]): Each List of values represents one segment. If N
343+ Lists are provided, we generate N+1 segments - the additional segment, denoted as
344+ the '__default__' segment, is for the rest of the values that are not covered by
345+ these lists. For continuous columns, a segment must be given as strings in interval
346+ notation (eg.: ["[1, 4]"] or ["(2, 5]"]). A segment can also be composed of
347+ multiple intervals (eg.: ["[1, 4]", "(5, 6]"] is one segment). For categorical
348+ columns, each segment should contain one or more of the categorical values for
349+ the categorical column, which may be strings or integers.
350+ Eg,: For a continuous column, ``segments`` could be
351+ [["[1, 4]", "(5, 6]"], ["(7, 9)"]] - this generates 3 segments including the
352+ default segment. For a categorical columns with values ("A", "B", "C", "D"),
353+ ``segments``,could be [["A", "B"]]. This generate 2 segments, including the default
354+ segment.
355+ config_name (str) - Optional name for the segment config to identify the config.
356+ display_aliases (List[str]) - Optional list of display names for the ``segments`` for
357+ the analysis output and report. This list should be the same length as the number of
358+ lists provided in ``segments`` or with one additional display alias for the default
359+ segment.
360+
361+ Raises:
362+ ValueError: when the ``name_or_index`` is None, ``segments`` is invalid, or a wrong
363+ number of ``display_aliases`` are specified.
364+ """
365+ if name_or_index is None :
366+ raise ValueError ("`name_or_index` cannot be None" )
367+ self .name_or_index = name_or_index
368+ if (
369+ not segments
370+ or not isinstance (segments , list )
371+ or not all ([isinstance (segment , list ) for segment in segments ])
372+ ):
373+ raise ValueError ("`segments` must be a list of lists of values or intervals." )
374+ self .segments = segments
375+ self .config_name = config_name
376+ if display_aliases is not None and not (
377+ len (display_aliases ) == len (segments ) or len (display_aliases ) == len (segments ) + 1
378+ ):
379+ raise ValueError (
380+ "Number of `display_aliases` must equal the number of segments"
381+ " specified or with one additional default segment display alias."
382+ )
383+ self .display_aliases = display_aliases
384+
385+ def to_dict (self ) -> Dict [str , Any ]: # pragma: no cover
386+ """Returns SegmentationConfig as a dict."""
387+ segment_config_dict = {"name_or_index" : self .name_or_index , "segments" : self .segments }
388+ if self .config_name :
389+ segment_config_dict ["config_name" ] = self .config_name
390+ if self .display_aliases :
391+ segment_config_dict ["display_aliases" ] = self .display_aliases
392+ return segment_config_dict
393+
394+
319395class DataConfig :
320396 """Config object related to configurations of the input and output dataset."""
321397
@@ -336,6 +412,7 @@ def __init__(
336412 predicted_label_headers : Optional [List [str ]] = None ,
337413 predicted_label : Optional [Union [str , int ]] = None ,
338414 excluded_columns : Optional [Union [List [int ], List [str ]]] = None ,
415+ segmentation_config : Optional [List [SegmentationConfig ]] = None ,
339416 ):
340417 """Initializes a configuration of both input and output datasets.
341418
@@ -402,6 +479,8 @@ def __init__(
402479 Only a single predicted label per sample is supported at this time.
403480 excluded_columns (list[int] or list[str]): A list of names or indices of the columns
404481 which are to be excluded from making model inference API calls.
482+ segmentation_config (list[SegmentationConfig]): A list of ``SegmentationConfig``
483+ objects.
405484
406485 Raises:
407486 ValueError: when the ``dataset_type`` is invalid, predicted label dataset parameters
@@ -469,6 +548,7 @@ def __init__(
469548 self .predicted_label_headers = predicted_label_headers
470549 self .predicted_label = predicted_label
471550 self .excluded_columns = excluded_columns
551+ self .segmentation_configs = segmentation_config
472552 self .analysis_config = {
473553 "dataset_type" : dataset_type ,
474554 }
@@ -486,6 +566,12 @@ def __init__(
486566 _set (predicted_label_headers , "predicted_label_headers" , self .analysis_config )
487567 _set (predicted_label , "predicted_label" , self .analysis_config )
488568 _set (excluded_columns , "excluded_columns" , self .analysis_config )
569+ if segmentation_config :
570+ _set (
571+ [item .to_dict () for item in segmentation_config ],
572+ "segment_config" ,
573+ self .analysis_config ,
574+ )
489575
490576 def get_config (self ):
491577 """Returns part of an analysis config dictionary."""
0 commit comments