11"""
22Base for Parameter providers
33"""
4+ from __future__ import annotations
45
56import base64
67import json
78from abc import ABC , abstractmethod
8- from collections import namedtuple
99from datetime import datetime , timedelta
10- from typing import TYPE_CHECKING , Any , Dict , Optional , Tuple , Type , Union
10+ from typing import (
11+ TYPE_CHECKING ,
12+ Any ,
13+ Callable ,
14+ Dict ,
15+ NamedTuple ,
16+ Optional ,
17+ Tuple ,
18+ Type ,
19+ Union ,
20+ cast ,
21+ overload ,
22+ )
1123
1224import boto3
1325from botocore .config import Config
1426
27+ from aws_lambda_powertools .utilities .parameters .types import TransformOptions
28+
1529from .exceptions import GetParameterError , TransformParameterError
1630
1731if TYPE_CHECKING :
2236
2337
2438DEFAULT_MAX_AGE_SECS = 5
25- ExpirableValue = namedtuple ("ExpirableValue" , ["value" , "ttl" ])
2639# These providers will be dynamically initialized on first use of the helper functions
2740DEFAULT_PROVIDERS : Dict [str , Any ] = {}
2841TRANSFORM_METHOD_JSON = "json"
2942TRANSFORM_METHOD_BINARY = "binary"
3043SUPPORTED_TRANSFORM_METHODS = [TRANSFORM_METHOD_JSON , TRANSFORM_METHOD_BINARY ]
3144ParameterClients = Union ["AppConfigDataClient" , "SecretsManagerClient" , "SSMClient" ]
3245
46+ TRANSFORM_METHOD_MAPPING = {
47+ TRANSFORM_METHOD_JSON : json .loads ,
48+ TRANSFORM_METHOD_BINARY : base64 .b64decode ,
49+ ".json" : json .loads ,
50+ ".binary" : base64 .b64decode ,
51+ None : lambda x : x ,
52+ }
53+
54+
55+ class ExpirableValue (NamedTuple ):
56+ value : str | bytes | Dict [str , Any ]
57+ ttl : datetime
58+
3359
3460class BaseProvider (ABC ):
3561 """
3662 Abstract Base Class for Parameter providers
3763 """
3864
39- store : Any = None
65+ store : Dict [ Tuple [ str , TransformOptions ], ExpirableValue ]
4066
4167 def __init__ (self ):
4268 """
4369 Initialize the base provider
4470 """
4571
46- self .store = {}
72+ self .store : Dict [ Tuple [ str , TransformOptions ], ExpirableValue ] = {}
4773
48- def _has_not_expired (self , key : Tuple [str , Optional [ str ] ]) -> bool :
74+ def has_not_expired_in_cache (self , key : Tuple [str , TransformOptions ]) -> bool :
4975 return key in self .store and self .store [key ].ttl >= datetime .now ()
5076
5177 def get (
5278 self ,
5379 name : str ,
5480 max_age : int = DEFAULT_MAX_AGE_SECS ,
55- transform : Optional [ str ] = None ,
81+ transform : TransformOptions = None ,
5682 force_fetch : bool = False ,
5783 ** sdk_options ,
5884 ) -> Optional [Union [str , dict , bytes ]]:
@@ -95,7 +121,7 @@ def get(
95121 value : Optional [Union [str , bytes , dict ]] = None
96122 key = (name , transform )
97123
98- if not force_fetch and self ._has_not_expired (key ):
124+ if not force_fetch and self .has_not_expired_in_cache (key ):
99125 return self .store [key ].value
100126
101127 try :
@@ -105,11 +131,11 @@ def get(
105131 raise GetParameterError (str (exc ))
106132
107133 if transform :
108- if isinstance (value , bytes ):
109- value = value .decode ("utf-8" )
110- value = transform_value (value , transform )
134+ value = transform_value (key = name , value = value , transform = transform , raise_on_transform_error = True )
111135
112- self .store [key ] = ExpirableValue (value , datetime .now () + timedelta (seconds = max_age ))
136+ # NOTE: don't cache None, as they might've been failed transforms and may be corrected
137+ if value is not None :
138+ self .store [key ] = ExpirableValue (value , datetime .now () + timedelta (seconds = max_age ))
113139
114140 return value
115141
@@ -124,7 +150,7 @@ def get_multiple(
124150 self ,
125151 path : str ,
126152 max_age : int = DEFAULT_MAX_AGE_SECS ,
127- transform : Optional [ str ] = None ,
153+ transform : TransformOptions = None ,
128154 raise_on_transform_error : bool = False ,
129155 force_fetch : bool = False ,
130156 ** sdk_options ,
@@ -160,8 +186,8 @@ def get_multiple(
160186 """
161187 key = (path , transform )
162188
163- if not force_fetch and self ._has_not_expired (key ):
164- return self .store [key ].value
189+ if not force_fetch and self .has_not_expired_in_cache (key ):
190+ return self .store [key ].value # type: ignore # need to revisit entire typing here
165191
166192 try :
167193 values = self ._get_multiple (path , ** sdk_options )
@@ -170,13 +196,8 @@ def get_multiple(
170196 raise GetParameterError (str (exc ))
171197
172198 if transform :
173- transformed_values : dict = {}
174- for (item , value ) in values .items ():
175- _transform = get_transform_method (item , transform )
176- if not _transform :
177- continue
178- transformed_values [item ] = transform_value (value , _transform , raise_on_transform_error )
179- values .update (transformed_values )
199+ values .update (transform_value (values , transform , raise_on_transform_error ))
200+
180201 self .store [key ] = ExpirableValue (values , datetime .now () + timedelta (seconds = max_age ))
181202
182203 return values
@@ -191,6 +212,12 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
191212 def clear_cache (self ):
192213 self .store .clear ()
193214
215+ def add_to_cache (self , key : Tuple [str , TransformOptions ], value : Any , max_age : int ):
216+ if max_age <= 0 :
217+ return
218+
219+ self .store [key ] = ExpirableValue (value , datetime .now () + timedelta (seconds = max_age ))
220+
194221 @staticmethod
195222 def _build_boto3_client (
196223 service_name : str ,
@@ -258,57 +285,81 @@ def _build_boto3_resource_client(
258285 return session .resource (service_name = service_name , config = config , endpoint_url = endpoint_url )
259286
260287
261- def get_transform_method (key : str , transform : Optional [ str ] = None ) -> Optional [ str ]:
288+ def get_transform_method (value : str , transform : TransformOptions = None ) -> Callable [..., Any ]:
262289 """
263290 Determine the transform method
264291
265292 Examples
266293 -------
267- >>> get_transform_method("key", "any_other_value")
294+ >>> get_transform_method("key","any_other_value")
268295 'any_other_value'
269- >>> get_transform_method("key.json", "auto")
296+ >>> get_transform_method("key.json","auto")
270297 'json'
271- >>> get_transform_method("key.binary", "auto")
298+ >>> get_transform_method("key.binary","auto")
272299 'binary'
273- >>> get_transform_method("key", "auto")
300+ >>> get_transform_method("key","auto")
274301 None
275- >>> get_transform_method("key", None)
302+ >>> get_transform_method("key",None)
276303 None
277304
278305 Parameters
279306 ---------
280- key : str
281- Only used when the tranform is "auto".
307+ value : str
308+ Only used when the transform is "auto".
282309 transform: str, optional
283310 Original transform method, only "auto" will try to detect the transform method by the key
284311
285312 Returns
286313 ------
287- Optional[str]:
288- The transform method either when transform is "auto" then None, "json" or "binary" is returned
289- or the original transform method
314+ Callable:
315+ Transform function could be json.loads, base64.b64decode, or a lambda that echo the str value
290316 """
291- if transform != "auto" :
292- return transform
317+ transform_method = TRANSFORM_METHOD_MAPPING .get (transform )
318+
319+ if transform == "auto" :
320+ key_suffix = value .rsplit ("." )[- 1 ]
321+ transform_method = TRANSFORM_METHOD_MAPPING .get (key_suffix , TRANSFORM_METHOD_MAPPING [None ])
322+
323+ return cast (Callable , transform_method ) # https://github.com/python/mypy/issues/10740
324+
325+
326+ @overload
327+ def transform_value (
328+ value : Dict [str , Any ],
329+ transform : TransformOptions ,
330+ raise_on_transform_error : bool = False ,
331+ key : str = "" ,
332+ ) -> Dict [str , Any ]:
333+ ...
334+
293335
294- for transform_method in SUPPORTED_TRANSFORM_METHODS :
295- if key .endswith ("." + transform_method ):
296- return transform_method
297- return None
336+ @overload
337+ def transform_value (
338+ value : Union [str , bytes , Dict [str , Any ]],
339+ transform : TransformOptions ,
340+ raise_on_transform_error : bool = False ,
341+ key : str = "" ,
342+ ) -> Optional [Union [str , bytes , Dict [str , Any ]]]:
343+ ...
298344
299345
300346def transform_value (
301- value : str , transform : str , raise_on_transform_error : Optional [bool ] = True
302- ) -> Optional [Union [dict , bytes ]]:
347+ value : Union [str , bytes , Dict [str , Any ]],
348+ transform : TransformOptions ,
349+ raise_on_transform_error : bool = True ,
350+ key : str = "" ,
351+ ) -> Optional [Union [str , bytes , Dict [str , Any ]]]:
303352 """
304- Apply a transform to a value
353+ Transform a value using one of the available options.
305354
306355 Parameters
307356 ---------
308357 value: str
309358 Parameter value to transform
310359 transform: str
311- Type of transform, supported values are "json" and "binary"
360+ Type of transform, supported values are "json", "binary", and "auto" based on suffix (.json, .binary)
361+ key: str
362+ Parameter key when transform is auto to infer its transform method
312363 raise_on_transform_error: bool, optional
313364 Raises an exception if any transform fails, otherwise this will
314365 return a None value for each transform that failed
@@ -318,18 +369,41 @@ def transform_value(
318369 TransformParameterError:
319370 When the parameter value could not be transformed
320371 """
372+ # Maintenance: For v3, we should consider returning the original value for soft transform failures.
373+
374+ err_msg = "Unable to transform value using '{transform}' transform: {exc}"
375+
376+ if isinstance (value , bytes ):
377+ value = value .decode ("utf-8" )
378+
379+ if isinstance (value , dict ):
380+ # NOTE: We must handle partial failures when receiving multiple values
381+ # where one of the keys might fail during transform, e.g. `{"a": "valid", "b": "{"}`
382+ # expected: `{"a": "valid", "b": None}`
383+
384+ transformed_values : Dict [str , Any ] = {}
385+ for dict_key , dict_value in value .items ():
386+ transform_method = get_transform_method (value = dict_key , transform = transform )
387+ try :
388+ transformed_values [dict_key ] = transform_method (dict_value )
389+ except Exception as exc :
390+ if raise_on_transform_error :
391+ raise TransformParameterError (err_msg .format (transform = transform , exc = exc )) from exc
392+ transformed_values [dict_key ] = None
393+ return transformed_values
394+
395+ if transform == "auto" :
396+ # key="a.json", value='{"a": "b"}', or key="a.binary", value="b64_encoded"
397+ transform_method = get_transform_method (value = key , transform = transform )
398+ else :
399+ # value='{"key": "value"}
400+ transform_method = get_transform_method (value = value , transform = transform )
321401
322402 try :
323- if transform == TRANSFORM_METHOD_JSON :
324- return json .loads (value )
325- elif transform == TRANSFORM_METHOD_BINARY :
326- return base64 .b64decode (value )
327- else :
328- raise ValueError (f"Invalid transform type '{ transform } '" )
329-
403+ return transform_method (value )
330404 except Exception as exc :
331405 if raise_on_transform_error :
332- raise TransformParameterError (str ( exc ))
406+ raise TransformParameterError (err_msg . format ( transform = transform , exc = exc )) from exc
333407 return None
334408
335409
0 commit comments