|
15 | 15 | import logging |
16 | 16 | import os |
17 | 17 |
|
| 18 | +from retrying import retry |
| 19 | +from botocore.config import Config |
| 20 | +from botocore.exceptions import ClientError |
18 | 21 | from langchain_aws import ChatBedrock |
19 | 22 | from langchain_core.messages import HumanMessage |
20 | 23 | from langchain_core.prompts import ChatPromptTemplate |
|
# Configure the root logger (getLogger() with no name returns it). The level
# is read from the LOG_LEVEL environment variable and falls back to "INFO"
# when that variable is unset.
logger = logging.getLogger()
logger.setLevel(os.getenv("LOG_LEVEL", "INFO"))
26 | 29 |
|
27 | | -bedrock_client = boto3.client('bedrock-runtime') |
# Module-level Bedrock runtime client, created once at import time.
# The botocore Config raises both socket timeouts to 180 seconds and turns on
# botocore's own "adaptive" retry mode with max_attempts set to 50, so failed
# calls are first retried inside the SDK before any error reaches this module.
bedrock_client = boto3.client(
    'bedrock-runtime',
    config=Config(
        retries={"mode": "adaptive", "max_attempts": 50},
        read_timeout=180,
        connect_timeout=180,
    ),
)
28 | 38 |
|
class BedrockRetryableError(Exception):
    """Raised for a Bedrock failure that is safe to attempt again.

    Carries no extra state; it exists only so retry logic can tell
    retryable failures apart from every other exception type.
    """
| 42 | + |
@retry(
    wait_fixed=10000,  # 10 seconds between retries
    stop_max_attempt_number=None,  # no attempt cap: keep retrying indefinitely
    retry_on_exception=lambda ex: isinstance(ex, BedrockRetryableError),
)
def invoke_chain_with_retry(chain):
    """Invoke ``chain`` and retry while Bedrock reports a transient failure.

    Calls ``chain.invoke({})`` and returns its result. If the call raises a
    botocore ``ClientError`` whose error code is ``ThrottlingException`` or
    ``ModelTimeoutException``, the error is converted into
    ``BedrockRetryableError`` so that the ``@retry`` decorator re-runs this
    function after a fixed 10 second wait.

    NOTE(review): the decorator sets no attempt or delay limit, so a
    persistently throttled call is retried forever; the caller's own
    execution timeout is the only bound visible here.

    Args:
        chain: Object exposing ``invoke(dict)``.

    Returns:
        Whatever ``chain.invoke({})`` returns.

    Raises:
        ClientError: For any Bedrock error code other than the two retryable
            ones; re-raised unchanged and not retried.
    """
    # Log text emitted for each error code that is considered transient.
    retry_messages = {
        "ThrottlingException": "Bedrock throttling. Retrying...",
        "ModelTimeoutException": "Bedrock ModelTimeoutException. Retrying...",
    }

    try:
        return chain.invoke({})
    except ClientError as exc:
        # The modeled exceptions on client.exceptions subclass ClientError,
        # so this single handler already sees throttling and model-timeout
        # errors; separate handlers placed after it could never run.
        logger.warning("Bedrock ClientError: %s", exc)

        # Read the code defensively so a malformed error payload cannot raise
        # KeyError here and hide the original exception.
        error_code = exc.response.get("Error", {}).get("Code")
        if error_code in retry_messages:
            logger.warning(retry_messages[error_code])
            # Chain explicitly so the original ClientError stays attached.
            raise BedrockRetryableError(str(exc)) from exc
        raise
29 | 69 |
|
30 | 70 | def invoke_llm(prompt, model_id, temperature=0.5, top_k=None, top_p=0.8, max_new_tokens=4096, verbose=False): |
31 | 71 | model_id = (model_id or CLAUDE_MODEL_ID) |
@@ -57,7 +97,7 @@ def invoke_llm(prompt, model_id, temperature=0.5, top_k=None, top_p=0.8, max_new |
57 | 97 | ]) |
58 | 98 | chain = prompt | chat |
59 | 99 |
|
60 | | - response = chain.invoke({}) |
| 100 | + response = invoke_chain_with_retry(chain) |
61 | 101 | content = response.content |
62 | 102 |
|
63 | 103 | usage_data = None |
|
0 commit comments