@@ -98,7 +98,9 @@ def experiment_from_metadata(
9898 return experiment_from_runs
9999
100100 experiment_from_data_provider_hparams = (
101- self ._experiment_from_data_provider_hparams (data_provider_hparams )
101+ self ._experiment_from_data_provider_hparams (
102+ ctx , experiment_id , data_provider_hparams
103+ )
102104 )
103105 return (
104106 experiment_from_data_provider_hparams
@@ -224,7 +226,7 @@ def _compute_experiment_from_runs(
224226 """
225227 hparam_infos = self ._compute_hparam_infos (hparams_run_to_tag_to_content )
226228 if hparam_infos :
227- metric_infos = self ._compute_metric_infos (
229+ metric_infos = self ._compute_metric_infos_from_runs (
228230 ctx , experiment_id , hparams_run_to_tag_to_content
229231 )
230232 else :
@@ -316,6 +318,8 @@ def _compute_hparam_info_from_values(self, name, values):
316318
317319 def _experiment_from_data_provider_hparams (
318320 self ,
321+ ctx ,
322+ experiment_id ,
319323 data_provider_hparams ,
320324 ):
321325 """Returns an experiment protobuffer based on data provider hparams.
@@ -334,18 +338,24 @@ def _experiment_from_data_provider_hparams(
334338 # until all internal implementations of DataProvider can be
335339 # migrated to use new return value of provider.ListHyperparametersResult.
336340 hyperparameters = data_provider_hparams
341+ session_groups = []
337342 else :
338343 # Is instance of provider.ListHyperparametersResult
339344 hyperparameters = data_provider_hparams .hyperparameters
340-
341- if not hyperparameters :
342- return None
345+ session_groups = data_provider_hparams .session_groups
343346
344347 hparam_infos = [
345348 self ._convert_data_provider_hparam (dp_hparam )
346349 for dp_hparam in hyperparameters
347350 ]
348- return api_pb2 .Experiment (hparam_infos = hparam_infos )
351+ metric_infos = (
352+ self .compute_metric_infos_from_data_provider_session_groups (
353+ ctx , experiment_id , session_groups
354+ )
355+ )
356+ return api_pb2 .Experiment (
357+ hparam_infos = hparam_infos , metric_infos = metric_infos
358+ )
349359
350360 def _convert_data_provider_hparam (self , dp_hparam ):
351361 """Builds an HParamInfo message from data provider Hyperparameter.
@@ -374,19 +384,37 @@ def _convert_data_provider_hparam(self, dp_hparam):
374384 hparam_info .domain_discrete .extend (dp_hparam .domain )
375385 return hparam_info
376386
377- def _compute_metric_infos (
387+ def _compute_metric_infos_from_runs (
378388 self , ctx , experiment_id , hparams_run_to_tag_to_content
379389 ):
390+ session_runs = set (
391+ run
392+ for run , tags in hparams_run_to_tag_to_content .items ()
393+ if metadata .SESSION_START_INFO_TAG in tags
394+ )
380395 return (
381396 api_pb2 .MetricInfo (name = api_pb2 .MetricName (group = group , tag = tag ))
382397 for tag , group in self ._compute_metric_names (
383- ctx , experiment_id , hparams_run_to_tag_to_content
398+ ctx , experiment_id , session_runs
384399 )
385400 )
386401
387- def _compute_metric_names (
388- self , ctx , experiment_id , hparams_run_to_tag_to_content
402+ def compute_metric_infos_from_data_provider_session_groups (
403+ self , ctx , experiment_id , session_groups
389404 ):
405+ session_runs = set (
406+ f"{ s .experiment_id } /{ s .run } "
407+ for sg in session_groups
408+ for s in sg .sessions
409+ )
410+ return [
411+ api_pb2 .MetricInfo (name = api_pb2 .MetricName (group = group , tag = tag ))
412+ for tag , group in self ._compute_metric_names (
413+ ctx , experiment_id , session_runs
414+ )
415+ ]
416+
417+ def _compute_metric_names (self , ctx , experiment_id , session_runs ):
390418 """Computes the list of metric names from all the scalar (run, tag)
391419 pairs.
392420
@@ -412,11 +440,6 @@ def _compute_metric_names(
412440 A python list containing pairs. Each pair is a (tag, group) pair
413441 representing a metric name used in some session.
414442 """
415- session_runs = set (
416- run
417- for run , tags in hparams_run_to_tag_to_content .items ()
418- if metadata .SESSION_START_INFO_TAG in tags
419- )
420443 metric_names_set = set ()
421444 scalars_run_to_tag_to_content = self .scalars_metadata (
422445 ctx , experiment_id
0 commit comments