@@ -18,6 +18,7 @@ extern crate stable_mir;
1818
1919use rustc_smir:: rustc_internal;
2020use stable_mir:: mir:: MirVisitor ;
21+ use stable_mir:: mir:: MutMirVisitor ;
2122use stable_mir:: * ;
2223use std:: collections:: HashSet ;
2324use std:: io:: Write ;
@@ -99,6 +100,83 @@ impl<'a> mir::MirVisitor for TestVisitor<'a> {
99100 }
100101}
101102
103+ fn test_mut_visitor ( ) -> ControlFlow < ( ) > {
104+ let main_fn = stable_mir:: entry_fn ( ) ;
105+ let mut main_body = main_fn. unwrap ( ) . expect_body ( ) ;
106+ let locals = main_body. locals ( ) . to_vec ( ) ;
107+ let mut main_visitor = TestMutVisitor :: collect ( locals) ;
108+ main_visitor. visit_body ( & mut main_body) ;
109+ assert ! ( main_visitor. ret_val. is_some( ) ) ;
110+ assert ! ( main_visitor. args. is_empty( ) ) ;
111+ assert ! ( main_visitor. tys. contains( & main_visitor. ret_val. unwrap( ) . ty) ) ;
112+ assert ! ( !main_visitor. calls. is_empty( ) ) ;
113+
114+ let exit_fn = main_visitor. calls . last ( ) . unwrap ( ) ;
115+ assert ! ( exit_fn. mangled_name( ) . contains( "exit_fn" ) , "Unexpected last function: {exit_fn:?}" ) ;
116+
117+ let mut exit_body = exit_fn. body ( ) . unwrap ( ) ;
118+ let locals = exit_body. locals ( ) . to_vec ( ) ;
119+ let mut exit_visitor = TestMutVisitor :: collect ( locals) ;
120+ exit_visitor. visit_body ( & mut exit_body) ;
121+ assert ! ( exit_visitor. ret_val. is_some( ) ) ;
122+ assert_eq ! ( exit_visitor. args. len( ) , 1 ) ;
123+ assert ! ( exit_visitor. tys. contains( & exit_visitor. ret_val. unwrap( ) . ty) ) ;
124+ assert ! ( exit_visitor. tys. contains( & exit_visitor. args[ 0 ] . ty) ) ;
125+ ControlFlow :: Continue ( ( ) )
126+ }
127+
128+ struct TestMutVisitor {
129+ locals : Vec < mir:: LocalDecl > ,
130+ pub tys : HashSet < ty:: Ty > ,
131+ pub ret_val : Option < mir:: LocalDecl > ,
132+ pub args : Vec < mir:: LocalDecl > ,
133+ pub calls : Vec < mir:: mono:: Instance > ,
134+ }
135+
136+ impl TestMutVisitor {
137+ fn collect ( locals : Vec < mir:: LocalDecl > ) -> TestMutVisitor {
138+ let visitor = TestMutVisitor {
139+ locals : locals,
140+ tys : Default :: default ( ) ,
141+ ret_val : None ,
142+ args : vec ! [ ] ,
143+ calls : vec ! [ ] ,
144+ } ;
145+ visitor
146+ }
147+ }
148+
149+ impl mir:: MutMirVisitor for TestMutVisitor {
150+ fn visit_ty ( & mut self , ty : & mut ty:: Ty , _location : mir:: visit:: Location ) {
151+ self . tys . insert ( * ty) ;
152+ self . super_ty ( ty)
153+ }
154+
155+ fn visit_ret_decl ( & mut self , local : mir:: Local , decl : & mut mir:: LocalDecl ) {
156+ assert ! ( local == mir:: RETURN_LOCAL ) ;
157+ assert ! ( self . ret_val. is_none( ) ) ;
158+ self . ret_val = Some ( decl. clone ( ) ) ;
159+ self . super_ret_decl ( local, decl) ;
160+ }
161+
162+ fn visit_arg_decl ( & mut self , local : mir:: Local , decl : & mut mir:: LocalDecl ) {
163+ self . args . push ( decl. clone ( ) ) ;
164+ assert_eq ! ( local, self . args. len( ) ) ;
165+ self . super_arg_decl ( local, decl) ;
166+ }
167+
168+ fn visit_terminator ( & mut self , term : & mut mir:: Terminator , location : mir:: visit:: Location ) {
169+ if let mir:: TerminatorKind :: Call { func, .. } = & mut term. kind {
170+ let ty:: TyKind :: RigidTy ( ty) = func. ty ( & self . locals ) . unwrap ( ) . kind ( ) else {
171+ unreachable ! ( )
172+ } ;
173+ let ty:: RigidTy :: FnDef ( def, args) = ty else { unreachable ! ( ) } ;
174+ self . calls . push ( mir:: mono:: Instance :: resolve ( def, & args) . unwrap ( ) ) ;
175+ }
176+ self . super_terminator ( term, location) ;
177+ }
178+ }
179+
102180/// This test will generate and analyze a dummy crate using the stable mir.
103181/// For that, it will first write the dummy crate into a file.
104182/// Then it will create a `StableMir` using custom arguments and then
@@ -113,7 +191,8 @@ fn main() {
113191 CRATE_NAME . to_string( ) ,
114192 path. to_string( ) ,
115193 ] ;
116- run ! ( args, test_visitor) . unwrap ( ) ;
194+ run ! ( args. clone( ) , test_visitor) . unwrap ( ) ;
195+ run ! ( args, test_mut_visitor) . unwrap ( ) ;
117196}
118197
119198fn generate_input ( path : & str ) -> std:: io:: Result < ( ) > {
0 commit comments