6565from betterproto .lib .google .protobuf .compiler import CodeGeneratorRequest
6666
6767from .. import which_one_of
68- from ..compile .importing import (
69- get_type_reference ,
70- parse_source_type_name ,
71- )
68+ from ..compile .importing import get_type_reference
7269from ..compile .naming import (
7370 pythonize_class_name ,
7471 pythonize_enum_member_name ,
@@ -205,6 +202,12 @@ def __post_init__(self) -> None:
205202 if field_val is PLACEHOLDER :
206203 raise ValueError (f"`{ field_name } ` is a required field." )
207204
205+ def ready (self ) -> None :
206+ """
207+ This function is called after all the compilers are created, but before generating the output code.
208+ """
209+ pass
210+
208211 @property
209212 def output_file (self ) -> "OutputTemplate" :
210213 current = self
@@ -214,10 +217,7 @@ def output_file(self) -> "OutputTemplate":
214217
215218 @property
216219 def request (self ) -> "PluginRequestCompiler" :
217- current = self
218- while not isinstance (current , OutputTemplate ):
219- current = current .parent
220- return current .parent_request
220+ return self .output_file .parent_request
221221
222222 @property
223223 def comment (self ) -> str :
@@ -228,6 +228,10 @@ def comment(self) -> str:
228228 proto_file = self .source_file , path = self .path , indent = self .comment_indent
229229 )
230230
231+ @property
232+ def deprecated (self ) -> bool :
233+ return self .proto_obj .options .deprecated
234+
231235
232236@dataclass
233237class PluginRequestCompiler :
@@ -244,7 +248,9 @@ def all_messages(self) -> List["MessageCompiler"]:
244248 List of all of the messages in this request.
245249 """
246250 return [
247- msg for output in self .output_packages .values () for msg in output .messages
251+ msg
252+ for output in self .output_packages .values ()
253+ for msg in output .messages .values ()
248254 ]
249255
250256
@@ -264,9 +270,9 @@ class OutputTemplate:
264270 datetime_imports : Set [str ] = field (default_factory = set )
265271 pydantic_imports : Set [str ] = field (default_factory = set )
266272 builtins_import : bool = False
267- messages : List [ "MessageCompiler" ] = field (default_factory = list )
268- enums : List [ "EnumDefinitionCompiler" ] = field (default_factory = list )
269- services : List [ "ServiceCompiler" ] = field (default_factory = list )
273+ messages : Dict [ str , "MessageCompiler" ] = field (default_factory = dict )
274+ enums : Dict [ str , "EnumDefinitionCompiler" ] = field (default_factory = dict )
275+ services : Dict [ str , "ServiceCompiler" ] = field (default_factory = dict )
270276 imports_type_checking_only : Set [str ] = field (default_factory = set )
271277 pydantic_dataclasses : bool = False
272278 output : bool = True
@@ -299,13 +305,13 @@ def python_module_imports(self) -> Set[str]:
299305 imports = set ()
300306
301307 has_deprecated = False
302- if any (m .deprecated for m in self .messages ):
308+ if any (m .deprecated for m in self .messages . values () ):
303309 has_deprecated = True
304- if any (x for x in self .messages if any (x .deprecated_fields )):
310+ if any (x for x in self .messages . values () if any (x .deprecated_fields )):
305311 has_deprecated = True
306312 if any (
307313 any (m .proto_obj .options .deprecated for m in s .methods )
308- for s in self .services
314+ for s in self .services . values ()
309315 ):
310316 has_deprecated = True
311317
@@ -329,17 +335,15 @@ class MessageCompiler(ProtoContentBase):
329335 fields : List [Union ["FieldCompiler" , "MessageCompiler" ]] = field (
330336 default_factory = list
331337 )
332- deprecated : bool = field (default = False , init = False )
333338 builtins_types : Set [str ] = field (default_factory = set )
334339
335340 def __post_init__ (self ) -> None :
336341 # Add message to output file
337342 if isinstance (self .parent , OutputTemplate ):
338343 if isinstance (self , EnumDefinitionCompiler ):
339- self .output_file .enums . append ( self )
344+ self .output_file .enums [ self . proto_name ] = self
340345 else :
341- self .output_file .messages .append (self )
342- self .deprecated = self .proto_obj .options .deprecated
346+ self .output_file .messages [self .proto_name ] = self
343347 super ().__post_init__ ()
344348
345349 @property
@@ -417,16 +421,24 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
417421
418422
419423@dataclass
420- class FieldCompiler (MessageCompiler ):
424+ class FieldCompiler (ProtoContentBase ):
425+ source_file : FileDescriptorProto
426+ typing_compiler : TypingCompiler
427+ path : List [int ] = PLACEHOLDER
428+ builtins_types : Set [str ] = field (default_factory = set )
429+
421430 parent : MessageCompiler = PLACEHOLDER
422431 proto_obj : FieldDescriptorProto = PLACEHOLDER
423432
424433 def __post_init__ (self ) -> None :
425434 # Add field to message
426- self .parent .fields .append (self )
435+ if isinstance (self .parent , MessageCompiler ):
436+ self .parent .fields .append (self )
437+ super ().__post_init__ ()
438+
439+ def ready (self ) -> None :
427440 # Check for new imports
428441 self .add_imports_to (self .output_file )
429- super ().__post_init__ () # call FieldCompiler-> MessageCompiler __post_init__
430442
431443 def get_field_string (self , indent : int = 4 ) -> str :
432444 """Construct string representation of this field as a field."""
@@ -544,6 +556,7 @@ def py_type(self) -> str:
544556 imports = self .output_file .imports_end ,
545557 source_type = self .proto_obj .type_name ,
546558 typing_compiler = self .typing_compiler ,
559+ request = self .request ,
547560 pydantic = self .output_file .pydantic_dataclasses ,
548561 )
549562 else :
@@ -587,12 +600,22 @@ def pydantic_imports(self) -> Set[str]:
587600
588601@dataclass
589602class MapEntryCompiler (FieldCompiler ):
590- py_k_type : Type = PLACEHOLDER
591- py_v_type : Type = PLACEHOLDER
592- proto_k_type : str = PLACEHOLDER
593- proto_v_type : str = PLACEHOLDER
603+ py_k_type : Optional [ Type ] = None
604+ py_v_type : Optional [ Type ] = None
605+ proto_k_type : str = ""
606+ proto_v_type : str = ""
594607
595- def __post_init__ (self ) -> None :
608+ def __post_init__ (self ):
609+ map_entry = f"{ self .proto_obj .name .replace ('_' , '' ).lower ()} entry"
610+ for nested in self .parent .proto_obj .nested_type :
611+ if (
612+ nested .name .replace ("_" , "" ).lower () == map_entry
613+ and nested .options .map_entry
614+ ):
615+ pass
616+ return super ().__post_init__ ()
617+
618+ def ready (self ) -> None :
596619 """Explore nested types and set k_type and v_type if unset."""
597620 map_entry = f"{ self .proto_obj .name .replace ('_' , '' ).lower ()} entry"
598621 for nested in self .parent .proto_obj .nested_type :
@@ -617,7 +640,9 @@ def __post_init__(self) -> None:
617640 # Get proto types
618641 self .proto_k_type = FieldDescriptorProtoType (nested .field [0 ].type ).name
619642 self .proto_v_type = FieldDescriptorProtoType (nested .field [1 ].type ).name
620- super ().__post_init__ () # call FieldCompiler-> MessageCompiler __post_init__
643+ return
644+
645+ raise ValueError ("can't find enum" )
621646
622647 @property
623648 def betterproto_field_args (self ) -> List [str ]:
@@ -678,7 +703,7 @@ class ServiceCompiler(ProtoContentBase):
678703
679704 def __post_init__ (self ) -> None :
680705 # Add service to output file
681- self .output_file .services . append ( self )
706+ self .output_file .services [ self . proto_name ] = self
682707 super ().__post_init__ () # check for unset fields
683708
684709 @property
@@ -744,6 +769,7 @@ def py_input_message_type(self) -> str:
744769 imports = self .output_file .imports_end ,
745770 source_type = self .proto_obj .input_type ,
746771 typing_compiler = self .output_file .typing_compiler ,
772+ request = self .request ,
747773 unwrap = False ,
748774 pydantic = self .output_file .pydantic_dataclasses ,
749775 ).strip ('"' )
@@ -774,6 +800,7 @@ def py_output_message_type(self) -> str:
774800 imports = self .output_file .imports_end ,
775801 source_type = self .proto_obj .output_type ,
776802 typing_compiler = self .output_file .typing_compiler ,
803+ request = self .request ,
777804 unwrap = False ,
778805 pydantic = self .output_file .pydantic_dataclasses ,
779806 ).strip ('"' )
0 commit comments