
Commit f2bda42

Merge pull request #3 from Constantinople-AI/GGGG-370-workers-model-cache
Variable workers & model cache
2 parents 2d785ab + 4934ec0 · commit f2bda42

3 files changed: +32 -4 lines


src/Dockerfile_ecs

Lines changed: 2 additions & 1 deletion
@@ -9,5 +9,6 @@ RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
 COPY ./api /app/api
 
 ENV PORT=80
+ENV WORKERS=1
 
-CMD ["sh", "-c", "uvicorn api.app:app --host 0.0.0.0 --port ${PORT}"]
+CMD ["sh", "-c", "uvicorn api.app:app --host 0.0.0.0 --port ${PORT} --workers ${WORKERS}"]
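With this change the uvicorn worker count is no longer hard-coded: WORKERS defaults to 1 in the image and is passed through to --workers at start-up. Assuming the container is launched with the standard Docker CLI, it can be overridden per deployment with something like docker run -e WORKERS=4 -e PORT=8080 <image>; in an ECS task definition the same value can be supplied through the container's environment settings.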

src/api/models/bedrock.py

Lines changed: 29 additions & 3 deletions
@@ -46,6 +46,7 @@
     ENABLE_CROSS_REGION_INFERENCE,
     ENABLE_APPLICATION_INFERENCE_PROFILES,
     MAX_RETRIES_AWS,
+    MODEL_CACHE_TTL,
 )
 
 logger = logging.getLogger(__name__)

@@ -171,15 +172,40 @@ def list_bedrock_models() -> dict:
     return model_list
 
 
+# In-memory cache
+_model_cache = {
+    "data": None,
+    "timestamp": 0
+}
+
+def _get_cached_models():
+    """Get models from in-memory cache if still valid."""
+    global _model_cache
+
+    current_time = time.time()
+    cache_age = current_time - _model_cache["timestamp"]
+
+    if _model_cache["data"] is None or cache_age > MODEL_CACHE_TTL:
+        fresh_models = list_bedrock_models()
+        if fresh_models:
+            _model_cache["data"] = fresh_models
+            _model_cache["timestamp"] = current_time
+        return fresh_models
+    else:
+        # Cache hit
+        return _model_cache["data"]
+
 # Initialize the model list.
-bedrock_model_list = list_bedrock_models()
+bedrock_model_list = _get_cached_models()
 
 
 class BedrockModel(BaseChatModel):
     def list_models(self) -> list[str]:
-        """Always refresh the latest model list"""
+        """Get model list using in-memory cache with TTL"""
         global bedrock_model_list
-        bedrock_model_list = list_bedrock_models()
+        cached_models = _get_cached_models()
+        if cached_models:
+            bedrock_model_list = cached_models
         return list(bedrock_model_list.keys())
 
     def validate(self, chat_request: ChatRequest):
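The cache added here is plain module-level state guarded by a TTL check, so its behaviour is easy to exercise in isolation. The following sketch is not part of the commit; it is a minimal, hypothetical reproduction of the same hit/miss logic, with fetch_models() standing in for list_bedrock_models() and a 2-second TTL standing in for MODEL_CACHE_TTL.

import time

MODEL_CACHE_TTL = 2  # stand-in for the real setting, in seconds

_model_cache = {"data": None, "timestamp": 0}

def fetch_models() -> dict:
    # Hypothetical stand-in for list_bedrock_models(); pretend this calls the Bedrock API.
    print("refreshing model list from upstream...")
    return {"anthropic.claude-3-sonnet": {}, "amazon.titan-text-express": {}}

def get_cached_models() -> dict:
    # Same structure as _get_cached_models() in the diff above.
    now = time.time()
    cache_age = now - _model_cache["timestamp"]
    if _model_cache["data"] is None or cache_age > MODEL_CACHE_TTL:
        fresh = fetch_models()
        if fresh:  # only store a non-empty result
            _model_cache["data"] = fresh
            _model_cache["timestamp"] = now
        return fresh
    return _model_cache["data"]  # cache hit: no upstream call

get_cached_models()              # miss: fetches from upstream
get_cached_models()              # hit: served from process memory
time.sleep(MODEL_CACHE_TTL + 1)
get_cached_models()              # TTL expired: fetches again

One consequence worth noting: the cache lives in process memory, and the Dockerfile change runs uvicorn with --workers ${WORKERS}, so each worker process keeps and refreshes its own copy; upstream refreshes therefore scale with the worker count rather than being shared.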

src/api/setting.py

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
 MAX_RETRIES_AWS = int(os.environ.get("MAX_RETRIES_AWS", "3"))
 ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
+MODEL_CACHE_TTL = int(os.environ.get("MODEL_CACHE_TTL", "3600"))  # 1 hour default
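MODEL_CACHE_TTL is read once at import time and interpreted in seconds, with a default of 3600 (one hour). A lower value, e.g. MODEL_CACHE_TTL=300 in the service environment, shortens the window before list_bedrock_models() is called again; a value of 0 effectively disables caching, since any measurable elapsed time exceeds the TTL.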
