|
25 | 25 |
|
26 | 26 | from fastapi import Body, FastAPI |
27 | 27 | from fastapi.exceptions import RequestValidationError |
| 28 | +from fastapi.middleware.cors import CORSMiddleware |
28 | 29 | from starlette.requests import Request |
29 | 30 | from starlette.responses import Response |
30 | 31 | from starlette.background import BackgroundTasks |
|
55 | 56 | ) |
56 | 57 |
|
57 | 58 | app = FastAPI() |
| 59 | + |
| 60 | +app.add_middleware( |
| 61 | + CORSMiddleware, |
| 62 | + allow_origins=["*"], |
| 63 | + allow_credentials=True, |
| 64 | + allow_methods=["*"], |
| 65 | + allow_headers=["*"], |
| 66 | +) |
| 67 | + |
58 | 68 | local_cache = {"api": None, "predictor_impl": None, "client": None, "class_set": set()} |
59 | 69 |
|
60 | 70 |
|
@@ -90,21 +100,18 @@ def is_prediction_request(request): |
90 | 100 | @app.exception_handler(StarletteHTTPException) |
91 | 101 | async def http_exception_handler(request, e): |
92 | 102 | response = Response(content=str(e.detail), status_code=e.status_code) |
93 | | - apply_cors_headers(request, response) |
94 | 103 | return response |
95 | 104 |
|
96 | 105 |
|
97 | 106 | @app.exception_handler(RequestValidationError) |
98 | 107 | async def validation_exception_handler(request, e): |
99 | 108 | response = Response(content=str(e), status_code=400) |
100 | | - apply_cors_headers(request, response) |
101 | 109 | return response |
102 | 110 |
|
103 | 111 |
|
104 | 112 | @app.exception_handler(Exception) |
105 | 113 | async def uncaught_exception_handler(request, e): |
106 | 114 | response = Response(content="internal server error", status_code=500) |
107 | | - apply_cors_headers(request, response) |
108 | 115 | return response |
109 | 116 |
|
110 | 117 |
|
@@ -132,20 +139,12 @@ async def register_request(request: Request, call_next): |
132 | 139 | status_code = 500 |
133 | 140 | if response is not None: |
134 | 141 | status_code = response.status_code |
135 | | - apply_cors_headers(request, response) |
136 | 142 | api = local_cache["api"] |
137 | 143 | api.post_request_metrics(status_code, time.time() - request.state.start_time) |
138 | 144 |
|
139 | 145 | return response |
140 | 146 |
|
141 | 147 |
|
142 | | -def apply_cors_headers(request: Request, response: Response): |
143 | | - response.headers["Access-Control-Allow-Origin"] = "*" |
144 | | - response.headers["Access-Control-Allow-Headers"] = request.headers.get( |
145 | | - "Access-Control-Request-Headers", "*" |
146 | | - ) |
147 | | - |
148 | | - |
149 | 148 | @app.post("/predict") |
150 | 149 | def predict(request: Any = Body(..., media_type="application/json"), debug=False): |
151 | 150 | api = local_cache["api"] |
|
0 commit comments