diff --git a/src/authenticator/passport.ts b/src/authenticator/passport.ts index 890428f..9c555f5 100644 --- a/src/authenticator/passport.ts +++ b/src/authenticator/passport.ts @@ -27,7 +27,7 @@ export class PassportAuthenticator implements ServiceAuthenticator { return this.authenticator; } - public getRoles(req: express.Request): Array { + public async getRoles(req: express.Request): Promise> { const roleKey = this.options.rolesKey || 'roles'; return _.castArray(_.get(req.user, roleKey, [])); } diff --git a/src/server/model/server-types.ts b/src/server/model/server-types.ts index df41897..db8ecd5 100644 --- a/src/server/model/server-types.ts +++ b/src/server/model/server-types.ts @@ -102,7 +102,7 @@ export interface ServiceAuthenticator { /** * Get the user list of roles. */ - getRoles: (req: express.Request) => Array; + getRoles: (req: express.Request) => Promise>; /** * Initialize the authenticator */ diff --git a/src/server/server-container.ts b/src/server/server-container.ts index 7854dff..463d78d 100644 --- a/src/server/server-container.ts +++ b/src/server/server-container.ts @@ -351,16 +351,19 @@ export class ServerContainer { } private buildAuthMiddleware(authenticator: ServiceAuthenticator, roles: Array): express.RequestHandler { - return (req: Request, res: Response, next: NextFunction) => { - const requestRoles = authenticator.getRoles(req); - if (this.debugger.runtime.enabled) { - this.debugger.runtime('Validating authentication roles: <%j>.', requestRoles); - } - if (requestRoles.some((role: string) => roles.indexOf(role) >= 0)) { - next(); - } - else { - throw new Errors.ForbiddenError(); + return async (req: Request, res: Response, next: NextFunction) => { + try { + const requestRoles = await authenticator.getRoles(req); + if (this.debugger.runtime.enabled) { + this.debugger.runtime('Validating authentication roles: <%j>.', requestRoles); + } + if (requestRoles.some((role: string) => roles.indexOf(role) >= 0)) { + next(); + } else { + throw new Errors.ForbiddenError(); + } + } catch (err) { + next(err); } }; }