1515import dataclasses
1616import datetime
1717import typing
18- from collections .abc import Callable
18+ from collections .abc import Callable , Container , Iterable , Mapping
1919
2020import celpy
2121from celpy import celtypes
22- from google .protobuf import any_pb2 , descriptor , message , message_factory
22+ from google .protobuf import any_pb2 , descriptor , duration_pb2 , message , message_factory
2323
24- from buf .validate import validate_pb2 # type: ignore
24+ from buf .validate import validate_pb2
2525from protovalidate .config import Config
2626from protovalidate .internal .cel_field_presence import InterpretedRunner , in_has
2727
@@ -30,14 +30,14 @@ class CompilationError(Exception):
3030 pass
3131
3232
33- def make_duration (msg : message . Message ) -> celtypes .DurationType :
33+ def make_duration (msg : duration_pb2 . Duration ) -> celtypes .DurationType :
3434 return celtypes .DurationType (
35- seconds = msg .seconds , # type: ignore
36- nanos = msg .nanos , # type: ignore
35+ seconds = msg .seconds ,
36+ nanos = msg .nanos ,
3737 )
3838
3939
40- def make_timestamp (msg : message . Message ) -> celtypes .TimestampType :
40+ def make_timestamp (msg : duration_pb2 . Duration ) -> celtypes .TimestampType :
4141 return celtypes .TimestampType (1970 , 1 , 1 ) + make_duration (msg )
4242
4343
@@ -62,7 +62,6 @@ def unwrap(msg: message.Message) -> celtypes.Value:
6262
6363class MessageType (celtypes .MapType ):
6464 msg : message .Message
65- desc : descriptor .Descriptor
6665
6766 def __init__ (self , msg : message .Message ):
6867 super ().__init__ ()
@@ -163,7 +162,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c
163162
164163
165164def _is_empty_field (msg : message .Message , field : descriptor .FieldDescriptor ) -> bool :
166- if field .has_presence : # type: ignore[attr-defined]
165+ if field .has_presence :
167166 return not _proto_message_has_field (msg , field )
168167 if field .label == descriptor .FieldDescriptor .LABEL_REPEATED :
169168 return len (_proto_message_get_field (msg , field )) == 0
@@ -176,14 +175,11 @@ def _repeated_field_to_cel(msg: message.Message, field: descriptor.FieldDescript
176175 return _repeated_field_value_to_cel (_proto_message_get_field (msg , field ), field )
177176
178177
179- def _repeated_field_value_to_cel (val : typing .Any , field : descriptor .FieldDescriptor ) -> celtypes .Value :
180- result = celtypes .ListType ()
181- for item in val :
182- result .append (_scalar_field_value_to_cel (item , field ))
183- return result
178+ def _repeated_field_value_to_cel (val : Iterable , field : descriptor .FieldDescriptor ) -> celtypes .Value :
179+ return celtypes .ListType (_scalar_field_value_to_cel (item , field ) for item in val )
184180
185181
186- def _map_field_value_to_cel (mapping : typing . Any , field : descriptor .FieldDescriptor ) -> celtypes .Value :
182+ def _map_field_value_to_cel (mapping : Mapping , field : descriptor .FieldDescriptor ) -> celtypes .Value :
187183 result = celtypes .MapType ()
188184 key_field = field .message_type .fields [0 ]
189185 val_field = field .message_type .fields [1 ]
@@ -269,6 +265,7 @@ class RuleContext:
269265 """The state associated with a single rule evaluation."""
270266
271267 _cfg : Config
268+ _violations : list [Violation ]
272269
273270 def __init__ (self , * , config : Config , violations : typing .Optional [list [Violation ]] = None ):
274271 self ._cfg = config
@@ -305,7 +302,7 @@ def done(self) -> bool:
305302 def has_errors (self ) -> bool :
306303 return len (self ._violations ) > 0
307304
308- def sub_context (self ):
305+ def sub_context (self ) -> "RuleContext" :
309306 return RuleContext (config = self ._cfg )
310307
311308
@@ -545,19 +542,17 @@ def __init__(
545542 type_case = field_level .WhichOneof ("type" )
546543 super ().__init__ (None if type_case is None else getattr (field_level , type_case ))
547544 self ._field = field
548- self ._ignore_empty = field_level .ignore in (validate_pb2 .IGNORE_IF_ZERO_VALUE ,) or (
549- field .has_presence # type: ignore[attr-defined]
550- and not for_items
545+ self ._ignore_empty = field_level .ignore == validate_pb2 .IGNORE_IF_ZERO_VALUE or (
546+ field .has_presence and not for_items
551547 )
552548 self ._required = field_level .required
553- type_case = field_level .WhichOneof ("type" )
554549 if type_case is not None :
555550 rules : message .Message = getattr (field_level , type_case )
556551 # For each set field in the message, look for the private rule
557552 # extension.
558553 for list_field , _ in rules .ListFields ():
559- if validate_pb2 .predefined in list_field .GetOptions ().Extensions :
560- for cel in list_field .GetOptions ().Extensions [validate_pb2 .predefined ].cel :
554+ if validate_pb2 .predefined in list_field .GetOptions ().Extensions : # type: ignore
555+ for cel in list_field .GetOptions ().Extensions [validate_pb2 .predefined ].cel : # type: ignore
561556 self .add_rule (
562557 env ,
563558 funcs ,
@@ -646,25 +641,20 @@ def __init__(
646641 field_level : validate_pb2 .FieldRules ,
647642 ):
648643 super ().__init__ (env , funcs , field , field_level )
649- self ._in = []
650- if getattr (field_level .any , "in" ):
651- self ._in = getattr (field_level .any , "in" )
652- self ._not_in = []
653- if field_level .any .not_in :
654- self ._not_in = field_level .any .not_in
644+ self ._in = getattr (field_level .any , "in" ) or []
645+ self ._not_in : Container [str ] = field_level .any .not_in or []
655646
656647 def _validate_value (self , ctx : RuleContext , value : any_pb2 .Any , * , for_key : bool = False ):
657- if len (self ._in ) > 0 :
658- if value .type_url not in self ._in :
659- ctx .add (
660- Violation (
661- rule = AnyRules ._in_rule_path ,
662- rule_value = self ._in ,
663- rule_id = "any.in" ,
664- message = "type URL must be in the allow list" ,
665- for_key = for_key ,
666- )
648+ if len (self ._in ) > 0 and value .type_url not in self ._in :
649+ ctx .add (
650+ Violation (
651+ rule = AnyRules ._in_rule_path ,
652+ rule_value = self ._in ,
653+ rule_id = "any.in" ,
654+ message = "type URL must be in the allow list" ,
655+ for_key = for_key ,
667656 )
657+ )
668658 if value .type_url in self ._not_in :
669659 ctx .add (
670660 Violation (
@@ -710,22 +700,20 @@ def validate(self, ctx: RuleContext, message: message.Message):
710700 super ().validate (ctx , message )
711701 if ctx .done :
712702 return
713- if self ._defined_only :
714- value = getattr (message , self ._field .name )
715- if value not in self ._field .enum_type .values_by_number :
716- ctx .add (
717- Violation (
718- field = validate_pb2 .FieldPath (
719- elements = [
720- _field_to_element (self ._field ),
721- ],
722- ),
723- rule = EnumRules ._defined_only_rule_path ,
724- rule_value = self ._defined_only ,
725- rule_id = "enum.defined_only" ,
726- message = "value must be one of the defined enum values" ,
703+ if self ._defined_only and getattr (message , self ._field .name ) not in self ._field .enum_type .values_by_number :
704+ ctx .add (
705+ Violation (
706+ field = validate_pb2 .FieldPath (
707+ elements = [
708+ _field_to_element (self ._field ),
709+ ],
727710 ),
728- )
711+ rule = EnumRules ._defined_only_rule_path ,
712+ rule_value = self ._defined_only ,
713+ rule_id = "enum.defined_only" ,
714+ message = "value must be one of the defined enum values" ,
715+ ),
716+ )
729717
730718
731719class RepeatedRules (FieldRules ):
@@ -875,7 +863,7 @@ def __init__(self, funcs: dict[str, celpy.CELFunction]):
875863 self ._funcs = funcs
876864 self ._cache = {}
877865
878- def get (self , descriptor : descriptor . Descriptor ) -> list [Rules ]:
866+ def get (self , descriptor ) -> list [Rules ]:
879867 if descriptor not in self ._cache :
880868 try :
881869 self ._cache [descriptor ] = self ._new_rules (descriptor )
@@ -1042,8 +1030,8 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]:
10421030 result : list [Rules ] = []
10431031 rule : typing .Optional [Rules ] = None
10441032 all_msg_oneof_fields = set ()
1045- if validate_pb2 . message in desc .GetOptions ().Extensions :
1046- message_level = desc .GetOptions ().Extensions [validate_pb2 .message ]
1033+ if desc .GetOptions ().HasExtension ( validate_pb2 . message ): # type: ignore
1034+ message_level = desc .GetOptions ().Extensions [validate_pb2 .message ] # type: ignore
10471035 for oneof in message_level .oneof :
10481036 all_msg_oneof_fields .update (oneof .fields )
10491037 if rule := self ._new_message_rule (message_level , desc ):
@@ -1094,8 +1082,8 @@ def __init__(
10941082 def validate (self , ctx : RuleContext , message : message .Message ):
10951083 if not message .HasField (self ._field .name ):
10961084 return
1097- rules = self ._factory .get (self ._field .message_type )
1098- if rules is None :
1085+ rules : list [ Rules ] = self ._factory .get (self ._field .message_type )
1086+ if not rules :
10991087 return
11001088 val = getattr (message , self ._field .name )
11011089 sub_ctx = ctx .sub_context ()
@@ -1124,8 +1112,8 @@ def validate(self, ctx: RuleContext, message: message.Message):
11241112 val = getattr (message , self ._field .name )
11251113 if not val :
11261114 return
1127- rules = self ._factory .get (self ._value_field .message_type )
1128- if rules is None :
1115+ rules : list [ Rules ] = self ._factory .get (self ._value_field .message_type )
1116+ if not rules :
11291117 return
11301118 for k , v in val .items ():
11311119 sub_ctx = ctx .sub_context ()
@@ -1151,8 +1139,8 @@ def validate(self, ctx: RuleContext, message: message.Message):
11511139 val = getattr (message , self ._field .name )
11521140 if not val :
11531141 return
1154- rules = self ._factory .get (self ._field .message_type )
1155- if rules is None :
1142+ rules : list [ Rules ] = self ._factory .get (self ._field .message_type )
1143+ if not rules :
11561144 return
11571145 for idx , item in enumerate (val ):
11581146 sub_ctx = ctx .sub_context ()
0 commit comments