Skip to content

Commit 8e8dfe1

Browse files
committed
Fixed and tested
1 parent 9124d2c commit 8e8dfe1

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

interpreter/core/async_core.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -700,17 +700,8 @@ async def chat_completion(request: ChatCompletionRequest):
700700
return router
701701

702702

703-
host = os.getenv(
704-
"HOST", "127.0.0.1"
705-
) # IP address for localhost, used for local testing. To expose to local network, use 0.0.0.0
706-
port = int(os.getenv("PORT", 8000)) # Default port is 8000
707-
708-
# FOR TESTING ONLY
709-
# host = "0.0.0.0"
710-
711-
712703
class Server:
713-
def __init__(self, async_interpreter, host="127.0.0.1", port=8000):
704+
def __init__(self, async_interpreter, host=None, port=None):
714705
self.app = FastAPI()
715706
router = create_router(async_interpreter)
716707
self.authenticate = authenticate_function
@@ -729,7 +720,7 @@ async def validate_api_key(request: Request, call_next):
729720
)
730721

731722
self.app.include_router(router)
732-
self.config = uvicorn.Config(app=self.app, host=host, port=port)
723+
self.config = uvicorn.Config(app=self.app, host=host or os.getenv("HOST", "127.0.0.1"), port=port or int(os.getenv("PORT", "8000")))
733724
self.uvicorn_server = uvicorn.Server(self.config)
734725

735726
@property

tests/core/test_async_core.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import os
2+
from unittest import TestCase, mock
3+
4+
from interpreter.core.async_core import Server, AsyncInterpreter
5+
6+
7+
class TestServerConstruction(TestCase):
8+
def test_host_and_port_from_env_1(self):
9+
fake_host = "fake_host"
10+
fake_port = 1234
11+
12+
with mock.patch.dict(os.environ, {"HOST": fake_host, "PORT": str(fake_port)}):
13+
s = Server(AsyncInterpreter())
14+
self.assertEqual(s.host, fake_host)
15+
self.assertEqual(s.port, fake_port)
16+
17+
def test_host_and_port_from_env_2(self):
18+
fake_host = "some-other-fake-host"
19+
fake_port = 4321
20+
21+
with mock.patch.dict(os.environ, {"HOST": fake_host, "PORT": str(fake_port)}):
22+
s = Server(AsyncInterpreter())
23+
self.assertEqual(s.host, fake_host)
24+
self.assertEqual(s.port, fake_port)

0 commit comments

Comments
 (0)