@@ -13,21 +13,71 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515import { Injectable } from '@angular/core' ;
16- import { Observable , of } from 'rxjs' ;
17- import { map } from 'rxjs/operators' ;
18- import { TBHttpClient } from '../../webapp_data_source/tb_http_client' ;
16+ import { Observable , of , throwError } from 'rxjs' ;
17+ import { catchError , map , mergeMap } from 'rxjs/operators' ;
1918import {
19+ HttpErrorResponse ,
20+ TBHttpClient ,
21+ } from '../../webapp_data_source/tb_http_client' ;
22+ import * as backendTypes from './runs_backend_types' ;
23+ import {
24+ Domain ,
25+ DomainType ,
2026 HparamsAndMetadata ,
27+ HparamSpec ,
28+ HparamValue ,
29+ MetricSpec ,
2130 Run ,
2231 RunsDataSource ,
32+ RunToHparamsAndMetrics ,
2333} from './runs_data_source_types' ;
2434
35+ const HPARAMS_HTTP_PATH_PREFIX = 'data/plugin/hparams' ;
36+
2537type BackendGetRunsResponse = string [ ] ;
2638
2739function runToRunId ( run : string , experimentId : string ) {
2840 return `${ experimentId } /${ run } ` ;
2941}
3042
43+ function transformBackendHparamSpec (
44+ hparamInfo : backendTypes . HparamSpec
45+ ) : HparamSpec {
46+ let domain : Domain ;
47+ if ( backendTypes . isDiscreteDomainHparamSpec ( hparamInfo ) ) {
48+ domain = { type : DomainType . DISCRETE , values : hparamInfo . domainDiscrete } ;
49+ } else if ( backendTypes . isIntervalDomainHparamSpec ( hparamInfo ) ) {
50+ domain = { ...hparamInfo . domainInterval , type : DomainType . INTERVAL } ;
51+ } else {
52+ domain = {
53+ type : DomainType . INTERVAL ,
54+ minValue : - Infinity ,
55+ maxValue : Infinity ,
56+ } ;
57+ }
58+ return {
59+ description : hparamInfo . description ,
60+ displayName : hparamInfo . displayName ,
61+ name : hparamInfo . name ,
62+ type : hparamInfo . type ,
63+ domain,
64+ } ;
65+ }
66+
67+ function transformBackendMetricSpec (
68+ metricInfo : backendTypes . MetricSpec
69+ ) : MetricSpec {
70+ const { name, ...otherSpec } = metricInfo ;
71+ return {
72+ ...otherSpec ,
73+ tag : name . tag ,
74+ } ;
75+ }
76+
77+ declare interface GetExperimentHparamRequestPayload {
78+ experimentName : string ;
79+ }
80+
3181@Injectable ( )
3282export class TBRunsDataSource implements RunsDataSource {
3383 constructor ( private readonly http : TBHttpClient ) { }
@@ -48,11 +98,112 @@ export class TBRunsDataSource implements RunsDataSource {
4898 }
4999
50100 fetchHparamsMetadata ( experimentId : string ) : Observable < HparamsAndMetadata > {
51- // Return a stub implementation.
52- return of ( {
53- hparamSpecs : [ ] ,
54- metricSpecs : [ ] ,
55- runToHparamsAndMetrics : { } ,
56- } ) ;
101+ const requestPayload : GetExperimentHparamRequestPayload = {
102+ experimentName : experimentId ,
103+ } ;
104+ return this . http
105+ . post < backendTypes . BackendHparamsExperimentResponse > (
106+ `/experiment/${ experimentId } /${ HPARAMS_HTTP_PATH_PREFIX } /experiment` ,
107+ requestPayload
108+ )
109+ . pipe (
110+ map ( ( response ) => {
111+ const colParams : backendTypes . BackendListSessionGroupRequest [ 'colParams' ] =
112+ [ ] ;
113+
114+ for ( const hparamInfo of response . hparamInfos ) {
115+ colParams . push ( { hparam : hparamInfo . name } ) ;
116+ }
117+ for ( const metricInfo of response . metricInfos ) {
118+ colParams . push ( { metric : metricInfo . name } ) ;
119+ }
120+
121+ const listSessionRequestParams : backendTypes . BackendListSessionGroupRequest =
122+ {
123+ experimentName : experimentId ,
124+ allowedStatuses : [
125+ backendTypes . RunStatus . STATUS_FAILURE ,
126+ backendTypes . RunStatus . STATUS_RUNNING ,
127+ backendTypes . RunStatus . STATUS_SUCCESS ,
128+ backendTypes . RunStatus . STATUS_UNKNOWN ,
129+ ] ,
130+ colParams,
131+ startIndex : 0 ,
132+ // arbitrary large number so it does not get clipped.
133+ sliceSize : 1e6 ,
134+ } ;
135+
136+ return {
137+ experimentHparamsInfo : response ,
138+ listSessionRequestParams,
139+ } ;
140+ } ) ,
141+ mergeMap ( ( { experimentHparamsInfo, listSessionRequestParams} ) => {
142+ return this . http
143+ . post < backendTypes . BackendListSessionGroupResponse > (
144+ `/experiment/${ experimentId } /${ HPARAMS_HTTP_PATH_PREFIX } /session_groups` ,
145+ JSON . stringify ( listSessionRequestParams )
146+ )
147+ . pipe (
148+ map ( ( sessionGroupsList ) => {
149+ return { experimentHparamsInfo, sessionGroupsList} ;
150+ } )
151+ ) ;
152+ } ) ,
153+ map ( ( { experimentHparamsInfo, sessionGroupsList} ) => {
154+ const runToHparamsAndMetrics : RunToHparamsAndMetrics = { } ;
155+
156+ // Reorganize the sessionGroup/session into run to <hparams,
157+ // metrics>.
158+ for ( const sessionGroup of sessionGroupsList . sessionGroups ) {
159+ const hparams : HparamValue [ ] = Object . entries (
160+ sessionGroup . hparams
161+ ) . map ( ( keyValue ) => {
162+ const [ hparam , value ] = keyValue ;
163+ return { name : hparam , value} ;
164+ } ) ;
165+
166+ for ( const session of sessionGroup . sessions ) {
167+ for ( const metricValue of session . metricValues ) {
168+ const runName = metricValue . name . group
169+ ? `${ session . name } /${ metricValue . name . group } `
170+ : session . name ;
171+ const runId = `${ experimentId } /${ runName } ` ;
172+ const hparamsAndMetrics = runToHparamsAndMetrics [ runId ] || {
173+ metrics : [ ] ,
174+ hparams,
175+ } ;
176+ hparamsAndMetrics . metrics . push ( {
177+ tag : metricValue . name . tag ,
178+ trainingStep : metricValue . trainingStep ,
179+ value : metricValue . value ,
180+ } ) ;
181+ runToHparamsAndMetrics [ runId ] = hparamsAndMetrics ;
182+ }
183+ }
184+ }
185+ return {
186+ hparamSpecs : experimentHparamsInfo . hparamInfos . map (
187+ transformBackendHparamSpec
188+ ) ,
189+ metricSpecs : experimentHparamsInfo . metricInfos . map (
190+ transformBackendMetricSpec
191+ ) ,
192+ runToHparamsAndMetrics,
193+ } ;
194+ } ) ,
195+ catchError ( ( error ) => {
196+ // HParams plugin return 400 when there are no hparams for an
197+ // experiment.
198+ if ( error instanceof HttpErrorResponse && error . status === 400 ) {
199+ return of ( {
200+ hparamSpecs : [ ] ,
201+ metricSpecs : [ ] ,
202+ runToHparamsAndMetrics : { } ,
203+ } ) ;
204+ }
205+ return throwError ( error ) ;
206+ } )
207+ ) ;
57208 }
58209}
0 commit comments