Skip to content

Commit 9df411b

Browse files
Improve typing (#359)
This is largely teaching mypy about the gen/ path, removing those `type-ignore` comments on the imports, and then fixing up the aftermath. There's also a bit of type narrowing, minor 'Pythonic' tweaks, and tweaking the types-protobuf dependency (and fixing those lints), so that we can eventually get better types 😄.
1 parent 42bfa58 commit 9df411b

File tree

4 files changed

+142
-198
lines changed

4 files changed

+142
-198
lines changed

protovalidate/internal/rules.py

Lines changed: 50 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
import dataclasses
1616
import datetime
1717
import typing
18-
from collections.abc import Callable
18+
from collections.abc import Callable, Container, Iterable, Mapping
1919

2020
import celpy
2121
from celpy import celtypes
22-
from google.protobuf import any_pb2, descriptor, message, message_factory
22+
from google.protobuf import any_pb2, descriptor, duration_pb2, message, message_factory
2323

24-
from buf.validate import validate_pb2 # type: ignore
24+
from buf.validate import validate_pb2
2525
from protovalidate.config import Config
2626
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has
2727

@@ -30,14 +30,14 @@ class CompilationError(Exception):
3030
pass
3131

3232

33-
def make_duration(msg: message.Message) -> celtypes.DurationType:
33+
def make_duration(msg: duration_pb2.Duration) -> celtypes.DurationType:
3434
return celtypes.DurationType(
35-
seconds=msg.seconds, # type: ignore
36-
nanos=msg.nanos, # type: ignore
35+
seconds=msg.seconds,
36+
nanos=msg.nanos,
3737
)
3838

3939

40-
def make_timestamp(msg: message.Message) -> celtypes.TimestampType:
40+
def make_timestamp(msg: duration_pb2.Duration) -> celtypes.TimestampType:
4141
return celtypes.TimestampType(1970, 1, 1) + make_duration(msg)
4242

4343

@@ -62,7 +62,6 @@ def unwrap(msg: message.Message) -> celtypes.Value:
6262

6363
class MessageType(celtypes.MapType):
6464
msg: message.Message
65-
desc: descriptor.Descriptor
6665

6766
def __init__(self, msg: message.Message):
6867
super().__init__()
@@ -163,7 +162,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c
163162

164163

165164
def _is_empty_field(msg: message.Message, field: descriptor.FieldDescriptor) -> bool:
166-
if field.has_presence: # type: ignore[attr-defined]
165+
if field.has_presence:
167166
return not _proto_message_has_field(msg, field)
168167
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
169168
return len(_proto_message_get_field(msg, field)) == 0
@@ -176,14 +175,11 @@ def _repeated_field_to_cel(msg: message.Message, field: descriptor.FieldDescript
176175
return _repeated_field_value_to_cel(_proto_message_get_field(msg, field), field)
177176

178177

179-
def _repeated_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value:
180-
result = celtypes.ListType()
181-
for item in val:
182-
result.append(_scalar_field_value_to_cel(item, field))
183-
return result
178+
def _repeated_field_value_to_cel(val: Iterable, field: descriptor.FieldDescriptor) -> celtypes.Value:
179+
return celtypes.ListType(_scalar_field_value_to_cel(item, field) for item in val)
184180

185181

186-
def _map_field_value_to_cel(mapping: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value:
182+
def _map_field_value_to_cel(mapping: Mapping, field: descriptor.FieldDescriptor) -> celtypes.Value:
187183
result = celtypes.MapType()
188184
key_field = field.message_type.fields[0]
189185
val_field = field.message_type.fields[1]
@@ -269,6 +265,7 @@ class RuleContext:
269265
"""The state associated with a single rule evaluation."""
270266

271267
_cfg: Config
268+
_violations: list[Violation]
272269

273270
def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None):
274271
self._cfg = config
@@ -305,7 +302,7 @@ def done(self) -> bool:
305302
def has_errors(self) -> bool:
306303
return len(self._violations) > 0
307304

308-
def sub_context(self):
305+
def sub_context(self) -> "RuleContext":
309306
return RuleContext(config=self._cfg)
310307

311308

@@ -545,19 +542,17 @@ def __init__(
545542
type_case = field_level.WhichOneof("type")
546543
super().__init__(None if type_case is None else getattr(field_level, type_case))
547544
self._field = field
548-
self._ignore_empty = field_level.ignore in (validate_pb2.IGNORE_IF_ZERO_VALUE,) or (
549-
field.has_presence # type: ignore[attr-defined]
550-
and not for_items
545+
self._ignore_empty = field_level.ignore == validate_pb2.IGNORE_IF_ZERO_VALUE or (
546+
field.has_presence and not for_items
551547
)
552548
self._required = field_level.required
553-
type_case = field_level.WhichOneof("type")
554549
if type_case is not None:
555550
rules: message.Message = getattr(field_level, type_case)
556551
# For each set field in the message, look for the private rule
557552
# extension.
558553
for list_field, _ in rules.ListFields():
559-
if validate_pb2.predefined in list_field.GetOptions().Extensions:
560-
for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel:
554+
if validate_pb2.predefined in list_field.GetOptions().Extensions: # type: ignore
555+
for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel: # type: ignore
561556
self.add_rule(
562557
env,
563558
funcs,
@@ -646,25 +641,20 @@ def __init__(
646641
field_level: validate_pb2.FieldRules,
647642
):
648643
super().__init__(env, funcs, field, field_level)
649-
self._in = []
650-
if getattr(field_level.any, "in"):
651-
self._in = getattr(field_level.any, "in")
652-
self._not_in = []
653-
if field_level.any.not_in:
654-
self._not_in = field_level.any.not_in
644+
self._in = getattr(field_level.any, "in") or []
645+
self._not_in: Container[str] = field_level.any.not_in or []
655646

656647
def _validate_value(self, ctx: RuleContext, value: any_pb2.Any, *, for_key: bool = False):
657-
if len(self._in) > 0:
658-
if value.type_url not in self._in:
659-
ctx.add(
660-
Violation(
661-
rule=AnyRules._in_rule_path,
662-
rule_value=self._in,
663-
rule_id="any.in",
664-
message="type URL must be in the allow list",
665-
for_key=for_key,
666-
)
648+
if len(self._in) > 0 and value.type_url not in self._in:
649+
ctx.add(
650+
Violation(
651+
rule=AnyRules._in_rule_path,
652+
rule_value=self._in,
653+
rule_id="any.in",
654+
message="type URL must be in the allow list",
655+
for_key=for_key,
667656
)
657+
)
668658
if value.type_url in self._not_in:
669659
ctx.add(
670660
Violation(
@@ -710,22 +700,20 @@ def validate(self, ctx: RuleContext, message: message.Message):
710700
super().validate(ctx, message)
711701
if ctx.done:
712702
return
713-
if self._defined_only:
714-
value = getattr(message, self._field.name)
715-
if value not in self._field.enum_type.values_by_number:
716-
ctx.add(
717-
Violation(
718-
field=validate_pb2.FieldPath(
719-
elements=[
720-
_field_to_element(self._field),
721-
],
722-
),
723-
rule=EnumRules._defined_only_rule_path,
724-
rule_value=self._defined_only,
725-
rule_id="enum.defined_only",
726-
message="value must be one of the defined enum values",
703+
if self._defined_only and getattr(message, self._field.name) not in self._field.enum_type.values_by_number:
704+
ctx.add(
705+
Violation(
706+
field=validate_pb2.FieldPath(
707+
elements=[
708+
_field_to_element(self._field),
709+
],
727710
),
728-
)
711+
rule=EnumRules._defined_only_rule_path,
712+
rule_value=self._defined_only,
713+
rule_id="enum.defined_only",
714+
message="value must be one of the defined enum values",
715+
),
716+
)
729717

730718

731719
class RepeatedRules(FieldRules):
@@ -875,7 +863,7 @@ def __init__(self, funcs: dict[str, celpy.CELFunction]):
875863
self._funcs = funcs
876864
self._cache = {}
877865

878-
def get(self, descriptor: descriptor.Descriptor) -> list[Rules]:
866+
def get(self, descriptor) -> list[Rules]:
879867
if descriptor not in self._cache:
880868
try:
881869
self._cache[descriptor] = self._new_rules(descriptor)
@@ -1042,8 +1030,8 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]:
10421030
result: list[Rules] = []
10431031
rule: typing.Optional[Rules] = None
10441032
all_msg_oneof_fields = set()
1045-
if validate_pb2.message in desc.GetOptions().Extensions:
1046-
message_level = desc.GetOptions().Extensions[validate_pb2.message]
1033+
if desc.GetOptions().HasExtension(validate_pb2.message): # type: ignore
1034+
message_level = desc.GetOptions().Extensions[validate_pb2.message] # type: ignore
10471035
for oneof in message_level.oneof:
10481036
all_msg_oneof_fields.update(oneof.fields)
10491037
if rule := self._new_message_rule(message_level, desc):
@@ -1094,8 +1082,8 @@ def __init__(
10941082
def validate(self, ctx: RuleContext, message: message.Message):
10951083
if not message.HasField(self._field.name):
10961084
return
1097-
rules = self._factory.get(self._field.message_type)
1098-
if rules is None:
1085+
rules: list[Rules] = self._factory.get(self._field.message_type)
1086+
if not rules:
10991087
return
11001088
val = getattr(message, self._field.name)
11011089
sub_ctx = ctx.sub_context()
@@ -1124,8 +1112,8 @@ def validate(self, ctx: RuleContext, message: message.Message):
11241112
val = getattr(message, self._field.name)
11251113
if not val:
11261114
return
1127-
rules = self._factory.get(self._value_field.message_type)
1128-
if rules is None:
1115+
rules: list[Rules] = self._factory.get(self._value_field.message_type)
1116+
if not rules:
11291117
return
11301118
for k, v in val.items():
11311119
sub_ctx = ctx.sub_context()
@@ -1151,8 +1139,8 @@ def validate(self, ctx: RuleContext, message: message.Message):
11511139
val = getattr(message, self._field.name)
11521140
if not val:
11531141
return
1154-
rules = self._factory.get(self._field.message_type)
1155-
if rules is None:
1142+
rules: list[Rules] = self._factory.get(self._field.message_type)
1143+
if not rules:
11561144
return
11571145
for idx, item in enumerate(val):
11581146
sub_ctx = ctx.sub_context()

protovalidate/validator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from google.protobuf import message
1818

19-
from buf.validate import validate_pb2 # type: ignore
19+
from buf.validate import validate_pb2
2020
from protovalidate.config import Config
2121
from protovalidate.internal import extra_func
2222
from protovalidate.internal import rules as _rules
@@ -38,7 +38,7 @@ class Validator:
3838
_factory: _rules.RuleFactory
3939
_cfg: Config
4040

41-
def __init__(self, config=None):
41+
def __init__(self, config: typing.Optional[Config] = None):
4242
self._cfg = config if config is not None else Config()
4343
funcs = extra_func.make_extra_funcs()
4444
self._factory = _rules.RuleFactory(funcs)
@@ -92,9 +92,9 @@ def collect_violations(
9292
break
9393
for violation in ctx.violations:
9494
if violation.proto.HasField("field"):
95-
violation.proto.field.elements.reverse()
95+
violation.proto.field.elements.reverse() # type: ignore
9696
if violation.proto.HasField("rule"):
97-
violation.proto.rule.elements.reverse()
97+
violation.proto.rule.elements.reverse() # type: ignore
9898
return ctx.violations
9999

100100

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ dev = [
4545
"google-re2-stubs>=0.1.1",
4646
"mypy>=1.17.1",
4747
"ruff>=0.12.0",
48-
"types-protobuf>=5",
48+
"types-protobuf>=5.29.1.20250315",
4949
]
5050

5151
[tool.hatch.version]
@@ -106,3 +106,9 @@ ban-relative-imports = "all"
106106
[tool.ruff.lint.per-file-ignores]
107107
# Tests can use magic values, assertions, and relative imports.
108108
"tests/**/*" = ["PLR2004", "S101", "TID252"]
109+
110+
[tool.mypy]
111+
mypy_path = "gen"
112+
113+
[tool.ty.environment]
114+
extra-paths = ["gen"]

0 commit comments

Comments
 (0)