@@ -324,17 +324,7 @@ def __post_init__(self) -> None:
324324 # Add field to message
325325 self .parent .fields .append (self )
326326 # Check for new imports
327- annotation = self .annotation
328- if "Optional[" in annotation :
329- self .output_file .typing_imports .add ("Optional" )
330- if "List[" in annotation :
331- self .output_file .typing_imports .add ("List" )
332- if "Dict[" in annotation :
333- self .output_file .typing_imports .add ("Dict" )
334- if "timedelta" in annotation :
335- self .output_file .datetime_imports .add ("timedelta" )
336- if "datetime" in annotation :
337- self .output_file .datetime_imports .add ("datetime" )
327+ self .add_imports_to (self .output_file )
338328 super ().__post_init__ () # call FieldCompiler-> MessageCompiler __post_init__
339329
340330 def get_field_string (self , indent : int = 4 ) -> str :
@@ -356,6 +346,33 @@ def betterproto_field_args(self) -> List[str]:
356346 args .append (f"wraps={ self .field_wraps } " )
357347 return args
358348
349+ @property
350+ def datetime_imports (self ) -> Set [str ]:
351+ imports = set ()
352+ annotation = self .annotation
353+ # FIXME: false positives - e.g. `MyDatetimedelta`
354+ if "timedelta" in annotation :
355+ imports .add ("timedelta" )
356+ if "datetime" in annotation :
357+ imports .add ("datetime" )
358+ return imports
359+
360+ @property
361+ def typing_imports (self ) -> Set [str ]:
362+ imports = set ()
363+ annotation = self .annotation
364+ if "Optional[" in annotation :
365+ imports .add ("Optional" )
366+ if "List[" in annotation :
367+ imports .add ("List" )
368+ if "Dict[" in annotation :
369+ imports .add ("Dict" )
370+ return imports
371+
372+ def add_imports_to (self , output_file : OutputTemplate ) -> None :
373+ output_file .datetime_imports .update (self .datetime_imports )
374+ output_file .typing_imports .update (self .typing_imports )
375+
359376 @property
360377 def field_wraps (self ) -> Optional [str ]:
361378 """Returns betterproto wrapped field type or None."""
@@ -577,11 +594,10 @@ def __post_init__(self) -> None:
577594 # Add method to service
578595 self .parent .methods .append (self )
579596
580- # Check for Optional import
597+ # Check for imports
581598 if self .py_input_message :
582599 for f in self .py_input_message .fields :
583- if f .default_value_string == "None" :
584- self .output_file .typing_imports .add ("Optional" )
600+ f .add_imports_to (self .output_file )
585601 if "Optional" in self .py_output_message_type :
586602 self .output_file .typing_imports .add ("Optional" )
587603 self .mutable_default_args # ensure this is called before rendering
0 commit comments