1- import type {
2- CorsOptions ,
3- Middleware ,
4- } from '../../types/rest.js' ;
1+ import type { CorsOptions , Middleware } from '../../types/rest.js' ;
52import {
63 DEFAULT_CORS_OPTIONS ,
74 HttpErrorCodes ,
85 HttpVerbs ,
96} from '../constants.js' ;
107
11- /**
12- * Resolves the origin value based on the configuration
13- */
14- const resolveOrigin = (
15- originConfig : NonNullable < CorsOptions [ 'origin' ] > ,
16- requestOrigin : string | null ,
17- ) : string => {
18- if ( Array . isArray ( originConfig ) ) {
19- return requestOrigin && originConfig . includes ( requestOrigin ) ? requestOrigin : '' ;
20- }
21- return originConfig ;
22- } ;
23-
248/**
259 * Creates a CORS middleware that adds appropriate CORS headers to responses
2610 * and handles OPTIONS preflight requests.
@@ -29,9 +13,9 @@ const resolveOrigin = (
2913 * ```typescript
3014 * import { Router } from '@aws-lambda-powertools/event-handler/experimental-rest';
3115 * import { cors } from '@aws-lambda-powertools/event-handler/experimental-rest/middleware';
32- *
16+ *
3317 * const app = new Router();
34- *
18+ *
3519 * // Use default configuration
3620 * app.use(cors());
3721 *
@@ -50,7 +34,7 @@ const resolveOrigin = (
5034 * }
5135 * }));
5236 * ```
53- *
37+ *
5438 * @param options.origin - The origin to allow requests from
5539 * @param options.allowMethods - The HTTP methods to allow
5640 * @param options.allowHeaders - The headers to allow
@@ -61,38 +45,93 @@ const resolveOrigin = (
6145export const cors = ( options ?: CorsOptions ) : Middleware => {
6246 const config = {
6347 ...DEFAULT_CORS_OPTIONS ,
64- ...options
48+ ...options ,
6549 } ;
50+ const allowedOrigins =
51+ typeof config . origin === 'string' ? [ config . origin ] : config . origin ;
52+ const allowsWildcard = allowedOrigins . includes ( '*' ) ;
53+ const allowedMethods = config . allowMethods . map ( ( method ) =>
54+ method . toUpperCase ( )
55+ ) ;
56+ const allowedHeaders = config . allowHeaders . map ( ( header ) =>
57+ header . toLowerCase ( )
58+ ) ;
6659
67- return async ( _params , reqCtx , next ) => {
68- const requestOrigin = reqCtx . request . headers . get ( 'Origin' ) ;
69- const resolvedOrigin = resolveOrigin ( config . origin , requestOrigin ) ;
60+ const isOriginAllowed = (
61+ requestOrigin : string | null
62+ ) : requestOrigin is string => {
63+ return (
64+ requestOrigin !== null &&
65+ ( allowsWildcard || allowedOrigins . includes ( requestOrigin ) )
66+ ) ;
67+ } ;
7068
71- reqCtx . res . headers . set ( 'access-control-allow-origin' , resolvedOrigin ) ;
72- if ( resolvedOrigin !== '*' ) {
73- reqCtx . res . headers . set ( 'Vary' , 'Origin' ) ;
69+ const isValidPreflightRequest = ( requestHeaders : Headers ) => {
70+ const accessControlRequestMethod = requestHeaders
71+ . get ( 'Access-Control-Request-Method' )
72+ ?. toUpperCase ( ) ;
73+ const accessControlRequestHeaders = requestHeaders
74+ . get ( 'Access-Control-Request-Headers' )
75+ ?. toLowerCase ( ) ;
76+ return (
77+ accessControlRequestMethod &&
78+ allowedMethods . includes ( accessControlRequestMethod ) &&
79+ accessControlRequestHeaders
80+ ?. split ( ',' )
81+ . every ( ( header ) => allowedHeaders . includes ( header . trim ( ) ) )
82+ ) ;
83+ } ;
84+
85+ const setCORSBaseHeaders = (
86+ requestOrigin : string ,
87+ responseHeaders : Headers
88+ ) => {
89+ const resolvedOrigin = allowsWildcard ? '*' : requestOrigin ;
90+ responseHeaders . set ( 'access-control-allow-origin' , resolvedOrigin ) ;
91+ if ( ! allowsWildcard && Array . isArray ( config . origin ) ) {
92+ responseHeaders . set ( 'vary' , 'Origin' ) ;
7493 }
75- config . allowMethods . forEach ( method => {
76- reqCtx . res . headers . append ( 'access-control-allow-methods' , method ) ;
77- } ) ;
78- config . allowHeaders . forEach ( header => {
79- reqCtx . res . headers . append ( 'access-control-allow-headers' , header ) ;
80- } ) ;
81- config . exposeHeaders . forEach ( header => {
82- reqCtx . res . headers . append ( 'access-control-expose-headers' , header ) ;
83- } ) ;
84- reqCtx . res . headers . set ( 'access-control-allow-credentials' , config . credentials . toString ( ) ) ;
85- if ( config . maxAge !== undefined ) {
86- reqCtx . res . headers . set ( 'access-control-max-age' , config . maxAge . toString ( ) ) ;
94+ if ( config . credentials ) {
95+ responseHeaders . set ( 'access-control-allow-credentials' , 'true' ) ;
96+ }
97+ } ;
98+
99+ return async ( _params , reqCtx , next ) => {
100+ const requestOrigin = reqCtx . request . headers . get ( 'Origin' ) ;
101+ if ( ! isOriginAllowed ( requestOrigin ) ) {
102+ await next ( ) ;
103+ return ;
87104 }
88105
89106 // Handle preflight OPTIONS request
90- if ( reqCtx . request . method === HttpVerbs . OPTIONS && reqCtx . request . headers . has ( 'Access-Control-Request-Method' ) ) {
107+ if ( reqCtx . request . method === HttpVerbs . OPTIONS ) {
108+ if ( ! isValidPreflightRequest ( reqCtx . request . headers ) ) {
109+ await next ( ) ;
110+ return ;
111+ }
112+ setCORSBaseHeaders ( requestOrigin , reqCtx . res . headers ) ;
113+ if ( config . maxAge !== undefined ) {
114+ reqCtx . res . headers . set (
115+ 'access-control-max-age' ,
116+ config . maxAge . toString ( )
117+ ) ;
118+ }
119+ for ( const method of allowedMethods ) {
120+ reqCtx . res . headers . append ( 'access-control-allow-methods' , method ) ;
121+ }
122+ for ( const header of allowedHeaders ) {
123+ reqCtx . res . headers . append ( 'access-control-allow-headers' , header ) ;
124+ }
91125 return new Response ( null , {
92126 status : HttpErrorCodes . NO_CONTENT ,
93127 headers : reqCtx . res . headers ,
94128 } ) ;
95129 }
130+
131+ setCORSBaseHeaders ( requestOrigin , reqCtx . res . headers ) ;
132+ for ( const header of config . exposeHeaders ) {
133+ reqCtx . res . headers . append ( 'access-control-expose-headers' , header ) ;
134+ }
96135 await next ( ) ;
97136 } ;
98137} ;
0 commit comments