44Batch processing utilities
55"""
66import copy
7+ import inspect
78import logging
89import sys
910from abc import ABC , abstractmethod
1516from aws_lambda_powertools .utilities .data_classes .dynamo_db_stream_event import DynamoDBRecord
1617from aws_lambda_powertools .utilities .data_classes .kinesis_stream_event import KinesisStreamRecord
1718from aws_lambda_powertools .utilities .data_classes .sqs_event import SQSRecord
19+ from aws_lambda_powertools .utilities .typing import LambdaContext
1820
1921logger = logging .getLogger (__name__ )
2022
@@ -55,6 +57,8 @@ class BasePartialProcessor(ABC):
5557 Abstract class for batch processors.
5658 """
5759
60+ lambda_context : LambdaContext
61+
5862 def __init__ (self ):
5963 self .success_messages : List [BatchEventTypes ] = []
6064 self .fail_messages : List [BatchEventTypes ] = []
@@ -94,7 +98,7 @@ def __enter__(self):
9498 def __exit__ (self , exception_type , exception_value , traceback ):
9599 self ._clean ()
96100
97- def __call__ (self , records : List [dict ], handler : Callable ):
101+ def __call__ (self , records : List [dict ], handler : Callable , lambda_context : Optional [ LambdaContext ] = None ):
98102 """
99103 Set instance attributes before execution
100104
@@ -107,6 +111,31 @@ def __call__(self, records: List[dict], handler: Callable):
107111 """
108112 self .records = records
109113 self .handler = handler
114+
115+ # NOTE: If a record handler has `lambda_context` parameter in its function signature, we inject it.
116+ # This is the earliest we can inspect for signature to prevent impacting performance.
117+ #
118+ # Mechanism:
119+ #
120+ # 1. When using the `@batch_processor` decorator, this happens automatically.
121+ # 2. When using the context manager, customers have to include `lambda_context` param.
122+ #
123+ # Scenario: Injects Lambda context
124+ #
125+ # def record_handler(record, lambda_context): ... # noqa: E800
126+ # with processor(records=batch, handler=record_handler, lambda_context=context): ... # noqa: E800
127+ #
128+ # Scenario: Does NOT inject Lambda context (default)
129+ #
130+ # def record_handler(record): pass # noqa: E800
131+ # with processor(records=batch, handler=record_handler): ... # noqa: E800
132+ #
133+ if lambda_context is None :
134+ self ._handler_accepts_lambda_context = False
135+ else :
136+ self .lambda_context = lambda_context
137+ self ._handler_accepts_lambda_context = "lambda_context" in inspect .signature (self .handler ).parameters
138+
110139 return self
111140
112141 def success_handler (self , record , result : Any ) -> SuccessResponse :
@@ -155,7 +184,7 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
155184
156185@lambda_handler_decorator
157186def batch_processor (
158- handler : Callable , event : Dict , context : Dict , record_handler : Callable , processor : BasePartialProcessor
187+ handler : Callable , event : Dict , context : LambdaContext , record_handler : Callable , processor : BasePartialProcessor
159188):
160189 """
161190 Middleware to handle batch event processing
@@ -166,7 +195,7 @@ def batch_processor(
166195 Lambda's handler
167196 event: Dict
168197 Lambda's Event
169- context: Dict
198+ context: LambdaContext
170199 Lambda's Context
171200 record_handler: Callable
172201 Callable to process each record from the batch
@@ -193,7 +222,7 @@ def batch_processor(
193222 """
194223 records = event ["Records" ]
195224
196- with processor (records , record_handler ):
225+ with processor (records , record_handler , lambda_context = context ):
197226 processor .process ()
198227
199228 return handler (event , context )
@@ -365,7 +394,11 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons
365394 """
366395 data = self ._to_batch_type (record = record , event_type = self .event_type , model = self .model )
367396 try :
368- result = self .handler (record = data )
397+ if self ._handler_accepts_lambda_context :
398+ result = self .handler (record = data , lambda_context = self .lambda_context )
399+ else :
400+ result = self .handler (record = data )
401+
369402 return self .success_handler (record = record , result = result )
370403 except Exception :
371404 return self .failure_handler (record = data , exception = sys .exc_info ())
0 commit comments