1717try :
1818 import janus
1919 import uvicorn
20- from fastapi import APIRouter , FastAPI , File , Form , UploadFile , WebSocket
21- from fastapi .responses import PlainTextResponse , StreamingResponse
20+ from fastapi import (
21+ APIRouter ,
22+ FastAPI ,
23+ File ,
24+ Form ,
25+ HTTPException ,
26+ Request ,
27+ UploadFile ,
28+ WebSocket ,
29+ )
30+ from fastapi .responses import JSONResponse , PlainTextResponse , StreamingResponse
31+ from starlette .status import HTTP_403_FORBIDDEN
2232except :
2333 # Server dependencies are not required by the main package.
2434 pass
@@ -204,6 +214,24 @@ def accumulate(self, chunk):
204214 self .messages [- 1 ]["content" ] += chunk
205215
206216
217+ def authenticate_function (key ):
218+ """
219+ This function checks if the provided key is valid for authentication.
220+
221+ Returns True if the key is valid, False otherwise.
222+ """
223+ # Fetch the API key from the environment variables. If it's not set, return True.
224+ api_key = os .getenv ("INTERPRETER_API_KEY" , None )
225+
226+ # If the API key is not set in the environment variables, return True.
227+ # Otherwise, check if the provided key matches the fetched API key.
228+ # Return True if they match, False otherwise.
229+ if api_key is None :
230+ return True
231+ else :
232+ return key == api_key
233+
234+
207235def create_router (async_interpreter ):
208236 router = APIRouter ()
209237
@@ -226,6 +254,7 @@ async def home():
226254 <button>Send</button>
227255 </form>
228256 <button id="approveCodeButton">Approve Code</button>
257+ <button id="authButton">Send Auth</button>
229258 <div id="messages"></div>
230259 <script>
231260 var ws = new WebSocket("ws://"""
@@ -234,6 +263,7 @@ async def home():
234263 + str (async_interpreter .server .port )
235264 + """/");
236265 var lastMessageElement = null;
266+
237267 ws.onmessage = function(event) {
238268
239269 var eventData = JSON.parse(event.data);
@@ -326,8 +356,15 @@ async def home():
326356 };
327357 ws.send(JSON.stringify(endCommandBlock));
328358 }
359+ function authenticate() {
360+ var authBlock = {
361+ "auth": "dummy-api-key"
362+ };
363+ ws.send(JSON.stringify(authBlock));
364+ }
329365
330366 document.getElementById("approveCodeButton").addEventListener("click", approveCode);
367+ document.getElementById("authButton").addEventListener("click", authenticate);
331368 </script>
332369 </body>
333370 </html>
@@ -338,13 +375,30 @@ async def home():
338375 @router .websocket ("/" )
339376 async def websocket_endpoint (websocket : WebSocket ):
340377 await websocket .accept ()
378+
341379 try :
342380
343381 async def receive_input ():
382+ authenticated = False
344383 while True :
345384 try :
346385 data = await websocket .receive ()
347386
387+ if not authenticated :
388+ if "text" in data :
389+ data = json .loads (data ["text" ])
390+ if "auth" in data :
391+ if async_interpreter .server .authenticate (
392+ data ["auth" ]
393+ ):
394+ authenticated = True
395+ await websocket .send_text (
396+ json .dumps ({"auth" : True })
397+ )
398+ if not authenticated :
399+ await websocket .send_text (json .dumps ({"auth" : False }))
400+ continue
401+
348402 if data .get ("type" ) == "websocket.receive" :
349403 if "text" in data :
350404 data = json .loads (data ["text" ])
@@ -474,19 +528,6 @@ async def post_input(payload: Dict[str, Any]):
474528 except Exception as e :
475529 return {"error" : str (e )}, 500
476530
477- @router .post ("/run" )
478- async def run_code (payload : Dict [str , Any ]):
479- language , code = payload .get ("language" ), payload .get ("code" )
480- if not (language and code ):
481- return {"error" : "Both 'language' and 'code' are required." }, 400
482- try :
483- print (f"Running { language } :" , code )
484- output = async_interpreter .computer .run (language , code )
485- print ("Output:" , output )
486- return {"output" : output }
487- except Exception as e :
488- return {"error" : str (e )}, 500
489-
490531 @router .post ("/settings" )
491532 async def set_settings (payload : Dict [str , Any ]):
492533 for key , value in payload .items ():
@@ -520,23 +561,38 @@ async def get_setting(setting: str):
520561 else :
521562 return json .dumps ({"error" : "Setting not found" }), 404
522563
523- @router .post ("/upload" )
524- async def upload_file (file : UploadFile = File (...), path : str = Form (...)):
525- try :
526- with open (path , "wb" ) as output_file :
527- shutil .copyfileobj (file .file , output_file )
528- return {"status" : "success" }
529- except Exception as e :
530- return {"error" : str (e )}, 500
564+ if os .getenv ("INTERPRETER_INSECURE_ROUTES" , "" ).lower () == "true" :
531565
532- @router .get ("/download/{filename}" )
533- async def download_file (filename : str ):
534- try :
535- return StreamingResponse (
536- open (filename , "rb" ), media_type = "application/octet-stream"
537- )
538- except Exception as e :
539- return {"error" : str (e )}, 500
566+ @router .post ("/run" )
567+ async def run_code (payload : Dict [str , Any ]):
568+ language , code = payload .get ("language" ), payload .get ("code" )
569+ if not (language and code ):
570+ return {"error" : "Both 'language' and 'code' are required." }, 400
571+ try :
572+ print (f"Running { language } :" , code )
573+ output = async_interpreter .computer .run (language , code )
574+ print ("Output:" , output )
575+ return {"output" : output }
576+ except Exception as e :
577+ return {"error" : str (e )}, 500
578+
579+ @router .post ("/upload" )
580+ async def upload_file (file : UploadFile = File (...), path : str = Form (...)):
581+ try :
582+ with open (path , "wb" ) as output_file :
583+ shutil .copyfileobj (file .file , output_file )
584+ return {"status" : "success" }
585+ except Exception as e :
586+ return {"error" : str (e )}, 500
587+
588+ @router .get ("/download/{filename}" )
589+ async def download_file (filename : str ):
590+ try :
591+ return StreamingResponse (
592+ open (filename , "rb" ), media_type = "application/octet-stream"
593+ )
594+ except Exception as e :
595+ return {"error" : str (e )}, 500
540596
541597 ### OPENAI COMPATIBLE ENDPOINT
542598
@@ -648,6 +704,21 @@ class Server:
648704 def __init__ (self , async_interpreter , host = "127.0.0.1" , port = 8000 ):
649705 self .app = FastAPI ()
650706 router = create_router (async_interpreter )
707+ self .authenticate = authenticate_function
708+
709+ # Add authentication middleware
710+ @self .app .middleware ("http" )
711+ async def validate_api_key (request : Request , call_next ):
712+ api_key = request .headers .get ("X-API-KEY" )
713+ if self .authenticate (api_key ):
714+ response = await call_next (request )
715+ return response
716+ else :
717+ return JSONResponse (
718+ status_code = HTTP_403_FORBIDDEN ,
719+ content = {"detail" : "Authentication failed" },
720+ )
721+
651722 self .app .include_router (router )
652723 self .config = uvicorn .Config (app = self .app , host = host , port = port )
653724 self .uvicorn_server = uvicorn .Server (self .config )
0 commit comments