22
33__all__ = ["RustCodeGen" ]
44
5- import dataclasses
65import functools
76import itertools
87import json
1110import sys
1211from abc import ABC , abstractmethod
1312from collections .abc import Iterator , MutableMapping , MutableSequence , Sequence
14- from dataclasses import dataclass
1513from importlib .resources import files as resource_files
1614from io import StringIO
1715from pathlib import Path
@@ -117,6 +115,7 @@ def to_rust_literal(value: Any) -> str:
117115
118116def make_avro (items : MutableSequence [JsonDataType ]) -> MutableSequence [NamedSchema ]:
119117 """Process a list of dictionaries to generate a list of Avro schemas."""
118+
120119 # Same as `from .utils import convert_to_dict`, which, however, is not public
121120 def convert_to_dict (j4 : Any ) -> Any :
122121 """Convert generic Mapping objects to dicts recursively."""
@@ -150,11 +149,13 @@ def convert_to_dict(j4: Any) -> Any:
150149RustIdent = str # alias
151150
152151
153- @dataclass # ASSERT: Immutable class
154152class RustLifetime :
155153 """Represents a Rust lifetime parameter (e.g., `'a`)."""
156154
157- ident : RustIdent
155+ __slots__ = ("ident" ,)
156+
157+ def __init__ (self , ident : RustIdent ):
158+ self .ident = ident
158159
159160 def __hash__ (self ) -> int :
160161 return hash (self .ident )
@@ -175,11 +176,16 @@ class RustMeta(ABC):
175176 pass
176177
177178
178- @dataclass (unsafe_hash = True ) # ASSERT: Immutable class
179179class RustAttribute :
180180 """Represents a Rust attribute (e.g., `#[derive(Debug)]`)."""
181181
182- meta : RustMeta
182+ __slots__ = ("meta" ,)
183+
184+ def __init__ (self , meta : RustMeta ):
185+ self .meta = meta
186+
187+ def __hash__ (self ) -> int :
188+ return hash (self .meta )
183189
184190 def __str__ (self ) -> str :
185191 return f"#[{ str (self .meta )} ]"
@@ -193,17 +199,22 @@ def __str__(self) -> str:
193199RustGenericsMut = MutableSequence [Union [RustLifetime , "RustPath" ]] # alias
194200
195201
196- @dataclass (unsafe_hash = True ) # ASSERT: Immutable class
197202class RustPathSegment :
198203 """Represents a segment in a Rust path with optional generics."""
199204
200- ident : RustIdent
201- generics : RustGenerics = dataclasses .field (default_factory = tuple )
205+ __slots__ = ("ident" , "generics" )
202206
203207 REX : ClassVar [Pattern [str ]] = re .compile (
204208 r"^([a-zA-Z_]\w*)(?:<([ \w\t,'<>]+)>)?$"
205209 ) # Using `re.Pattern[str]` raise CI build errors
206210
211+ def __init__ (self , ident : RustIdent , generics : Optional [RustGenerics ] = None ):
212+ self .ident = ident
213+ self .generics = () if generics is None else generics
214+
215+ def __hash__ (self ) -> int :
216+ return hash ((self .ident , self .generics ))
217+
207218 def __str__ (self ) -> str :
208219 if not self .generics :
209220 return self .ident
@@ -256,13 +267,18 @@ def parse_generics_string(value_generics: str) -> RustGenerics:
256267RustPathSegmentsMut = MutableSequence [RustPathSegment ] # alias
257268
258269
259- @dataclass (unsafe_hash = True ) # ASSERT: Immutable class
260270class RustPath (RustMeta ):
261271 """Represents a complete Rust path (e.g., `::std::vec::Vec<T>`)."""
262272
273+ __slots__ = ("segments" , "leading_colon" )
274+
263275 # ASSERT: Never initialized with an empty sequence
264- segments : RustPathSegments
265- leading_colon : bool = False
276+ def __init__ (self , segments : RustPathSegments , leading_colon : bool = False ):
277+ self .segments = segments
278+ self .leading_colon = leading_colon
279+
280+ def __hash__ (self ) -> int :
281+ return hash ((self .segments , self .leading_colon ))
266282
267283 def __truediv__ (self , other : Union ["RustPath" , RustPathSegment ]) -> "RustPath" :
268284 if self .segments [- 1 ].generics :
@@ -304,24 +320,31 @@ def from_str(cls, value: str) -> "RustPath":
304320 return cls (segments = tuple (segments ), leading_colon = leading_colon )
305321
306322
307- @dataclass (unsafe_hash = True ) # ASSERT: Immutable class
308323class RustTypeTuple (RustType ):
309324 """Represents a Rust tuple type (e.g., `(T, U)`)."""
310325
326+ __slots__ = ("types" ,)
327+
311328 # ASSERT: Never initialized with an empty sequence
312- types : Sequence [RustPath ]
329+ def __init__ (self , types : Sequence [RustPath ]):
330+ self .types = types
331+
332+ def __hash__ (self ) -> int :
333+ return hash (self .types )
313334
314335 def __str__ (self ) -> str :
315336 types_str = ", " .join (str (ty ) for ty in self .types )
316337 return f"({ types_str } )"
317338
318339
319- @dataclass # ASSERT: Immutable class
320340class RustMetaList (RustMeta ):
321341 """Represents attribute meta list information (e.g., `derive(Debug, Clone)`).."""
322342
323- path : RustPath
324- metas : Sequence [RustMeta ] = tuple ()
343+ __slots__ = ("path" , "metas" )
344+
345+ def __init__ (self , path : RustPath , metas : Optional [Sequence [RustMeta ]] = None ):
346+ self .path = path
347+ self .metas = () if metas is None else metas
325348
326349 def __hash__ (self ) -> int :
327350 return hash (self .path )
@@ -331,12 +354,14 @@ def __str__(self) -> str:
331354 return f"{ str (self .path )} (" + meta_str + ")"
332355
333356
334- @dataclass # ASSERT: Immutable class
335357class RustMetaNameValue (RustMeta ):
336358 """Represents attribute meta name-value information (e.g., `key = value`)."""
337359
338- path : RustPath
339- value : Any = True
360+ __slots__ = ("path" , "value" )
361+
362+ def __init__ (self , path : RustPath , value : Any = True ):
363+ self .path = path
364+ self .value = value
340365
341366 def __hash__ (self ) -> int :
342367 return hash (self .path )
@@ -350,13 +375,17 @@ def __str__(self) -> str:
350375#
351376
352377
353- @dataclass
354378class RustNamedType (ABC ): # ABC class
355379 """Abstract class for Rust struct and enum types."""
356380
357- ident : RustIdent
358- attrs : RustAttributes = dataclasses .field (default_factory = list )
359- visibility : str = "pub"
381+ __slots__ = ("ident" , "attrs" , "visibility" )
382+
383+ def __init__ (
384+ self , ident : RustIdent , attrs : Optional [RustAttributes ] = None , visibility : str = "pub"
385+ ):
386+ self .ident = ident
387+ self .attrs = () if attrs is None else attrs
388+ self .visibility = visibility
360389
361390 def __hash__ (self ) -> int :
362391 return hash (self .ident )
@@ -371,13 +400,15 @@ def __str__(self) -> str:
371400 return output .getvalue ()
372401
373402
374- @dataclass # ASSERT: Immutable class
375403class RustField :
376404 """Represents a field in a Rust struct."""
377405
378- ident : RustIdent
379- type : RustPath
380- attrs : RustAttributes = dataclasses .field (default_factory = list )
406+ __slots__ = ("ident" , "type" , "attrs" )
407+
408+ def __init__ (self , ident : RustIdent , type : RustPath , attrs : Optional [RustAttributes ] = None ):
409+ self .ident = ident
410+ self .type = type
411+ self .attrs = () if attrs is None else attrs
381412
382413 def __hash__ (self ) -> int :
383414 return hash (self .ident )
@@ -394,11 +425,21 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
394425RustFieldsMut = Union [MutableSequence [RustField ], RustTypeTuple ] # alias
395426
396427
397- @dataclass
398428class RustStruct (RustNamedType ):
399429 """Represents a Rust struct definition."""
400430
401- fields : Optional [RustFields ] = None
431+ __slots__ = ("fields" ,)
432+
433+ def __init__ (
434+ self ,
435+ ident : RustIdent ,
436+ fields : Optional [RustFields ] = None ,
437+ attrs : Optional [RustAttributes ] = None ,
438+ visibility : str = "pub" ,
439+ ):
440+ _attrs = () if attrs is None else attrs
441+ super ().__init__ (ident , _attrs , visibility )
442+ self .fields = fields
402443
403444 def write_to (self , writer : IO [str ], depth : int = 0 ) -> None :
404445 indent = " " * depth
@@ -419,13 +460,20 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
419460 writer .write (f"{ indent } }}\n " )
420461
421462
422- @dataclass # ASSERT: Immutable class
423463class RustVariant :
424464 """Represents a variant in a Rust enum."""
425465
426- ident : RustIdent
427- tuple : Optional [RustTypeTuple ] = None
428- attrs : RustAttributes = dataclasses .field (default_factory = list )
466+ __slots__ = ("ident" , "tuple" , "attrs" )
467+
468+ def __init__ (
469+ self ,
470+ ident : RustIdent ,
471+ tuple : Optional [RustTypeTuple ] = None ,
472+ attrs : Optional [RustAttributes ] = None ,
473+ ):
474+ self .ident = ident
475+ self .tuple = tuple
476+ self .attrs = () if attrs is None else attrs
429477
430478 def __hash__ (self ) -> int :
431479 return hash (self .ident )
@@ -435,7 +483,6 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
435483
436484 if self .attrs :
437485 writer .write ("\n " .join (f"{ indent } { str (attr )} " for attr in self .attrs ) + "\n " )
438-
439486 writer .write (f"{ indent } { self .ident } " )
440487 if self .tuple :
441488 writer .write (str (self .tuple ))
@@ -462,11 +509,21 @@ def from_path(cls, path: RustPath) -> "RustVariant":
462509RustVariantsMut = MutableSequence [RustVariant ] # alias
463510
464511
465- @dataclass
466512class RustEnum (RustNamedType ):
467513 """Represents a Rust enum definition."""
468514
469- variants : RustVariants = dataclasses .field (default_factory = tuple )
515+ __slots__ = ("variants" ,)
516+
517+ def __init__ (
518+ self ,
519+ ident : RustIdent ,
520+ variants : Optional [RustVariants ] = None ,
521+ attrs : Optional [RustAttributes ] = None ,
522+ visibility : str = "pub" ,
523+ ):
524+ _attrs = () if attrs is None else attrs
525+ super ().__init__ (ident , _attrs , visibility )
526+ self .variants = () if variants is None else variants
470527
471528 def write_to (self , writer : IO [str ], depth : int = 0 ) -> None :
472529 indent = " " * depth
@@ -495,16 +552,22 @@ def salad_macro_write_to(ty: RustNamedType, writer: IO[str], depth: int = 0) ->
495552#
496553
497554
498- @dataclass
499555class RustModuleTree :
500556 """Represents a Rust module with submodules and named types."""
501557
502- ident : RustIdent # ASSERT: Immutable field
503- parent : Optional ["RustModuleTree" ] # ASSERT: Immutable field
504- named_types : MutableMapping [RustIdent , RustNamedType ] = dataclasses .field (default_factory = dict )
505- submodules : MutableMapping [RustIdent , "RustModuleTree" ] = dataclasses .field (
506- default_factory = dict
507- )
558+ __slots__ = ("ident" , "parent" , "named_types" , "submodules" )
559+
560+ def __init__ (
561+ self ,
562+ ident : RustIdent , # ASSERT: Immutable field
563+ parent : Optional ["RustModuleTree" ] = None , # ASSERT: Immutable field
564+ named_types : Optional [MutableMapping [RustIdent , RustNamedType ]] = None ,
565+ submodules : Optional [MutableMapping [RustIdent , "RustModuleTree" ]] = None ,
566+ ):
567+ self .ident = ident
568+ self .parent = parent
569+ self .named_types = {} if named_types is None else named_types
570+ self .submodules = {} if submodules is None else submodules
508571
509572 def __hash__ (self ) -> int :
510573 return hash ((self .ident , self .parent ))
0 commit comments