Skip to content

Commit 8c7bd77

Browse files
committed
impl model list caching
1 parent 2d785ab commit 8c7bd77

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

src/api/models/bedrock.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
ENABLE_CROSS_REGION_INFERENCE,
4747
ENABLE_APPLICATION_INFERENCE_PROFILES,
4848
MAX_RETRIES_AWS,
49+
MODEL_CACHE_TTL,
4950
)
5051

5152
logger = logging.getLogger(__name__)
@@ -171,15 +172,40 @@ def list_bedrock_models() -> dict:
171172
return model_list
172173

173174

175+
# In-memory cache
176+
_model_cache = {
177+
"data": None,
178+
"timestamp": 0
179+
}
180+
181+
def _get_cached_models():
    """Return the Bedrock model list, refreshing the in-memory cache if needed.

    The cached value is reused while it is no older than ``MODEL_CACHE_TTL``
    seconds. When the cache is empty or expired, ``list_bedrock_models()`` is
    called; a non-empty result replaces the cached data and resets its
    timestamp.

    Returns:
        The cached or freshly fetched model mapping. If a refresh yields
        nothing, the previously cached data is returned when there is any;
        otherwise the (empty) refresh result itself is returned.
    """
    now = time.time()
    cached = _model_cache["data"]
    cache_age = now - _model_cache["timestamp"]

    # Cache hit: data exists and has not outlived its TTL.
    if cached is not None and cache_age <= MODEL_CACHE_TTL:
        return cached

    fresh_models = list_bedrock_models()
    if fresh_models:
        # The module-level dict is mutated in place, so no ``global``
        # statement is required.
        _model_cache["data"] = fresh_models
        _model_cache["timestamp"] = now
        return fresh_models

    # The refresh produced nothing. Keep serving the stale entry rather than
    # handing callers an empty result; the timestamp is left untouched so the
    # next call attempts another refresh.
    if cached is not None:
        return cached
    return fresh_models
197+
174198
# Initialize the model list.
175-
bedrock_model_list = list_bedrock_models()
199+
bedrock_model_list = _get_cached_models()
176200

177201

178202
class BedrockModel(BaseChatModel):
179203
def list_models(self) -> list[str]:
180-
"""Always refresh the latest model list"""
204+
"""Get model list using in-memory cache with TTL"""
181205
global bedrock_model_list
182-
bedrock_model_list = list_bedrock_models()
206+
cached_models = _get_cached_models()
207+
if cached_models:
208+
bedrock_model_list = cached_models
183209
return list(bedrock_model_list.keys())
184210

185211
def validate(self, chat_request: ChatRequest):

src/api/setting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
2121
MAX_RETRIES_AWS = int(os.environ.get("MAX_RETRIES_AWS", "3"))
2222
ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
23+
MODEL_CACHE_TTL = int(os.environ.get("MODEL_CACHE_TTL", "3600")) # 1 hour default

0 commit comments

Comments
 (0)