From 406617ddbd162f158c63c057ba6e5fdbdac9a60c Mon Sep 17 00:00:00 2001 From: Kazmer Nagy-Betegh Date: Wed, 19 Nov 2025 16:53:51 +0200 Subject: [PATCH 01/18] bedrock utils for retry --- .../idp_common/utils/bedrock_utils.py | 409 ++++++++++++++++++ 1 file changed, 409 insertions(+) create mode 100644 lib/idp_common_pkg/idp_common/utils/bedrock_utils.py diff --git a/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py b/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py new file mode 100644 index 00000000..04885781 --- /dev/null +++ b/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py @@ -0,0 +1,409 @@ +import asyncio +import json +import logging +import os +import random +import time +from collections.abc import Awaitable, Callable +from functools import wraps +from typing import Unpack + +import botocore.exceptions +from mypy_boto3_bedrock_runtime import BedrockRuntimeClient +from mypy_boto3_bedrock_runtime.type_defs import ( + ConverseRequestTypeDef, + ConverseResponseTypeDef, + ConverseStreamRequestTypeDef, + ConverseStreamResponseTypeDef, + InvokeModelRequestTypeDef, + InvokeModelResponseTypeDef, +) +from pydantic_core import ArgsKwargs + +# Configure logger +logger = logging.getLogger(__name__) +logger.setLevel(os.environ.get("LOG_LEVEL", "INFO")) + + +def async_exponential_backoff_retry[T, **P]( + max_retries: int = 5, + initial_delay: float = 1.0, + max_delay: float = 32.0, + exponential_base: float = 2.0, + jitter: float = 0.1, + retryable_errors: list[str] | None = None, +) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: + if not retryable_errors: + retryable_errors = [ + "ThrottlingException", + "throttlingException", + "ModelErrorException", + "ValidationException", + ] + + def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + @wraps(func) + async def wrapper(*args, **kwargs) -> T: + delay = initial_delay + + def log_bedrock_invocation_error(error: Exception, attempt_num: int): + """Log bedrock invocation details when an error occurs""" + # Fallback logging if extraction fails + logger.error( + "Bedrock invocation error", + extra={ + "function_name": func.__name__, + "original_error": str(error), + "max_attempts": max_retries, + "attempt_num":attempt_num + }, + ) + + for attempt in range(max_retries): + try: + return await func(*args, **kwargs) + except botocore.exceptions.ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + + # Log bedrock invocation details for all errors + log_bedrock_invocation_error(e, attempt + 1) + + if ( + error_code == "ValidationException" + and "Output blocked by content filtering policy" + not in e.response.get("Error", {}).get("Message", "") + ): + raise + if error_code not in retryable_errors or attempt == max_retries - 1: + raise + + jitter_value = random.uniform(-jitter, jitter) + sleep_time = max(0.1, delay * (1 + jitter_value)) + logger.warning( + f"{error_code}:{e.response.get('Error', {}).get('Message', '')} encountered in {func.__name__}. Retrying in {sleep_time:.2f} seconds. " + f"Attempt {attempt + 1}/{max_retries}" + ) + await asyncio.sleep(sleep_time) + delay = min(delay * exponential_base, max_delay) + except Exception as e: + # Log bedrock invocation details for non-ClientError exceptions too + log_bedrock_invocation_error(e, attempt + 1) + raise + + return await func(*args, **kwargs) + + return wrapper + + return decorator + + +def exponential_backoff_retry[T, **P]( + max_retries: int = 5, + initial_delay: float = 1.0, + max_delay: float = 32.0, + exponential_base: float = 2.0, + jitter: float = 0.1, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + def decorator(func: Callable[P, T]) -> Callable[P, T]: + @wraps(func) + def wrapper(*args, **kwargs) -> T: + delay = initial_delay + + def log_bedrock_invocation_error(error: Exception, attempt_num: int): + """Log bedrock invocation details when an error occurs""" + try: + # Check for invoke_model API (has 'body' parameter) + if "body" in kwargs: + logger.error( + "Bedrock invoke_model failed", + extra={ + "attempt_number": attempt_num, + "max_retries": max_retries, + "function_name": func.__name__, + "error": str(error), + "body": kwargs["body"], + }, + ) + # Check for converse API (has structured parameters) + elif any( + key in kwargs + for key in [ + "messages", + "inferenceConfig", + "system", + "toolConfig", + ] + ): + # Log converse API parameters + converse_data = { + k: v + for k, v in kwargs.items() + if k + in [ + "messages", + "inferenceConfig", + "system", + "toolConfig", + "additionalModelRequestFields", + "guardrailConfig", + "performanceConfig", + "promptVariables", + "requestMetadata", + ] + } + logger.error( + "Bedrock converse failed", + extra={ + "attempt_number": attempt_num, + "max_retries": max_retries, + "function_name": func.__name__, + "error": str(error), + "parameters": json.dumps(converse_data, default=str), + }, + ) + else: + # Generic bedrock error logging + logger.error( + "Bedrock invocation failed", + extra={ + "attempt_number": attempt_num, + "max_retries": max_retries, + "function_name": func.__name__, + "error": str(error), + }, + ) + + except Exception as log_error: + # Fallback logging if extraction fails + logger.error( + "Failed to log bedrock invocation details", + extra={ + "function_name": func.__name__, + "log_error": str(log_error), + "original_error": str(error), + }, + ) + + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except botocore.exceptions.ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + + # Log bedrock invocation details for all errors + log_bedrock_invocation_error(e, attempt + 1) + + if ( + error_code == "ValidationException" + and "Output blocked by content filtering policy" + not in e.response.get("Error", {}).get("Message", "") + ): + raise + if ( + error_code + not in [ + "ThrottlingException", + "ModelErrorException", + "ValidationException", + ] + or attempt == max_retries - 1 + ): + raise + + jitter_value = random.uniform(-jitter, jitter) + sleep_time = max(0.1, delay * (1 + jitter_value)) + logger.warning( + f"{error_code}:{e.response.get('Error', {}).get('Message', '')} encountered in {func.__name__}. Retrying in {sleep_time:.2f} seconds. " + f"Attempt {attempt + 1}/{max_retries}" + ) + time.sleep(sleep_time) + delay = min(delay * exponential_base, max_delay) + except Exception as e: + # Log bedrock invocation details for non-ClientError exceptions too + log_bedrock_invocation_error(e, attempt + 1) + raise + + return func(*args, **kwargs) + + return wrapper + + return decorator + + +class BedrockClientWrapper: + """ + A wrapper around AWS Bedrock Runtime Client that provides automatic retry logic + with exponential backoff for handling transient errors and rate limiting. + + This wrapper automatically retries failed requests for specific error types: + - ThrottlingException: When API rate limits are exceeded + - ModelErrorException: When the model encounters temporary errors + - ValidationException: When content filtering blocks output (retryable case) + + The retry mechanism uses exponential backoff with jitter to avoid thundering herd + problems when multiple clients retry simultaneously. + + Attributes: + client (BedrockRuntimeClient): The underlying AWS Bedrock Runtime client + max_retries (int): Maximum number of retry attempts + initial_delay (float): Initial delay between retries in seconds + max_delay (float): Maximum delay between retries in seconds + exponential_base (float): Base for exponential backoff calculation + jitter (float): Random jitter factor to add variance to retry delays + invoke_model: Wrapped invoke_model method with retry logic + converse: Wrapped converse method with retry logic + + Example: + >>> import boto3 + >>> from mypy_boto3_bedrock_runtime import BedrockRuntimeClient + >>> bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1") + >>> wrapper = BedrockClientWrapper(bedrock_client, max_retries=3) + >>> # Use invoke_model with automatic retries + >>> response = wrapper.invoke_model( + ... modelId="anthropic.claude-3-sonnet-20240229-v1:0", + ... body=json.dumps( + ... { + ... "messages": [{"role": "user", "content": "Hello"}], + ... "max_tokens": 100, + ... } + ... ), + ... ) + >>> # Use converse API with automatic retries + >>> response = wrapper.converse( + ... modelId="anthropic.claude-3-sonnet-20240229-v1:0", + ... messages=[{"role": "user", "content": [{"text": "Hello"}]}], + ... ) + """ + + def __init__( + self, + bedrock_client: BedrockRuntimeClient, + max_retries: int = 5, + initial_delay: float = 1.0, + max_delay: float = 32.0, + exponential_base: float = 2.0, + jitter: float = 0.1, + ): + """ + Initialize the BedrockClientWrapper with retry configuration. + + Args: + bedrock_client (BedrockRuntimeClient): The AWS Bedrock Runtime client to wrap + max_retries (int, optional): Maximum number of retry attempts. Defaults to 5. + initial_delay (float, optional): Initial delay between retries in seconds. Defaults to 1.0. + max_delay (float, optional): Maximum delay between retries in seconds. Defaults to 32.0. + exponential_base (float, optional): Base for exponential backoff calculation. Defaults to 2.0. + jitter (float, optional): Random jitter factor (0.0-1.0) to add variance to retry delays. Defaults to 0.1. + + Raises: + TypeError: If bedrock_client is not a BedrockRuntimeClient instance + ValueError: If retry parameters are invalid (negative values, jitter > 1.0, etc.) + """ + self.client = bedrock_client + + self.max_retries = max_retries + self.initial_delay = initial_delay + self.max_delay = max_delay + self.exponential_base = exponential_base + self.jitter = jitter + + # Apply decorator directly to client methods + self._decorated_invoke_model = exponential_backoff_retry( + max_retries=max_retries, + initial_delay=initial_delay, + max_delay=max_delay, + exponential_base=exponential_base, + jitter=jitter, + )(self.client.invoke_model) + + self._decorated_converse = exponential_backoff_retry( + max_retries=max_retries, + initial_delay=initial_delay, + max_delay=max_delay, + exponential_base=exponential_base, + jitter=jitter, + )(self.client.converse) + + self._decorated_converse_stream_async = exponential_backoff_retry( + max_retries=max_retries, + initial_delay=initial_delay, + max_delay=max_delay, + exponential_base=exponential_base, + jitter=jitter, + )(self.client.converse_stream) + + def invoke_model( + self, **kwargs: Unpack[InvokeModelRequestTypeDef] + ) -> InvokeModelResponseTypeDef: + """ + Invoke a model with automatic retry logic. + + This method has the same signature as BedrockRuntimeClient.invoke_model() + but includes automatic retry logic with exponential backoff. + + Args: + modelId: The ID or ARN of the model to invoke + body: The input data to send to the model + contentType: The MIME type of the input data + accept: The desired MIME type of the response + **kwargs: Additional arguments passed to the underlying API + + Returns: + InvokeModelResponseTypeDef: The response from the model invocation + + Raises: + botocore.exceptions.ClientError: For non-retryable errors or after max retries + """ + return self._decorated_invoke_model(**kwargs) + + def converse( + self, + **kwargs: Unpack[ConverseRequestTypeDef], + ) -> ConverseResponseTypeDef: + """ + Converse with a model using the conversation API with automatic retry logic. + + This method has the same signature as BedrockRuntimeClient.converse() + but includes automatic retry logic with exponential backoff. + + Args: + modelId: The ID or ARN of the model to invoke + messages: The conversation messages + system: System prompts to provide context + inferenceConfig: Configuration for model inference parameters + toolConfig: Configuration for tool use + guardrailConfig: Configuration for content filtering + additionalModelRequestFields: Additional model-specific request fields + promptVariables: Variables to substitute in prompts + additionalModelResponseFieldPaths: Additional response field paths + performanceConfig: Performance optimization configuration + requestMetadata: Metadata for the request + **kwargs: Additional arguments passed to the underlying API + + Returns: + ConverseResponseTypeDef: The response from the conversation + + Raises: + botocore.exceptions.ClientError: For non-retryable errors or after max retries + """ + return self._decorated_converse(**kwargs) + + def converse_stream( + self, **kwargs: Unpack[ConverseStreamRequestTypeDef] + ) -> ConverseStreamResponseTypeDef: + """ + Async version of converse_stream with automatic retry logic. + + This method has the same signature as BedrockRuntimeClient.converse_stream() + but runs asynchronously with automatic retry logic and exponential backoff. + + Args: + **kwargs: All arguments passed to the underlying converse_stream API + + Returns: + The streaming response from the conversation + + Raises: + botocore.exceptions.ClientError: For non-retryable errors or after max retries + """ + return self._decorated_converse_stream_async(**kwargs) From 9d63a54c3d6fd6fee74d1515dd3ab0148aba3a19 Mon Sep 17 00:00:00 2001 From: Kazmer Nagy-Betegh Date: Wed, 19 Nov 2025 17:01:03 +0200 Subject: [PATCH 02/18] invoke extraction with long retries --- .../idp_common/extraction/agentic_idp.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py index 7d8a2415..a053f837 100644 --- a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py +++ b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py @@ -42,6 +42,9 @@ update_todo, view_todo_list, ) +from lib.idp_common_pkg.idp_common.utils.bedrock_utils import ( + async_exponential_backoff_retry, +) # Use AWS Lambda Powertools Logger for structured logging # Automatically logs as JSON with Lambda context, request_id, timestamp, etc. @@ -458,6 +461,16 @@ def patch_buffer_data(patches: list[dict[str, Any]], agent: Agent) -> str: """ +@async_exponential_backoff_retry( + max_retries=50, + initial_delay=5, + max_delay=1800, + jitter=0.5, +) +async def invoke_agent_with_retry(input: Any, agent: Agent): + return await agent.invoke_async(input) + + async def structured_output_async( model_id: str, data_format: type[TargetModel], @@ -755,7 +768,7 @@ async def structured_output_async( for attempt in range(max_retries): try: - response = await agent.invoke_async(prompt_content) + response = await invoke_agent_with_retry(agent=agent, input=prompt_content) logger.debug("Agent response received") break # Success, exit retry loop except Exception as e: @@ -885,7 +898,9 @@ async def structured_output_async( ) ) - review_response = await agent.invoke_async(review_prompt) + review_response = await invoke_agent_with_retry( + agent=agent, input=review_prompt + ) logger.debug("Review response received", extra={"review_completed": True}) # Accumulate token usage from review From 00e7df42c828a92f55fb3279a41273a111507eb1 Mon Sep 17 00:00:00 2001 From: Kazmer Nagy-Betegh Date: Wed, 19 Nov 2025 18:50:33 +0200 Subject: [PATCH 03/18] caching the first prompt --- .../idp_common/extraction/agentic_idp.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py index a053f837..e55b9127 100644 --- a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py +++ b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py @@ -29,7 +29,7 @@ from strands import Agent, tool from strands.agent.conversation_manager import SummarizingConversationManager from strands.models.bedrock import BedrockModel -from strands.types.content import ContentBlock, Message +from strands.types.content import CachePoint, ContentBlock, Message from strands.types.media import ( DocumentContent, ImageContent, @@ -725,15 +725,21 @@ async def structured_output_async( format="png", source=ImageSource(bytes=img_bytes) ) ), + ContentBlock(cachePoint=CachePoint(type="default")), ], ) ] elif isinstance(prompt, dict) and "content" in prompt: prompt_content = [prompt] - # Extract and store images as binary strings else: prompt_content = [ - Message(role="user", content=[ContentBlock(text=str(prompt))]) + Message( + role="user", + content=[ + ContentBlock(text=str(prompt)), + ContentBlock(cachePoint=CachePoint(type="default")), + ], + ) ] # Track token usage From f96499f9a2d3066806548c840763d82dd038a4cd Mon Sep 17 00:00:00 2001 From: Kazmer Nagy-Betegh Date: Thu, 20 Nov 2025 11:31:34 +0200 Subject: [PATCH 04/18] add review agent model config --- .../lending-package-sample/config.yaml | 2 +- .../idp_common/config/models.py | 13 ++++- .../idp_common/extraction/agentic_idp.py | 54 ++++++++++++------- .../idp_common/extraction/service.py | 2 +- patterns/pattern-2/template.yaml | 29 +++++----- 5 files changed, 64 insertions(+), 36 deletions(-) diff --git a/config_library/pattern-2/lending-package-sample/config.yaml b/config_library/pattern-2/lending-package-sample/config.yaml index 37da9a05..cbac09b4 100644 --- a/config_library/pattern-2/lending-package-sample/config.yaml +++ b/config_library/pattern-2/lending-package-sample/config.yaml @@ -1811,7 +1811,7 @@ agents: parameters: max_log_events: 5 time_range_hours_default: 24 - + chat_companion: model_id: us.anthropic.claude-haiku-4-5-20251001-v1:0 pricing: diff --git a/lib/idp_common_pkg/idp_common/config/models.py b/lib/idp_common_pkg/idp_common/config/models.py index 604d84fc..a0efc832 100644 --- a/lib/idp_common_pkg/idp_common/config/models.py +++ b/lib/idp_common_pkg/idp_common/config/models.py @@ -19,7 +19,8 @@ """ from typing import Any, Dict, List, Optional, Union, Literal, Annotated -from pydantic import BaseModel, ConfigDict, Field, field_validator, Discriminator +from typing_extensions import Self +from pydantic import BaseModel, ConfigDict, Field, field_validator, Discriminator, model_validator class ImageConfig(BaseModel): @@ -78,6 +79,7 @@ class AgenticConfig(BaseModel): enabled: bool = Field(default=False, description="Enable agentic extraction") review_agent: bool = Field(default=False, description="Enable review agent") + review_agent_model: str | None= Field(default=None, description="Model used for reviewing and correcting extraction work") class ExtractionConfig(BaseModel): @@ -119,6 +121,15 @@ def parse_int(cls, v: Any) -> int: if isinstance(v, str): return int(v) if v else 0 return int(v) + + @model_validator(mode="after") + def model_validator(self) -> Self: + + if not self.agentic.review_agent_model: + self.agentic.review_agent_model = self.model + + return self + class ClassificationConfig(BaseModel): diff --git a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py index e55b9127..a4a2e974 100644 --- a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py +++ b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py @@ -24,6 +24,10 @@ from aws_lambda_powertools import Logger from botocore.config import Config from botocore.exceptions import ClientError +from idp_common_pkg.idp_common import IDPConfig +from idp_common_pkg.idp_common.utils.bedrock_utils import ( + async_exponential_backoff_retry, +) from PIL import Image from pydantic import BaseModel, Field from strands import Agent, tool @@ -42,9 +46,6 @@ update_todo, view_todo_list, ) -from lib.idp_common_pkg.idp_common.utils.bedrock_utils import ( - async_exponential_backoff_retry, -) # Use AWS Lambda Powertools Logger for structured logging # Automatically logs as JSON with Lambda context, request_id, timestamp, etc. @@ -478,7 +479,7 @@ async def structured_output_async( existing_data: BaseModel | None = None, system_prompt: str | None = None, custom_instruction: str | None = None, - review_agent: bool = False, + config: IDPConfig = IDPConfig(), context: str = "Extraction", max_retries: int = 7, connect_timeout: float = 10.0, @@ -593,9 +594,6 @@ async def structured_output_async( }, ) - # Build final system prompt without modifying the original - final_system_prompt = system_prompt - # Configure retry behavior and timeouts using boto3 Config boto_config = Config( retries={ @@ -606,13 +604,6 @@ async def structured_output_async( read_timeout=read_timeout, ) - model_config = dict(model_id=model_id, boto_client_config=boto_config) - # Set max_tokens based on actual model limits - # Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/ - - # Determine model's maximum - # Use regex for more flexible matching (e.g., claude-sonnet-4-5 should match claude-sonnet-4) - model_max = 4_096 # Default fallback model_id_lower = model_id.lower() # Check Claude 4 patterns first (more specific) @@ -681,7 +672,7 @@ async def structured_output_async( else: logger.debug("Caching not supported for model", extra={"model_id": model_id}) - final_system_prompt = SYSTEM_PROMPT + final_system_prompt = system_prompt if system_prompt else SYSTEM_PROMPT if custom_instruction: final_system_prompt = f"{final_system_prompt}\n\nCustom Instructions for this specific task: {custom_instruction}" @@ -763,6 +754,7 @@ async def structured_output_async( ContentBlock( text=f"Please update the existing data using the extraction tool or patches. Existing data: {existing_data.model_dump()}" ), + ContentBlock(cachePoint=CachePoint(type="default")), ], ) ) @@ -875,7 +867,7 @@ async def structured_output_async( ) # Add explicit review step (Option 2) - if review_agent: + if config.extraction.agentic.enabled and config.extraction.agentic.review_agent: logger.debug( "Initiating final review of extracted data", extra={"review_enabled": True}, @@ -899,10 +891,32 @@ async def structured_output_async( If everything is correct, respond with "Data verified and accurate." If corrections are needed, use the apply_json_patches tool to fix any issues you find. """ - ) + ), + ContentBlock(cachePoint=CachePoint(type="default")), ], ) ) + model_config = dict( + model_id=config.extraction.agentic.review_agent_model, + boto_client_config=boto_config, + max_tokens=max_output_tokens, + ) + agent = Agent( + model=BedrockModel(**model_config), # pyright: ignore[reportArgumentType] + tools=tools, + system_prompt=f"{final_system_prompt}", + state={ + "current_extraction": None, + "images": {}, + "existing_data": existing_data.model_dump() + if existing_data + else None, + "extraction_schema_json": schema_json, # Store for schema reminder tool + }, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.8, preserve_recent_messages=2 + ), + ) review_response = await invoke_agent_with_retry( agent=agent, input=review_prompt @@ -960,8 +974,8 @@ def structured_output( existing_data: BaseModel | None = None, system_prompt: str | None = None, custom_instruction: str | None = None, - review_agent: bool = False, context: str = "Extraction", + config: IDPConfig = IDPConfig(), max_retries: int = 7, connect_timeout: float = 10.0, read_timeout: float = 300.0, @@ -1045,7 +1059,7 @@ def run_in_new_loop(): existing_data=existing_data, system_prompt=system_prompt, custom_instruction=custom_instruction, - review_agent=review_agent, + config=config, context=context, max_retries=max_retries, connect_timeout=connect_timeout, @@ -1076,7 +1090,7 @@ def run_in_new_loop(): existing_data=existing_data, system_prompt=system_prompt, custom_instruction=custom_instruction, - review_agent=review_agent, + config=config, context=context, max_retries=max_retries, connect_timeout=connect_timeout, diff --git a/lib/idp_common_pkg/idp_common/extraction/service.py b/lib/idp_common_pkg/idp_common/extraction/service.py index b93ac293..260ef512 100644 --- a/lib/idp_common_pkg/idp_common/extraction/service.py +++ b/lib/idp_common_pkg/idp_common/extraction/service.py @@ -1104,7 +1104,7 @@ def process_document_section(self, document: Document, section_id: str) -> Docum data_format=dynamic_model, prompt=message_prompt, # pyright: ignore[reportArgumentType] custom_instruction=system_prompt, - review_agent=self.config.extraction.agentic.review_agent, # Type-safe boolean! + config=self.config, context="Extraction", ) diff --git a/patterns/pattern-2/template.yaml b/patterns/pattern-2/template.yaml index 3716d16c..91715908 100644 --- a/patterns/pattern-2/template.yaml +++ b/patterns/pattern-2/template.yaml @@ -114,14 +114,14 @@ Parameters: EnableXRayTracing: Type: String - Default: 'true' - AllowedValues: ['true', 'false'] + Default: "true" + AllowedValues: ["true", "false"] Description: Enable X-Ray tracing EnableECRImageScanning: Type: String - Default: 'true' - AllowedValues: ['true', 'false'] + Default: "true" + AllowedValues: ["true", "false"] Description: Enable automatic vulnerability scanning for Lambda container images in ECR PermissionsBoundaryArn: @@ -384,7 +384,7 @@ Resources: MemorySize: 128 Handler: index.handler CodeUri: ../../src/lambda/start_codebuild - Description: CodeBuild trigger Lambda for Docker image builds + Description: CodeBuild trigger Lambda for Docker image builds LoggingConfig: LogGroup: !Ref CodeBuildTriggerLogGroup @@ -426,7 +426,7 @@ Resources: Properties: ServiceToken: !GetAtt CodeBuildTrigger.Arn RepositoryName: !Ref ECRRepository - + # Shared IAM policy for Lambda functions to pull container images from ECR LambdaECRAccessPolicy: Type: AWS::IAM::ManagedPolicy @@ -852,7 +852,6 @@ Resources: - "eu.anthropic.claude-sonnet-4-5-20250929-v1:0" - "eu.anthropic.claude-sonnet-4-5-20250929-v1:0:1m" - order: 1 classificationMethod: type: string @@ -925,6 +924,10 @@ Resources: description: This introduces a second agent to review the first agents work. Only use with highly complex workflows as it increases token usage. order: 1 default: false + review_agent_model: + type: string + description: Model to review the initial extraction agents work and correct it if needed, if not specified will default to the same as the extraction model. + default: Null image: type: object sectionLabel: Image Processing Settings @@ -1098,7 +1101,7 @@ Resources: "us.anthropic.claude-sonnet-4-5-20250929-v1:0:1m", "us.anthropic.claude-opus-4-20250514-v1:0", "us.anthropic.claude-opus-4-1-20250805-v1:0", - "eu.amazon.nova-lite-v1:0", + "eu.amazon.nova-lite-v1:0", "eu.amazon.nova-pro-v1:0", "eu.anthropic.claude-3-haiku-20240307-v1:0", "eu.anthropic.claude-haiku-4-5-20251001-v1:0", @@ -1163,7 +1166,7 @@ Resources: "us.anthropic.claude-sonnet-4-5-20250929-v1:0:1m", "us.anthropic.claude-opus-4-20250514-v1:0", "us.anthropic.claude-opus-4-1-20250805-v1:0", - "eu.amazon.nova-lite-v1:0", + "eu.amazon.nova-lite-v1:0", "eu.amazon.nova-pro-v1:0", "eu.anthropic.claude-3-haiku-20240307-v1:0", "eu.anthropic.claude-haiku-4-5-20251001-v1:0", @@ -1231,7 +1234,7 @@ Resources: "us.anthropic.claude-sonnet-4-5-20250929-v1:0:1m", "us.anthropic.claude-opus-4-20250514-v1:0", "us.anthropic.claude-opus-4-1-20250805-v1:0", - "eu.amazon.nova-lite-v1:0", + "eu.amazon.nova-lite-v1:0", "eu.amazon.nova-pro-v1:0", "eu.anthropic.claude-3-haiku-20240307-v1:0", "eu.anthropic.claude-haiku-4-5-20251001-v1:0", @@ -1300,7 +1303,7 @@ Resources: "us.anthropic.claude-sonnet-4-5-20250929-v1:0:1m", "us.anthropic.claude-opus-4-20250514-v1:0", "us.anthropic.claude-opus-4-1-20250805-v1:0", - "eu.amazon.nova-lite-v1:0", + "eu.amazon.nova-lite-v1:0", "eu.amazon.nova-pro-v1:0", "eu.anthropic.claude-3-haiku-20240307-v1:0", "eu.anthropic.claude-haiku-4-5-20251001-v1:0", @@ -1369,7 +1372,7 @@ Resources: "us.anthropic.claude-sonnet-4-5-20250929-v1:0:1m", "us.anthropic.claude-opus-4-20250514-v1:0", "us.anthropic.claude-opus-4-1-20250805-v1:0", - "eu.amazon.nova-lite-v1:0", + "eu.amazon.nova-lite-v1:0", "eu.amazon.nova-pro-v1:0", "eu.anthropic.claude-3-haiku-20240307-v1:0", "eu.anthropic.claude-haiku-4-5-20251001-v1:0", @@ -2514,7 +2517,7 @@ Resources: SAVE_REPORTING_FUNCTION_NAME: !Ref SaveReportingFunctionName CONFIGURATION_TABLE_NAME: !Ref ConfigurationTable WORKING_BUCKET: !Ref WorkingBucket - DOCUMENT_TRACKING_MODE: !If [HasAppSyncApi, 'appsync', 'dynamodb'] + DOCUMENT_TRACKING_MODE: !If [HasAppSyncApi, "appsync", "dynamodb"] TRACKING_TABLE: !Ref TrackingTable LoggingConfig: LogGroup: !Ref EvaluationFunctionLogGroup From 37576ac1744aa2c2337e05d6b7e5c3a4355b1af0 Mon Sep 17 00:00:00 2001 From: Kazmer Nagy-Betegh Date: Thu, 20 Nov 2025 14:18:51 +0200 Subject: [PATCH 05/18] refactor extraction service --- .../idp_common/extraction/agentic_idp.py | 814 +++++---- .../idp_common/extraction/service.py | 1484 ++++++++--------- 2 files changed, 1181 insertions(+), 1117 deletions(-) diff --git a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py index a4a2e974..78060bd0 100644 --- a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py +++ b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py @@ -9,10 +9,10 @@ import asyncio import io import json +import logging import os import re import threading -import traceback from pathlib import Path from typing import ( Any, @@ -23,16 +23,12 @@ import jsonpatch from aws_lambda_powertools import Logger from botocore.config import Config -from botocore.exceptions import ClientError -from idp_common_pkg.idp_common import IDPConfig -from idp_common_pkg.idp_common.utils.bedrock_utils import ( - async_exponential_backoff_retry, -) from PIL import Image from pydantic import BaseModel, Field from strands import Agent, tool from strands.agent.conversation_manager import SummarizingConversationManager -from strands.models.bedrock import BedrockModel +from strands.models import BedrockModel +from strands.types.agent import AgentInput from strands.types.content import CachePoint, ContentBlock, Message from strands.types.media import ( DocumentContent, @@ -41,6 +37,10 @@ ) from idp_common.bedrock.client import CACHEPOINT_SUPPORTED_MODELS +from idp_common.config.models import IDPConfig +from idp_common.utils.bedrock_utils import ( + async_exponential_backoff_retry, +) from idp_common.utils.strands_agent_tools.todo_list import ( create_todo_list, update_todo, @@ -52,7 +52,7 @@ # In Lambda: Full JSON structured logs # Outside Lambda: Human-readable format for local development logger = Logger(service="agentic_idp", level=os.getenv("LOG_LEVEL", "INFO")) - +logging.getLogger("strands.models.bedrock").setLevel(logging.DEBUG) TargetModel = TypeVar("TargetModel", bound=BaseModel) @@ -138,14 +138,6 @@ class BedrockInvokeModelResponse(TypedDict): metering: dict[str, BedrockUsage] # Key format: "{context}/bedrock/{model_id}" -# Data Models for structured extraction -class BoolResponseModel(BaseModel): - """Model for boolean validation responses.""" - - valid_result: bool - description: str = Field(..., description="explanation of the decision") - - class JsonPatchModel(BaseModel): """Model for JSON patch operations.""" @@ -182,6 +174,59 @@ def apply_patches_to_data( return patched_dict +def create_view_image_tool(page_images: list[bytes]) -> Any: + """ + Create a view_image tool that has access to page images. + + Args: + page_images: List of page image bytes (with grid overlay already applied) + sorted_page_ids: List of page IDs in sorted order + + Returns: + A Strands tool function for viewing images + """ + + @tool + def view_image(image_index: int, agent: Agent) -> dict: + """ + View a specific page image. Use this tool when the doc has more images than what you already see. + """ + + # Validate image index exists + if image_index >= len(page_images): + raise ValueError( + f"Invalid image_index {image_index}. " + f"Valid range: 0-{len(page_images) - 1}" + ) + + # Get the base image (already has grid overlay) + img_bytes = page_images[image_index] + + logger.info( + "Returning image to agent", + extra={ + "image_index": image_index, + "image_size_bytes": len(img_bytes), + }, + ) + + return { + "status": "success", + "content": [ + { + "image": { + "format": "png", + "source": { + "bytes": img_bytes, + }, + } + } + ], + } + + return view_image + + def create_dynamic_extraction_tool_and_patch_tool(model_class: type[TargetModel]): """ Create a dynamic tool function that extracts data according to a Pydantic model. @@ -468,10 +513,338 @@ def patch_buffer_data(patches: list[dict[str, Any]], agent: Agent) -> str: max_delay=1800, jitter=0.5, ) -async def invoke_agent_with_retry(input: Any, agent: Agent): +async def invoke_agent_with_retry(input: AgentInput, agent: Agent): return await agent.invoke_async(input) +def _initialize_token_usage() -> dict[str, int]: + """Initialize token usage tracking dictionary.""" + return { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + } + + +def _accumulate_token_usage(response: Any, token_usage: dict[str, int]) -> None: + """ + Accumulate token usage from response into usage dict. + + Args: + response: Agent response object with metrics + token_usage: Dictionary to accumulate usage into (modified in place) + """ + if response and response.metrics and response.metrics.accumulated_usage: + for key in token_usage.keys(): + token_usage[key] += response.metrics.accumulated_usage.get(key, 0) + + +def _build_system_prompt( + base_prompt: str, custom_instruction: str | None, data_format: type[BaseModel] +) -> tuple[str, str]: + """ + Build complete system prompt with custom instructions and schema. + + Args: + base_prompt: The base system prompt (typically SYSTEM_PROMPT constant) + custom_instruction: Optional custom instructions to append + data_format: Pydantic model class to extract schema from + + Returns: + Tuple of (complete system prompt with schema, schema_json for state storage) + """ + # Generate and clean schema + schema_json = json.dumps(data_format.model_json_schema(), indent=2) + + # Build final prompt + final_prompt = base_prompt + if custom_instruction: + final_prompt = f"{final_prompt}\n\nCustom Instructions for this specific task: {custom_instruction}" + + complete_prompt = f"{final_prompt}\n\nExpected Schema:\n{schema_json}" + + return complete_prompt, schema_json + + +def _build_model_config( + model_id: str, + max_tokens: int | None, + max_retries: int, + connect_timeout: float, + read_timeout: float, +) -> dict[str, Any]: + """ + Build model configuration with token limits and caching settings. + + This function: + 1. Creates boto3 Config with retry and timeout settings + 2. Determines model-specific max token limits + 3. Validates and caps max_tokens if needed + 4. Auto-detects and enables caching support (prompt and tool caching) + + Args: + model_id: Bedrock model identifier (supports us.*, eu.*, and global.anthropic.*) + max_tokens: Optional max tokens override (will be capped at model max) + max_retries: Maximum retry attempts for API calls + connect_timeout: Connection timeout in seconds + read_timeout: Read timeout in seconds + + Returns: + Dictionary of model configuration parameters for create_strands_bedrock_model. + Automatically uses BedrockModel for regional models (us.*, eu.*) and + AnthropicModel with AnthropicBedrock for cross-region models (global.anthropic.*). + """ + # Configure retry behavior and timeouts using boto3 Config + boto_config = Config( + retries={ + "max_attempts": max_retries, + "mode": "adaptive", # Uses exponential backoff with adaptive retry mode + }, + connect_timeout=connect_timeout, + read_timeout=read_timeout, + ) + + # Determine model-specific maximum token limits + model_max = 4_096 # Default fallback + model_id_lower = model_id.lower() + + # Check Claude 4 patterns first (more specific) + if re.search(r"claude-(opus|sonnet|haiku)-4", model_id_lower): + model_max = 64_000 + # Check Nova models + elif any( + nova in model_id_lower + for nova in ["nova-premier", "nova-pro", "nova-lite", "nova-micro"] + ): + model_max = 10_000 + # Check Claude 3 models + elif "claude-3" in model_id_lower: + model_max = 8_192 + + # Use config value if provided, but cap at model's maximum + if max_tokens is not None: + if max_tokens > model_max: + logger.warning( + "Config max_tokens exceeds model limit, capping at model maximum", + extra={ + "config_max_tokens": max_tokens, + "model_max_tokens": model_max, + "model_id": model_id, + }, + ) + max_output_tokens = model_max + else: + max_output_tokens = max_tokens + else: + # No config value - use model maximum for agentic extraction + max_output_tokens = model_max + + # Build base model config + model_config = dict( + model_id=model_id, boto_client_config=boto_config, max_tokens=max_output_tokens + ) + + logger.info( + "Setting max_tokens for model", + extra={ + "max_tokens": max_output_tokens, + "model_id": model_id, + "model_max_tokens": model_max, + }, + ) + + # Auto-detect caching support based on model capabilities + if supports_prompt_caching(model_id): + model_config["cache_prompt"] = "default" + logger.info( + "Prompt caching enabled for model", + extra={"model_id": model_id, "auto_detected": True}, + ) + + # Only enable tool caching if the model supports it (Claude only, not Nova) + if supports_tool_caching(model_id): + model_config["cache_tools"] = "default" + logger.info( + "Tool caching enabled for model", + extra={"model_id": model_id, "auto_detected": True}, + ) + else: + logger.info( + "Tool caching not supported for model", + extra={"model_id": model_id, "reason": "prompt_caching_only"}, + ) + else: + logger.debug("Caching not supported for model", extra={"model_id": model_id}) + + return model_config + + +def _prepare_prompt_content( + prompt: str | Message | Image.Image, + page_images: list[bytes] | None, + existing_data: BaseModel | None, +) -> list[ContentBlock]: + """ + Prepare prompt content from various input types. + + Converts different prompt types (text, PIL Image, Message dict) into + a list of ContentBlocks, adds page images, and appends existing data context. + + Args: + prompt: Input content (text string, PIL Image, or Message dict) + page_images: Optional list of page image bytes to include + existing_data: Optional existing extraction data to update + + Returns: + List of ContentBlock objects ready for agent invocation + """ + prompt_content: list[ContentBlock] = [] + + # Process prompt based on type + if isinstance(prompt, Image.Image): + # Convert PIL Image to binary string + img_buffer = io.BytesIO() + prompt.save(img_buffer, format="PNG") + img_bytes = img_buffer.getvalue() + + logger.debug( + "Processing PIL Image", + extra={"size": prompt.size, "mode": prompt.mode}, + ) + + prompt_content = [ + ContentBlock(text="Extract structured data from this image:"), + ContentBlock( + image=ImageContent(format="png", source=ImageSource(bytes=img_bytes)) + ), + ] + elif isinstance(prompt, dict) and "content" in prompt: + prompt_content = prompt["content"] # type: ignore + else: + prompt_content = [ContentBlock(text=str(prompt))] + + # Add page images if provided + if page_images: + if len(page_images) > 20: + prompt_content.append( + ContentBlock( + text=f"There are {len(page_images)} images, initially you'll see 20 of them, use the tools to see the rest." + ) + ) + + prompt_content += [ + ContentBlock( + image=ImageContent(format="png", source=ImageSource(bytes=img_bytes)) + ) + for img_bytes in page_images + ] + + # Add existing data context if provided + if existing_data: + prompt_content.append( + ContentBlock( + text=f"Please update the existing data using the extraction tool or patches. Existing data: {existing_data.model_dump()}" + ) + ) + + prompt_content += [ + ContentBlock(text="end of your main task description"), + ContentBlock(cachePoint=CachePoint(type="default")), + ] + return prompt_content + + +async def _invoke_agent_for_extraction( + agent: Agent, + prompt_content: list[ContentBlock], + data_format: type[TargetModel], + max_extraction_retries: int = 3, +) -> tuple[Any, TargetModel | None]: + """ + Invoke agent and retry if extraction fails. + + Unlike network retries (handled by invoke_agent_with_retry), this retries when + the agent completes successfully but fails to produce a valid extraction. + + Args: + agent: The Strands agent to invoke + prompt_content: List of ContentBlocks to send to the agent + data_format: Pydantic model class for validation + max_extraction_retries: Maximum retry attempts for failed extractions (default: 3) + + Returns: + Tuple of (response, validated_result or None) + """ + response = None + + for attempt in range(max_extraction_retries): + # invoke_agent_with_retry already handles network errors and throttling + response = await invoke_agent_with_retry(agent=agent, input=prompt_content) + logger.debug("Agent response received") + + # Try to get extraction from state + current_extraction = agent.state.get("current_extraction") + + if current_extraction: + try: + result = data_format(**current_extraction) + logger.debug( + "Successfully validated extraction", + extra={"data_format": data_format.__name__, "attempt": attempt + 1}, + ) + return response, result + except Exception as e: + logger.warning( + "Extraction validation failed, retrying", + extra={ + "attempt": attempt + 1, + "max_retries": max_extraction_retries, + "error": str(e), + "data_format": data_format.__name__, + }, + ) + if attempt < max_extraction_retries - 1: + # Ask agent to fix the extraction + prompt_content = [ + ContentBlock( + text=f"The extraction failed validation with error: {str(e)}. Please fix the extraction using the tools." + ) + ] + continue + else: + # Last attempt failed + logger.error( + "Failed to validate extraction after all retries", + extra={ + "data_format": data_format.__name__, + "error": str(e), + "extraction_data": current_extraction, + }, + ) + return response, None + else: + logger.warning( + "No extraction found in agent state", + extra={"attempt": attempt + 1, "max_retries": max_extraction_retries}, + ) + if attempt < max_extraction_retries - 1: + # Ask agent to provide extraction + prompt_content = [ + ContentBlock( + text="No extraction was found. Please use the extraction_tool to provide the extracted data." + ) + ] + continue + + # Should never reach here, but handle it gracefully + if response is None: + raise ValueError("No response from agent after retries") + + return response, None + + async def structured_output_async( model_id: str, data_format: type[TargetModel], @@ -480,6 +853,7 @@ async def structured_output_async( system_prompt: str | None = None, custom_instruction: str | None = None, config: IDPConfig = IDPConfig(), + page_images: list[bytes] | None = None, context: str = "Extraction", max_retries: int = 7, connect_timeout: float = 10.0, @@ -566,10 +940,14 @@ async def structured_output_async( dynamic_extraction_tools = create_dynamic_extraction_tool_and_patch_tool( data_format ) + image_tools = [] + if page_images: + image_tools.append(create_view_image_tool(page_images)) # Prepare tools list tools = [ *dynamic_extraction_tools, + *image_tools, view_existing_extraction, patch_buffer_data, view_buffer_data, @@ -582,8 +960,13 @@ async def structured_output_async( view_todo_list, ] - # Create agent with system prompt and tools - schema_json = json.dumps(data_format.model_json_schema(), indent=2) + # Build system prompt with schema + final_system_prompt, schema_json = _build_system_prompt( + base_prompt=system_prompt or SYSTEM_PROMPT, + custom_instruction=custom_instruction, + data_format=data_format, + ) + tool_names = [getattr(tool, "__name__", str(tool)) for tool in tools] logger.debug( "Created agent with tools", @@ -594,93 +977,26 @@ async def structured_output_async( }, ) - # Configure retry behavior and timeouts using boto3 Config - boto_config = Config( - retries={ - "max_attempts": max_retries, - "mode": "adaptive", # Uses exponential backoff with adaptive retry mode - }, + # Build model configuration with token limits and caching + model_config = _build_model_config( + model_id=model_id, + max_tokens=max_tokens, + max_retries=max_retries, connect_timeout=connect_timeout, read_timeout=read_timeout, ) - model_max = 4_096 # Default fallback - model_id_lower = model_id.lower() - # Check Claude 4 patterns first (more specific) - if re.search(r"claude-(opus|sonnet|haiku)-4", model_id_lower): - model_max = 64_000 - # Check Nova models - elif any( - nova in model_id_lower - for nova in ["nova-premier", "nova-pro", "nova-lite", "nova-micro"] - ): - model_max = 10_000 - # Check Claude 3 models - elif "claude-3" in model_id_lower: - model_max = 8_192 - - # Use config value if provided, but cap at model's maximum - if max_tokens is not None: - if max_tokens > model_max: - logger.warning( - "Config max_tokens exceeds model limit, capping at model maximum", - extra={ - "config_max_tokens": max_tokens, - "model_max_tokens": model_max, - "model_id": model_id, - }, - ) - max_output_tokens = model_max - else: - max_output_tokens = max_tokens - else: - # No config value - use model maximum for agentic extraction - max_output_tokens = model_max - - model_config = dict( - model_id=model_id, boto_client_config=boto_config, max_tokens=max_output_tokens - ) - logger.info( - "Setting max_tokens for model", - extra={ - "max_tokens": max_output_tokens, - "model_id": model_id, - "model_max_tokens": model_max, - }, + # Prepare prompt content + prompt_content = _prepare_prompt_content( + prompt=prompt, page_images=page_images, existing_data=existing_data ) - # Auto-detect caching support based on model capabilities - if supports_prompt_caching(model_id): - model_config["cache_prompt"] = "default" - logger.info( - "Prompt caching enabled for model", - extra={"model_id": model_id, "auto_detected": True}, - ) - - # Only enable tool caching if the model supports it (Claude only, not Nova) - if supports_tool_caching(model_id): - model_config["cache_tools"] = "default" - logger.info( - "Tool caching enabled for model", - extra={"model_id": model_id, "auto_detected": True}, - ) - else: - logger.info( - "Tool caching not supported for model", - extra={"model_id": model_id, "reason": "prompt_caching_only"}, - ) - else: - logger.debug("Caching not supported for model", extra={"model_id": model_id}) - - final_system_prompt = system_prompt if system_prompt else SYSTEM_PROMPT - - if custom_instruction: - final_system_prompt = f"{final_system_prompt}\n\nCustom Instructions for this specific task: {custom_instruction}" - + # Track token usage + token_usage = _initialize_token_usage() agent = Agent( model=BedrockModel(**model_config), # pyright: ignore[reportArgumentType] tools=tools, - system_prompt=f"{final_system_prompt}\n\nExpected Schema:\n{schema_json}", + system_prompt=final_system_prompt, state={ "current_extraction": None, "images": {}, @@ -691,195 +1007,40 @@ async def structured_output_async( summary_ratio=0.8, preserve_recent_messages=2 ), ) - - # Process prompt based on type - if isinstance(prompt, Image.Image): - # Convert PIL Image to binary string for state storage - img_buffer = io.BytesIO() - prompt.save(img_buffer, format="PNG") - img_bytes = img_buffer.getvalue() - - logger.debug( - "Processing PIL Image", - extra={"size": prompt.size, "mode": prompt.mode}, - ) - - # Store image as binary string in state - - prompt_content = [ - Message( - role="user", - content=[ - ContentBlock(text="Extract structured data from this image:"), - ContentBlock( - image=ImageContent( - format="png", source=ImageSource(bytes=img_bytes) - ) - ), - ContentBlock(cachePoint=CachePoint(type="default")), - ], - ) - ] - elif isinstance(prompt, dict) and "content" in prompt: - prompt_content = [prompt] - else: - prompt_content = [ - Message( - role="user", - content=[ - ContentBlock(text=str(prompt)), - ContentBlock(cachePoint=CachePoint(type="default")), - ], - ) - ] - - # Track token usage - token_usage = { - "inputTokens": 0, - "outputTokens": 0, - "totalTokens": 0, - "cacheReadInputTokens": 0, - "cacheWriteInputTokens": 0, - } - - # Main extraction loop - result = None - response = None - # Prepare prompt for this cycle if existing_data: - prompt_content.append( - Message( - role="user", - content=[ - ContentBlock( - text=f"Please update the existing data using the extraction tool or patches. Existing data: {existing_data.model_dump()}" - ), - ContentBlock(cachePoint=CachePoint(type="default")), - ], - ) - ) agent.state.set("current_extraction", existing_data.model_dump()) - # Retry logic for network errors (ProtocolError, etc.) - max_retries = 3 - retry_delay = 2 # seconds - - for attempt in range(max_retries): - try: - response = await invoke_agent_with_retry(agent=agent, input=prompt_content) - logger.debug("Agent response received") - break # Success, exit retry loop - except Exception as e: - error_type = type(e).__name__ - error_msg = str(e) - is_last_attempt = attempt == max_retries - 1 - - # Check if this is a retryable network error - is_retryable = ( - error_type - in [ - "ProtocolError", - "ConnectionError", - "ReadTimeoutError", - "IncompleteRead", - ] - or "Response ended prematurely" in error_msg - or "Connection" in error_msg - ) - - if is_retryable and not is_last_attempt: - logger.warning( - "Network error during agent invocation, retrying", - extra={ - "attempt": attempt + 1, - "max_retries": max_retries, - "error_type": error_type, - "error_message": error_msg, - "retry_delay_seconds": retry_delay, - }, - ) - await asyncio.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff - continue - - # Log the error - - logger.error( - "Agent invocation failed", - extra={ - "error_type": error_type, - "error_message": error_msg, - "traceback": traceback.format_exc(), - }, - ) - - # Re-raise ClientError (including ThrottlingException) directly for Step Functions retry handling - if isinstance(e, ClientError): - logger.error( - "Bedrock ClientError detected", - extra={ - "error_code": e.response["Error"]["Code"], - "error_message": e.response["Error"].get("Message", ""), - }, - ) - raise - - # Wrap other exceptions - raise ValueError(f"Agent invocation failed: {error_msg}") + response, result = await _invoke_agent_for_extraction( + agent=agent, + prompt_content=prompt_content, + data_format=data_format, + max_extraction_retries=3, + ) # Accumulate token usage - if response and response.metrics and response.metrics.accumulated_usage: - for key in token_usage.keys(): - token_usage[key] += response.metrics.accumulated_usage.get(key, 0) + _accumulate_token_usage(response, token_usage) - # Check for extraction in state - current_extraction = agent.state.get("current_extraction") - logger.debug( - "Current extraction from state", - extra={"extraction": current_extraction}, - ) + # Add explicit review step (Option 2) + if ( + config.extraction.agentic.enabled + and config.extraction.agentic.review_agent + and config.extraction.agentic.review_agent_model + ): + # result is guaranteed to be non-None here (we raised an error earlier if it was None) + assert result is not None - if current_extraction: - try: - result = data_format(**current_extraction) - logger.debug( - "Successfully created extraction instance", - extra={"data_format": data_format.__name__}, - ) - except Exception as e: - logger.error( - "Failed to validate extraction against schema", - extra={ - "data_format": data_format.__name__, - "error": str(e), - "extraction_data": current_extraction, - }, - ) - raise ValueError(f"Failed to validate extraction against schema: {str(e)}") - else: - logger.error( - "No extraction found in agent state", - extra={"agent_state_keys": list(agent.state._state.keys())}, - ) - logger.error( - "Full agent state dump", - extra={"agent_state": agent.state._state}, + logger.debug( + "Initiating final review of extracted data", + extra={"review_enabled": True}, ) - - # Add explicit review step (Option 2) - if config.extraction.agentic.enabled and config.extraction.agentic.review_agent: - logger.debug( - "Initiating final review of extracted data", - extra={"review_enabled": True}, - ) - review_prompt = prompt_content.append( - Message( - role="user", - content=[ - ContentBlock( - text=f""" + review_prompt = Message( + role="user", + content=[ + *prompt_content, + ContentBlock( + text=f""" You have successfully extracted the following data: - {json.dumps(current_extraction, indent=2)} + {json.dumps(result.model_dump(), indent=2)} Please take one final careful look at this extraction: 1. Check each field against the source document @@ -891,60 +1052,56 @@ async def structured_output_async( If everything is correct, respond with "Data verified and accurate." If corrections are needed, use the apply_json_patches tool to fix any issues you find. """ - ), - ContentBlock(cachePoint=CachePoint(type="default")), - ], - ) - ) - model_config = dict( - model_id=config.extraction.agentic.review_agent_model, - boto_client_config=boto_config, - max_tokens=max_output_tokens, - ) - agent = Agent( - model=BedrockModel(**model_config), # pyright: ignore[reportArgumentType] - tools=tools, - system_prompt=f"{final_system_prompt}", - state={ - "current_extraction": None, - "images": {}, - "existing_data": existing_data.model_dump() - if existing_data - else None, - "extraction_schema_json": schema_json, # Store for schema reminder tool - }, - conversation_manager=SummarizingConversationManager( - summary_ratio=0.8, preserve_recent_messages=2 ), - ) + ContentBlock(cachePoint=CachePoint(type="default")), + ], + ) + # Build config for review agent + review_model_config = _build_model_config( + model_id=config.extraction.agentic.review_agent_model, + max_tokens=max_tokens, + max_retries=max_retries, + connect_timeout=connect_timeout, + read_timeout=read_timeout, + ) + agent = Agent( + model=BedrockModel(**review_model_config), # pyright: ignore[reportArgumentType] + tools=tools, + system_prompt=f"{final_system_prompt}", + state={ + "current_extraction": None, + "images": {}, + "existing_data": existing_data.model_dump() if existing_data else None, + "extraction_schema_json": schema_json, # Store for schema reminder tool + }, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.8, preserve_recent_messages=2 + ), + ) - review_response = await invoke_agent_with_retry( - agent=agent, input=review_prompt - ) - logger.debug("Review response received", extra={"review_completed": True}) + review_response = await invoke_agent_with_retry( + agent=agent, input=review_prompt + ) + logger.debug("Review response received", extra={"review_completed": True}) - # Accumulate token usage from review - if review_response.metrics and review_response.metrics.accumulated_usage: - for key in token_usage.keys(): - token_usage[key] += review_response.metrics.accumulated_usage.get( - key, 0 - ) + # Accumulate token usage from review + _accumulate_token_usage(review_response, token_usage) - # Check if patches were applied during review - updated_extraction = agent.state.get("current_extraction") - if updated_extraction != current_extraction: - # Patches were applied, validate the new extraction - try: - result = data_format(**updated_extraction) - logger.debug( - "Applied corrections after final review", - extra={"corrections_applied": True}, - ) - except Exception as e: - logger.debug( - "Post-review validation failed", - extra={"error": str(e)}, - ) + # Check if patches were applied during review + updated_extraction = agent.state.get("current_extraction") + if updated_extraction != result.model_dump(): + # Patches were applied, validate the new extraction + try: + result = data_format(**updated_extraction) + logger.debug( + "Applied corrections after final review", + extra={"corrections_applied": True}, + ) + except Exception as e: + logger.debug( + "Post-review validation failed", + extra={"error": str(e)}, + ) # Return best effort result if result and response: @@ -974,6 +1131,7 @@ def structured_output( existing_data: BaseModel | None = None, system_prompt: str | None = None, custom_instruction: str | None = None, + page_images: list[bytes] | None = None, context: str = "Extraction", config: IDPConfig = IDPConfig(), max_retries: int = 7, @@ -1064,6 +1222,7 @@ def run_in_new_loop(): max_retries=max_retries, connect_timeout=connect_timeout, read_timeout=read_timeout, + page_images=page_images, ) ) except Exception as e: @@ -1095,6 +1254,7 @@ def run_in_new_loop(): max_retries=max_retries, connect_timeout=connect_timeout, read_timeout=read_timeout, + page_images=page_images, ) ) diff --git a/lib/idp_common_pkg/idp_common/extraction/service.py b/lib/idp_common_pkg/idp_common/extraction/service.py index 260ef512..49b027c1 100644 --- a/lib/idp_common_pkg/idp_common/extraction/service.py +++ b/lib/idp_common_pkg/idp_common/extraction/service.py @@ -14,9 +14,10 @@ import logging import os import time -from typing import Any, Dict, List, Union +from typing import Any from idp_common import bedrock, image, metrics, s3, utils +from idp_common.bedrock import format_prompt from idp_common.config.models import IDPConfig from idp_common.config.schema_constants import ( ID_FIELD, @@ -36,18 +37,54 @@ AGENTIC_AVAILABLE = True except ImportError: AGENTIC_AVAILABLE = False +from pydantic import BaseModel + from idp_common.utils import extract_json_from_text logger = logging.getLogger(__name__) +# Pydantic models for internal data transfer +class SectionInfo(BaseModel): + """Metadata about a document section being processed.""" + + class_label: str + sorted_page_ids: list[str] + page_indices: list[int] + output_bucket: str + output_key: str + output_uri: str + start_page: int + end_page: int + + +class ExtractionConfig(BaseModel): + """Configuration for model invocation.""" + + model_id: str + temperature: float + top_k: float + top_p: float + max_tokens: int | None + system_prompt: str + + +class ExtractionResult(BaseModel): + """Result from model extraction.""" + + extracted_fields: dict[str, Any] + metering: dict[str, Any] + parsing_succeeded: bool + total_duration: float + + class ExtractionService: """Service for extracting fields from documents using LLMs.""" def __init__( self, - region: str = None, - config: Union[Dict[str, Any], "IDPConfig"] = None, + region: str | None = None, + config: dict[str, Any] | IDPConfig | None = None, ): """ Initialize the extraction service. @@ -67,13 +104,62 @@ def __init__( self.config = config_model self.region = region or os.environ.get("AWS_REGION") + # Instance variables for prompt context + # These are initialized here and populated during each process_document_section call + # This allows methods to access context without passing multiple parameters + self._document_text: str = "" + self._class_label: str = "" + self._attribute_descriptions: str = "" + self._class_schema: dict[str, Any] = {} + self._page_images: list[Any] = [] + self._image_uris: list[str] = [] + # Get model_id from config for logging (type-safe access with fallback) model_id = ( self.config.extraction.model if self.config.extraction else "not configured" ) logger.info(f"Initialized extraction service with model {model_id}") - def _get_class_schema(self, class_label: str) -> Dict[str, Any]: + @property + def _substitutions(self) -> dict[str, str]: + """Get prompt placeholder substitutions from stored context.""" + return { + "DOCUMENT_TEXT": self._document_text, + "DOCUMENT_CLASS": self._class_label, + "ATTRIBUTE_NAMES_AND_DESCRIPTIONS": self._attribute_descriptions, + } + + def _get_default_prompt_content(self) -> list[dict[str, Any]]: + """ + Build default fallback prompt content when no template is provided. + + Returns: + List of content items with default prompt text and images + """ + task_prompt = f""" + Extract the following fields from this {self._class_label} document: + + {self._attribute_descriptions} + + Document text: + {self._document_text} + + Respond with a JSON object containing each field name and its extracted value. + """ + content = [{"text": task_prompt}] + + # Add image attachments to the content (limit to 20 images as per Bedrock constraints) + if self._page_images: + logger.info( + f"Attaching images to default prompt, for {len(self._page_images)} pages." + ) + # Limit to 20 images as per Bedrock constraints + for img in self._page_images[:20]: + content.append(image.prepare_bedrock_image_attachment(img)) + + return content + + def _get_class_schema(self, class_label: str) -> dict[str, Any]: """ Get JSON Schema for a specific document class from configuration. @@ -96,7 +182,7 @@ def _get_class_schema(self, class_label: str) -> Dict[str, Any]: return {} - def _clean_schema_for_prompt(self, schema: Dict[str, Any]) -> Dict[str, Any]: + def _clean_schema_for_prompt(self, schema: dict[str, Any]) -> dict[str, Any]: """ Clean JSON Schema by removing IDP custom fields (x-aws-idp-*) for the prompt. Keeps all standard JSON Schema fields including descriptions. @@ -129,7 +215,7 @@ def _clean_schema_for_prompt(self, schema: Dict[str, Any]) -> Dict[str, Any]: return cleaned - def _format_schema_for_prompt(self, schema: Dict[str, Any]) -> str: + def _format_schema_for_prompt(self, schema: dict[str, Any]) -> str: """ Format JSON Schema for inclusion in the extraction prompt. @@ -148,8 +234,8 @@ def _format_schema_for_prompt(self, schema: Dict[str, Any]) -> str: def _prepare_prompt_from_template( self, prompt_template: str, - substitutions: Dict[str, str], - required_placeholders: List[str] = None, + substitutions: dict[str, str], + required_placeholders: list[str] | None = None, ) -> str: """ Prepare prompt from template by replacing placeholders with values. @@ -165,338 +251,156 @@ def _prepare_prompt_from_template( Raises: ValueError: If a required placeholder is missing from the template """ - from idp_common.bedrock import format_prompt return format_prompt(prompt_template, substitutions, required_placeholders) - def _build_content_with_or_without_image_placeholder( + def _build_prompt_content( self, prompt_template: str, - document_text: str, - class_label: str, - attribute_descriptions: str, image_content: Any = None, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ - Build content array, automatically deciding whether to use image placeholder processing. + Build prompt content array handling FEW_SHOT_EXAMPLES and DOCUMENT_IMAGE placeholders. - If the prompt contains {DOCUMENT_IMAGE}, the image will be inserted at that location. - If the prompt does NOT contain {DOCUMENT_IMAGE}, the image will NOT be included at all. + This consolidated method handles all placeholder types and combinations: + - {FEW_SHOT_EXAMPLES}: Inserts few-shot examples from config + - {DOCUMENT_IMAGE}: Inserts images at specific location + - Regular text placeholders: DOCUMENT_TEXT, DOCUMENT_CLASS, etc. Args: - prompt_template: The prompt template that may contain {DOCUMENT_IMAGE} - document_text: The document text content - class_label: The document class label - attribute_descriptions: Formatted attribute names and descriptions - image_content: Optional image content to insert (only used when {DOCUMENT_IMAGE} is present) + prompt_template: The prompt template with optional placeholders + image_content: Optional image content to insert (only used with {DOCUMENT_IMAGE}) Returns: - List of content items with text and image content properly ordered based on presence of placeholder + List of content items with text and image content properly ordered """ - if "{DOCUMENT_IMAGE}" in prompt_template: - return self._build_content_with_image_placeholder( - prompt_template, - document_text, - class_label, - attribute_descriptions, - image_content, - ) - else: - return self._build_content_without_image_placeholder( - prompt_template, - document_text, - class_label, - attribute_descriptions, - image_content, - ) + content: list[dict[str, Any]] = [] + + # Handle FEW_SHOT_EXAMPLES placeholder first + if "{FEW_SHOT_EXAMPLES}" in prompt_template: + parts = prompt_template.split("{FEW_SHOT_EXAMPLES}") + if len(parts) == 2: + # Process before examples + content.extend( + self._build_text_and_image_content(parts[0], image_content) + ) + + # Add few-shot examples + content.extend(self._build_few_shot_examples_content()) + + # Process after examples (only pass images if not already used) + image_for_after = ( + None if "{DOCUMENT_IMAGE}" in parts[0] else image_content + ) + content.extend( + self._build_text_and_image_content(parts[1], image_for_after) + ) + + return content + + # No FEW_SHOT_EXAMPLES, just handle text and images + return self._build_text_and_image_content(prompt_template, image_content) - def _build_content_with_image_placeholder( + def _build_text_and_image_content( self, prompt_template: str, - document_text: str, - class_label: str, - attribute_descriptions: str, image_content: Any = None, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ - Build content array with image inserted at DOCUMENT_IMAGE placeholder if present. + Build content array with text and optionally images based on DOCUMENT_IMAGE placeholder. Args: - prompt_template: The prompt template that may contain {DOCUMENT_IMAGE} - document_text: The document text content - class_label: The document class label - attribute_descriptions: Formatted attribute names and descriptions - image_content: Optional image content to insert + prompt_template: Template that may contain {DOCUMENT_IMAGE} + image_content: Optional image content Returns: - List of content items with text and image content properly ordered + List of content items """ - # Split the prompt at the DOCUMENT_IMAGE placeholder - parts = prompt_template.split("{DOCUMENT_IMAGE}") - - if len(parts) != 2: - logger.warning( - "Invalid DOCUMENT_IMAGE placeholder usage, falling back to standard processing" - ) - # Fallback to standard processing - return self._build_content_without_image_placeholder( - prompt_template, - document_text, - class_label, - attribute_descriptions, - image_content, - ) - - # Process the parts before and after the image placeholder - before_image = self._prepare_prompt_from_template( - parts[0], - { - "DOCUMENT_TEXT": document_text, - "DOCUMENT_CLASS": class_label, - "ATTRIBUTE_NAMES_AND_DESCRIPTIONS": attribute_descriptions, - }, - required_placeholders=[], # Don't enforce required placeholders for partial templates - ) + content: list[dict[str, Any]] = [] - after_image = self._prepare_prompt_from_template( - parts[1], - { - "DOCUMENT_TEXT": document_text, - "DOCUMENT_CLASS": class_label, - "ATTRIBUTE_NAMES_AND_DESCRIPTIONS": attribute_descriptions, - }, - required_placeholders=[], # Don't enforce required placeholders for partial templates - ) + # Handle DOCUMENT_IMAGE placeholder + if "{DOCUMENT_IMAGE}" in prompt_template: + parts = prompt_template.split("{DOCUMENT_IMAGE}") + if len(parts) == 2: + # Add text before image + before_text = self._prepare_prompt_from_template( + parts[0], self._substitutions, required_placeholders=[] + ) + if before_text.strip(): + content.append({"text": before_text}) - # Build content array with image in the middle - content = [] + # Add images + if image_content: + content.extend(self._prepare_image_attachments(image_content)) - # Add the part before the image - if before_image.strip(): - content.append({"text": before_image}) + # Add text after image + after_text = self._prepare_prompt_from_template( + parts[1], self._substitutions, required_placeholders=[] + ) + if after_text.strip(): + content.append({"text": after_text}) - # Add the image if available - if image_content: - if isinstance(image_content, list): - # Multiple images (limit to 20 as per Bedrock constraints) - if len(image_content) > 20: - logger.warning( - f"Found {len(image_content)} images, truncating to 20 due to Bedrock constraints. " - f"{len(image_content) - 20} images will be dropped." - ) - for img in image_content[:20]: - content.append(image.prepare_bedrock_image_attachment(img)) + return content else: - # Single image - content.append(image.prepare_bedrock_image_attachment(image_content)) - - # Add the part after the image - if after_image.strip(): - content.append({"text": after_image}) - - return content - - def _build_content_without_image_placeholder( - self, - prompt_template: str, - document_text: str, - class_label: str, - attribute_descriptions: str, - image_content: Any = None, - ) -> List[Dict[str, Any]]: - """ - Build content array without DOCUMENT_IMAGE placeholder (standard processing). - - Note: This method does NOT attach the image content when no placeholder is present. - - Args: - prompt_template: The prompt template - document_text: The document text content - class_label: The document class label - attribute_descriptions: Formatted attribute names and descriptions - image_content: Optional image content (not used when no placeholder is present) + logger.warning("Invalid DOCUMENT_IMAGE placeholder usage") - Returns: - List of content items with text content only (no image) - """ - # Prepare the full prompt + # No image placeholder, just text task_prompt = self._prepare_prompt_from_template( - prompt_template, - { - "DOCUMENT_TEXT": document_text, - "DOCUMENT_CLASS": class_label, - "ATTRIBUTE_NAMES_AND_DESCRIPTIONS": attribute_descriptions, - }, - required_placeholders=[], + prompt_template, self._substitutions, required_placeholders=[] ) - - content = [{"text": task_prompt}] - - # No longer adding image content when no placeholder is present + content.append({"text": task_prompt}) return content - def _build_content_with_few_shot_examples( - self, - task_prompt_template: str, - document_text: str, - class_label: str, - attribute_descriptions: str, - image_content: Any = None, - ) -> List[Dict[str, Any]]: + def _prepare_image_attachments(self, image_content: Any) -> list[dict[str, Any]]: """ - Build content array with few-shot examples inserted at the FEW_SHOT_EXAMPLES placeholder. - Also supports DOCUMENT_IMAGE placeholder for image positioning. + Prepare image attachments for Bedrock, limiting to 20 images. Args: - task_prompt_template: The task prompt template containing {FEW_SHOT_EXAMPLES} - document_text: The document text content - class_label: The document class label - attribute_descriptions: Formatted attribute names and descriptions - image_content: Optional image content to insert + image_content: Single image or list of images Returns: - List of content items with text and image content properly ordered + List of image attachment dicts """ - # Split the task prompt at the FEW_SHOT_EXAMPLES placeholder - parts = task_prompt_template.split("{FEW_SHOT_EXAMPLES}") - - if len(parts) != 2: - # Fallback to regular prompt processing if placeholder not found or malformed - return self._build_content_with_or_without_image_placeholder( - task_prompt_template, - document_text, - class_label, - attribute_descriptions, - image_content, - ) - - # Process each part using the unified function - before_examples_content = self._build_content_with_or_without_image_placeholder( - parts[0], document_text, class_label, attribute_descriptions, image_content - ) - - # Only pass image_content if it wasn't already used in the first part - image_for_second_part = ( - None if "{DOCUMENT_IMAGE}" in parts[0] else image_content - ) - after_examples_content = self._build_content_with_or_without_image_placeholder( - parts[1], - document_text, - class_label, - attribute_descriptions, - image_for_second_part, - ) - - # Build content array - content = [] - - # Add the part before examples (may include image if DOCUMENT_IMAGE was in the first part) - content.extend(before_examples_content) - - # Add few-shot examples from config for this specific class - examples_content = self._build_few_shot_examples_content(class_label) - content.extend(examples_content) - - # Add the part after examples (may include image if DOCUMENT_IMAGE was in the second part) - content.extend(after_examples_content) - - # No longer appending image content when no placeholder is found + attachments: list[dict[str, Any]] = [] + + if isinstance(image_content, list): + # Multiple images (limit to 20 as per Bedrock constraints) + if len(image_content) > 20: + logger.warning( + f"Found {len(image_content)} images, truncating to 20 due to Bedrock constraints. " + f"{len(image_content) - 20} images will be dropped." + ) + for img in image_content[:20]: + attachments.append(image.prepare_bedrock_image_attachment(img)) + else: + # Single image + attachments.append(image.prepare_bedrock_image_attachment(image_content)) - return content + return attachments - def _build_few_shot_examples_content( - self, class_label: str - ) -> List[Dict[str, Any]]: + def _build_few_shot_examples_content(self) -> list[dict[str, Any]]: """ Build content items for few-shot examples from the configuration for a specific class. - Args: - class_label: The document class label to get examples for - Returns: List of content items containing text and image content for examples """ - content = [] + content: list[dict[str, Any]] = [] - # Find the specific class that matches the class_label (now in JSON Schema format) - target_class = self._get_class_schema(class_label) - - if not target_class: + # Use the stored class schema + if not self._class_schema: logger.warning( - f"No class found matching '{class_label}' for few-shot examples" + f"No class schema found for '{self._class_label}' for few-shot examples" ) return content # Get examples from the JSON Schema for this specific class - content = build_few_shot_extraction_examples_content(target_class) + content = build_few_shot_extraction_examples_content(self._class_schema) return content - def _get_image_files_from_path(self, image_path: str) -> List[str]: - """ - Get list of image files from a path that could be a single file, directory, or S3 prefix. - - Args: - image_path: Path to image file, directory, or S3 prefix - - Returns: - List of image file paths/URIs sorted by filename - """ - import os - - from idp_common import s3 - - # Handle S3 URIs - if image_path.startswith("s3://"): - # Check if it's a direct file or a prefix - if image_path.endswith( - (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif", ".webp") - ): - # Direct S3 file - return [image_path] - else: - # S3 prefix - list all images - return s3.list_images_from_path(image_path) - else: - # Handle local paths - config_bucket = os.environ.get("CONFIGURATION_BUCKET") - root_dir = os.environ.get("ROOT_DIR") - - if config_bucket: - # Use environment bucket with imagePath as key - s3_uri = f"s3://{config_bucket}/{image_path}" - - # Check if it's a direct file or a prefix - if image_path.endswith( - (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif", ".webp") - ): - # Direct S3 file - return [s3_uri] - else: - # S3 prefix - list all images - return s3.list_images_from_path(s3_uri) - elif root_dir: - # Use relative path from ROOT_DIR - full_path = os.path.join(root_dir, image_path) - full_path = os.path.normpath(full_path) - - if os.path.isfile(full_path): - # Single local file - return [full_path] - elif os.path.isdir(full_path): - # Local directory - list all images - return s3.list_images_from_path(full_path) - else: - # Path doesn't exist - logger.warning(f"Image path does not exist: {full_path}") - return [] - else: - raise ValueError( - "No CONFIGURATION_BUCKET or ROOT_DIR set. Cannot read example images from local filesystem." - ) - - def _make_json_serializable(self, obj): + def _make_json_serializable(self, obj: Any) -> Any: """ Recursively convert any object to a JSON-serializable format. @@ -534,96 +438,18 @@ def _make_json_serializable(self, obj): # Convert non-serializable objects to string representation return str(obj) - def _convert_image_bytes_to_uris_in_content( - self, content: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: - """ - Convert image bytes to URIs in content array for JSON serialization. - - Args: - content: Content array that may contain image objects with bytes - - Returns: - Content array with image URIs instead of bytes - """ - converted_content = [] - - for item in content: - if "image" in item and isinstance(item["image"], dict): - # Extract image URI if it exists, or use placeholder - if "source" in item["image"] and "bytes" in item["image"]["source"]: - # This is a bytes-based image - replace with URI reference - # In practice, we need to store these bytes somewhere accessible - # For now, we'll use a placeholder that indicates bytes were present - converted_item = { - "image_uri": f"" - } - else: - # Keep other image formats as-is - converted_item = item.copy() - else: - # Keep non-image items as-is - converted_item = item.copy() - - converted_content.append(converted_item) - - return converted_content - - def _convert_image_uris_to_bytes_in_content( - self, content: List[Dict[str, Any]], original_images: List[Any] - ) -> List[Dict[str, Any]]: - """ - Convert image URIs back to bytes in content array after Lambda processing. - - Args: - content: Content array from Lambda that may contain image URIs - original_images: Original image data to restore - - Returns: - Content array with image bytes restored - """ - converted_content = [] - image_index = 0 - - for item in content: - if "image_uri" in item: - # Convert image URI back to bytes format - if image_index < len(original_images): - # Restore original image bytes - converted_item = image.prepare_bedrock_image_attachment( - original_images[image_index] - ) - image_index += 1 - else: - # Skip if no original image data - logger.warning( - "No original image data available for URI conversion" - ) - continue - elif "image" in item: - # Keep existing image objects as-is - converted_item = item.copy() - else: - # Keep non-image items as-is - converted_item = item.copy() - - converted_content.append(converted_item) - - return converted_content - def _invoke_custom_prompt_lambda( - self, lambda_arn: str, payload: dict, original_images: List[Any] = None - ) -> dict: + self, lambda_arn: str, payload: dict[str, Any] + ) -> dict[str, Any]: """ Invoke custom prompt generator Lambda function with JSON-serializable payload. Args: lambda_arn: ARN of the Lambda function to invoke payload: Payload to send to Lambda function (must be JSON serializable) - original_images: Original image data for restoration after Lambda processing Returns: - Dict containing system_prompt and task_prompt_content with images restored + Dict containing system_prompt and task_prompt_content Raises: Exception: If Lambda invocation fails or returns invalid response @@ -665,14 +491,6 @@ def _invoke_custom_prompt_lambda( logger.error(error_msg) raise Exception(error_msg) - # Convert image URIs back to bytes in the response - if original_images: - result["task_prompt_content"] = ( - self._convert_image_uris_to_bytes_in_content( - result["task_prompt_content"], original_images - ) - ) - return result except Exception as e: @@ -680,41 +498,58 @@ def _invoke_custom_prompt_lambda( logger.error(error_msg) raise Exception(error_msg) - def process_document_section(self, document: Document, section_id: str) -> Document: + def _reset_context(self) -> None: + """Reset instance variables for clean state before processing.""" + self._document_text = "" + self._class_label = "" + self._attribute_descriptions = "" + self._class_schema = {} + self._page_images = [] + self._image_uris = [] + + def _validate_and_find_section( + self, document: Document, section_id: str + ) -> Any | None: """ - Process a single section from a Document object. + Validate document and find section by ID. Args: - document: Document object containing section to process - section_id: ID of the section to process + document: Document to validate + section_id: ID of section to find Returns: - Document: Updated Document object with extraction results for the section + Section if found, None otherwise (errors added to document) """ - # Validate input document if not document: logger.error("No document provided") - return document + return None if not document.sections: logger.error("Document has no sections to process") document.errors.append("Document has no sections to process") - return document + return None # Find the section with the given ID - section = None - for s in document.sections: - if s.section_id == section_id: - section = s - break + for section in document.sections: + if section.section_id == section_id: + return section - if not section: - error_msg = f"Section {section_id} not found in document" - logger.error(error_msg) - document.errors.append(error_msg) - return document + error_msg = f"Section {section_id} not found in document" + logger.error(error_msg) + document.errors.append(error_msg) + return None + + def _prepare_section_info(self, document: Document, section: Any) -> SectionInfo: + """ + Prepare section metadata and output paths. + + Args: + document: Document being processed + section: Section being processed - # Extract information about the section + Returns: + SectionInfo with all metadata + """ class_label = section.classification output_bucket = document.output_bucket output_prefix = document.input_key @@ -723,19 +558,23 @@ def process_document_section(self, document: Document, section_id: str) -> Docum # Check if the section has required pages if not section.page_ids: - error_msg = f"Section {section_id} has no page IDs" + error_msg = f"Section {section.section_id} has no page IDs" logger.error(error_msg) document.errors.append(error_msg) - return document + raise ValueError(error_msg) # Sort pages by page number sorted_page_ids = sorted(section.page_ids, key=int) start_page = int(sorted_page_ids[0]) end_page = int(sorted_page_ids[-1]) - # Convert 1-based page IDs to 0-based indices for the original document packet - # This preserves the actual position of pages in the document (e.g., pages [5,6,7] -> indices [4,5,6]) - page_indices = [int(page_id) - 1 for page_id in sorted_page_ids] + # Find minimum page ID across all sections + min_page_id = min( + int(page_id) for sec in document.sections for page_id in sec.page_ids + ) + + # Adjust page indices to be zero-based + page_indices = [int(page_id) - min_page_id for page_id in sorted_page_ids] logger.info( f"Processing {len(sorted_page_ids)} pages, class {class_label}: {start_page}-{end_page}" @@ -745,442 +584,507 @@ def process_document_section(self, document: Document, section_id: str) -> Docum metrics.put_metric("InputDocuments", 1) metrics.put_metric("InputDocumentPages", len(section.page_ids)) - try: - # Read document text from all pages in order - t0 = time.time() - document_texts = [] - for page_id in sorted_page_ids: - if page_id not in document.pages: - error_msg = f"Page {page_id} not found in document" - logger.error(error_msg) - document.errors.append(error_msg) - continue + return SectionInfo( + class_label=class_label, + sorted_page_ids=sorted_page_ids, + page_indices=page_indices, + output_bucket=output_bucket, + output_key=output_key, + output_uri=output_uri, + start_page=start_page, + end_page=end_page, + ) - page = document.pages[page_id] - text_path = page.parsed_text_uri - page_text = s3.get_text_content(text_path) - document_texts.append(page_text) + def _load_document_text( + self, document: Document, sorted_page_ids: list[str] + ) -> str: + """ + Load and concatenate text from all pages. + + Args: + document: Document containing pages + sorted_page_ids: Sorted list of page IDs + + Returns: + Concatenated document text + """ + t0 = time.time() + document_texts = [] + + for page_id in sorted_page_ids: + if page_id not in document.pages: + error_msg = f"Page {page_id} not found in document" + logger.error(error_msg) + document.errors.append(error_msg) + continue + + page = document.pages[page_id] + text_path = page.parsed_text_uri + page_text = s3.get_text_content(text_path) + document_texts.append(page_text) + + document_text = "\n".join(document_texts) + t1 = time.time() + logger.info(f"Time taken to read text content: {t1 - t0:.2f} seconds") + + return document_text + + def _load_document_images( + self, document: Document, sorted_page_ids: list[str] + ) -> list[Any]: + """ + Load images from all pages. - document_text = "\n".join(document_texts) - t1 = time.time() - logger.info(f"Time taken to read text content: {t1 - t0:.2f} seconds") + Args: + document: Document containing pages + sorted_page_ids: Sorted list of page IDs + + Returns: + List of prepared images + """ + t0 = time.time() + target_width = self.config.extraction.image.target_width + target_height = self.config.extraction.image.target_height + + page_images = [] + for page_id in sorted_page_ids: + if page_id not in document.pages: + continue + + page = document.pages[page_id] + image_uri = page.image_uri + image_content = image.prepare_image(image_uri, target_width, target_height) + page_images.append(image_content) + + t1 = time.time() + logger.info(f"Time taken to read images: {t1 - t0:.2f} seconds") - # Read page images with configurable dimensions (type-safe access) - target_width = self.config.extraction.image.target_width - target_height = self.config.extraction.image.target_height + return page_images + + def _initialize_extraction_context( + self, + class_label: str, + document_text: str, + page_images: list[Any], + sorted_page_ids: list[str], + document: Document, + ) -> tuple[dict[str, Any], str]: + """ + Initialize extraction context and set instance variables. - page_images = [] - for page_id in sorted_page_ids: - if page_id not in document.pages: - continue + Args: + class_label: Document class + document_text: Text content + page_images: Prepared images + sorted_page_ids: Sorted page IDs + document: Document being processed + Returns: + Tuple of (class_schema, attribute_descriptions) + """ + # Get JSON Schema for this document class + class_schema = self._get_class_schema(class_label) + attribute_descriptions = self._format_schema_for_prompt(class_schema) + + # Store context in instance variables + self._document_text = document_text + self._class_label = class_label + self._attribute_descriptions = attribute_descriptions + self._class_schema = class_schema + self._page_images = page_images + + # Prepare image URIs for Lambda + image_uris = [] + for page_id in sorted_page_ids: + if page_id in document.pages: page = document.pages[page_id] - image_uri = page.image_uri - # Just pass the values directly - prepare_image handles empty strings/None - image_content = image.prepare_image( - image_uri, target_width, target_height - ) - page_images.append(image_content) - - t2 = time.time() - logger.info(f"Time taken to read images: {t2 - t1:.2f} seconds") - - # Get extraction configuration (type-safe access, automatic type conversion) - model_id = self.config.extraction.model - temperature = ( - self.config.extraction.temperature - ) # Already float, no conversion needed! - top_k = self.config.extraction.top_k # Already float! - top_p = self.config.extraction.top_p # Already float! - max_tokens = ( - self.config.extraction.max_tokens - if self.config.extraction.max_tokens - else None - ) - system_prompt = self.config.extraction.system_prompt + if page.image_uri: + image_uris.append(page.image_uri) + self._image_uris = image_uris - # Get JSON Schema for this document class - class_schema = self._get_class_schema(class_label) - attribute_descriptions = self._format_schema_for_prompt(class_schema) + return class_schema, attribute_descriptions - # Check if schema has properties - if not, skip LLM invocation entirely - if ( - not class_schema.get(SCHEMA_PROPERTIES) - or not attribute_descriptions.strip() - ): - logger.info( - f"No attributes defined for class {class_label}, skipping LLM extraction" - ) + def _handle_empty_schema( + self, + document: Document, + section: Any, + section_info: SectionInfo, + section_id: str, + t0: float, + ) -> Document: + """ + Handle case when schema has no attributes - skip LLM and return empty result. - # Create empty result structure without invoking LLM - extracted_fields = {} - metering = { - "input_tokens": 0, - "output_tokens": 0, - "invocation_count": 0, - "total_cost": 0.0, - } - total_duration = 0.0 - parsing_succeeded = True - - # Write to S3 with empty extraction result - output = { - "document_class": {"type": class_label}, - "split_document": {"page_indices": page_indices}, - "inference_result": extracted_fields, - "metadata": { - "parsing_succeeded": parsing_succeeded, - "extraction_time_seconds": total_duration, - "skipped_due_to_empty_attributes": True, - }, - } - s3.write_content( - output, output_bucket, output_key, content_type="application/json" - ) + Args: + document: Document being processed + section: Section being processed + section_info: Section metadata + section_id: Section ID + t0: Start time + + Returns: + Updated document + """ + logger.info( + f"No attributes defined for class {section_info.class_label}, skipping LLM extraction" + ) - # Update the section with extraction result URI - section.extraction_result_uri = output_uri + # Create empty result structure + extracted_fields = {} + metering = { + "input_tokens": 0, + "output_tokens": 0, + "invocation_count": 0, + "total_cost": 0.0, + } + total_duration = 0.0 + parsing_succeeded = True + + # Write to S3 + output = { + "document_class": {"type": section_info.class_label}, + "split_document": {"page_indices": section_info.page_indices}, + "inference_result": extracted_fields, + "metadata": { + "parsing_succeeded": parsing_succeeded, + "extraction_time_seconds": total_duration, + "skipped_due_to_empty_attributes": True, + }, + } + s3.write_content( + output, + section_info.output_bucket, + section_info.output_key, + content_type="application/json", + ) - # Update document with zero metering data - document.metering = utils.merge_metering_data( - document.metering, metering - ) + # Update section and document + section.extraction_result_uri = section_info.output_uri + document.metering = utils.merge_metering_data(document.metering, metering) - t3 = time.time() - logger.info( - f"Skipped extraction for section {section_id} due to empty attributes: {t3 - t0:.2f} seconds" - ) - return document - - # Check for custom prompt Lambda function (type-safe access) - custom_lambda_arn = self.config.extraction.custom_prompt_lambda_arn - - if custom_lambda_arn and custom_lambda_arn.strip(): - logger.info(f"Using custom prompt Lambda: {custom_lambda_arn}") - - # Prepare prompt placeholders including image URIs - image_uris = [] - for page_id in sorted_page_ids: - if page_id in document.pages: - page = document.pages[page_id] - if page.image_uri: - image_uris.append(page.image_uri) - - prompt_placeholders = { - "DOCUMENT_TEXT": document_text, - "DOCUMENT_CLASS": class_label, - "ATTRIBUTE_NAMES_AND_DESCRIPTIONS": attribute_descriptions, - "DOCUMENT_IMAGE": image_uris, - } - - logger.info( - f"Lambda will receive {len(image_uris)} image URIs in DOCUMENT_IMAGE placeholder" + t3 = time.time() + logger.info( + f"Skipped extraction for section {section_id} due to empty attributes: {t3 - t0:.2f} seconds" + ) + return document + + def _build_extraction_content( + self, + document: Document, + page_images: list[Any], + ) -> tuple[list[dict[str, Any]], str]: + """ + Build prompt content (with or without custom Lambda). + + Args: + document: Document being processed + page_images: Prepared page images + + Returns: + Tuple of (content, system_prompt) + """ + system_prompt = self.config.extraction.system_prompt + custom_lambda_arn = self.config.extraction.custom_prompt_lambda_arn + + if custom_lambda_arn and custom_lambda_arn.strip(): + logger.info(f"Using custom prompt Lambda: {custom_lambda_arn}") + + prompt_placeholders = { + "DOCUMENT_TEXT": self._document_text, + "DOCUMENT_CLASS": self._class_label, + "ATTRIBUTE_NAMES_AND_DESCRIPTIONS": self._attribute_descriptions, + "DOCUMENT_IMAGE": self._image_uris, + } + + logger.info( + f"Lambda will receive {len(self._image_uris)} image URIs in DOCUMENT_IMAGE placeholder" + ) + + # Build default content for Lambda input + prompt_template = self.config.extraction.task_prompt + if prompt_template: + default_content = self._build_prompt_content( + prompt_template, page_images ) + else: + default_content = self._get_default_prompt_content() + + # Prepare Lambda payload + try: + document_dict = document.to_dict() + except Exception as e: + logger.warning(f"Error serializing document for Lambda payload: {e}") + document_dict = {"id": getattr(document, "id", "unknown")} + + payload = { + "config": self._make_json_serializable(self.config), + "prompt_placeholders": prompt_placeholders, + "default_task_prompt_content": self._make_json_serializable( + default_content + ), + "serialized_document": document_dict, + } + + # Invoke custom Lambda + lambda_result = self._invoke_custom_prompt_lambda( + custom_lambda_arn, payload + ) - # Build default content for Lambda input - prompt_template = self.config.extraction.task_prompt - if prompt_template: - # Check if task prompt contains FEW_SHOT_EXAMPLES placeholder - if "{FEW_SHOT_EXAMPLES}" in prompt_template: - default_content = self._build_content_with_few_shot_examples( - prompt_template, - document_text, - class_label, - attribute_descriptions, - page_images, - ) - else: - # Use the unified content builder for DOCUMENT_IMAGE placeholder support - default_content = ( - self._build_content_with_or_without_image_placeholder( - prompt_template, - document_text, - class_label, - attribute_descriptions, - page_images, - ) - ) - else: - # Default content if no template - task_prompt = f""" - Extract the following fields from this {class_label} document: - - {attribute_descriptions} - - Document text: - {document_text} - - Respond with a JSON object containing each field name and its extracted value. - """ - default_content = [{"text": task_prompt}] - if page_images: - for img in page_images[:20]: - default_content.append( - image.prepare_bedrock_image_attachment(img) - ) - - # Prepare Lambda payload with JSON-serializable content + # Use Lambda results + system_prompt = lambda_result.get("system_prompt", system_prompt) + content = lambda_result.get("task_prompt_content", default_content) + + logger.info("Successfully applied custom prompt from Lambda function") + else: + # Use default prompt logic + logger.info( + "No custom prompt Lambda configured - using default prompt generation" + ) + prompt_template = self.config.extraction.task_prompt + + if not prompt_template: + content = self._get_default_prompt_content() + else: try: - # Use Document's built-in to_dict() method which properly handles Status enum conversion - document_dict = document.to_dict() - except Exception as e: + content = self._build_prompt_content(prompt_template, page_images) + except ValueError as e: logger.warning( - f"Error serializing document for Lambda payload: {e}" + f"Error formatting prompt template: {str(e)}. Using default prompt." ) - document_dict = {"id": getattr(document, "id", "unknown")} + content = self._get_default_prompt_content() - # Convert image bytes to URIs in default content for JSON serialization - serializable_default_content = ( - self._convert_image_bytes_to_uris_in_content(default_content) - ) + return content, system_prompt - # Create fully serializable payload using comprehensive helper - payload = { - "config": self._make_json_serializable(self.config), - "prompt_placeholders": prompt_placeholders, - "default_task_prompt_content": serializable_default_content, - "serialized_document": document_dict, - } + def _invoke_extraction_model( + self, + content: list[dict[str, Any]], + system_prompt: str, + section_info: SectionInfo, + ) -> ExtractionResult: + """ + Invoke Bedrock model (agentic or standard) and parse response. - # Test JSON serialization before sending to Lambda to catch any remaining issues - try: - json.dumps(payload) - logger.info("Lambda payload successfully serialized") - except (TypeError, ValueError) as e: - logger.error( - f"Lambda payload still contains non-serializable data: {e}" - ) - logger.info("Using comprehensive serialization as fallback") - # Apply comprehensive serialization to entire payload - payload = self._make_json_serializable(payload) - try: - json.dumps(payload) - logger.info("Comprehensive serialization successful") - except (TypeError, ValueError) as e2: - logger.error(f"Even comprehensive serialization failed: {e2}") - # Ultimate fallback to minimal payload - payload = { - "config": { - "extraction": {"model": self.config.extraction.model} - }, - "prompt_placeholders": prompt_placeholders, - "default_task_prompt_content": [ - {"text": "Fallback content"} - ], - "serialized_document": { - "id": str(document.id), - "status": "PROCESSING", - }, - } - - # Invoke custom Lambda and get result (pass original images for restoration) - lambda_result = self._invoke_custom_prompt_lambda( - custom_lambda_arn, payload, page_images + Args: + content: Prompt content + system_prompt: System prompt + section_info: Section metadata + + Returns: + ExtractionResult with extracted fields and metering + """ + logger.info( + f"Extracting fields for {section_info.class_label} document, section" + ) + + # Get extraction config + model_id = self.config.extraction.model + temperature = self.config.extraction.temperature + top_k = self.config.extraction.top_k + top_p = self.config.extraction.top_p + max_tokens = ( + self.config.extraction.max_tokens + if self.config.extraction.max_tokens + else None + ) + + # Time the model invocation + request_start_time = time.time() + + if self.config.extraction.agentic.enabled: + if not AGENTIC_AVAILABLE: + raise ImportError( + "Agentic extraction requires Python 3.10+ and strands-agents dependencies. " + "Install with: pip install 'idp_common[agents]' or use agentic=False" ) - # Use Lambda results - system_prompt = lambda_result.get("system_prompt", system_prompt) - content = lambda_result.get("task_prompt_content", default_content) + # Create dynamic Pydantic model from JSON Schema + dynamic_model = create_pydantic_model_from_json_schema( + schema=self._class_schema, + class_label=section_info.class_label, + clean_schema=False, # Already cleaned + ) - logger.info("Successfully applied custom prompt from Lambda function") + # Log schema for debugging + model_schema = dynamic_model.model_json_schema() + logger.debug(f"Pydantic model schema for {section_info.class_label}:") + logger.debug(json.dumps(model_schema, indent=2)) + # Use agentic extraction + if isinstance(content, list): + message_prompt = {"role": "user", "content": content} else: - # Use default prompt logic when no custom Lambda is configured - logger.info( - "No custom prompt Lambda configured - using default prompt generation" - ) - prompt_template = self.config.extraction.task_prompt - - if not prompt_template: - # Default prompt if template not found - task_prompt = f""" - Extract the following fields from this {class_label} document: - - {attribute_descriptions} - - Document text: - {document_text} - - Respond with a JSON object containing each field name and its extracted value. - """ - content = [{"text": task_prompt}] - - # Add image attachments to the content (limit to 20 images as per Bedrock constraints) - if page_images: - logger.info( - f"Attaching images to prompt, for {len(page_images)} pages." - ) - # Limit to 20 images as per Bedrock constraints - for img in page_images[:20]: - content.append(image.prepare_bedrock_image_attachment(img)) - else: - # Check if task prompt contains FEW_SHOT_EXAMPLES placeholder - if "{FEW_SHOT_EXAMPLES}" in prompt_template: - content = self._build_content_with_few_shot_examples( - prompt_template, - document_text, - class_label, - attribute_descriptions, - page_images, # Pass images to the content builder - ) - else: - # Use the unified content builder for DOCUMENT_IMAGE placeholder support - try: - content = ( - self._build_content_with_or_without_image_placeholder( - prompt_template, - document_text, - class_label, - attribute_descriptions, - page_images, # Pass images to the content builder - ) - ) - except ValueError as e: - logger.warning( - f"Error formatting prompt template: {str(e)}. Using default prompt." - ) - # Fall back to default prompt if template validation fails - task_prompt = f""" - Extract the following fields from this {class_label} document: - - {attribute_descriptions} - - Document text: - {document_text} - - Respond with a JSON object containing each field name and its extracted value. - """ - content = [{"text": task_prompt}] - - # Add image attachments for fallback case - if page_images: - logger.info( - f"Attaching images to prompt, for {len(page_images)} pages." - ) - # Limit to 20 images as per Bedrock constraints - for img in page_images[:20]: - content.append( - image.prepare_bedrock_image_attachment(img) - ) + message_prompt = content + + logger.info("Using Agentic extraction") + logger.debug(f"Using input: {str(message_prompt)}") + + structured_data, response_with_metering = structured_output( + model_id=model_id, + data_format=dynamic_model, + prompt=message_prompt, + page_images=self._page_images, + config=self.config, + context="Extraction", + ) - logger.info( - f"Extracting fields for {class_label} document, section {section_id}" + extracted_fields = structured_data.model_dump() + metering = response_with_metering["metering"] + parsing_succeeded = True + else: + # Standard Bedrock invocation + response_with_metering = bedrock.invoke_model( + model_id=model_id, + system_prompt=system_prompt, + content=content, + temperature=temperature, + top_k=top_k, + top_p=top_p, + max_tokens=max_tokens, + context="Extraction", ) - # Time the model invocation - request_start_time = time.time() + extracted_text = bedrock.extract_text_from_response( + dict(response_with_metering) + ) + metering = response_with_metering["metering"] - # Type-safe boolean access - no string conversion needed! - if self.config.extraction.agentic.enabled: - if not AGENTIC_AVAILABLE: - raise ImportError( - "Agentic extraction requires Python 3.10+ and strands-agents dependencies. " - "Install with: pip install 'idp_common[agents]' or use agentic=False" - ) + # Parse response into JSON + extracted_fields = {} + parsing_succeeded = True - # Create dynamic Pydantic model from JSON Schema - # Schema is already cleaned by _clean_schema_for_prompt before being passed here - dynamic_model = create_pydantic_model_from_json_schema( - schema=class_schema, - class_label=class_label, - clean_schema=False, # Already cleaned + try: + extracted_fields = json.loads(extract_json_from_text(extracted_text)) + except Exception as e: + logger.error( + f"Error parsing LLM output - invalid JSON?: {extracted_text} - {e}" ) + logger.info("Using unparsed LLM output.") + extracted_fields = {"raw_output": extracted_text} + parsing_succeeded = False + + total_duration = time.time() - request_start_time + logger.info(f"Time taken for extraction: {total_duration:.2f} seconds") + + return ExtractionResult( + extracted_fields=extracted_fields, + metering=metering, + parsing_succeeded=parsing_succeeded, + total_duration=total_duration, + ) - # Log the Pydantic model schema for debugging - model_schema = dynamic_model.model_json_schema() - logger.debug(f"Pydantic model schema for {class_label}:") - logger.debug(json.dumps(model_schema, indent=2)) - - # Use agentic extraction with the dynamic model - # Wrap content list in proper Message format for agentic_idp compatibility - if isinstance(content, list): - message_prompt = {"role": "user", "content": content} - else: - message_prompt = content - logger.info("Using Agentic extraction") - logger.debug(f"Using input: {str(message_prompt)}") - structured_data, response_with_metering = structured_output( # pyright: ignore[reportPossiblyUnboundVariable] - model_id=model_id, - data_format=dynamic_model, - prompt=message_prompt, # pyright: ignore[reportArgumentType] - custom_instruction=system_prompt, - config=self.config, - context="Extraction", - ) + def _save_results( + self, + document: Document, + section: Any, + result: ExtractionResult, + section_info: SectionInfo, + section_id: str, + t0: float, + ) -> None: + """ + Save extraction results to S3 and update document. - # Extract the structured data as dict for compatibility with existing code - extracted_fields = structured_data.model_dump() - # Extract metering from BedrockInvokeModelResponse - metering = response_with_metering["metering"] - parsing_succeeded = True # Agentic approach always succeeds in parsing since it returns structured data + Args: + document: Document being processed + section: Section being processed + result: Extraction result + section_info: Section metadata + section_id: Section ID + t0: Start time + """ + # Write to S3 + output = { + "document_class": {"type": section_info.class_label}, + "split_document": {"page_indices": section_info.page_indices}, + "inference_result": result.extracted_fields, + "metadata": { + "parsing_succeeded": result.parsing_succeeded, + "extraction_time_seconds": result.total_duration, + }, + } + s3.write_content( + output, + section_info.output_bucket, + section_info.output_key, + content_type="application/json", + ) - else: - # Invoke Bedrock with the common library - response_with_metering = bedrock.invoke_model( - model_id=model_id, - system_prompt=system_prompt, - content=content, - temperature=temperature, - top_k=top_k, - top_p=top_p, - max_tokens=max_tokens, - context="Extraction", - ) - # For non-agentic approach, response_with_metering is BedrockInvokeModelResponse - # Extract text from response for non-agentic approach - extracted_text = bedrock.extract_text_from_response( - dict(response_with_metering) - ) - metering = response_with_metering["metering"] + # Update section and document + section.extraction_result_uri = section_info.output_uri + document.metering = utils.merge_metering_data( + document.metering, result.metering or {} + ) - # Parse response into JSON - extracted_fields = {} - parsing_succeeded = True # Flag to track if parsing was successful + t3 = time.time() + logger.info( + f"Total extraction time for section {section_id}: {t3 - t0:.2f} seconds" + ) - try: - # Try to parse the extracted text as JSON - extracted_fields = json.loads( - extract_json_from_text(extracted_text) - ) - except Exception as e: - # Handle parsing error - logger.error( - f"Error parsing LLM output - invalid JSON?: {extracted_text} - {e}" - ) - logger.info("Using unparsed LLM output.") - extracted_fields = {"raw_output": extracted_text} - parsing_succeeded = False # Mark that parsing failed - - total_duration = time.time() - request_start_time - logger.info(f"Time taken for extraction: {total_duration:.2f} seconds") - - # Write to S3 - output = { - "document_class": {"type": class_label}, - "split_document": {"page_indices": page_indices}, - "inference_result": extracted_fields, - "metadata": { - "parsing_succeeded": parsing_succeeded, - "extraction_time_seconds": total_duration, - }, - } - s3.write_content( - output, output_bucket, output_key, content_type="application/json" - ) + def process_document_section(self, document: Document, section_id: str) -> Document: + """ + Process a single section from a Document object. + + Args: + document: Document object containing section to process + section_id: ID of the section to process + + Returns: + Document: Updated Document object with extraction results for the section + """ + # Reset state + self._reset_context() + + # Validate and get section + section = self._validate_and_find_section(document, section_id) + if not section: + return document + + # Prepare section metadata + try: + section_info = self._prepare_section_info(document, section) + except ValueError: + return document - # Update the section with extraction result URI only (not the attributes themselves) - section.extraction_result_uri = output_uri + try: + t0 = time.time() - # Update document with metering data - document.metering = utils.merge_metering_data( - document.metering, metering or {} + # Load document content + document_text = self._load_document_text( + document, section_info.sorted_page_ids + ) + page_images = self._load_document_images( + document, section_info.sorted_page_ids ) - t3 = time.time() - logger.info( - f"Total extraction time for section {section_id}: {t3 - t0:.2f} seconds" + # Initialize extraction context + class_schema, attribute_descriptions = self._initialize_extraction_context( + section_info.class_label, + document_text, + page_images, + section_info.sorted_page_ids, + document, ) + # Handle empty schema case (early return) + if ( + not class_schema.get(SCHEMA_PROPERTIES) + or not attribute_descriptions.strip() + ): + return self._handle_empty_schema( + document, section, section_info, section_id, t0 + ) + + # Build prompt content + content, system_prompt = self._build_extraction_content( + document, page_images + ) + + # Invoke model + result = self._invoke_extraction_model(content, system_prompt, section_info) + + # Save results + self._save_results(document, section, result, section_info, section_id, t0) + except Exception as e: error_msg = f"Error processing section {section_id}: {str(e)}" logger.error(error_msg) From 6ba436cbfeff32420ff3da996995a4583b5f5a64 Mon Sep 17 00:00:00 2001 From: Kazmer Nagy-Betegh Date: Fri, 21 Nov 2025 17:37:17 +0000 Subject: [PATCH 06/18] fix pipeline ci failure --- .github/workflows/developer-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/developer-tests.yml b/.github/workflows/developer-tests.yml index 90f3956a..eee84472 100644 --- a/.github/workflows/developer-tests.yml +++ b/.github/workflows/developer-tests.yml @@ -59,7 +59,7 @@ jobs: - name: Install Node.js and basedpyright run: | - curl -fsSL https://deb.nodesource.com/setup_20.x | bash - + curl -fsSL https://deb.nodesource.com/setup_22.x | bash - apt-get install -y nodejs npm install -g basedpyright From 4a0ac0bd3b1abe3c7104b1b24916941d7375ecc4 Mon Sep 17 00:00:00 2001 From: Kazmer Nagy-Betegh Date: Wed, 19 Nov 2025 09:47:23 +0200 Subject: [PATCH 07/18] deps --- lib/idp_common_pkg/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/idp_common_pkg/pyproject.toml b/lib/idp_common_pkg/pyproject.toml index 6f2fdb82..dba3dd79 100644 --- a/lib/idp_common_pkg/pyproject.toml +++ b/lib/idp_common_pkg/pyproject.toml @@ -71,7 +71,8 @@ extraction = [ # Assessment module dependencies assessment = [ - "Pillow==11.2.1", # For image handling + "Pillow==11.2.1", # For image handling + "aws-lambda-powertools>=3.2.0", # Structured logging and observability ] # Evaluation module dependencies From 0729f036bfe4ed6580aca9285667afdb7fb43448 Mon Sep 17 00:00:00 2001 From: Kazmer Nagy-Betegh Date: Mon, 24 Nov 2025 13:02:45 +0000 Subject: [PATCH 08/18] bug fixes --- .../idp_common/extraction/agentic_idp.py | 14 ++++++++++---- .../idp_common/extraction/service.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py index 78060bd0..45adb18e 100644 --- a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py +++ b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py @@ -251,9 +251,13 @@ def extraction_tool( When you call this tool it overwrites the previous extraction, if you want to expand the extraction use jsonpatch. This tool needs to be Successfully invoked before the patch tool can be used.""" - logger.info("extraction_tool called", extra={"models_extraction": extraction}) - extraction_model = model_class(**extraction) # pyright: ignore[reportAssignmentType] - extraction_dict = extraction_model.model_dump() + # Note: The @tool decorator passes data as a dict, not as a model instance + # We need to validate it manually using the Pydantic model + extraction_model = model_class.model_validate(extraction) # pyright: ignore[reportAssignmentType] + extraction_dict = extraction_model.model_dump(mode="json") + logger.info( + "extraction_tool called", extra={"models_extraction": extraction_dict} + ) agent.state.set(key="current_extraction", value=extraction_dict) logger.debug( "Successfully stored extraction in state", @@ -1000,7 +1004,9 @@ async def structured_output_async( state={ "current_extraction": None, "images": {}, - "existing_data": existing_data.model_dump() if existing_data else None, + "existing_data": existing_data.model_dump(mode="json") + if existing_data + else None, "extraction_schema_json": schema_json, # Store for schema reminder tool }, conversation_manager=SummarizingConversationManager( diff --git a/lib/idp_common_pkg/idp_common/extraction/service.py b/lib/idp_common_pkg/idp_common/extraction/service.py index 49b027c1..ecc86729 100644 --- a/lib/idp_common_pkg/idp_common/extraction/service.py +++ b/lib/idp_common_pkg/idp_common/extraction/service.py @@ -927,7 +927,7 @@ def _invoke_extraction_model( context="Extraction", ) - extracted_fields = structured_data.model_dump() + extracted_fields = structured_data.model_dump(mode="json") metering = response_with_metering["metering"] parsing_succeeded = True else: From dcd70533be089b45c4e556ebd4bbdb70830bc9f0 Mon Sep 17 00:00:00 2001 From: Kazmer Nagy-Betegh Date: Mon, 24 Nov 2025 13:28:40 +0000 Subject: [PATCH 09/18] pydantic validator update --- lib/idp_common_pkg/idp_common/schema/pydantic_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/idp_common_pkg/idp_common/schema/pydantic_generator.py b/lib/idp_common_pkg/idp_common/schema/pydantic_generator.py index bdc21cd2..e91480fe 100644 --- a/lib/idp_common_pkg/idp_common/schema/pydantic_generator.py +++ b/lib/idp_common_pkg/idp_common/schema/pydantic_generator.py @@ -146,7 +146,7 @@ def create_json_schema_validator( def validate_against_json_schema(value: BaseModel) -> BaseModel: """Validate model data against the original JSON Schema.""" # Convert Pydantic model to dict for JSON Schema validation - data = value.model_dump() + data = value.model_dump(mode="json") try: # Validate against JSON Schema From f7c053476de8a3775880f06fdb66d44623b4eb5d Mon Sep 17 00:00:00 2001 From: "Kazmer, Nagy-Betegh" Date: Mon, 24 Nov 2025 18:21:38 +0000 Subject: [PATCH 10/18] pass the linting --- src/ui/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ui/package.json b/src/ui/package.json index 12d18735..d57b7081 100644 --- a/src/ui/package.json +++ b/src/ui/package.json @@ -4,7 +4,7 @@ "private": true, "engines": { "node": ">=22.12.0", - "npm": ">=11.0.0" + "npm": ">=10.0.0" }, "dependencies": { "@aws-amplify/ui-react": "^6.12.0", From 8f05ff70fc571262f1d156907761dd47f32b0700 Mon Sep 17 00:00:00 2001 From: "Kazmer, Nagy-Betegh" Date: Mon, 24 Nov 2025 18:23:40 +0000 Subject: [PATCH 11/18] default to empty string --- patterns/pattern-2/template.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/patterns/pattern-2/template.yaml b/patterns/pattern-2/template.yaml index 91715908..d987b160 100644 --- a/patterns/pattern-2/template.yaml +++ b/patterns/pattern-2/template.yaml @@ -927,7 +927,7 @@ Resources: review_agent_model: type: string description: Model to review the initial extraction agents work and correct it if needed, if not specified will default to the same as the extraction model. - default: Null + default: "" image: type: object sectionLabel: Image Processing Settings From 24acad145b147385d1159214d1670d3a8d7998be Mon Sep 17 00:00:00 2001 From: "Kazmer, Nagy-Betegh" Date: Mon, 24 Nov 2025 18:45:18 +0000 Subject: [PATCH 12/18] deps fix --- lib/idp_common_pkg/pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lib/idp_common_pkg/pyproject.toml b/lib/idp_common_pkg/pyproject.toml index dba3dd79..a8a15148 100644 --- a/lib/idp_common_pkg/pyproject.toml +++ b/lib/idp_common_pkg/pyproject.toml @@ -142,6 +142,11 @@ test = [ "ruff>=0.14.0", "deepdiff>=6.0.0", # Required for BDA blueprint service tests "datamodel-code-generator>=0.25.0", # Required for schema/pydantic generator tests + # Evaluation module dependencies (from evaluation group) + "stickler-eval==0.1.3", + "genson==1.3.0", + "munkres>=1.1.4", + "numpy==1.26.4", ] # Full package with all dependencies From 9ae0feb5efed24d5a52b7db671033edc2caad51b Mon Sep 17 00:00:00 2001 From: "Kazmer, Nagy-Betegh" Date: Mon, 24 Nov 2025 18:54:32 +0000 Subject: [PATCH 13/18] disable ci comments --- .github/workflows/developer-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/developer-tests.yml b/.github/workflows/developer-tests.yml index eee84472..726506cb 100644 --- a/.github/workflows/developer-tests.yml +++ b/.github/workflows/developer-tests.yml @@ -104,6 +104,7 @@ jobs: with: files: lib/idp_common_pkg/test-reports/test-results.xml check_name: Test Results + comment_mode: off # Disable PR comments to avoid permission issues on fork PRs - name: Code Coverage Report uses: irongut/CodeCoverageSummary@v1.3.0 From 15c8bbda8f68f108303ac262023e4befddc11d21 Mon Sep 17 00:00:00 2001 From: "Kazmer, Nagy-Betegh" Date: Wed, 26 Nov 2025 10:28:25 +0000 Subject: [PATCH 14/18] fixes --- .../idp_common/config/models.py | 34 +++++++++++-------- .../idp_common/extraction/agentic_idp.py | 25 +++++++++----- .../idp_common/extraction/service.py | 2 +- .../idp_common/utils/bedrock_utils.py | 6 ++-- .../utils/strands_agent_tools/__init__.py | 14 ++++++++ lib/idp_common_pkg/pyproject.toml | 1 + 6 files changed, 55 insertions(+), 27 deletions(-) create mode 100644 lib/idp_common_pkg/idp_common/utils/strands_agent_tools/__init__.py diff --git a/lib/idp_common_pkg/idp_common/config/models.py b/lib/idp_common_pkg/idp_common/config/models.py index a0efc832..b06f14f3 100644 --- a/lib/idp_common_pkg/idp_common/config/models.py +++ b/lib/idp_common_pkg/idp_common/config/models.py @@ -20,7 +20,14 @@ from typing import Any, Dict, List, Optional, Union, Literal, Annotated from typing_extensions import Self -from pydantic import BaseModel, ConfigDict, Field, field_validator, Discriminator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + Discriminator, + model_validator, +) class ImageConfig(BaseModel): @@ -79,7 +86,10 @@ class AgenticConfig(BaseModel): enabled: bool = Field(default=False, description="Enable agentic extraction") review_agent: bool = Field(default=False, description="Enable review agent") - review_agent_model: str | None= Field(default=None, description="Model used for reviewing and correcting extraction work") + review_agent_model: str | None = Field( + default=None, + description="Model used for reviewing and correcting extraction work", + ) class ExtractionConfig(BaseModel): @@ -121,17 +131,16 @@ def parse_int(cls, v: Any) -> int: if isinstance(v, str): return int(v) if v else 0 return int(v) - - @model_validator(mode="after") - def model_validator(self) -> Self: + @model_validator(mode="after") + def set_default_review_agent_model(self) -> Self: + """Set review_agent_model to extraction model if not specified.""" if not self.agentic.review_agent_model: self.agentic.review_agent_model = self.model return self - class ClassificationConfig(BaseModel): """Document classification configuration""" @@ -434,7 +443,7 @@ class ErrorAnalyzerConfig(BaseModel): "AccessDenied", "ThrottlingException", ], - description="Error patterns to search for in logs" + description="Error patterns to search for in logs", ) system_prompt: str = Field( default=""" @@ -522,11 +531,10 @@ class ErrorAnalyzerConfig(BaseModel): - No time specified: 24 hours (default) IMPORTANT: Do not include any search quality reflections, search quality scores, or meta-analysis sections in your response. Only provide the three required sections: Root Cause, Recommendations, and Evidence.""", - description="System prompt for error analyzer" + description="System prompt for error analyzer", ) parameters: ErrorAnalyzerParameters = Field( - default_factory=ErrorAnalyzerParameters, - description="Error analyzer parameters" + default_factory=ErrorAnalyzerParameters, description="Error analyzer parameters" ) @@ -646,12 +654,10 @@ class AgentsConfig(BaseModel): """Agents configuration""" error_analyzer: Optional[ErrorAnalyzerConfig] = Field( - default_factory=ErrorAnalyzerConfig, - description="Error analyzer configuration" + default_factory=ErrorAnalyzerConfig, description="Error analyzer configuration" ) chat_companion: Optional[ChatCompanionConfig] = Field( - default_factory=ChatCompanionConfig, - description="Chat companion configuration" + default_factory=ChatCompanionConfig, description="Chat companion configuration" ) diff --git a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py index 45adb18e..77fff86c 100644 --- a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py +++ b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py @@ -52,7 +52,10 @@ # In Lambda: Full JSON structured logs # Outside Lambda: Human-readable format for local development logger = Logger(service="agentic_idp", level=os.getenv("LOG_LEVEL", "INFO")) -logging.getLogger("strands.models.bedrock").setLevel(logging.DEBUG) +# Configure strands bedrock logger based on environment variable +logging.getLogger("strands.models.bedrock").setLevel( + os.getenv("STRANDS_LOG_LEVEL", os.getenv("LOG_LEVEL", "INFO")) +) TargetModel = TypeVar("TargetModel", bound=BaseModel) @@ -193,6 +196,8 @@ def view_image(image_index: int, agent: Agent) -> dict: """ # Validate image index exists + if not page_images: + raise ValueError("No images available to view.") if image_index >= len(page_images): raise ValueError( f"Invalid image_index {image_index}. " @@ -729,12 +734,12 @@ def _prepare_prompt_content( else: prompt_content = [ContentBlock(text=str(prompt))] - # Add page images if provided + # Add page images if provided (limit to 20 as per Bedrock constraints) if page_images: if len(page_images) > 20: prompt_content.append( ContentBlock( - text=f"There are {len(page_images)} images, initially you'll see 20 of them, use the tools to see the rest." + text=f"There are {len(page_images)} images, initially you'll see 20 of them, use the view_image tool to see the rest." ) ) @@ -742,14 +747,14 @@ def _prepare_prompt_content( ContentBlock( image=ImageContent(format="png", source=ImageSource(bytes=img_bytes)) ) - for img_bytes in page_images + for img_bytes in page_images[:20] ] # Add existing data context if provided if existing_data: prompt_content.append( ContentBlock( - text=f"Please update the existing data using the extraction tool or patches. Existing data: {existing_data.model_dump()}" + text=f"Please update the existing data using the extraction tool or patches. Existing data: {existing_data.model_dump(mode='json')}" ) ) @@ -1014,7 +1019,7 @@ async def structured_output_async( ), ) if existing_data: - agent.state.set("current_extraction", existing_data.model_dump()) + agent.state.set("current_extraction", existing_data.model_dump(mode="json")) response, result = await _invoke_agent_for_extraction( agent=agent, @@ -1075,9 +1080,11 @@ async def structured_output_async( tools=tools, system_prompt=f"{final_system_prompt}", state={ - "current_extraction": None, + "current_extraction": result.model_dump(mode="json"), "images": {}, - "existing_data": existing_data.model_dump() if existing_data else None, + "existing_data": existing_data.model_dump(mode="json") + if existing_data + else None, "extraction_schema_json": schema_json, # Store for schema reminder tool }, conversation_manager=SummarizingConversationManager( @@ -1095,7 +1102,7 @@ async def structured_output_async( # Check if patches were applied during review updated_extraction = agent.state.get("current_extraction") - if updated_extraction != result.model_dump(): + if updated_extraction != result.model_dump(mode="json"): # Patches were applied, validate the new extraction try: result = data_format(**updated_extraction) diff --git a/lib/idp_common_pkg/idp_common/extraction/service.py b/lib/idp_common_pkg/idp_common/extraction/service.py index ecc86729..8468cd17 100644 --- a/lib/idp_common_pkg/idp_common/extraction/service.py +++ b/lib/idp_common_pkg/idp_common/extraction/service.py @@ -111,7 +111,7 @@ def __init__( self._class_label: str = "" self._attribute_descriptions: str = "" self._class_schema: dict[str, Any] = {} - self._page_images: list[Any] = [] + self._page_images: list[bytes] = [] self._image_uris: list[str] = [] # Get model_id from config for logging (type-safe access with fallback) diff --git a/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py b/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py index 04885781..c91d456c 100644 --- a/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py +++ b/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py @@ -18,7 +18,6 @@ InvokeModelRequestTypeDef, InvokeModelResponseTypeDef, ) -from pydantic_core import ArgsKwargs # Configure logger logger = logging.getLogger(__name__) @@ -48,14 +47,14 @@ async def wrapper(*args, **kwargs) -> T: def log_bedrock_invocation_error(error: Exception, attempt_num: int): """Log bedrock invocation details when an error occurs""" - # Fallback logging if extraction fails + # Fallback logging if extraction fails logger.error( "Bedrock invocation error", extra={ "function_name": func.__name__, "original_error": str(error), "max_attempts": max_retries, - "attempt_num":attempt_num + "attempt_num": attempt_num, }, ) @@ -203,6 +202,7 @@ def log_bedrock_invocation_error(error: Exception, attempt_num: int): error_code not in [ "ThrottlingException", + "throttlingException", "ModelErrorException", "ValidationException", ] diff --git a/lib/idp_common_pkg/idp_common/utils/strands_agent_tools/__init__.py b/lib/idp_common_pkg/idp_common/utils/strands_agent_tools/__init__.py new file mode 100644 index 00000000..2154b5a3 --- /dev/null +++ b/lib/idp_common_pkg/idp_common/utils/strands_agent_tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +""" +Strands agent tools for IDP common library. +""" + +from idp_common.utils.strands_agent_tools.todo_list import ( + create_todo_list, + update_todo, + view_todo_list, +) + +__all__ = ["create_todo_list", "update_todo", "view_todo_list"] diff --git a/lib/idp_common_pkg/pyproject.toml b/lib/idp_common_pkg/pyproject.toml index a8a15148..ddd60a7e 100644 --- a/lib/idp_common_pkg/pyproject.toml +++ b/lib/idp_common_pkg/pyproject.toml @@ -177,6 +177,7 @@ agentic-extraction = [ "tabulate>=0.9.0", "aws-lambda-powertools>=3.2.0", # Structured logging and observability "datamodel-code-generator>=0.25.0", # Generate Pydantic models from JSON Schema + "mypy-boto3-bedrock-runtime>=1.39.0", # Type stubs for bedrock_utils.py ] [project.urls] From c85a5a316541c1a9a36eb3d007716d46172400f0 Mon Sep 17 00:00:00 2001 From: "Kazmer, Nagy-Betegh" Date: Wed, 26 Nov 2025 10:37:42 +0000 Subject: [PATCH 15/18] json mode for model dump --- lib/idp_common_pkg/idp_common/extraction/agentic_idp.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py index 77fff86c..cdd7b3e0 100644 --- a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py +++ b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py @@ -291,7 +291,8 @@ def apply_json_patches( patched_data = apply_patches_to_data(current_data, patches) validated_patched_data = model_class(**patched_data) agent.state.set( - key="current_extraction", value=validated_patched_data.model_dump() + key="current_extraction", + value=validated_patched_data.model_dump(mode="json"), ) return { @@ -303,9 +304,9 @@ def apply_json_patches( def make_buffer_data_final_extraction(agent: Agent) -> str: valid_extraction = model_class(**agent.state.get("intermediate_extraction")) - agent.state.set("current_extraction", valid_extraction.model_dump()) + agent.state.set("current_extraction", valid_extraction.model_dump(mode="json")) - return f"Successfully made the existing extraction the same as the buffer data {str(valid_extraction.model_dump())[100:]}..." + return f"Successfully made the existing extraction the same as the buffer data {str(valid_extraction.model_dump(mode='json'))[:100]}..." return extraction_tool, apply_json_patches, make_buffer_data_final_extraction @@ -1051,7 +1052,7 @@ async def structured_output_async( ContentBlock( text=f""" You have successfully extracted the following data: - {json.dumps(result.model_dump(), indent=2)} + {json.dumps(result.model_dump(mode="json"), indent=2)} Please take one final careful look at this extraction: 1. Check each field against the source document From b297e182d4f21b96b95d0a2e20ce70d345ce3c1d Mon Sep 17 00:00:00 2001 From: "Kazmer, Nagy-Betegh" Date: Wed, 26 Nov 2025 17:33:15 +0000 Subject: [PATCH 16/18] add additional retryable errors --- lib/idp_common_pkg/idp_common/utils/bedrock_utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py b/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py index c91d456c..1af3f814 100644 --- a/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py +++ b/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py @@ -30,15 +30,21 @@ def async_exponential_backoff_retry[T, **P]( max_delay: float = 32.0, exponential_base: float = 2.0, jitter: float = 0.1, - retryable_errors: list[str] | None = None, + retryable_errors: set[str] | None = None, ) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: if not retryable_errors: - retryable_errors = [ + retryable_errors = set( [ "ThrottlingException", "throttlingException", "ModelErrorException", "ValidationException", - ] + "ServiceQuotaExceededException", + "RequestLimitExceeded", + "TooManyRequestsException", + "ServiceUnavailableException", + "RequestTimeout", + "RequestTimeoutException", + ] ) def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: @wraps(func) From 5400f31870d27f06a98e29462deb7070cc72a244 Mon Sep 17 00:00:00 2001 From: "Kazmer, Nagy-Betegh" Date: Wed, 26 Nov 2025 17:44:52 +0000 Subject: [PATCH 17/18] additional retryable errors --- .../idp_common/utils/bedrock_utils.py | 37 +- .../tests/unit/test_bedrock_utils.py | 587 ++++++++++++++++++ 2 files changed, 612 insertions(+), 12 deletions(-) create mode 100644 lib/idp_common_pkg/tests/unit/test_bedrock_utils.py diff --git a/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py b/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py index 1af3f814..3f901b20 100644 --- a/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py +++ b/lib/idp_common_pkg/idp_common/utils/bedrock_utils.py @@ -3,6 +3,7 @@ import logging import os import random +import re import time from collections.abc import Awaitable, Callable from functools import wraps @@ -33,18 +34,21 @@ def async_exponential_backoff_retry[T, **P]( retryable_errors: set[str] | None = None, ) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: if not retryable_errors: - retryable_errors = set( [ - "ThrottlingException", - "throttlingException", - "ModelErrorException", - "ValidationException", - "ServiceQuotaExceededException", - "RequestLimitExceeded", - "TooManyRequestsException", - "ServiceUnavailableException", - "RequestTimeout", - "RequestTimeoutException", - ] ) + retryable_errors = set( + [ + "ThrottlingException", + "throttlingException", + "ModelErrorException", + "ValidationException", + "ServiceQuotaExceededException", + "RequestLimitExceeded", + "TooManyRequestsException", + "ServiceUnavailableException", + "serviceUnavailableException", # lowercase variant from EventStreamError + "RequestTimeout", + "RequestTimeoutException", + ] + ) def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: @wraps(func) @@ -70,6 +74,15 @@ def log_bedrock_invocation_error(error: Exception, attempt_num: int): except botocore.exceptions.ClientError as e: error_code = e.response.get("Error", {}).get("Code") + # For EventStreamError (subclass of ClientError), the error code + # may be in a different location or need to be extracted from the message + if not error_code: + # Try to extract error code from exception message + # Format: "An error occurred (errorCode) when calling..." + match = re.search(r"\((\w+)\)", str(e)) + if match: + error_code = match.group(1) + # Log bedrock invocation details for all errors log_bedrock_invocation_error(e, attempt + 1) diff --git a/lib/idp_common_pkg/tests/unit/test_bedrock_utils.py b/lib/idp_common_pkg/tests/unit/test_bedrock_utils.py new file mode 100644 index 00000000..39e403e8 --- /dev/null +++ b/lib/idp_common_pkg/tests/unit/test_bedrock_utils.py @@ -0,0 +1,587 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +""" +Unit tests for the bedrock_utils module. + +Tests the async_exponential_backoff_retry and exponential_backoff_retry decorators, +including handling of ClientError, EventStreamError, and other exceptions. +""" + +import asyncio +import re +from unittest.mock import patch + +import botocore.exceptions +import pytest + + +class MockClientError(botocore.exceptions.ClientError): + """Mock ClientError for testing that mimics botocore.exceptions.ClientError""" + + def __init__(self, error_response, operation_name): + self.response = error_response + self.operation_name = operation_name + error = error_response.get("Error", {}) + msg = f"An error occurred ({error.get('Code', 'Unknown')}) when calling the {operation_name} operation: {error.get('Message', 'Unknown')}" + # Call Exception.__init__ directly to avoid ClientError's __init__ + Exception.__init__(self, msg) + + +class MockEventStreamError(MockClientError): + """Mock EventStreamError that mimics botocore.exceptions.EventStreamError + + EventStreamError is a subclass of ClientError but may have a different + response structure where error code needs to be extracted from the message. + """ + + pass + + +# Now import the module under test +from idp_common.utils.bedrock_utils import ( + async_exponential_backoff_retry, + exponential_backoff_retry, +) + + +@pytest.mark.unit +class TestAsyncExponentialBackoffRetry: + """Tests for the async_exponential_backoff_retry decorator.""" + + @pytest.mark.asyncio + async def test_successful_call_no_retry(self): + """Test that successful calls don't trigger retries.""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3) + async def successful_func(): + nonlocal call_count + call_count += 1 + return "success" + + result = await successful_func() + + assert result == "success" + assert call_count == 1 + + @pytest.mark.asyncio + async def test_retry_on_throttling_exception(self): + """Test retry on ThrottlingException.""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def throttled_func(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise MockClientError( + { + "Error": { + "Code": "ThrottlingException", + "Message": "Rate exceeded", + } + }, + "TestOperation", + ) + return "success" + + result = await throttled_func() + + assert result == "success" + assert call_count == 3 + + @pytest.mark.asyncio + async def test_retry_on_service_unavailable_exception(self): + """Test retry on ServiceUnavailableException (uppercase).""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def unavailable_func(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise MockClientError( + { + "Error": { + "Code": "ServiceUnavailableException", + "Message": "Service unavailable", + } + }, + "TestOperation", + ) + return "success" + + result = await unavailable_func() + + assert result == "success" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_retry_on_event_stream_error_with_lowercase_error_code(self): + """Test retry on EventStreamError with lowercase serviceUnavailableException. + + This tests the fix for the issue where EventStreamError from ConverseStream + uses lowercase error codes like 'serviceUnavailableException'. + """ + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def stream_error_func(): + nonlocal call_count + call_count += 1 + if call_count < 2: + # Simulate EventStreamError where error code is not in the response dict + # but is in the exception message + error = MockEventStreamError( + {"Error": {}}, # Empty error dict - code not available here + "ConverseStream", + ) + # Override the message to match actual EventStreamError format + error.args = ( + "An error occurred (serviceUnavailableException) when calling the ConverseStream operation: Bedrock is unable to process your request.", + ) + raise error + return "success" + + result = await stream_error_func() + + assert result == "success" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_retry_on_event_stream_error_extracts_code_from_message(self): + """Test that error code is extracted from exception message when not in response.""" + call_count = 0 + + original_decorator = async_exponential_backoff_retry( + max_retries=3, initial_delay=0.01 + ) + + @original_decorator + async def stream_error_func(): + nonlocal call_count + call_count += 1 + if call_count < 2: + # Create error with empty Error dict but error code in message + error = MockEventStreamError( + {"Error": {}}, + "ConverseStream", + ) + error.args = ( + "An error occurred (throttlingException) when calling the ConverseStream operation: Too many requests", + ) + raise error + return "success" + + result = await stream_error_func() + + assert result == "success" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_no_retry_on_non_retryable_error(self): + """Test that non-retryable errors are not retried.""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def non_retryable_func(): + nonlocal call_count + call_count += 1 + raise MockClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "Access denied", + } + }, + "TestOperation", + ) + + with pytest.raises(MockClientError) as exc_info: + await non_retryable_func() + + assert "AccessDeniedException" in str(exc_info.value) + assert call_count == 1 + + @pytest.mark.asyncio + async def test_max_retries_exceeded(self): + """Test that exception is raised after max retries are exceeded.""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def always_fails_func(): + nonlocal call_count + call_count += 1 + raise MockClientError( + {"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}}, + "TestOperation", + ) + + with pytest.raises(MockClientError) as exc_info: + await always_fails_func() + + assert "ThrottlingException" in str(exc_info.value) + assert call_count == 3 + + @pytest.mark.asyncio + async def test_validation_exception_not_retried_by_default(self): + """Test that ValidationException without content filtering message is not retried.""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def validation_error_func(): + nonlocal call_count + call_count += 1 + raise MockClientError( + { + "Error": { + "Code": "ValidationException", + "Message": "Invalid parameter", + } + }, + "TestOperation", + ) + + with pytest.raises(MockClientError): + await validation_error_func() + + assert call_count == 1 + + @pytest.mark.asyncio + async def test_validation_exception_with_content_filtering_is_retried(self): + """Test that ValidationException with content filtering message is retried.""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def content_filtered_func(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise MockClientError( + { + "Error": { + "Code": "ValidationException", + "Message": "Output blocked by content filtering policy", + } + }, + "TestOperation", + ) + return "success" + + result = await content_filtered_func() + + assert result == "success" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_non_client_error_not_retried(self): + """Test that non-ClientError exceptions are not retried.""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def generic_error_func(): + nonlocal call_count + call_count += 1 + raise ValueError("Some other error") + + with pytest.raises(ValueError): + await generic_error_func() + + assert call_count == 1 + + @pytest.mark.asyncio + async def test_custom_retryable_errors(self): + """Test that custom retryable errors can be specified.""" + call_count = 0 + + @async_exponential_backoff_retry( + max_retries=3, + initial_delay=0.01, + retryable_errors={"CustomRetryableError"}, + ) + async def custom_error_func(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise MockClientError( + { + "Error": { + "Code": "CustomRetryableError", + "Message": "Custom error", + } + }, + "TestOperation", + ) + return "success" + + result = await custom_error_func() + + assert result == "success" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_exponential_backoff_increases_delay(self): + """Test that delay increases exponentially between retries.""" + delays = [] + call_count = 0 + + @async_exponential_backoff_retry( + max_retries=4, + initial_delay=0.1, + exponential_base=2.0, + jitter=0.0, # Disable jitter for predictable delays + ) + async def tracking_func(): + nonlocal call_count + call_count += 1 + raise MockClientError( + {"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}}, + "TestOperation", + ) + + # Patch asyncio.sleep to capture delay values + original_sleep = asyncio.sleep + + async def mock_sleep(delay): + delays.append(delay) + await original_sleep(0.001) # Actually sleep a tiny bit + + with patch("asyncio.sleep", mock_sleep): + with pytest.raises(MockClientError): + await tracking_func() + + assert call_count == 4 + assert len(delays) == 3 # 3 retries = 3 sleeps + # Check delays are increasing (with some tolerance for jitter) + assert delays[0] < delays[1] < delays[2] + + @pytest.mark.asyncio + async def test_all_retryable_error_codes(self): + """Test that all documented retryable error codes are retried.""" + retryable_codes = [ + "ThrottlingException", + "throttlingException", + "ModelErrorException", + "ServiceQuotaExceededException", + "RequestLimitExceeded", + "TooManyRequestsException", + "ServiceUnavailableException", + "serviceUnavailableException", + "RequestTimeout", + "RequestTimeoutException", + ] + + for error_code in retryable_codes: + call_count = 0 + + @async_exponential_backoff_retry(max_retries=2, initial_delay=0.01) + async def retryable_func(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise MockClientError( + { + "Error": { + "Code": error_code, + "Message": f"Test {error_code}", + } + }, + "TestOperation", + ) + return "success" + + result = await retryable_func() + assert result == "success", f"Failed for error code: {error_code}" + assert call_count == 2, ( + f"Expected 2 calls for {error_code}, got {call_count}" + ) + + +@pytest.mark.unit +class TestExponentialBackoffRetry: + """Tests for the synchronous exponential_backoff_retry decorator.""" + + def test_successful_call_no_retry(self): + """Test that successful calls don't trigger retries.""" + call_count = 0 + + @exponential_backoff_retry(max_retries=3) + def successful_func(): + nonlocal call_count + call_count += 1 + return "success" + + result = successful_func() + + assert result == "success" + assert call_count == 1 + + def test_retry_on_throttling_exception(self): + """Test retry on ThrottlingException.""" + call_count = 0 + + @exponential_backoff_retry(max_retries=3, initial_delay=0.01) + def throttled_func(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise MockClientError( + { + "Error": { + "Code": "ThrottlingException", + "Message": "Rate exceeded", + } + }, + "TestOperation", + ) + return "success" + + result = throttled_func() + + assert result == "success" + assert call_count == 3 + + def test_max_retries_exceeded(self): + """Test that exception is raised after max retries are exceeded.""" + call_count = 0 + + @exponential_backoff_retry(max_retries=3, initial_delay=0.01) + def always_fails_func(): + nonlocal call_count + call_count += 1 + raise MockClientError( + {"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}}, + "TestOperation", + ) + + with pytest.raises(MockClientError) as exc_info: + always_fails_func() + + assert "ThrottlingException" in str(exc_info.value) + assert call_count == 3 + + def test_non_client_error_not_retried(self): + """Test that non-ClientError exceptions are not retried.""" + call_count = 0 + + @exponential_backoff_retry(max_retries=3, initial_delay=0.01) + def generic_error_func(): + nonlocal call_count + call_count += 1 + raise ValueError("Some other error") + + with pytest.raises(ValueError): + generic_error_func() + + assert call_count == 1 + + +@pytest.mark.unit +class TestEventStreamErrorHandling: + """Specific tests for EventStreamError handling from ConverseStream API.""" + + @pytest.mark.asyncio + async def test_event_stream_error_service_unavailable_lowercase(self): + """Test the exact error format from the reported issue. + + This reproduces the error: + EventStreamError: An error occurred (serviceUnavailableException) when calling + the ConverseStream operation: Bedrock is unable to process your request. + """ + call_count = 0 + + @async_exponential_backoff_retry(max_retries=5, initial_delay=0.01) + async def converse_stream_func(): + nonlocal call_count + call_count += 1 + if call_count < 3: + # Simulate the exact error from the issue + error = MockEventStreamError( + {"Error": {}}, # Empty - error code not in response dict + "ConverseStream", + ) + # Override message to match actual format + error.args = ( + "An error occurred (serviceUnavailableException) when calling the ConverseStream operation: Bedrock is unable to process your request.", + ) + raise error + return {"result": "success"} + + result = await converse_stream_func() + + assert result == {"result": "success"} + assert call_count == 3 # 2 failures + 1 success + + @pytest.mark.asyncio + async def test_event_stream_error_with_response_code(self): + """Test EventStreamError when error code is in response dict.""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def converse_stream_func(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise MockEventStreamError( + { + "Error": { + "Code": "ServiceUnavailableException", + "Message": "Service unavailable", + } + }, + "ConverseStream", + ) + return "success" + + result = await converse_stream_func() + + assert result == "success" + assert call_count == 2 + + @pytest.mark.asyncio + async def test_event_stream_error_unknown_code_not_retried(self): + """Test EventStreamError with unknown error code is not retried.""" + call_count = 0 + + @async_exponential_backoff_retry(max_retries=3, initial_delay=0.01) + async def converse_stream_func(): + nonlocal call_count + call_count += 1 + error = MockEventStreamError( + {"Error": {}}, + "ConverseStream", + ) + error.args = ( + "An error occurred (unknownException) when calling the ConverseStream operation: Unknown error", + ) + raise error + + with pytest.raises(MockEventStreamError): + await converse_stream_func() + + assert call_count == 1 + + def test_error_code_extraction_regex(self): + """Test that the regex correctly extracts error codes from exception messages.""" + test_cases = [ + ( + "An error occurred (serviceUnavailableException) when calling the ConverseStream operation", + "serviceUnavailableException", + ), + ( + "An error occurred (ThrottlingException) when calling the Converse operation", + "ThrottlingException", + ), + ( + "An error occurred (ModelErrorException) when calling the InvokeModel operation: Model error", + "ModelErrorException", + ), + ] + + for message, expected_code in test_cases: + match = re.search(r"\((\w+)\)", message) + assert match is not None, f"Failed to match: {message}" + assert match.group(1) == expected_code, ( + f"Expected {expected_code}, got {match.group(1)}" + ) From ee455ebace7e114bc3657fa0d2637eade80991e5 Mon Sep 17 00:00:00 2001 From: "Kazmer, Nagy-Betegh" Date: Wed, 26 Nov 2025 17:51:20 +0000 Subject: [PATCH 18/18] linting fix --- lib/idp_common_pkg/tests/unit/test_bedrock_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/idp_common_pkg/tests/unit/test_bedrock_utils.py b/lib/idp_common_pkg/tests/unit/test_bedrock_utils.py index 39e403e8..3d2c317b 100644 --- a/lib/idp_common_pkg/tests/unit/test_bedrock_utils.py +++ b/lib/idp_common_pkg/tests/unit/test_bedrock_utils.py @@ -38,8 +38,8 @@ class MockEventStreamError(MockClientError): pass -# Now import the module under test -from idp_common.utils.bedrock_utils import ( +# Now import the module under test (must be after mock classes are defined) +from idp_common.utils.bedrock_utils import ( # noqa: E402 async_exponential_backoff_retry, exponential_backoff_retry, )