Skip to content

Commit e45f36a

Browse files
Merge pull request #4 from betterproto/models-refactoring-2
Fix parse_source_type_name
2 parents 7e0cce7 + ef271df commit e45f36a

File tree

13 files changed

+165
-547
lines changed

13 files changed

+165
-547
lines changed

src/betterproto/compile/importing.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
if TYPE_CHECKING:
20+
from ..plugin.models import PluginRequestCompiler
2021
from ..plugin.typing_compiler import TypingCompiler
2122

2223
WRAPPER_TYPES: Dict[str, Type] = {
@@ -32,20 +33,42 @@
3233
}
3334

3435

35-
def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
36+
def parse_source_type_name(
37+
field_type_name: str, request: "PluginRequestCompiler"
38+
) -> Tuple[str, str]:
3639
"""
3740
Split full source type name into package and type name.
3841
E.g. 'root.package.Message' -> ('root.package', 'Message')
3942
'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum')
43+
44+
The function goes through the symbols that have been defined (names, enums, packages) to find the actual package and
45+
name of the object that is referenced.
4046
"""
41-
package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name)
42-
if package_match:
43-
package = package_match.group(1)
44-
name = package_match.group(2)
45-
else:
46-
package = ""
47-
name = field_type_name.lstrip(".")
48-
return package, name
47+
if field_type_name[0] != ".":
48+
raise RuntimeError("relative names are not supported")
49+
field_type_name = field_type_name[1:]
50+
parts = field_type_name.split(".")
51+
52+
answer = None
53+
54+
# a.b.c:
55+
# i=0: "", "a.b.c"
56+
# i=1: "a", "b.c"
57+
# i=2: "a.b", "c"
58+
for i in range(len(parts)):
59+
package_name, object_name = ".".join(parts[:i]), ".".join(parts[i:])
60+
61+
if package := request.output_packages.get(package_name):
62+
if object_name in package.messages or object_name in package.enums:
63+
if answer:
64+
# This should have already been handeled by protoc
65+
raise ValueError(f"ambiguous definition: {field_type_name}")
66+
answer = package_name, object_name
67+
68+
if answer:
69+
return answer
70+
71+
raise ValueError(f"can't find type name: {field_type_name}")
4972

5073

5174
def get_type_reference(
@@ -54,6 +77,7 @@ def get_type_reference(
5477
imports: set,
5578
source_type: str,
5679
typing_compiler: TypingCompiler,
80+
request: "PluginRequestCompiler",
5781
unwrap: bool = True,
5882
pydantic: bool = False,
5983
) -> str:
@@ -72,7 +96,7 @@ def get_type_reference(
7296
elif source_type == ".google.protobuf.Timestamp":
7397
return "datetime"
7498

75-
source_package, source_type = parse_source_type_name(source_type)
99+
source_package, source_type = parse_source_type_name(source_type, request)
76100

77101
current_package: List[str] = package.split(".") if package else []
78102
py_package: List[str] = source_package.split(".") if source_package else []

src/betterproto/plugin/models.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,7 @@
6565
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
6666

6767
from .. import which_one_of
68-
from ..compile.importing import (
69-
get_type_reference,
70-
parse_source_type_name,
71-
)
68+
from ..compile.importing import get_type_reference
7269
from ..compile.naming import (
7370
pythonize_class_name,
7471
pythonize_enum_member_name,
@@ -205,6 +202,12 @@ def __post_init__(self) -> None:
205202
if field_val is PLACEHOLDER:
206203
raise ValueError(f"`{field_name}` is a required field.")
207204

205+
def ready(self) -> None:
206+
"""
207+
This function is called after all the compilers are created, but before generating the output code.
208+
"""
209+
pass
210+
208211
@property
209212
def output_file(self) -> "OutputTemplate":
210213
current = self
@@ -214,10 +217,7 @@ def output_file(self) -> "OutputTemplate":
214217

215218
@property
216219
def request(self) -> "PluginRequestCompiler":
217-
current = self
218-
while not isinstance(current, OutputTemplate):
219-
current = current.parent
220-
return current.parent_request
220+
return self.output_file.parent_request
221221

222222
@property
223223
def comment(self) -> str:
@@ -228,6 +228,10 @@ def comment(self) -> str:
228228
proto_file=self.source_file, path=self.path, indent=self.comment_indent
229229
)
230230

231+
@property
232+
def deprecated(self) -> bool:
233+
return self.proto_obj.options.deprecated
234+
231235

232236
@dataclass
233237
class PluginRequestCompiler:
@@ -244,7 +248,9 @@ def all_messages(self) -> List["MessageCompiler"]:
244248
List of all of the messages in this request.
245249
"""
246250
return [
247-
msg for output in self.output_packages.values() for msg in output.messages
251+
msg
252+
for output in self.output_packages.values()
253+
for msg in output.messages.values()
248254
]
249255

250256

@@ -264,9 +270,9 @@ class OutputTemplate:
264270
datetime_imports: Set[str] = field(default_factory=set)
265271
pydantic_imports: Set[str] = field(default_factory=set)
266272
builtins_import: bool = False
267-
messages: List["MessageCompiler"] = field(default_factory=list)
268-
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
269-
services: List["ServiceCompiler"] = field(default_factory=list)
273+
messages: Dict[str, "MessageCompiler"] = field(default_factory=dict)
274+
enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
275+
services: Dict[str, "ServiceCompiler"] = field(default_factory=dict)
270276
imports_type_checking_only: Set[str] = field(default_factory=set)
271277
pydantic_dataclasses: bool = False
272278
output: bool = True
@@ -299,13 +305,13 @@ def python_module_imports(self) -> Set[str]:
299305
imports = set()
300306

301307
has_deprecated = False
302-
if any(m.deprecated for m in self.messages):
308+
if any(m.deprecated for m in self.messages.values()):
303309
has_deprecated = True
304-
if any(x for x in self.messages if any(x.deprecated_fields)):
310+
if any(x for x in self.messages.values() if any(x.deprecated_fields)):
305311
has_deprecated = True
306312
if any(
307313
any(m.proto_obj.options.deprecated for m in s.methods)
308-
for s in self.services
314+
for s in self.services.values()
309315
):
310316
has_deprecated = True
311317

@@ -329,17 +335,15 @@ class MessageCompiler(ProtoContentBase):
329335
fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(
330336
default_factory=list
331337
)
332-
deprecated: bool = field(default=False, init=False)
333338
builtins_types: Set[str] = field(default_factory=set)
334339

335340
def __post_init__(self) -> None:
336341
# Add message to output file
337342
if isinstance(self.parent, OutputTemplate):
338343
if isinstance(self, EnumDefinitionCompiler):
339-
self.output_file.enums.append(self)
344+
self.output_file.enums[self.proto_name] = self
340345
else:
341-
self.output_file.messages.append(self)
342-
self.deprecated = self.proto_obj.options.deprecated
346+
self.output_file.messages[self.proto_name] = self
343347
super().__post_init__()
344348

345349
@property
@@ -417,16 +421,24 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
417421

418422

419423
@dataclass
420-
class FieldCompiler(MessageCompiler):
424+
class FieldCompiler(ProtoContentBase):
425+
source_file: FileDescriptorProto
426+
typing_compiler: TypingCompiler
427+
path: List[int] = PLACEHOLDER
428+
builtins_types: Set[str] = field(default_factory=set)
429+
421430
parent: MessageCompiler = PLACEHOLDER
422431
proto_obj: FieldDescriptorProto = PLACEHOLDER
423432

424433
def __post_init__(self) -> None:
425434
# Add field to message
426-
self.parent.fields.append(self)
435+
if isinstance(self.parent, MessageCompiler):
436+
self.parent.fields.append(self)
437+
super().__post_init__()
438+
439+
def ready(self) -> None:
427440
# Check for new imports
428441
self.add_imports_to(self.output_file)
429-
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
430442

431443
def get_field_string(self, indent: int = 4) -> str:
432444
"""Construct string representation of this field as a field."""
@@ -544,6 +556,7 @@ def py_type(self) -> str:
544556
imports=self.output_file.imports_end,
545557
source_type=self.proto_obj.type_name,
546558
typing_compiler=self.typing_compiler,
559+
request=self.request,
547560
pydantic=self.output_file.pydantic_dataclasses,
548561
)
549562
else:
@@ -587,12 +600,22 @@ def pydantic_imports(self) -> Set[str]:
587600

588601
@dataclass
589602
class MapEntryCompiler(FieldCompiler):
590-
py_k_type: Type = PLACEHOLDER
591-
py_v_type: Type = PLACEHOLDER
592-
proto_k_type: str = PLACEHOLDER
593-
proto_v_type: str = PLACEHOLDER
603+
py_k_type: Optional[Type] = None
604+
py_v_type: Optional[Type] = None
605+
proto_k_type: str = ""
606+
proto_v_type: str = ""
594607

595-
def __post_init__(self) -> None:
608+
def __post_init__(self):
609+
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
610+
for nested in self.parent.proto_obj.nested_type:
611+
if (
612+
nested.name.replace("_", "").lower() == map_entry
613+
and nested.options.map_entry
614+
):
615+
pass
616+
return super().__post_init__()
617+
618+
def ready(self) -> None:
596619
"""Explore nested types and set k_type and v_type if unset."""
597620
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
598621
for nested in self.parent.proto_obj.nested_type:
@@ -617,7 +640,9 @@ def __post_init__(self) -> None:
617640
# Get proto types
618641
self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name
619642
self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name
620-
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
643+
return
644+
645+
raise ValueError("can't find enum")
621646

622647
@property
623648
def betterproto_field_args(self) -> List[str]:
@@ -678,7 +703,7 @@ class ServiceCompiler(ProtoContentBase):
678703

679704
def __post_init__(self) -> None:
680705
# Add service to output file
681-
self.output_file.services.append(self)
706+
self.output_file.services[self.proto_name] = self
682707
super().__post_init__() # check for unset fields
683708

684709
@property
@@ -744,6 +769,7 @@ def py_input_message_type(self) -> str:
744769
imports=self.output_file.imports_end,
745770
source_type=self.proto_obj.input_type,
746771
typing_compiler=self.output_file.typing_compiler,
772+
request=self.request,
747773
unwrap=False,
748774
pydantic=self.output_file.pydantic_dataclasses,
749775
).strip('"')
@@ -774,6 +800,7 @@ def py_output_message_type(self) -> str:
774800
imports=self.output_file.imports_end,
775801
source_type=self.proto_obj.output_type,
776802
typing_compiler=self.output_file.typing_compiler,
803+
request=self.request,
777804
unwrap=False,
778805
pydantic=self.output_file.pydantic_dataclasses,
779806
).strip('"')

src/betterproto/plugin/parser.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from .typing_compiler import (
4141
DirectImportTypingCompiler,
4242
NoTyping310TypingCompiler,
43-
TypingCompiler,
4443
TypingImportTypingCompiler,
4544
)
4645

@@ -61,7 +60,13 @@ def _traverse(
6160
for i, item in enumerate(items):
6261
# Adjust the name since we flatten the hierarchy.
6362
# Todo: don't change the name, but include full name in returned tuple
64-
item.name = next_prefix = f"{prefix}_{item.name}"
63+
should_rename = (
64+
not isinstance(item, DescriptorProto) or not item.options.map_entry
65+
)
66+
67+
item.name = next_prefix = (
68+
f"{prefix}.{item.name}" if prefix and should_rename else item.name
69+
)
6570
yield item, [*path, i]
6671

6772
if isinstance(item, DescriptorProto):
@@ -145,6 +150,21 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
145150
for index, service in enumerate(proto_input_file.service):
146151
read_protobuf_service(proto_input_file, service, index, output_package)
147152

153+
# All the hierarchy is ready. We can perform pre-computations before generating the output files
154+
for package in request_data.output_packages.values():
155+
for message in package.messages.values():
156+
for field in message.fields:
157+
field.ready()
158+
message.ready()
159+
for enum in package.enums.values():
160+
for variant in enum.fields:
161+
variant.ready()
162+
enum.ready()
163+
for service in package.services.values():
164+
for method in service.methods:
165+
method.ready()
166+
service.ready()
167+
148168
# Generate output files
149169
output_paths: Set[pathlib.Path] = set()
150170
for output_package_name, output_package in request_data.output_packages.items():

src/betterproto/templates/header.py.j2

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
# This file has been @generated
55

66
__all__ = (
7-
{%- for enum in output_file.enums -%}
7+
{% for _, enum in output_file.enums|dictsort(by="key") %}
88
"{{ enum.py_name }}",
99
{%- endfor -%}
10-
{%- for message in output_file.messages -%}
10+
{% for _, message in output_file.messages|dictsort(by="key") %}
1111
"{{ message.py_name }}",
1212
{%- endfor -%}
13-
{%- for service in output_file.services -%}
13+
{% for _, service in output_file.services|dictsort(by="key") %}
1414
"{{ service.py_name }}Stub",
1515
"{{ service.py_name }}Base",
1616
{%- endfor -%}

src/betterproto/templates/template.py.j2

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% if output_file.enums %}{% for enum in output_file.enums %}
1+
{% if output_file.enums %}{% for _, enum in output_file.enums|dictsort(by="key") %}
22
class {{ enum.py_name }}(betterproto.Enum):
33
{% if enum.comment %}
44
{{ enum.comment }}
@@ -22,7 +22,7 @@ class {{ enum.py_name }}(betterproto.Enum):
2222

2323
{% endfor %}
2424
{% endif %}
25-
{% for message in output_file.messages %}
25+
{% for _, message in output_file.messages|dictsort(by="key") %}
2626
{% if output_file.pydantic_dataclasses %}
2727
@dataclass(eq=False, repr=False, config={"extra": "forbid"})
2828
{% else %}
@@ -63,7 +63,7 @@ class {{ message.py_name }}(betterproto.Message):
6363
{% endif %}
6464

6565
{% endfor %}
66-
{% for service in output_file.services %}
66+
{% for _, service in output_file.services|dictsort(by="key") %}
6767
class {{ service.py_name }}Stub(betterproto.ServiceStub):
6868
{% if service.comment %}
6969
{{ service.comment }}
@@ -147,7 +147,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
147147
{{ i }}
148148
{% endfor %}
149149

150-
{% for service in output_file.services %}
150+
{% for _, service in output_file.services|dictsort(by="key") %}
151151
class {{ service.py_name }}Base(ServiceBase):
152152
{% if service.comment %}
153153
{{ service.comment }}

tests/inputs/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
"namespace_keywords", # 70
55
"googletypes_struct", # 9
66
"googletypes_value", # 9
7-
"import_capitalized_package",
87
"example", # This is the example in the readme. Not a test.
98
}
109

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
syntax = "proto3";
2+
3+
package import_child_scoping_rules.aaa.bbb.ccc.ddd;
4+
5+
message ChildMessage {
6+
7+
}

0 commit comments

Comments
 (0)