Skip to content

Commit a0865d9

Browse files
authored
Bug Fix: update hparams plugin to generate domains for boolean hparams (#6393)
## Motivation for features / changes This has been a bug for a while that we only noticed now that hparams are being used to filter runs in the time series dashboard. Googlers see https://chat.google.com/room/AAAA03izhrk/VyyPgojcNvY for context on how this was discovered ## Technical description of changes Boolean HParams were not having a domain set. This lead to the ui treating them as intervals, however, this lead to incorrect filter conditions. We currently treat all non number values being filtered this way as not matching (the alternative would be to treat them all as being true which also seems wrong). This leads to all runs with a value for a boolean hparam being filtered out. ## Screenshots of UI changes (or N/A) Googlers see cl/532932348 ## Alternate designs / implementations considered (or N/A) This could have been done on the client but ideally the bug would be fixed in other places where the api is being used.
1 parent 87b3656 commit a0865d9

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

tensorboard/plugins/hparams/backend_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,9 @@ def _compute_hparam_info_from_values(self, name, values):
311311
):
312312
result.domain_discrete.extend(distinct_values)
313313

314+
if result.type == api_pb2.DATA_TYPE_BOOL:
315+
result.domain_discrete.extend([True, False])
316+
314317
return result
315318

316319
def _experiment_from_data_provider_hparams(

tensorboard/plugins/hparams/backend_context_test.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def _mock_list_tensors(
8080
},
8181
}
8282
result = {}
83-
for (run, tag_to_content) in hparams_content.items():
83+
for run, tag_to_content in hparams_content.items():
8484
result.setdefault(run, {})
85-
for (tag, content) in tag_to_content.items():
85+
for tag, content in tag_to_content.items():
8686
t = provider.TensorTimeSeries(
8787
max_step=0,
8888
max_wall_time=0,
@@ -131,9 +131,9 @@ def _mock_list_scalars(
131131
},
132132
}
133133
result = {}
134-
for (run, tag_to_content) in scalars_content.items():
134+
for run, tag_to_content in scalars_content.items():
135135
result.setdefault(run, {})
136-
for (tag, content) in tag_to_content.items():
136+
for tag, content in tag_to_content.items():
137137
t = provider.ScalarTimeSeries(
138138
max_step=0,
139139
max_wall_time=0,
@@ -358,6 +358,46 @@ def test_experiment_without_experiment_tag_many_distinct_values(self):
358358
_canonicalize_experiment(actual_exp)
359359
self.assertProtoEquals(expected_exp, actual_exp)
360360

361+
def test_experiment_with_bool_types(self):
362+
self.session_1_start_info_ = """
363+
hparams:[
364+
{key: 'batch_size' value: {bool_value: true}}
365+
]
366+
"""
367+
self.session_2_start_info_ = """
368+
hparams:[
369+
{key: 'batch_size' value: {bool_value: true}}
370+
]
371+
"""
372+
self.session_3_start_info_ = """
373+
hparams:[
374+
]
375+
"""
376+
expected_exp = """
377+
hparam_infos: {
378+
name: 'batch_size'
379+
type: DATA_TYPE_BOOL
380+
domain_discrete: {
381+
values: [{bool_value: true}, {bool_value: false}]
382+
}
383+
}
384+
metric_infos: {
385+
name: {group: '', tag: 'accuracy'}
386+
}
387+
metric_infos: {
388+
name: {group: '', tag: 'loss'}
389+
}
390+
metric_infos: {
391+
name: {group: 'eval', tag: 'loss'}
392+
}
393+
metric_infos: {
394+
name: {group: 'train', tag: 'loss'}
395+
}
396+
"""
397+
actual_exp = self._experiment_from_metadata()
398+
_canonicalize_experiment(actual_exp)
399+
self.assertProtoEquals(expected_exp, actual_exp)
400+
361401
def test_experiment_without_any_hparams(self):
362402
request_ctx = context.RequestContext()
363403
actual_exp = self._experiment_from_metadata()

0 commit comments

Comments
 (0)