@@ -114,17 +114,29 @@ def _session_groups_from_tags(self):
114114 )
115115 session_groups = self ._filter (session_groups , filters )
116116 self ._sort (session_groups , extractors )
117+
118+ if _specifies_include (self ._request .col_params ):
119+ _reduce_to_hparams_to_include (
120+ session_groups , self ._request .col_params
121+ )
122+
117123 return session_groups
118124
119125 def _session_groups_from_data_provider (self ):
120126 """Constructs lists of SessionGroups based on DataProvider results."""
121127 filters = _build_data_provider_filters (self ._request .col_params )
122128 sort = _build_data_provider_sort (self ._request .col_params )
129+ hparams_to_include = (
130+ _get_hparams_to_include (self ._request .col_params )
131+ if _specifies_include (self ._request .col_params )
132+ else None
133+ )
123134 response = self ._backend_context .session_groups_from_data_provider (
124135 self ._request_context ,
125136 self ._experiment_id ,
126137 filters ,
127138 sort ,
139+ hparams_to_include ,
128140 )
129141
130142 metric_infos = (
@@ -968,3 +980,62 @@ def _build_data_provider_sort_item(col_param):
968980 hyperparameter_name = col_param .hparam ,
969981 sort_direction = sort_direction ,
970982 )
983+
984+
985+ def _specifies_include (col_params ):
986+ """Determines whether any `ColParam` contains the `include_in_result` field.
987+
988+ In the case where none of the col_params contains the field, we should assume
989+ that all fields should be included in the response.
990+ """
991+ return any (
992+ col_param .HasField ("include_in_result" ) for col_param in col_params
993+ )
994+
995+
996+ def _get_hparams_to_include (col_params ):
997+ """Generates the list of hparams to include in the response.
998+
999+ The determination is based on the `include_in_result` field in ColParam. If
1000+ a ColParam either has `include_in_result: True` or does not specify the
1001+ field at all, then it should be included in the result.
1002+
1003+ Args:
1004+ col_params: A collection of `ColParams` protos.
1005+
1006+ Returns:
1007+ A list of names of hyperparameters to include in the response.
1008+ """
1009+ hparams_to_include = []
1010+ for col_param in col_params :
1011+ if (
1012+ col_param .HasField ("include_in_result" )
1013+ and not col_param .include_in_result
1014+ ):
1015+ # Explicitly set to exclude this hparam.
1016+ continue
1017+ if col_param .hparam :
1018+ hparams_to_include .append (col_param .hparam )
1019+ return hparams_to_include
1020+
1021+
1022+ def _reduce_to_hparams_to_include (session_groups , col_params ):
1023+ """Removes hparams from session_groups that should not be included.
1024+
1025+ Args:
1026+ session_groups: A collection of `SessionGroup` protos, which will be
1027+ modified in place.
1028+ col_params: A collection of `ColParams` protos.
1029+ """
1030+ hparams_to_include = _get_hparams_to_include (col_params )
1031+
1032+ for session_group in session_groups :
1033+ new_hparams = {
1034+ hparam : value
1035+ for (hparam , value ) in session_group .hparams .items ()
1036+ if hparam in hparams_to_include
1037+ }
1038+
1039+ session_group .ClearField ("hparams" )
1040+ for (hparam , value ) in new_hparams .items ():
1041+ session_group .hparams [hparam ].CopyFrom (value )
0 commit comments