99import functools
1010import logging
1111import warnings
12- from typing import TYPE_CHECKING , Any , Callable , Mapping , Sequence , overload
12+ from copy import deepcopy
13+ from typing import TYPE_CHECKING , Any , Callable , Mapping , Sequence
1314
1415from jsonpath_ng .ext import parse
1516
1819 DataMaskingUnsupportedTypeError ,
1920)
2021from aws_lambda_powertools .utilities .data_masking .provider import BaseProvider
22+ from aws_lambda_powertools .warnings import PowertoolsUserWarning
2123
2224if TYPE_CHECKING :
2325 from numbers import Number
@@ -67,11 +69,39 @@ def encrypt(
6769 provider_options : dict | None = None ,
6870 ** encryption_context : str ,
6971 ) -> str :
72+ """
73+ Encrypt data using the configured encryption provider.
74+
75+ Parameters
76+ ----------
77+ data : dict, Mapping, Sequence, or Number
78+ The data to encrypt.
79+ provider_options : dict, optional
80+ Provider-specific options for encryption.
81+ **encryption_context : str
82+ Additional key-value pairs for encryption context.
83+
84+ Returns
85+ -------
86+ str
87+ The encrypted data as a base64-encoded string.
88+
89+ Example
90+ --------
91+
92+ encryption_provider = AWSEncryptionSDKProvider(keys=[KMS_KEY_ARN])
93+ data_masker = DataMasking(provider=encryption_provider)
94+ encrypted = data_masker.encrypt({"secret": "value"})
95+ """
7096 return self ._apply_action (
7197 data = data ,
7298 fields = None ,
7399 action = self .provider .encrypt ,
74100 provider_options = provider_options or {},
101+ dynamic_mask = None ,
102+ custom_mask = None ,
103+ regex_pattern = None ,
104+ mask_format = None ,
75105 ** encryption_context ,
76106 )
77107
@@ -81,37 +111,104 @@ def decrypt(
81111 provider_options : dict | None = None ,
82112 ** encryption_context : str ,
83113 ) -> Any :
114+ """
115+ Decrypt data using the configured encryption provider.
116+
117+ Parameters
118+ ----------
119+ data : dict, Mapping, Sequence, or Number
120+ The data to decrypt.
121+ provider_options : dict, optional
122+ Provider-specific options for decryption.
123+ **encryption_context : str
124+ Additional key-value pairs for encryption context.
125+
126+ Returns
127+ -------
128+ Any
129+ The decrypted data.
130+
131+ Example
132+ --------
133+
134+ encryption_provider = AWSEncryptionSDKProvider(keys=[KMS_KEY_ARN])
135+ data_masker = DataMasking(provider=encryption_provider)
136+ decrypted = data_masker.decrypt(encrypted_data)
137+ """
138+
84139 return self ._apply_action (
85140 data = data ,
86141 fields = None ,
87142 action = self .provider .decrypt ,
88143 provider_options = provider_options or {},
144+ dynamic_mask = None ,
145+ custom_mask = None ,
146+ regex_pattern = None ,
147+ mask_format = None ,
89148 ** encryption_context ,
90149 )
91150
92- @overload
93- def erase (self , data , fields : None ) -> str : ...
94-
95- @overload
96- def erase (self , data : list , fields : list [str ]) -> list [str ]: ...
97-
98- @overload
99- def erase (self , data : tuple , fields : list [str ]) -> tuple [str ]: ...
151+ def erase (
152+ self ,
153+ data : Any ,
154+ fields : list [str ] | None = None ,
155+ * ,
156+ dynamic_mask : bool | None = None ,
157+ custom_mask : str | None = None ,
158+ regex_pattern : str | None = None ,
159+ mask_format : str | None = None ,
160+ masking_rules : dict | None = None ,
161+ ) -> Any :
162+ """
163+ Erase or mask sensitive data in the input.
100164
101- @overload
102- def erase (self , data : dict , fields : list [str ]) -> dict : ...
165+ Parameters
166+ ----------
167+ data : Any
168+ The data to be erased or masked.
169+ fields : list of str, optional
170+ List of field names to be erased or masked.
171+ dynamic_mask : bool, optional
172+ Whether to use dynamic masking.
173+ custom_mask : str, optional
174+ Custom mask to apply instead of the default.
175+ regex_pattern : str, optional
176+ Regular expression pattern for identifying data to mask.
177+ mask_format : str, optional
178+ Format string for the mask.
179+ masking_rules : dict, optional
180+ Dictionary of custom masking rules.
103181
104- def erase (self , data : Sequence | Mapping , fields : list [str ] | None = None ) -> str | list [str ] | tuple [str ] | dict :
105- return self ._apply_action (data = data , fields = fields , action = self .provider .erase )
182+ Returns
183+ -------
184+ Any
185+ The data with sensitive information erased or masked.
186+ """
187+ if masking_rules :
188+ return self ._apply_masking_rules (data = data , masking_rules = masking_rules )
189+ else :
190+ return self ._apply_action (
191+ data = data ,
192+ fields = fields ,
193+ action = self .provider .erase ,
194+ dynamic_mask = dynamic_mask ,
195+ custom_mask = custom_mask ,
196+ regex_pattern = regex_pattern ,
197+ mask_format = mask_format ,
198+ )
106199
107200 def _apply_action (
108201 self ,
109202 data ,
110203 fields : list [str ] | None ,
111204 action : Callable ,
112205 provider_options : dict | None = None ,
113- ** encryption_context : str ,
114- ):
206+ dynamic_mask : bool | None = None ,
207+ custom_mask : str | None = None ,
208+ regex_pattern : str | None = None ,
209+ mask_format : str | None = None ,
210+ ** kwargs : Any ,
211+ ) -> Any :
115212 """
116213 Helper method to determine whether to apply a given action to the entire input data
117214 or to specific fields if the 'fields' argument is specified.
@@ -127,8 +224,6 @@ def _apply_action(
127224 and returns the modified value.
128225 provider_options : dict
129226 Provider specific keyword arguments to propagate; used as an escape hatch.
130- encryption_context: str
131- Encryption context to use in encrypt and decrypt operations.
132227
133228 Returns
134229 -------
@@ -143,18 +238,34 @@ def _apply_action(
143238 fields = fields ,
144239 action = action ,
145240 provider_options = provider_options ,
146- ** encryption_context ,
241+ dynamic_mask = dynamic_mask ,
242+ custom_mask = custom_mask ,
243+ regex_pattern = regex_pattern ,
244+ mask_format = mask_format ,
245+ ** kwargs ,
147246 )
148247 else :
149248 logger .debug (f"Running action { action .__name__ } with the entire data" )
150- return action (data = data , provider_options = provider_options , ** encryption_context )
249+ return action (
250+ data = data ,
251+ provider_options = provider_options ,
252+ dynamic_mask = dynamic_mask ,
253+ custom_mask = custom_mask ,
254+ regex_pattern = regex_pattern ,
255+ mask_format = mask_format ,
256+ ** kwargs ,
257+ )
151258
152259 def _apply_action_to_fields (
153260 self ,
154261 data : dict | str ,
155262 fields : list ,
156263 action : Callable ,
157264 provider_options : dict | None = None ,
265+ dynamic_mask : bool | None = None ,
266+ custom_mask : str | None = None ,
267+ regex_pattern : str | None = None ,
268+ mask_format : str | None = None ,
158269 ** encryption_context : str ,
159270 ) -> dict | str :
160271 """
@@ -201,8 +312,10 @@ def _apply_action_to_fields(
201312 new_dict = {'a': {'b': {'c': '*****'}}, 'x': {'y': '*****'}}
202313 ```
203314 """
315+ if not fields :
316+ raise ValueError ("Fields parameter cannot be empty" )
204317
205- data_parsed : dict = self ._normalize_data_to_parse (fields , data )
318+ data_parsed : dict = self ._normalize_data_to_parse (data )
206319
207320 # For in-place updates, json_parse accepts a callback function
208321 # this function must receive 3 args: field_value, fields, field_name
@@ -211,6 +324,10 @@ def _apply_action_to_fields(
211324 self ._call_action ,
212325 action = action ,
213326 provider_options = provider_options ,
327+ dynamic_mask = dynamic_mask ,
328+ custom_mask = custom_mask ,
329+ regex_pattern = regex_pattern ,
330+ mask_format = mask_format ,
214331 ** encryption_context , # type: ignore[arg-type]
215332 )
216333
@@ -232,12 +349,6 @@ def _apply_action_to_fields(
232349 # For in-place updates, json_parse accepts a callback function
233350 # that receives 3 args: field_value, fields, field_name
234351 # We create a partial callback to pre-populate known provider options (action, provider opts, enc ctx)
235- update_callback = functools .partial (
236- self ._call_action ,
237- action = action ,
238- provider_options = provider_options ,
239- ** encryption_context , # type: ignore[arg-type]
240- )
241352
242353 json_parse .update (
243354 data_parsed ,
@@ -246,13 +357,70 @@ def _apply_action_to_fields(
246357
247358 return data_parsed
248359
def _apply_masking_rules(self, data: dict, masking_rules: dict) -> dict:
    """
    Apply per-path masking rules to a deep copy of the data.

    Each key of ``masking_rules`` is prefixed with ``$.`` and parsed as a JSONPath
    expression. Every match whose value is not ``None`` is converted with ``str()``
    and replaced by the result of ``self.provider.erase(str(value), **rule)``.

    Masking is best-effort: a path that matches nothing, a path that fails to
    parse, or a value that fails to mask emits a warning and is skipped instead
    of aborting the whole operation.

    Parameters
    ----------
    data : dict
        The data to mask. It is deep-copied first, so the input is not modified.
    masking_rules : dict
        Mapping of field name or path expression to a mapping of keyword
        arguments that is unpacked into ``self.provider.erase``.

    Returns
    -------
    dict
        The masked copy of ``data``.
    """
    result = deepcopy(data)

    for path, rule in masking_rules.items():
        try:
            jsonpath_expr = parse(f"$.{path}")
            matches = jsonpath_expr.find(result)

            if not matches:
                # Same category as the other warnings in this method, so a single
                # warnings filter on PowertoolsUserWarning covers all of them.
                warnings.warn(
                    f"No matches found for path: {path}",
                    category=PowertoolsUserWarning,
                    stacklevel=2,
                )
                continue

            for match in matches:
                try:
                    value = match.value
                    # None values are left untouched rather than masked as the string "None"
                    if value is not None:
                        masked_value = self.provider.erase(str(value), **rule)
                        match.full_path.update(result, masked_value)

                except Exception as e:
                    # Deliberate best-effort: one failing match must not stop the others
                    warnings.warn(
                        f"Error masking value for path {path}: {e}",
                        category=PowertoolsUserWarning,
                        stacklevel=2,
                    )
                    continue

        except Exception as e:
            # Deliberate best-effort: an unparsable path is reported and skipped
            warnings.warn(
                f"Error processing path {path}: {e}",
                category=PowertoolsUserWarning,
                stacklevel=2,
            )
            continue

    return result
402+
403+ def _mask_nested_field (self , data : dict , field_path : str , mask_function ):
404+ keys = field_path .split ("." )
405+ current = data
406+ for key in keys [:- 1 ]:
407+ current = current .get (key , {})
408+ if not isinstance (current , dict ):
409+ return
410+ if keys [- 1 ] in current :
411+ current [keys [- 1 ]] = self .provider .erase (current [keys [- 1 ]], ** mask_function )
412+
249413 @staticmethod
250414 def _call_action (
251415 field_value : Any ,
252416 fields : dict [str , Any ],
253417 field_name : str ,
254418 action : Callable ,
255419 provider_options : dict [str , Any ] | None = None ,
420+ dynamic_mask : bool | None = None ,
421+ custom_mask : str | None = None ,
422+ regex_pattern : str | None = None ,
423+ mask_format : str | None = None ,
256424 ** encryption_context ,
257425 ) -> None :
258426 """
@@ -270,13 +438,18 @@ def _call_action(
270438 Returns:
271439 - fields[field_name]: Returns the processed field value
272440 """
273- fields [field_name ] = action (field_value , provider_options = provider_options , ** encryption_context )
441+ fields [field_name ] = action (
442+ field_value ,
443+ provider_options = provider_options ,
444+ dynamic_mask = dynamic_mask ,
445+ custom_mask = custom_mask ,
446+ regex_pattern = regex_pattern ,
447+ mask_format = mask_format ,
448+ ** encryption_context ,
449+ )
274450 return fields [field_name ]
275451
276- def _normalize_data_to_parse (self , fields : list , data : str | dict ) -> dict :
277- if not fields :
278- raise ValueError ("No fields specified." )
279-
452+ def _normalize_data_to_parse (self , data : str | dict ) -> dict :
280453 if isinstance (data , str ):
281454 # Parse JSON string as dictionary
282455 data_parsed = self .json_deserializer (data )
0 commit comments