@@ -185,7 +185,7 @@ pub enum ScopeParent {
185185
186186// List of class names that a type refers to, after stripping Optional and Awaitable.
187187#[ derive( Debug , Clone , Serialize , PartialEq , Eq ) ]
188- struct ClassNamesFromType {
188+ pub struct ClassNamesFromType {
189189 class_names : Vec < ClassRef > ,
190190 #[ serde( skip_serializing_if = "<&bool>::not" ) ]
191191 stripped_coroutine : bool ,
@@ -593,9 +593,9 @@ fn has_superclass(class: &Class, want: &Class, context: &ModuleContext) -> bool
593593}
594594
595595impl ClassNamesFromType {
596- fn from_class ( class : Class , context : & ModuleContext ) -> ClassNamesFromType {
596+ pub fn from_class ( class : & Class , context : & ModuleContext ) -> ClassNamesFromType {
597597 ClassNamesFromType {
598- class_names : vec ! [ ClassRef :: from_class( & class, context. module_ids) ] ,
598+ class_names : vec ! [ ClassRef :: from_class( class, context. module_ids) ] ,
599599 stripped_coroutine : false ,
600600 stripped_optional : false ,
601601 stripped_readonly : false ,
@@ -604,7 +604,19 @@ impl ClassNamesFromType {
604604 }
605605 }
606606
607- fn not_a_class ( ) -> ClassNamesFromType {
607+ #[ cfg( test) ]
608+ pub fn from_classes ( class_names : Vec < ClassRef > , is_exhaustive : bool ) -> ClassNamesFromType {
609+ ClassNamesFromType {
610+ class_names,
611+ stripped_coroutine : false ,
612+ stripped_optional : false ,
613+ stripped_readonly : false ,
614+ unbound_type_variable : false ,
615+ is_exhaustive,
616+ }
617+ }
618+
619+ pub fn not_a_class ( ) -> ClassNamesFromType {
608620 ClassNamesFromType {
609621 class_names : vec ! [ ] ,
610622 stripped_coroutine : false ,
@@ -619,13 +631,13 @@ impl ClassNamesFromType {
619631 self . class_names . is_empty ( )
620632 }
621633
622- fn with_strip_optional ( mut self ) -> ClassNamesFromType {
623- self . stripped_optional = true ;
634+ pub fn with_strip_optional ( mut self , stripped_optional : bool ) -> ClassNamesFromType {
635+ self . stripped_optional = stripped_optional ;
624636 self
625637 }
626638
627- fn with_strip_coroutine ( mut self ) -> ClassNamesFromType {
628- self . stripped_coroutine = true ;
639+ pub fn with_strip_coroutine ( mut self , stripped_coroutine : bool ) -> ClassNamesFromType {
640+ self . stripped_coroutine = stripped_coroutine ;
629641 self
630642 }
631643
@@ -701,22 +713,20 @@ fn is_scalar_type(get: &Type, want: &Class, context: &ModuleContext) -> bool {
701713
702714fn get_classes_of_type ( type_ : & Type , context : & ModuleContext ) -> ClassNamesFromType {
703715 if let Some ( inner) = strip_optional ( type_) {
704- return get_classes_of_type ( inner, context) . with_strip_optional ( ) ;
716+ return get_classes_of_type ( inner, context) . with_strip_optional ( true ) ;
705717 }
706718 if let Some ( inner) = strip_awaitable ( type_, context) {
707- return get_classes_of_type ( inner, context) . with_strip_coroutine ( ) ;
719+ return get_classes_of_type ( inner, context) . with_strip_coroutine ( true ) ;
708720 }
709721 if let Some ( inner) = strip_coroutine ( type_, context) {
710- return get_classes_of_type ( inner, context) . with_strip_coroutine ( ) ;
722+ return get_classes_of_type ( inner, context) . with_strip_coroutine ( true ) ;
711723 }
712724 // No need to strip ReadOnly[], it is already stripped by pyrefly.
713725 match type_ {
714726 Type :: ClassType ( class_type) => {
715- ClassNamesFromType :: from_class ( class_type. class_object ( ) . clone ( ) , context)
716- }
717- Type :: Tuple ( _) => {
718- ClassNamesFromType :: from_class ( context. stdlib . tuple_object ( ) . clone ( ) , context)
727+ ClassNamesFromType :: from_class ( class_type. class_object ( ) , context)
719728 }
729+ Type :: Tuple ( _) => ClassNamesFromType :: from_class ( context. stdlib . tuple_object ( ) , context) ,
720730 Type :: Union ( elements) if !elements. is_empty ( ) => elements
721731 . iter ( )
722732 . map ( |inner| get_classes_of_type ( inner, context) )
@@ -729,6 +739,42 @@ fn get_classes_of_type(type_: &Type, context: &ModuleContext) -> ClassNamesFromT
729739}
730740
731741impl PysaType {
742+ #[ cfg( test) ]
743+ pub fn new ( string : String , class_names : ClassNamesFromType ) -> PysaType {
744+ PysaType {
745+ string,
746+ is_bool : false ,
747+ is_int : false ,
748+ is_float : false ,
749+ is_enum : false ,
750+ class_names,
751+ }
752+ }
753+
754+ #[ cfg( test) ]
755+ pub fn with_is_bool ( mut self , is_bool : bool ) -> PysaType {
756+ self . is_bool = is_bool;
757+ self
758+ }
759+
760+ #[ cfg( test) ]
761+ pub fn with_is_int ( mut self , is_int : bool ) -> PysaType {
762+ self . is_int = is_int;
763+ self
764+ }
765+
766+ #[ cfg( test) ]
767+ pub fn with_is_float ( mut self , is_float : bool ) -> PysaType {
768+ self . is_float = is_float;
769+ self
770+ }
771+
772+ #[ cfg( test) ]
773+ pub fn with_is_enum ( mut self , is_enum : bool ) -> PysaType {
774+ self . is_enum = is_enum;
775+ self
776+ }
777+
732778 pub fn from_type ( type_ : & Type , context : & ModuleContext ) -> PysaType {
733779 // Promote `Literal[..]` into `str` or `int`.
734780 let type_ = type_. clone ( ) . promote_literals ( & context. stdlib ) ;
0 commit comments