33
44use proc_macro:: TokenStream ;
55use proc_macro_error:: proc_macro_error;
6+ use quote:: { format_ident, quote, quote_spanned} ;
7+ use syn:: {
8+ parse_macro_input, parse_quote, spanned:: Spanned , Data , DataEnum , DeriveInput , Fields ,
9+ GenericParam , Generics , Ident , Index , ItemStruct ,
10+ } ;
611
712#[ cfg( kani_host) ]
813#[ path = "kani.rs" ]
@@ -12,6 +17,135 @@ mod tool;
1217#[ path = "runtime.rs" ]
1318mod tool;
1419
20+ /// Expands the `#[invariant(...)]` attribute macro.
21+ /// The macro expands to an implementation of the `is_safe` method for the `Invariant` trait.
22+ /// This attribute is only supported for structs.
23+ ///
24+ /// # Example
25+ ///
26+ /// ```ignore
27+ /// #[invariant(self.width == self.height)]
28+ /// struct Square {
29+ /// width: u32,
30+ /// height: u32,
31+ /// }
32+ /// ```
33+ ///
34+ /// expands to:
35+ /// ```ignore
36+ /// impl core::ub_checks::Invariant for Square {
37+ /// fn is_safe(&self) -> bool {
38+ /// self.width == self.height
39+ /// }
40+ /// }
41+ /// ```
42+ /// For more information on the Invariant trait, see its documentation in core::ub_checks.
43+ #[ proc_macro_error]
44+ #[ proc_macro_attribute]
45+ pub fn invariant ( attr : TokenStream , item : TokenStream ) -> TokenStream {
46+ let safe_body = proc_macro2:: TokenStream :: from ( attr) ;
47+ let item = parse_macro_input ! ( item as ItemStruct ) ;
48+ let item_name = & item. ident ;
49+ let ( impl_generics, ty_generics, where_clause) = item. generics . split_for_impl ( ) ;
50+
51+ let expanded = quote ! {
52+ #item
53+ #[ unstable( feature="invariant" , issue="none" ) ]
54+ impl #impl_generics core:: ub_checks:: Invariant for #item_name #ty_generics #where_clause {
55+ fn is_safe( & self ) -> bool {
56+ #safe_body
57+ }
58+ }
59+ } ;
60+
61+ proc_macro:: TokenStream :: from ( expanded)
62+ }
63+
64+ /// Expands the derive macro for the Invariant trait.
65+ /// The macro expands to an implementation of the `is_safe` method for the `Invariant` trait.
66+ /// This macro is only supported for structs and enums.
67+ ///
68+ /// # Example
69+ ///
70+ /// ```ignore
71+ /// #[derive(Invariant)]
72+ /// struct Square {
73+ /// width: u32,
74+ /// height: u32,
75+ /// }
76+ /// ```
77+ ///
78+ /// expands to:
79+ /// ```ignore
80+ /// impl core::ub_checks::Invariant for Square {
81+ /// fn is_safe(&self) -> bool {
82+ /// self.width.is_safe() && self.height.is_safe()
83+ /// }
84+ /// }
85+ /// ```
86+ /// For enums, the body of `is_safe` matches on the variant and calls `is_safe` on its fields,
87+ /// # Example
88+ ///
89+ /// ```ignore
90+ /// #[derive(Invariant)]
91+ /// enum MyEnum {
92+ /// OptionOne(u32, u32),
93+ /// OptionTwo(Square),
94+ /// OptionThree
95+ /// }
96+ /// ```
97+ ///
98+ /// expands to:
99+ /// ```ignore
100+ /// impl core::ub_checks::Invariant for MyEnum {
101+ /// fn is_safe(&self) -> bool {
102+ /// match self {
103+ /// MyEnum::OptionOne(field1, field2) => field1.is_safe() && field2.is_safe(),
104+ /// MyEnum::OptionTwo(field1) => field1.is_safe(),
105+ /// MyEnum::OptionThree => true,
106+ /// }
107+ /// }
108+ /// }
109+ /// ```
110+ /// For more information on the Invariant trait, see its documentation in core::ub_checks.
111+ #[ proc_macro_error]
112+ #[ proc_macro_derive( Invariant ) ]
113+ pub fn derive_invariant ( item : TokenStream ) -> TokenStream {
114+ let derive_item = parse_macro_input ! ( item as DeriveInput ) ;
115+ let item_name = & derive_item. ident ;
116+ let safe_body = match derive_item. data {
117+ Data :: Struct ( struct_data) => {
118+ safe_body ( & struct_data. fields )
119+ } ,
120+ Data :: Enum ( enum_data) => {
121+ let variant_checks = variant_checks ( enum_data, item_name) ;
122+
123+ quote ! {
124+ match self {
125+ #( #variant_checks) , *
126+ }
127+ }
128+ } ,
129+ Data :: Union ( ..) => unimplemented ! ( "Attempted to derive Invariant on a union; Invariant can only be derived for structs and enums." ) ,
130+ } ;
131+
132+ // Add a bound `T: Invariant` to every type parameter T.
133+ let generics = add_trait_bound_invariant ( derive_item. generics ) ;
134+ // Generate an expression to sum up the heap size of each field.
135+ let ( impl_generics, ty_generics, where_clause) = generics. split_for_impl ( ) ;
136+
137+ let expanded = quote ! {
138+ // The generated implementation.
139+ #[ unstable( feature="invariant" , issue="none" ) ]
140+ impl #impl_generics core:: ub_checks:: Invariant for #item_name #ty_generics #where_clause {
141+ fn is_safe( & self ) -> bool {
142+ #safe_body
143+ }
144+ }
145+ } ;
146+ proc_macro:: TokenStream :: from ( expanded)
147+ }
148+
15149#[ proc_macro_error]
16150#[ proc_macro_attribute]
17151pub fn requires ( attr : TokenStream , item : TokenStream ) -> TokenStream {
@@ -29,3 +163,96 @@ pub fn ensures(attr: TokenStream, item: TokenStream) -> TokenStream {
29163pub fn loop_invariant ( attr : TokenStream , stmt_stream : TokenStream ) -> TokenStream {
30164 tool:: loop_invariant ( attr, stmt_stream)
31165}
166+
167+ /// Add a bound `T: Invariant` to every type parameter T.
168+ fn add_trait_bound_invariant ( mut generics : Generics ) -> Generics {
169+ generics. params . iter_mut ( ) . for_each ( |param| {
170+ if let GenericParam :: Type ( type_param) = param {
171+ type_param
172+ . bounds
173+ . push ( parse_quote ! ( core:: ub_checks:: Invariant ) ) ;
174+ }
175+ } ) ;
176+ generics
177+ }
178+
179+ /// Generate safety checks for each variant of an enum
180+ fn variant_checks ( enum_data : DataEnum , item_name : & Ident ) -> Vec < proc_macro2:: TokenStream > {
181+ enum_data
182+ . variants
183+ . iter ( )
184+ . map ( |variant| {
185+ let variant_name = & variant. ident ;
186+ match & variant. fields {
187+ Fields :: Unnamed ( fields) => {
188+ let field_names: Vec < _ > = fields
189+ . unnamed
190+ . iter ( )
191+ . enumerate ( )
192+ . map ( |( i, _) | format_ident ! ( "field{}" , i + 1 ) )
193+ . collect ( ) ;
194+
195+ let field_checks: Vec < _ > = field_names
196+ . iter ( )
197+ . map ( |field_name| {
198+ quote ! { #field_name. is_safe( ) }
199+ } )
200+ . collect ( ) ;
201+
202+ quote ! {
203+ #item_name:: #variant_name( #( #field_names) , * ) => #( #field_checks) &&*
204+ }
205+ }
206+ Fields :: Unit => {
207+ quote ! {
208+ #item_name:: #variant_name => true
209+ }
210+ }
211+ Fields :: Named ( _) => unreachable ! ( "Enums do not have named fields" ) ,
212+ }
213+ } )
214+ . collect ( )
215+ }
216+
217+ /// Generate the body for the `is_safe` method.
218+ /// For each field of the type, enforce that it is safe.
219+ fn safe_body ( fields : & Fields ) -> proc_macro2:: TokenStream {
220+ match fields {
221+ Fields :: Named ( ref fields) => {
222+ let field_safe_calls: Vec < proc_macro2:: TokenStream > = fields
223+ . named
224+ . iter ( )
225+ . map ( |field| {
226+ let name = & field. ident ;
227+ quote_spanned ! { field. span( ) =>
228+ self . #name. is_safe( )
229+ }
230+ } )
231+ . collect ( ) ;
232+ if !field_safe_calls. is_empty ( ) {
233+ quote ! { #( #field_safe_calls ) &&* }
234+ } else {
235+ quote ! { true }
236+ }
237+ }
238+ Fields :: Unnamed ( ref fields) => {
239+ let field_safe_calls: Vec < proc_macro2:: TokenStream > = fields
240+ . unnamed
241+ . iter ( )
242+ . enumerate ( )
243+ . map ( |( idx, field) | {
244+ let field_idx = Index :: from ( idx) ;
245+ quote_spanned ! { field. span( ) =>
246+ self . #field_idx. is_safe( )
247+ }
248+ } )
249+ . collect ( ) ;
250+ if !field_safe_calls. is_empty ( ) {
251+ quote ! { #( #field_safe_calls ) &&* }
252+ } else {
253+ quote ! { true }
254+ }
255+ }
256+ Fields :: Unit => quote ! { true } ,
257+ }
258+ }
0 commit comments