2626from .casing import camel_case , safe_snake_case , snake_case
2727from .grpc .grpclib_client import ServiceStub
2828
29- if not ( sys .version_info . major == 3 and sys . version_info . minor >= 7 ):
29+ if sys .version_info [: 2 ] < ( 3 , 7 ):
3030 # Apply backport of datetime.fromisoformat from 3.7
3131 from backports .datetime_fromisoformat import MonkeyPatch
3232
110110
111111
112112# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
113- def datetime_default_gen ():
113+ def datetime_default_gen () -> datetime :
114114 return datetime (1970 , 1 , 1 , tzinfo = timezone .utc )
115115
116116
@@ -256,8 +256,7 @@ class Enum(enum.IntEnum):
256256
257257 @classmethod
258258 def from_string (cls , name : str ) -> "Enum" :
259- """
260- Return the value which corresponds to the string name.
259+ """Return the value which corresponds to the string name.
261260
262261 Parameters
263262 -----------
@@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
316315 return encode_varint (value )
317316 elif proto_type in [TYPE_SINT32 , TYPE_SINT64 ]:
318317 # Handle zig-zag encoding.
319- if value >= 0 :
320- value = value << 1
321- else :
322- value = (value << 1 ) ^ (~ 0 )
323- return encode_varint (value )
318+ return encode_varint (value << 1 if value >= 0 else (value << 1 ) ^ (~ 0 ))
324319 elif proto_type in FIXED_TYPES :
325320 return struct .pack (_pack_fmt (proto_type ), value )
326321 elif proto_type == TYPE_STRING :
@@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
413408 wire_type = num_wire & 0x7
414409
415410 decoded : Any = None
416- if wire_type == 0 :
411+ if wire_type == WIRE_VARINT :
417412 decoded , i = decode_varint (value , i )
418- elif wire_type == 1 :
413+ elif wire_type == WIRE_FIXED_64 :
419414 decoded , i = value [i : i + 8 ], i + 8
420- elif wire_type == 2 :
415+ elif wire_type == WIRE_LEN_DELIM :
421416 length , i = decode_varint (value , i )
422417 decoded = value [i : i + length ]
423418 i += length
424- elif wire_type == 5 :
419+ elif wire_type == WIRE_FIXED_32 :
425420 decoded , i = value [i : i + 4 ], i + 4
426421
427422 yield ParsedField (
@@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
430425
431426
432427class ProtoClassMetadata :
433- oneof_group_by_field : Dict [str , str ]
434- oneof_field_by_group : Dict [str , Set [dataclasses .Field ]]
435- default_gen : Dict [str , Callable ]
436- cls_by_field : Dict [str , Type ]
437- field_name_by_number : Dict [int , str ]
438- meta_by_field_name : Dict [str , FieldMetadata ]
439428 __slots__ = (
440429 "oneof_group_by_field" ,
441430 "oneof_field_by_group" ,
@@ -446,6 +435,14 @@ class ProtoClassMetadata:
446435 "sorted_field_names" ,
447436 )
448437
438+ oneof_group_by_field : Dict [str , str ]
439+ oneof_field_by_group : Dict [str , Set [dataclasses .Field ]]
440+ field_name_by_number : Dict [int , str ]
441+ meta_by_field_name : Dict [str , FieldMetadata ]
442+ sorted_field_names : Tuple [str , ...]
443+ default_gen : Dict [str , Callable [[], Any ]]
444+ cls_by_field : Dict [str , Type ]
445+
449446 def __init__ (self , cls : Type ["Message" ]):
450447 by_field = {}
451448 by_group : Dict [str , Set ] = {}
@@ -470,23 +467,21 @@ def __init__(self, cls: Type["Message"]):
470467 self .field_name_by_number = by_field_number
471468 self .meta_by_field_name = by_field_name
472469 self .sorted_field_names = tuple (
473- by_field_number [number ] for number in sorted (by_field_number . keys () )
470+ by_field_number [number ] for number in sorted (by_field_number )
474471 )
475-
476472 self .default_gen = self ._get_default_gen (cls , fields )
477473 self .cls_by_field = self ._get_cls_by_field (cls , fields )
478474
479475 @staticmethod
480- def _get_default_gen (cls , fields ):
481- default_gen = {}
482-
483- for field in fields :
484- default_gen [field .name ] = cls ._get_field_default_gen (field )
485-
486- return default_gen
476+ def _get_default_gen (
477+ cls : Type ["Message" ], fields : List [dataclasses .Field ]
478+ ) -> Dict [str , Callable [[], Any ]]:
479+ return {field .name : cls ._get_field_default_gen (field ) for field in fields }
487480
488481 @staticmethod
489- def _get_cls_by_field (cls , fields ):
482+ def _get_cls_by_field (
483+ cls : Type ["Message" ], fields : List [dataclasses .Field ]
484+ ) -> Dict [str , Type ]:
490485 field_cls = {}
491486
492487 for field in fields :
@@ -503,7 +498,7 @@ def _get_cls_by_field(cls, fields):
503498 ],
504499 bases = (Message ,),
505500 )
506- field_cls [field .name + " .value" ] = vt
501+ field_cls [f" { field .name } .value" ] = vt
507502 else :
508503 field_cls [field .name ] = cls ._cls_for (field )
509504
@@ -612,7 +607,7 @@ def __setattr__(self, attr: str, value: Any) -> None:
612607 super ().__setattr__ (attr , value )
613608
614609 @property
615- def _betterproto (self ):
610+ def _betterproto (self ) -> ProtoClassMetadata :
616611 """
617612 Lazy initialize metadata for each protobuf class.
618613 It may be initialized multiple times in a multi-threaded environment,
@@ -726,9 +721,8 @@ def _type_hint(cls, field_name: str) -> Type:
726721
727722 @classmethod
728723 def _type_hints (cls ) -> Dict [str , Type ]:
729- module = inspect .getmodule (cls )
730- type_hints = get_type_hints (cls , vars (module ))
731- return type_hints
724+ module = sys .modules [cls .__module__ ]
725+ return get_type_hints (cls , vars (module ))
732726
733727 @classmethod
734728 def _cls_for (cls , field : dataclasses .Field , index : int = 0 ) -> Type :
@@ -739,7 +733,7 @@ def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
739733 field_cls = field_cls .__args__ [index ]
740734 return field_cls
741735
742- def _get_field_default (self , field_name ) :
736+ def _get_field_default (self , field_name : str ) -> Any :
743737 return self ._betterproto .default_gen [field_name ]()
744738
745739 @classmethod
@@ -762,7 +756,7 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
762756 elif issubclass (t , Enum ):
763757 # Enums always default to zero.
764758 return int
765- elif t == datetime :
759+ elif t is datetime :
766760 # Offsets are relative to 1970-01-01T00:00:00Z
767761 return datetime_default_gen
768762 else :
@@ -966,7 +960,7 @@ def to_dict(
966960 )
967961 ):
968962 output [cased_name ] = value .to_dict (casing , include_default_values )
969- elif meta .proto_type == "map" :
963+ elif meta .proto_type == TYPE_MAP :
970964 for k in value :
971965 if hasattr (value [k ], "to_dict" ):
972966 value [k ] = value [k ].to_dict (casing , include_default_values )
@@ -1032,12 +1026,12 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
10321026 continue
10331027
10341028 if value [key ] is not None :
1035- if meta .proto_type == "message" :
1029+ if meta .proto_type == TYPE_MESSAGE :
10361030 v = getattr (self , field_name )
10371031 if isinstance (v , list ):
10381032 cls = self ._betterproto .cls_by_field [field_name ]
1039- for i in range ( len ( value [key ])) :
1040- v .append (cls ().from_dict (value [ key ][ i ] ))
1033+ for item in value [key ]:
1034+ v .append (cls ().from_dict (item ))
10411035 elif isinstance (v , datetime ):
10421036 v = datetime .fromisoformat (value [key ].replace ("Z" , "+00:00" ))
10431037 setattr (self , field_name , v )
@@ -1052,7 +1046,7 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
10521046 v .from_dict (value [key ])
10531047 elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
10541048 v = getattr (self , field_name )
1055- cls = self ._betterproto .cls_by_field [field_name + " .value" ]
1049+ cls = self ._betterproto .cls_by_field [f" { field_name } .value" ]
10561050 for k in value [key ]:
10571051 v [k ] = cls ().from_dict (value [key ][k ])
10581052 else :
@@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool:
11341128 return message ._serialized_on_wire
11351129
11361130
1137- def which_one_of (message : Message , group_name : str ) -> Tuple [str , Any ]:
1131+ def which_one_of (message : Message , group_name : str ) -> Tuple [str , Optional [ Any ] ]:
11381132 """
11391133 Return the name and value of a message's one-of field group.
11401134
@@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
11451139 """
11461140 field_name = message ._group_current .get (group_name )
11471141 if not field_name :
1148- return ( "" , None )
1149- return ( field_name , getattr (message , field_name ) )
1142+ return "" , None
1143+ return field_name , getattr (message , field_name )
11501144
11511145
11521146# Circular import workaround: google.protobuf depends on base classes defined above.
11531147from .lib .google .protobuf import ( # noqa
1154- Duration ,
1155- Timestamp ,
11561148 BoolValue ,
11571149 BytesValue ,
11581150 DoubleValue ,
1151+ Duration ,
11591152 FloatValue ,
11601153 Int32Value ,
11611154 Int64Value ,
11621155 StringValue ,
1156+ Timestamp ,
11631157 UInt32Value ,
11641158 UInt64Value ,
11651159)
@@ -1174,8 +1168,8 @@ def delta_to_json(delta: timedelta) -> str:
11741168 parts = str (delta .total_seconds ()).split ("." )
11751169 if len (parts ) > 1 :
11761170 while len (parts [1 ]) not in [3 , 6 , 9 ]:
1177- parts [1 ] = parts [1 ] + " 0"
1178- return "." .join (parts ) + " s"
1171+ parts [1 ] = f" { parts [1 ]} 0"
1172+ return f" { '.' .join (parts )} s"
11791173
11801174
11811175class _Timestamp (Timestamp ):
@@ -1191,15 +1185,15 @@ def timestamp_to_json(dt: datetime) -> str:
11911185 if (nanos % 1e9 ) == 0 :
11921186 # If there are 0 fractional digits, the fractional
11931187 # point '.' should be omitted when serializing.
1194- return result + " Z"
1188+ return f" { result } Z"
11951189 if (nanos % 1e6 ) == 0 :
11961190 # Serialize 3 fractional digits.
1197- return result + ".%03dZ" % (nanos / 1e6 )
1191+ return f" { result } . { int (nanos // 1e6 ) :03d } Z"
11981192 if (nanos % 1e3 ) == 0 :
11991193 # Serialize 6 fractional digits.
1200- return result + ".%06dZ" % (nanos / 1e3 )
1194+ return f" { result } . { int (nanos // 1e3 ) :06d } Z"
12011195 # Serialize 9 fractional digits.
1202- return result + ".%09dZ" % nanos
1196+ return f" { result } . { nanos :09d } "
12031197
12041198
12051199class _WrappedMessage (Message ):
0 commit comments