@@ -251,6 +251,7 @@ def __init__( # noqa: PLR0913
251251 self .resource_attributes = resource_attributes
252252 self .passthrough_resource_attributes = passthrough_resource_attributes
253253 self .open_api_version = open_api_version
254+ self .remove_extra_stage = open_api_version
254255 self .models = models
255256 self .domain = domain
256257 self .fail_on_warnings = fail_on_warnings
@@ -399,7 +400,7 @@ def _construct_deployment(self, rest_api: ApiGatewayRestApi) -> ApiGatewayDeploy
399400 self .logical_id + "Deployment" , attributes = self .passthrough_resource_attributes
400401 )
401402 deployment .RestApiId = rest_api .get_runtime_attr ("rest_api_id" )
402- if not self .open_api_version :
403+ if not self .remove_extra_stage :
403404 deployment .StageName = "Stage"
404405
405406 return deployment
@@ -437,7 +438,7 @@ def _construct_stage(
437438 if swagger is not None :
438439 deployment .make_auto_deployable (
439440 stage ,
440- self .open_api_version ,
441+ self .remove_extra_stage ,
441442 swagger ,
442443 self .domain ,
443444 redeploy_restapi_parameters ,
@@ -1124,10 +1125,11 @@ def _openapi_postprocess(self, definition_body: Dict[str, Any]) -> Dict[str, Any
11241125 if definition_body .get ("swagger" ) is not None :
11251126 return definition_body
11261127
1127- normalized_open_api_version = definition_body .get ("openapi" , self .open_api_version )
1128+ if definition_body .get ("openapi" ) is not None and self .open_api_version is None :
1129+ self .open_api_version = definition_body .get ("openapi" )
11281130
1129- if normalized_open_api_version and SwaggerEditor .safe_compare_regex_with_string (
1130- SwaggerEditor ._OPENAPI_VERSION_3_REGEX , normalized_open_api_version
1131+ if self . open_api_version and SwaggerEditor .safe_compare_regex_with_string (
1132+ SwaggerEditor ._OPENAPI_VERSION_3_REGEX , self . open_api_version
11311133 ):
11321134 if definition_body .get ("securityDefinitions" ):
11331135 components = definition_body .get ("components" , Py27Dict ())
0 commit comments