Skip to content

Commit 6d8d2be

Browse files
authored
Hparams: Use DataProvider.list_hyperparameters() to generate HparamInfos (#6391)
Use the result from a DataProvider.list_hyperparameters() call to generate HparamInfos for operations in the Hparams plugin. Generate the HparamInfo name, display_name, data_type, and domain from each Hyperparameter. We use the same two-step strategy for generating these as is done with tensor-based hparams. The tensor-based hparams has one operation, `hparams_metadata()` for retrieving the tensor data, and then passes the result to `experiment_metadata()` to generate the HparamInfos and MetricInfos. Similarly, we add an `hparams_from_data_provider()` operation for retrieving the hyperparameter data and we pass the result to `experiment_metadata()`. Note: We don't yet generate `MetricInfos`. We don't yet have enough information from the data provider. Note: We don't yet merge the new HparamInfos with the tensor-based HparamInfos.
1 parent abc4219 commit 6d8d2be

File tree

4 files changed

+287
-54
lines changed

4 files changed

+287
-54
lines changed

tensorboard/plugins/hparams/backend_context.py

Lines changed: 107 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,18 @@
2020
import os
2121

2222

23+
from tensorboard.data import provider
2324
from tensorboard.plugins.hparams import api_pb2
2425
from tensorboard.plugins.hparams import metadata
2526
from google.protobuf import json_format
2627
from tensorboard.plugins.scalar import metadata as scalar_metadata
2728

29+
_DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE = {
30+
provider.HyperparameterDomainType.DISCRETE_BOOL: api_pb2.DATA_TYPE_BOOL,
31+
provider.HyperparameterDomainType.DISCRETE_FLOAT: api_pb2.DATA_TYPE_FLOAT64,
32+
provider.HyperparameterDomainType.DISCRETE_STRING: api_pb2.DATA_TYPE_STRING,
33+
}
34+
2835

2936
class Context:
3037
"""Wraps the base_plugin.TBContext to stores additional data shared across
@@ -51,33 +58,57 @@ def __init__(self, tb_context, max_domain_discrete_len=10):
5158
self._max_domain_discrete_len = max_domain_discrete_len
5259

5360
def experiment_from_metadata(
54-
self, ctx, experiment_id, hparams_run_to_tag_to_content
61+
self,
62+
ctx,
63+
experiment_id,
64+
hparams_run_to_tag_to_content,
65+
data_provider_hparams,
5566
):
56-
"""Returns the experiment protobuffer defining the experiment.
57-
58-
Accepts a dict containing the plugin contents for all summary tags
59-
associated with the hparams plugin, as an optimization for callers
60-
who already have this information available, so that this function
61-
can minimize its calls to the underlying `DataProvider`.
67+
"""Returns the experiment proto defining the experiment.
6268
6369
This method first attempts to find a metadata.EXPERIMENT_TAG tag and
64-
retrieve the associated protobuffer. If no such tag is found, the method
65-
will attempt to build a minimal experiment protobuffer by scanning for
66-
all metadata.SESSION_START_INFO_TAG tags (to compute the hparam_infos
67-
field of the experiment) and for all scalar tags (to compute the
68-
metric_infos field of the experiment).
70+
retrieve the associated proto.
71+
72+
If no such tag is found, the method will attempt to build a minimal
73+
experiment proto by scanning for all metadata.SESSION_START_INFO_TAG
74+
tags (to compute the hparam_infos field of the experiment) and for all
75+
scalar tags (to compute the metric_infos field of the experiment).
76+
77+
If no metadata.EXPERIMENT_TAG nor metadata.SESSION_START_INFO_TAG tags
78+
are found, then will build an experiment proto using the results from
79+
DataProvider.list_hyperparameters().
80+
81+
Args:
82+
experiment_id: String, from `plugin_util.experiment_id`.
83+
hparams_run_to_tag_to_content: The output from an hparams_metadata()
84+
call. A dict `d` such that `d[run][tag]` is a `bytes` value with the
85+
summary metadata content for the keyed time series.
86+
data_provider_hparams: The ouput from an hparams_from_data_provider()
87+
call, corresponding to DataProvider.list_hyperparameters().
88+
A Collection[provider.Hyperparameter].
6989
7090
Returns:
71-
The experiment protobuffer. If no tags are found from which an experiment
72-
protobuffer can be built (possibly, because the event data has not been
73-
completely loaded yet), returns an entirely empty experiment.
91+
The experiment proto. If no data is found for an experiment proto to
92+
be built, returns an entirely empty experiment.
7493
"""
7594
experiment = self._find_experiment_tag(hparams_run_to_tag_to_content)
7695
if experiment:
7796
return experiment
78-
return self._compute_experiment_from_runs(
97+
98+
experiment_from_runs = self._compute_experiment_from_runs(
7999
ctx, experiment_id, hparams_run_to_tag_to_content
80100
)
101+
if experiment_from_runs:
102+
return experiment_from_runs
103+
104+
experiment_from_data_provider_hparams = (
105+
self._experiment_from_data_provider_hparams(data_provider_hparams)
106+
)
107+
return (
108+
experiment_from_data_provider_hparams
109+
if experiment_from_data_provider_hparams
110+
else api_pb2.Experiment()
111+
)
81112

82113
@property
83114
def tb_context(self):
@@ -159,6 +190,12 @@ def read_last_scalars(self, ctx, experiment_id, run_tag_filter):
159190
for (run, tag_to_data) in data_provider_output.items()
160191
}
161192

193+
def hparams_from_data_provider(self, ctx, experiment_id):
194+
"""Calls DataProvider.list_hyperparameters() and returns the result."""
195+
return self._tb_context.data_provider.list_hyperparameters(
196+
ctx, experiment_ids=[experiment_id]
197+
)
198+
162199
def _find_experiment_tag(self, hparams_run_to_tag_to_content):
163200
"""Finds the experiment associcated with the metadata.EXPERIMENT_TAG
164201
tag.
@@ -179,7 +216,7 @@ def _compute_experiment_from_runs(
179216
):
180217
"""Computes a minimal Experiment protocol buffer by scanning the runs.
181218
182-
Returns an empty Experiment if there are no hparam infos logged.
219+
Returns None if there are no hparam infos logged.
183220
"""
184221
hparam_infos = self._compute_hparam_infos(hparams_run_to_tag_to_content)
185222
if hparam_infos:
@@ -188,6 +225,9 @@ def _compute_experiment_from_runs(
188225
)
189226
else:
190227
metric_infos = []
228+
if not hparam_infos and not metric_infos:
229+
return None
230+
191231
return api_pb2.Experiment(
192232
hparam_infos=hparam_infos, metric_infos=metric_infos
193233
)
@@ -273,6 +313,56 @@ def _compute_hparam_info_from_values(self, name, values):
273313

274314
return result
275315

316+
def _experiment_from_data_provider_hparams(
317+
self,
318+
data_provider_hparams,
319+
):
320+
"""Returns an experiment protobuffer based on data provider hparams.
321+
322+
Args:
323+
data_provider_hparams: The ouput from an hparams_from_data_provider()
324+
call, corresponding to DataProvider.list_hyperparameters().
325+
A Collection[provider.Hyperparameter].
326+
327+
Returns:
328+
The experiment proto. If there are no hyperparameters in the input,
329+
returns None.
330+
"""
331+
if not data_provider_hparams:
332+
return None
333+
334+
hparam_infos = [
335+
self._convert_data_provider_hparam(dp_hparam)
336+
for dp_hparam in data_provider_hparams
337+
]
338+
return api_pb2.Experiment(hparam_infos=hparam_infos)
339+
340+
def _convert_data_provider_hparam(self, dp_hparam):
341+
"""Builds an HParamInfo message from data provider Hyperparameter.
342+
343+
Args:
344+
dp_hparam: The provider.Hyperparameter returned by the call to
345+
provider.DataProvider.list_hyperparameters().
346+
347+
Returns:
348+
An HParamInfo to include in the Experiment.
349+
"""
350+
hparam_info = api_pb2.HParamInfo(
351+
name=dp_hparam.hyperparameter_name,
352+
display_name=dp_hparam.hyperparameter_display_name,
353+
)
354+
if dp_hparam.domain_type == provider.HyperparameterDomainType.INTERVAL:
355+
hparam_info.type = api_pb2.DATA_TYPE_FLOAT64
356+
(dp_hparam_min, dp_hparam_max) = dp_hparam.domain
357+
hparam_info.domain_interval.min_value = dp_hparam_min
358+
hparam_info.domain_interval.max_value = dp_hparam_max
359+
elif dp_hparam.domain_type in _DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE.keys():
360+
hparam_info.type = _DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE.get(
361+
dp_hparam.domain_type
362+
)
363+
hparam_info.domain_discrete.extend(dp_hparam.domain)
364+
return hparam_info
365+
276366
def _compute_metric_infos(
277367
self, ctx, experiment_id, hparams_run_to_tag_to_content
278368
):

0 commit comments

Comments
 (0)