2020import os
2121
2222
23+ from tensorboard .data import provider
2324from tensorboard .plugins .hparams import api_pb2
2425from tensorboard .plugins .hparams import metadata
2526from google .protobuf import json_format
2627from tensorboard .plugins .scalar import metadata as scalar_metadata
2728
29+ _DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE = {
30+ provider .HyperparameterDomainType .DISCRETE_BOOL : api_pb2 .DATA_TYPE_BOOL ,
31+ provider .HyperparameterDomainType .DISCRETE_FLOAT : api_pb2 .DATA_TYPE_FLOAT64 ,
32+ provider .HyperparameterDomainType .DISCRETE_STRING : api_pb2 .DATA_TYPE_STRING ,
33+ }
34+
2835
2936class Context :
3037 """Wraps the base_plugin.TBContext to stores additional data shared across
@@ -51,33 +58,57 @@ def __init__(self, tb_context, max_domain_discrete_len=10):
5158 self ._max_domain_discrete_len = max_domain_discrete_len
5259
5360 def experiment_from_metadata (
54- self , ctx , experiment_id , hparams_run_to_tag_to_content
61+ self ,
62+ ctx ,
63+ experiment_id ,
64+ hparams_run_to_tag_to_content ,
65+ data_provider_hparams ,
5566 ):
56- """Returns the experiment protobuffer defining the experiment.
57-
58- Accepts a dict containing the plugin contents for all summary tags
59- associated with the hparams plugin, as an optimization for callers
60- who already have this information available, so that this function
61- can minimize its calls to the underlying `DataProvider`.
67+ """Returns the experiment proto defining the experiment.
6268
6369 This method first attempts to find a metadata.EXPERIMENT_TAG tag and
64- retrieve the associated protobuffer. If no such tag is found, the method
65- will attempt to build a minimal experiment protobuffer by scanning for
66- all metadata.SESSION_START_INFO_TAG tags (to compute the hparam_infos
67- field of the experiment) and for all scalar tags (to compute the
68- metric_infos field of the experiment).
70+ retrieve the associated proto.
71+
72+ If no such tag is found, the method will attempt to build a minimal
73+ experiment proto by scanning for all metadata.SESSION_START_INFO_TAG
74+ tags (to compute the hparam_infos field of the experiment) and for all
75+ scalar tags (to compute the metric_infos field of the experiment).
76+
77+ If no metadata.EXPERIMENT_TAG nor metadata.SESSION_START_INFO_TAG tags
78+ are found, then will build an experiment proto using the results from
79+ DataProvider.list_hyperparameters().
80+
81+ Args:
82+ experiment_id: String, from `plugin_util.experiment_id`.
83+ hparams_run_to_tag_to_content: The output from an hparams_metadata()
84+ call. A dict `d` such that `d[run][tag]` is a `bytes` value with the
85+ summary metadata content for the keyed time series.
86+ data_provider_hparams: The ouput from an hparams_from_data_provider()
87+ call, corresponding to DataProvider.list_hyperparameters().
88+ A Collection[provider.Hyperparameter].
6989
7090 Returns:
71- The experiment protobuffer. If no tags are found from which an experiment
72- protobuffer can be built (possibly, because the event data has not been
73- completely loaded yet), returns an entirely empty experiment.
91+ The experiment proto. If no data is found for an experiment proto to
92+ be built, returns an entirely empty experiment.
7493 """
7594 experiment = self ._find_experiment_tag (hparams_run_to_tag_to_content )
7695 if experiment :
7796 return experiment
78- return self ._compute_experiment_from_runs (
97+
98+ experiment_from_runs = self ._compute_experiment_from_runs (
7999 ctx , experiment_id , hparams_run_to_tag_to_content
80100 )
101+ if experiment_from_runs :
102+ return experiment_from_runs
103+
104+ experiment_from_data_provider_hparams = (
105+ self ._experiment_from_data_provider_hparams (data_provider_hparams )
106+ )
107+ return (
108+ experiment_from_data_provider_hparams
109+ if experiment_from_data_provider_hparams
110+ else api_pb2 .Experiment ()
111+ )
81112
82113 @property
83114 def tb_context (self ):
@@ -159,6 +190,12 @@ def read_last_scalars(self, ctx, experiment_id, run_tag_filter):
159190 for (run , tag_to_data ) in data_provider_output .items ()
160191 }
161192
193+ def hparams_from_data_provider (self , ctx , experiment_id ):
194+ """Calls DataProvider.list_hyperparameters() and returns the result."""
195+ return self ._tb_context .data_provider .list_hyperparameters (
196+ ctx , experiment_ids = [experiment_id ]
197+ )
198+
162199 def _find_experiment_tag (self , hparams_run_to_tag_to_content ):
163200 """Finds the experiment associcated with the metadata.EXPERIMENT_TAG
164201 tag.
@@ -179,7 +216,7 @@ def _compute_experiment_from_runs(
179216 ):
180217 """Computes a minimal Experiment protocol buffer by scanning the runs.
181218
182- Returns an empty Experiment if there are no hparam infos logged.
219+ Returns None if there are no hparam infos logged.
183220 """
184221 hparam_infos = self ._compute_hparam_infos (hparams_run_to_tag_to_content )
185222 if hparam_infos :
@@ -188,6 +225,9 @@ def _compute_experiment_from_runs(
188225 )
189226 else :
190227 metric_infos = []
228+ if not hparam_infos and not metric_infos :
229+ return None
230+
191231 return api_pb2 .Experiment (
192232 hparam_infos = hparam_infos , metric_infos = metric_infos
193233 )
@@ -273,6 +313,56 @@ def _compute_hparam_info_from_values(self, name, values):
273313
274314 return result
275315
316+ def _experiment_from_data_provider_hparams (
317+ self ,
318+ data_provider_hparams ,
319+ ):
320+ """Returns an experiment protobuffer based on data provider hparams.
321+
322+ Args:
323+ data_provider_hparams: The ouput from an hparams_from_data_provider()
324+ call, corresponding to DataProvider.list_hyperparameters().
325+ A Collection[provider.Hyperparameter].
326+
327+ Returns:
328+ The experiment proto. If there are no hyperparameters in the input,
329+ returns None.
330+ """
331+ if not data_provider_hparams :
332+ return None
333+
334+ hparam_infos = [
335+ self ._convert_data_provider_hparam (dp_hparam )
336+ for dp_hparam in data_provider_hparams
337+ ]
338+ return api_pb2 .Experiment (hparam_infos = hparam_infos )
339+
340+ def _convert_data_provider_hparam (self , dp_hparam ):
341+ """Builds an HParamInfo message from data provider Hyperparameter.
342+
343+ Args:
344+ dp_hparam: The provider.Hyperparameter returned by the call to
345+ provider.DataProvider.list_hyperparameters().
346+
347+ Returns:
348+ An HParamInfo to include in the Experiment.
349+ """
350+ hparam_info = api_pb2 .HParamInfo (
351+ name = dp_hparam .hyperparameter_name ,
352+ display_name = dp_hparam .hyperparameter_display_name ,
353+ )
354+ if dp_hparam .domain_type == provider .HyperparameterDomainType .INTERVAL :
355+ hparam_info .type = api_pb2 .DATA_TYPE_FLOAT64
356+ (dp_hparam_min , dp_hparam_max ) = dp_hparam .domain
357+ hparam_info .domain_interval .min_value = dp_hparam_min
358+ hparam_info .domain_interval .max_value = dp_hparam_max
359+ elif dp_hparam .domain_type in _DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE .keys ():
360+ hparam_info .type = _DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE .get (
361+ dp_hparam .domain_type
362+ )
363+ hparam_info .domain_discrete .extend (dp_hparam .domain )
364+ return hparam_info
365+
276366 def _compute_metric_infos (
277367 self , ctx , experiment_id , hparams_run_to_tag_to_content
278368 ):
0 commit comments