@@ -282,20 +282,26 @@ func TestNewClientExplicitNoAuth(t *testing.T) {
282282
283283func TestCustomToken (t * testing.T ) {
284284 client := & Client {
285- signer : testSigner ,
286- clock : testClock ,
285+ baseClient : & baseClient {
286+ signer : testSigner ,
287+ clock : testClock ,
288+ },
287289 }
288290 token , err := client .CustomToken (context .Background (), "user1" )
289291 if err != nil {
290292 t .Fatal (err )
291293 }
292- verifyCustomToken (context .Background (), token , nil , t )
294+ if err := verifyCustomToken (context .Background (), token , nil , "" ); err != nil {
295+ t .Fatal (err )
296+ }
293297}
294298
295299func TestCustomTokenWithClaims (t * testing.T ) {
296300 client := & Client {
297- signer : testSigner ,
298- clock : testClock ,
301+ baseClient : & baseClient {
302+ signer : testSigner ,
303+ clock : testClock ,
304+ },
299305 }
300306 claims := map [string ]interface {}{
301307 "foo" : "bar" ,
@@ -306,19 +312,46 @@ func TestCustomTokenWithClaims(t *testing.T) {
306312 if err != nil {
307313 t .Fatal (err )
308314 }
309- verifyCustomToken (context .Background (), token , claims , t )
315+ if err := verifyCustomToken (context .Background (), token , claims , "" ); err != nil {
316+ t .Fatal (err )
317+ }
310318}
311319
312320func TestCustomTokenWithNilClaims (t * testing.T ) {
313321 client := & Client {
314- signer : testSigner ,
315- clock : testClock ,
322+ baseClient : & baseClient {
323+ signer : testSigner ,
324+ clock : testClock ,
325+ },
316326 }
317327 token , err := client .CustomTokenWithClaims (context .Background (), "user1" , nil )
318328 if err != nil {
319329 t .Fatal (err )
320330 }
321- verifyCustomToken (context .Background (), token , nil , t )
331+ if err := verifyCustomToken (context .Background (), token , nil , "" ); err != nil {
332+ t .Fatal (err )
333+ }
334+ }
335+
336+ func TestCustomTokenForTenant (t * testing.T ) {
337+ client := & Client {
338+ baseClient : & baseClient {
339+ tenantID : "tenantID" ,
340+ signer : testSigner ,
341+ clock : testClock ,
342+ },
343+ }
344+ claims := map [string ]interface {}{
345+ "foo" : "bar" ,
346+ "premium" : true ,
347+ }
348+ token , err := client .CustomTokenWithClaims (context .Background (), "user1" , claims )
349+ if err != nil {
350+ t .Fatal (err )
351+ }
352+ if err := verifyCustomToken (context .Background (), token , claims , "tenantID" ); err != nil {
353+ t .Fatal (err )
354+ }
322355}
323356
324357func TestCustomTokenError (t * testing.T ) {
@@ -333,7 +366,7 @@ func TestCustomTokenError(t *testing.T) {
333366 {"ReservedClaims" , "uid" , map [string ]interface {}{"sub" : "1234" , "aud" : "foo" }},
334367 }
335368
336- client := & Client {
369+ client := & baseClient {
337370 signer : testSigner ,
338371 clock : testClock ,
339372 }
@@ -628,9 +661,9 @@ func TestCustomTokenVerification(t *testing.T) {
628661 client := & Client {
629662 baseClient : & baseClient {
630663 idTokenVerifier : testIDTokenVerifier ,
664+ signer : testSigner ,
665+ clock : testClock ,
631666 },
632- signer : testSigner ,
633- clock : testClock ,
634667 }
635668 token , err := client .CustomToken (context .Background (), "user1" )
636669 if err != nil {
@@ -1137,52 +1170,61 @@ func checkBaseClient(client *Client, wantProjectID string) error {
11371170 return nil
11381171}
11391172
1140- func verifyCustomToken (ctx context.Context , token string , expected map [string ]interface {}, t * testing.T ) {
1173+ func verifyCustomToken (
1174+ ctx context.Context , token string , expected map [string ]interface {}, tenantID string ) error {
1175+
11411176 if err := testIDTokenVerifier .verifySignature (ctx , token ); err != nil {
1142- t . Fatal ( err )
1177+ return err
11431178 }
1179+
11441180 var (
11451181 header jwtHeader
11461182 payload customToken
11471183 )
11481184 segments := strings .Split (token , "." )
11491185 if err := decode (segments [0 ], & header ); err != nil {
1150- t . Fatal ( err )
1186+ return err
11511187 }
11521188 if err := decode (segments [1 ], & payload ); err != nil {
1153- t . Fatal ( err )
1189+ return err
11541190 }
11551191
11561192 email , err := testSigner .Email (ctx )
11571193 if err != nil {
1158- t . Fatal ( err )
1194+ return err
11591195 }
11601196
11611197 if header .Algorithm != "RS256" {
1162- t .Errorf ("Algorithm: %q; want: 'RS256'" , header .Algorithm )
1198+ return fmt .Errorf ("Algorithm: %q; want: 'RS256'" , header .Algorithm )
11631199 } else if header .Type != "JWT" {
1164- t .Errorf ("Type: %q; want: 'JWT'" , header .Type )
1200+ return fmt .Errorf ("Type: %q; want: 'JWT'" , header .Type )
11651201 } else if payload .Aud != firebaseAudience {
1166- t .Errorf ("Audience: %q; want: %q" , payload .Aud , firebaseAudience )
1202+ return fmt .Errorf ("Audience: %q; want: %q" , payload .Aud , firebaseAudience )
11671203 } else if payload .Iss != email {
1168- t .Errorf ("Issuer: %q; want: %q" , payload .Iss , email )
1204+ return fmt .Errorf ("Issuer: %q; want: %q" , payload .Iss , email )
11691205 } else if payload .Sub != email {
1170- t .Errorf ("Subject: %q; want: %q" , payload .Sub , email )
1206+ return fmt .Errorf ("Subject: %q; want: %q" , payload .Sub , email )
11711207 }
11721208
11731209 now := testClock .Now ().Unix ()
11741210 if payload .Exp != now + 3600 {
1175- t .Errorf ("Exp: %d; want: %d" , payload .Exp , now + 3600 )
1211+ return fmt .Errorf ("Exp: %d; want: %d" , payload .Exp , now + 3600 )
11761212 }
11771213 if payload .Iat != now {
1178- t .Errorf ("Iat: %d; want: %d" , payload .Iat , now )
1214+ return fmt .Errorf ("Iat: %d; want: %d" , payload .Iat , now )
11791215 }
11801216
11811217 for k , v := range expected {
11821218 if payload .Claims [k ] != v {
1183- t .Errorf ("Claim[%q]: %v; want: %v" , k , payload .Claims [k ], v )
1219+ return fmt .Errorf ("Claim[%q]: %v; want: %v" , k , payload .Claims [k ], v )
11841220 }
11851221 }
1222+
1223+ if payload .TenantID != tenantID {
1224+ return fmt .Errorf ("Tenant ID: %q; want: %q" , payload .TenantID , tenantID )
1225+ }
1226+
1227+ return nil
11861228}
11871229
11881230func logFatal (err error ) {
0 commit comments