11# pyright: reportMissingTypeArgument=true, reportMissingParameterType=true
2- import copy
32import dataclasses
43import enum
5- import functools
64import inspect
75import itertools
86import json
4543_RE_SNAKE_CASE_2 = re .compile (r"[A-Z]" )
4644
4745
48- @functools .lru_cache (1024 )
46+ __not_valid = object ()
47+
48+ __to_snake_case_cache : Dict [str , str ] = {}
49+
50+
4951def to_snake_case (s : str ) -> str :
50- s = _RE_SNAKE_CASE_1 .sub ("_" , s )
51- if not s :
52- return s
53- return s [0 ].lower () + _RE_SNAKE_CASE_2 .sub (lambda matched : "_" + matched .group (0 ).lower (), s [1 :])
52+ result = __to_snake_case_cache .get (s , __not_valid )
53+ if result is __not_valid :
54+ s = _RE_SNAKE_CASE_1 .sub ("_" , s )
55+ if not s :
56+ result = s
57+ else :
58+ result = s [0 ].lower () + _RE_SNAKE_CASE_2 .sub (lambda matched : "_" + matched .group (0 ).lower (), s [1 :])
59+ __to_snake_case_cache [s ] = result
60+ return cast (str , result )
5461
5562
5663_RE_CAMEL_CASE_1 = re .compile (r"^[\-_\.]" )
5764_RE_CAMEL_CASE_2 = re .compile (r"[\-_\.\s]([a-z])" )
5865
66+ __to_snake_camel_cache : Dict [str , str ] = {}
67+
5968
60- @functools .lru_cache (1024 )
6169def to_camel_case (s : str ) -> str :
62- s = _RE_CAMEL_CASE_1 .sub ("" , str (s ))
63- if not s :
64- return s
65- return str (s [0 ]).lower () + _RE_CAMEL_CASE_2 .sub (
66- lambda matched : str (matched .group (1 )).upper (),
67- s [1 :],
68- )
70+ result = __to_snake_camel_cache .get (s , __not_valid )
71+ if result is __not_valid :
72+ s = _RE_CAMEL_CASE_1 .sub ("" , s )
73+ if not s :
74+ result = s
75+ else :
76+ result = str (s [0 ]).lower () + _RE_CAMEL_CASE_2 .sub (
77+ lambda matched : str (matched .group (1 )).upper (),
78+ s [1 :],
79+ )
80+ __to_snake_camel_cache [s ] = result
81+ return cast (str , result )
6982
7083
7184class CamelSnakeMixin :
@@ -110,21 +123,13 @@ def _decode_case(cls, s: str) -> str:
110123 return s
111124
112125
113- __default_config : Optional [DefaultConfig ] = None
114-
115-
116- def __get_default_config () -> DefaultConfig :
117- global __default_config
118-
119- if __default_config is None :
120- __default_config = DefaultConfig ()
121- return __default_config
126+ __default_config = DefaultConfig ()
122127
123128
124129def __get_config (obj : Any , entry_protocol : Type [_T ]) -> _T :
125130 if isinstance (obj , entry_protocol ):
126131 return obj
127- return cast (_T , __get_default_config () )
132+ return cast (_T , __default_config )
128133
129134
130135def encode_case (obj : Any , field : dataclasses .Field ) -> str : # type: ignore
@@ -357,23 +362,32 @@ def from_json(
357362
358363
359364def as_dict (
360- value : Any , * , remove_defaults : bool = False , dict_factory : Callable [[Any ], Dict [str , Any ]] = dict
365+ value : Any ,
366+ * ,
367+ remove_defaults : bool = False ,
368+ dict_factory : Callable [[Any ], Dict [str , Any ]] = dict ,
369+ encode : bool = True ,
361370) -> Dict [str , Any ]:
362371 if not dataclasses .is_dataclass (value ):
363372 raise TypeError ("as_dict() should be called on dataclass instances" )
364373
365- return cast (Dict [str , Any ], _as_dict_inner (value , remove_defaults , dict_factory ))
374+ return cast (Dict [str , Any ], _as_dict_inner (value , remove_defaults , dict_factory , encode ))
366375
367376
368- def _as_dict_inner (value : Any , remove_defaults : bool , dict_factory : Callable [[Any ], Dict [str , Any ]]) -> Any :
377+ def _as_dict_inner (
378+ value : Any ,
379+ remove_defaults : bool ,
380+ dict_factory : Callable [[Any ], Dict [str , Any ]],
381+ encode : bool = True ,
382+ ) -> Any :
369383 if dataclasses .is_dataclass (value ):
370384 result = []
371385 for f in dataclasses .fields (value ):
372386 v = _as_dict_inner (getattr (value , f .name ), remove_defaults , dict_factory )
373387
374388 if remove_defaults and v == f .default :
375389 continue
376- result .append ((f .name , v ))
390+ result .append ((encode_case ( value , f ) if encode else f .name , v ))
377391 return dict_factory (result )
378392
379393 if isinstance (value , tuple ) and hasattr (value , "_fields" ):
@@ -388,7 +402,7 @@ def _as_dict_inner(value: Any, remove_defaults: bool, dict_factory: Callable[[An
388402 for k , v in value .items ()
389403 )
390404
391- return copy . deepcopy ( value )
405+ return value
392406
393407
394408class TypeValidationError (Exception ):
0 commit comments