@@ -153,7 +153,9 @@ def _mock_list_hyperparameters(
153153 ):
154154 return self ._hyperparameters
155155
156- def _experiment_from_metadata (self , * , include_metrics = True ):
156+ def _experiment_from_metadata (
157+ self , * , include_metrics = True , hparams_limit = None
158+ ):
157159 """Calls the expected operations for generating an Experiment proto."""
158160 ctxt = backend_context .Context (self ._mock_tb_context )
159161 request_ctx = context .RequestContext ()
@@ -162,7 +164,10 @@ def _experiment_from_metadata(self, *, include_metrics=True):
162164 "123" ,
163165 include_metrics ,
164166 ctxt .hparams_metadata (request_ctx , "123" ),
165- ctxt .hparams_from_data_provider (request_ctx , "123" , limit = None ),
167+ ctxt .hparams_from_data_provider (
168+ request_ctx , "123" , limit = hparams_limit
169+ ),
170+ hparams_limit ,
166171 )
167172
168173 def test_experiment_with_experiment_tag (self ):
@@ -897,6 +902,178 @@ def test_experiment_from_data_provider_old_response_type(self):
897902 """
898903 self .assertProtoEquals (expected_exp , actual_exp )
899904
905+ def test_experiment_from_tags_with_hparams_limit_no_differed_hparams (self ):
906+ experiment = """
907+ name: 'Test experiment'
908+ hparam_infos: {
909+ name: 'batch_size'
910+ type: DATA_TYPE_FLOAT64
911+ differs: false
912+ }
913+ hparam_infos: {
914+ name: 'lr'
915+ type: DATA_TYPE_FLOAT64
916+ differs: false
917+ }
918+ hparam_infos: {
919+ name: 'use_batch_norm'
920+ type: DATA_TYPE_BOOL
921+ differs: false
922+ }
923+ hparam_infos: {
924+ name: 'model_type'
925+ type: DATA_TYPE_STRING
926+ differs: false
927+ }
928+ """
929+ t = provider .TensorTimeSeries (
930+ max_step = 0 ,
931+ max_wall_time = 0 ,
932+ plugin_content = self ._serialized_plugin_data (
933+ DATA_TYPE_EXPERIMENT , experiment
934+ ),
935+ description = "" ,
936+ display_name = "" ,
937+ )
938+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
939+ self ._mock_tb_context .data_provider .list_tensors .return_value = {
940+ "train" : {metadata .EXPERIMENT_TAG : t }
941+ }
942+ expected_exp = """
943+ name: 'Test experiment'
944+ hparam_infos: {
945+ name: 'batch_size'
946+ type: DATA_TYPE_FLOAT64
947+ differs: false
948+ }
949+ hparam_infos: {
950+ name: 'lr'
951+ type: DATA_TYPE_FLOAT64
952+ differs: false
953+ }
954+ """
955+ actual_exp = self ._experiment_from_metadata (
956+ include_metrics = False , hparams_limit = 2
957+ )
958+ self .assertProtoEquals (expected_exp , actual_exp )
959+
960+ def test_experiment_from_tags_with_hparams_limit_returns_differed_hparams_first (
961+ self ,
962+ ):
963+ experiment = """
964+ name: 'Test experiment'
965+ hparam_infos: {
966+ name: 'batch_size'
967+ type: DATA_TYPE_FLOAT64
968+ differs: false
969+ }
970+ hparam_infos: {
971+ name: 'lr'
972+ type: DATA_TYPE_FLOAT64
973+ differs: true
974+ }
975+ hparam_infos: {
976+ name: 'use_batch_norm'
977+ type: DATA_TYPE_BOOL
978+ differs: false
979+ }
980+ hparam_infos: {
981+ name: 'model_type'
982+ type: DATA_TYPE_STRING
983+ differs: true
984+ }
985+ """
986+ t = provider .TensorTimeSeries (
987+ max_step = 0 ,
988+ max_wall_time = 0 ,
989+ plugin_content = self ._serialized_plugin_data (
990+ DATA_TYPE_EXPERIMENT , experiment
991+ ),
992+ description = "" ,
993+ display_name = "" ,
994+ )
995+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
996+ self ._mock_tb_context .data_provider .list_tensors .return_value = {
997+ "train" : {metadata .EXPERIMENT_TAG : t }
998+ }
999+ expected_exp = """
1000+ name: 'Test experiment'
1001+ hparam_infos: {
1002+ name: 'lr'
1003+ type: DATA_TYPE_FLOAT64
1004+ differs: true
1005+ },
1006+ hparam_infos: {
1007+ name: 'model_type'
1008+ type: DATA_TYPE_STRING
1009+ differs: true
1010+ }
1011+ """
1012+ actual_exp = self ._experiment_from_metadata (
1013+ include_metrics = False , hparams_limit = 2
1014+ )
1015+ self .assertProtoEquals (expected_exp , actual_exp )
1016+
1017+ def test_experiment_from_tags_sorts_differed_hparams_first (self ):
1018+ experiment = """
1019+ name: 'Test experiment'
1020+ hparam_infos: {
1021+ name: 'batch_size'
1022+ type: DATA_TYPE_FLOAT64
1023+ differs: false
1024+ }
1025+ hparam_infos: {
1026+ name: 'lr'
1027+ type: DATA_TYPE_FLOAT64
1028+ differs: true
1029+ }
1030+ hparam_infos: {
1031+ name: 'use_batch_norm'
1032+ type: DATA_TYPE_BOOL
1033+ differs: false
1034+ }
1035+ hparam_infos: {
1036+ name: 'model_type'
1037+ type: DATA_TYPE_STRING
1038+ differs: true
1039+ }
1040+ """
1041+ t = provider .TensorTimeSeries (
1042+ max_step = 0 ,
1043+ max_wall_time = 0 ,
1044+ plugin_content = self ._serialized_plugin_data (
1045+ DATA_TYPE_EXPERIMENT , experiment
1046+ ),
1047+ description = "" ,
1048+ display_name = "" ,
1049+ )
1050+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
1051+ self ._mock_tb_context .data_provider .list_tensors .return_value = {
1052+ "train" : {metadata .EXPERIMENT_TAG : t }
1053+ }
1054+ expected_exp = """
1055+ name: 'Test experiment'
1056+ hparam_infos: {
1057+ name: 'lr'
1058+ type: DATA_TYPE_FLOAT64
1059+ differs: true
1060+ }
1061+ hparam_infos: {
1062+ name: 'model_type'
1063+ type: DATA_TYPE_STRING
1064+ differs: true
1065+ }
1066+ hparam_infos: {
1067+ name: 'batch_size'
1068+ type: DATA_TYPE_FLOAT64
1069+ differs: false
1070+ }
1071+ """
1072+ actual_exp = self ._experiment_from_metadata (
1073+ include_metrics = False , hparams_limit = None
1074+ )
1075+ self .assertProtoEquals (expected_exp , actual_exp )
1076+
9001077 def _serialized_plugin_data (self , data_oneof_field , text_protobuffer ):
9011078 oneof_type_dict = {
9021079 DATA_TYPE_EXPERIMENT : api_pb2 .Experiment ,
0 commit comments