|
46 | 46 | ENABLE_CROSS_REGION_INFERENCE, |
47 | 47 | ENABLE_APPLICATION_INFERENCE_PROFILES, |
48 | 48 | MAX_RETRIES_AWS, |
| 49 | + MODEL_CACHE_TTL, |
49 | 50 | ) |
50 | 51 |
|
51 | 52 | logger = logging.getLogger(__name__) |
@@ -171,15 +172,40 @@ def list_bedrock_models() -> dict: |
171 | 172 | return model_list |
172 | 173 |
|
173 | 174 |
|
# In-memory cache of the result of list_bedrock_models().
# "data" stays None until a non-empty result has been stored; "timestamp" is
# the time.monotonic() reading taken when "data" was last stored.
_model_cache = {
    "data": None,
    "timestamp": 0,
}


def _get_cached_models():
    """Return the model list, refreshing it once the cached copy has expired.

    The cached value is served as long as it is no older than
    ``MODEL_CACHE_TTL`` seconds. Otherwise ``list_bedrock_models()`` is called
    and, when it returns a non-empty result, that result replaces the cache.

    Returns:
        The cached models on a cache hit, or the freshly fetched models after
        a successful refresh. When a refresh yields an empty result, the
        previously cached models are returned if there are any; with nothing
        cached yet, the empty result is returned unchanged so callers can
        still test it for truthiness.
    """
    cached = _model_cache["data"]
    # Use the monotonic clock for the age calculation: wall-clock time can be
    # stepped backwards or forwards (NTP, manual changes), which would keep
    # the cache "fresh" indefinitely or expire it prematurely.
    now = time.monotonic()

    if cached is not None and now - _model_cache["timestamp"] <= MODEL_CACHE_TTL:
        # Cache hit
        return cached

    fresh_models = list_bedrock_models()
    if fresh_models:
        _model_cache["data"] = fresh_models
        _model_cache["timestamp"] = now
        return fresh_models

    # The refresh produced nothing. The timestamp is deliberately left
    # untouched so the next call retries; meanwhile prefer stale data over an
    # empty answer when stale data exists.
    return cached if cached is not None else fresh_models
| 197 | + |
# Populate the module-level model list when this module is first imported.
# It is obtained through the cache helper rather than by calling
# list_bedrock_models() directly.
bedrock_model_list = _get_cached_models()
176 | 200 |
|
177 | 201 |
|
178 | 202 | class BedrockModel(BaseChatModel): |
def list_models(self) -> list[str]:
    """Return the keys of the module-level model list, served via the TTL cache.

    The module-level ``bedrock_model_list`` is replaced only when the cache
    helper returns a non-empty result; otherwise the previous list is kept.

    Returns:
        The keys of ``bedrock_model_list`` (presumably model IDs; confirm
        against ``list_bedrock_models``).
    """
    global bedrock_model_list
    if latest := _get_cached_models():
        bedrock_model_list = latest
    return list(bedrock_model_list)
184 | 210 |
|
185 | 211 | def validate(self, chat_request: ChatRequest): |
|
0 commit comments