88
99use std:: iter;
1010
11- use proc_macro2:: TokenStream ;
11+ use proc_macro2:: { Span , TokenStream } ;
1212use quote:: quote;
1313use syn:: {
1414 parse:: { Parse , ParseStream } ,
1515 parse_macro_input,
1616 punctuated:: Punctuated ,
1717 token:: Plus ,
18- Error , FnArg , Generics , Ident , ItemTrait , Pat , PatType , Result , ReturnType , Signature , Token ,
19- TraitBound , TraitItem , TraitItemConst , TraitItemFn , TraitItemType , Type , TypeImplTrait ,
20- TypeParamBound ,
18+ Error , FnArg , Generics , Ident , ItemTrait , Pat , PatType , Receiver , Result , ReturnType ,
19+ Signature , Token , TraitBound , TraitItem , TraitItemConst , TraitItemFn , TraitItemType , Type ,
20+ TypeImplTrait , TypeParamBound , WhereClause ,
2121} ;
2222
2323struct Attrs {
@@ -119,10 +119,10 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
119119 // fn stream(&self) -> impl Iterator<Item = i32> + Send;
120120 // fn call(&self) -> u32;
121121 // }
122- let TraitItem :: Fn ( fn_item @ TraitItemFn { sig, .. } ) = item else {
122+ let TraitItem :: Fn ( fn_item @ TraitItemFn { sig, default , .. } ) = item else {
123123 return item. clone ( ) ;
124124 } ;
125- let ( arrow , output ) = if sig. asyncness . is_some ( ) {
125+ let ( sig , default ) = if sig. asyncness . is_some ( ) {
126126 let orig = match & sig. output {
127127 ReturnType :: Default => quote ! { ( ) } ,
128128 ReturnType :: Type ( _, ty) => quote ! { #ty } ,
@@ -134,7 +134,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
134134 . chain ( bounds. iter ( ) . cloned ( ) )
135135 . collect ( ) ,
136136 } ) ;
137- ( syn:: parse2 ( quote ! { -> } ) . unwrap ( ) , ty)
137+ let mut sig = sig. clone ( ) ;
138+ if default. is_some ( ) {
139+ add_receiver_bounds ( & mut sig) ;
140+ }
141+
142+ (
143+ Signature {
144+ asyncness : None ,
145+ output : ReturnType :: Type ( syn:: parse2 ( quote ! { -> } ) . unwrap ( ) , Box :: new ( ty) ) ,
146+ ..sig. clone ( )
147+ } ,
148+ fn_item
149+ . default
150+ . as_ref ( )
151+ . map ( |b| syn:: parse2 ( quote ! { { async move #b } } ) . unwrap ( ) ) ,
152+ )
138153 } else {
139154 match & sig. output {
140155 ReturnType :: Type ( arrow, ty) => match & * * ty {
@@ -143,19 +158,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
143158 impl_token : it. impl_token ,
144159 bounds : it. bounds . iter ( ) . chain ( bounds) . cloned ( ) . collect ( ) ,
145160 } ) ;
146- ( * arrow, ty)
161+ (
162+ Signature {
163+ output : ReturnType :: Type ( * arrow, Box :: new ( ty) ) ,
164+ ..sig. clone ( )
165+ } ,
166+ fn_item. default . clone ( ) ,
167+ )
147168 }
148169 _ => return item. clone ( ) ,
149170 } ,
150171 ReturnType :: Default => return item. clone ( ) ,
151172 }
152173 } ;
153174 TraitItem :: Fn ( TraitItemFn {
154- sig : Signature {
155- asyncness : None ,
156- output : ReturnType :: Type ( arrow, Box :: new ( output) ) ,
157- ..sig. clone ( )
158- } ,
175+ sig,
176+ default,
159177 ..fn_item. clone ( )
160178 } )
161179}
@@ -164,8 +182,26 @@ fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
164182 let orig = & tr. ident ;
165183 let variant = & attrs. variant . name ;
166184 let items = tr. items . iter ( ) . map ( |item| blanket_impl_item ( item, variant) ) ;
185+ let self_is_sync = tr
186+ . items
187+ . iter ( )
188+ . any ( |item| {
189+ matches ! (
190+ item,
191+ TraitItem :: Fn ( TraitItemFn {
192+ default : Some ( _) ,
193+ ..
194+ } )
195+ )
196+ } )
197+ . then ( || quote ! { Self : Sync } )
198+ . unwrap_or_default ( ) ;
167199 quote ! {
168- impl <T > #orig for T where T : #variant {
200+ impl <T > #orig for T
201+ where
202+ T : #variant,
203+ #self_is_sync
204+ {
169205 #( #items) *
170206 }
171207 }
@@ -205,6 +241,7 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
205241 } else {
206242 quote ! { }
207243 } ;
244+
208245 quote ! {
209246 #sig {
210247 <Self as #variant>:: #ident( #( #args) , * ) #maybe_await
@@ -228,3 +265,40 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
228265 _ => Error :: new_spanned ( item, "unsupported item type" ) . into_compile_error ( ) ,
229266 }
230267}
268+
269+ fn add_receiver_bounds ( sig : & mut Signature ) {
270+ if let Some ( FnArg :: Receiver ( Receiver { ty, reference, .. } ) ) = sig. inputs . first_mut ( ) {
271+ let predicate =
272+ if let ( Type :: Reference ( reference) , Some ( ( _and, lt) ) ) = ( & mut * * ty, reference) {
273+ let lifetime = syn:: Lifetime {
274+ apostrophe : Span :: mixed_site ( ) ,
275+ ident : Ident :: new ( "the_self_lt" , Span :: mixed_site ( ) ) ,
276+ } ;
277+ sig. generics . params . insert (
278+ 0 ,
279+ syn:: GenericParam :: Lifetime ( syn:: LifetimeParam {
280+ lifetime : lifetime. clone ( ) ,
281+ colon_token : None ,
282+ bounds : Default :: default ( ) ,
283+ attrs : Default :: default ( ) ,
284+ } ) ,
285+ ) ;
286+ reference. lifetime = Some ( lifetime. clone ( ) ) ;
287+ let predicate = syn:: parse2 ( quote ! { #reference: Send } ) . unwrap ( ) ;
288+ * lt = Some ( lifetime) ;
289+ predicate
290+ } else {
291+ syn:: parse2 ( quote ! { #ty: Send } ) . unwrap ( )
292+ } ;
293+
294+ if let Some ( wh) = & mut sig. generics . where_clause {
295+ wh. predicates . push ( predicate) ;
296+ } else {
297+ let where_clause = WhereClause {
298+ where_token : Token ! [ where ] ( Span :: mixed_site ( ) ) ,
299+ predicates : Punctuated :: from_iter ( [ predicate] ) ,
300+ } ;
301+ sig. generics . where_clause = Some ( where_clause) ;
302+ }
303+ }
304+ }
0 commit comments