@@ -3,17 +3,17 @@ extern crate proc_macro;
33use proc_macro2:: { Span , TokenStream } ;
44use quote:: quote;
55use quote:: ToTokens ;
6- use syn:: { parse_quote, DeriveInput , GenericParam , Ident , TypeParamBound } ;
6+ use syn:: { parse_quote, DeriveInput , Ident , TypeParam , TypeParamBound } ;
77
88use synstructure:: decl_derive;
99
1010/// Checks whether a generic parameter has a `: HasInterner` bound
11- fn has_interner ( param : & GenericParam ) -> Option < & Ident > {
11+ fn has_interner ( param : & TypeParam ) -> Option < & Ident > {
1212 bounded_by_trait ( param, "HasInterner" )
1313}
1414
1515/// Checks whether a generic parameter has a `: Interner` bound
16- fn is_interner ( param : & GenericParam ) -> Option < & Ident > {
16+ fn is_interner ( param : & TypeParam ) -> Option < & Ident > {
1717 bounded_by_trait ( param, "Interner" )
1818}
1919
@@ -28,48 +28,44 @@ fn has_interner_attr(input: &DeriveInput) -> Option<TokenStream> {
2828 )
2929}
3030
31- fn bounded_by_trait < ' p > ( param : & ' p GenericParam , name : & str ) -> Option < & ' p Ident > {
31+ fn bounded_by_trait < ' p > ( param : & ' p TypeParam , name : & str ) -> Option < & ' p Ident > {
3232 let name = Some ( String :: from ( name) ) ;
33- match param {
34- GenericParam :: Type ( ref t) => t. bounds . iter ( ) . find_map ( |b| {
35- if let TypeParamBound :: Trait ( trait_bound) = b {
36- if trait_bound
37- . path
38- . segments
39- . last ( )
40- . map ( |s| s. ident . to_string ( ) )
41- == name
42- {
43- return Some ( & t. ident ) ;
44- }
33+ param. bounds . iter ( ) . find_map ( |b| {
34+ if let TypeParamBound :: Trait ( trait_bound) = b {
35+ if trait_bound
36+ . path
37+ . segments
38+ . last ( )
39+ . map ( |s| s. ident . to_string ( ) )
40+ == name
41+ {
42+ return Some ( & param. ident ) ;
4543 }
46- None
47- } ) ,
48- _ => None ,
49- }
44+ }
45+ None
46+ } )
5047}
5148
52- fn get_generic_param ( input : & DeriveInput ) -> & GenericParam {
53- match input. generics . params . len ( ) {
54- 1 => { }
49+ fn get_intern_param ( input : & DeriveInput ) -> Option < ( DeriveKind , & Ident ) > {
50+ let mut params = input. generics . type_params ( ) . filter_map ( |param| {
51+ has_interner ( param)
52+ . map ( |ident| ( DeriveKind :: FromHasInterner , ident) )
53+ . or_else ( || is_interner ( param) . map ( |ident| ( DeriveKind :: FromInterner , ident) ) )
54+ } ) ;
5555
56- 0 => panic ! (
57- "deriving this trait requires a single type parameter or a `#[has_interner]` attr"
58- ) ,
56+ let param = params. next ( ) ;
57+ assert ! ( params. next( ) . is_none( ) , "deriving this trait only works with at most one type parameter that implements HasInterner or Interner" ) ;
5958
60- _ => panic ! ( "deriving this trait only works with a single type parameter" ) ,
61- } ;
62- & input. generics . params [ 0 ]
59+ param
6360}
6461
65- fn get_generic_param_name ( input : & DeriveInput ) -> Option < & Ident > {
66- match get_generic_param ( input) {
67- GenericParam :: Type ( t) => Some ( & t. ident ) ,
68- _ => None ,
69- }
62+ fn get_intern_param_name ( input : & DeriveInput ) -> & Ident {
63+ get_intern_param ( input)
64+ . expect ( "deriving this trait requires a parameter that implements HasInterner or Interner" )
65+ . 1
7066}
7167
72- fn find_interner ( s : & mut synstructure:: Structure ) -> ( TokenStream , DeriveKind ) {
68+ fn try_find_interner ( s : & mut synstructure:: Structure ) -> Option < ( TokenStream , DeriveKind ) > {
7369 let input = s. ast ( ) ;
7470
7571 if let Some ( arg) = has_interner_attr ( input) {
@@ -79,35 +75,40 @@ fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
7975 // struct S {
8076 //
8177 // }
82- return ( arg, DeriveKind :: FromHasInternerAttr ) ;
78+ return Some ( ( arg, DeriveKind :: FromHasInternerAttr ) ) ;
8379 }
8480
85- let generic_param0 = get_generic_param ( input) ;
86-
87- if let Some ( param) = has_interner ( generic_param0) {
88- // HasInterner bound:
89- //
90- // Example:
91- //
92- // struct Binders<T: HasInterner> { }
93- s. add_impl_generic ( parse_quote ! { _I } ) ;
94-
95- s. add_where_predicate ( parse_quote ! { _I: :: chalk_ir:: interner:: Interner } ) ;
96- s. add_where_predicate (
97- parse_quote ! { #param: :: chalk_ir:: interner:: HasInterner <Interner = _I> } ,
98- ) ;
81+ get_intern_param ( input) . map ( |generic_param0| match generic_param0 {
82+ ( DeriveKind :: FromHasInterner , param) => {
83+ // HasInterner bound:
84+ //
85+ // Example:
86+ //
87+ // struct Binders<T: HasInterner> { }
88+ s. add_impl_generic ( parse_quote ! { _I } ) ;
89+
90+ s. add_where_predicate ( parse_quote ! { _I: :: chalk_ir:: interner:: Interner } ) ;
91+ s. add_where_predicate (
92+ parse_quote ! { #param: :: chalk_ir:: interner:: HasInterner <Interner = _I> } ,
93+ ) ;
94+
95+ ( quote ! { _I } , DeriveKind :: FromHasInterner )
96+ }
97+ ( DeriveKind :: FromInterner , i) => {
98+ // Interner bound:
99+ //
100+ // Example:
101+ //
102+ // struct Foo<I: Interner> { }
103+ ( quote ! { #i } , DeriveKind :: FromInterner )
104+ }
105+ _ => unreachable ! ( ) ,
106+ } )
107+ }
99108
100- ( quote ! { _I } , DeriveKind :: FromHasInterner )
101- } else if let Some ( i) = is_interner ( generic_param0) {
102- // Interner bound:
103- //
104- // Example:
105- //
106- // struct Foo<I: Interner> { }
107- ( quote ! { #i } , DeriveKind :: FromInterner )
108- } else {
109- panic ! ( "deriving this trait requires a parameter that implements HasInterner or Interner" , ) ;
110- }
109+ fn find_interner ( s : & mut synstructure:: Structure ) -> ( TokenStream , DeriveKind ) {
110+ try_find_interner ( s)
111+ . expect ( "deriving this trait requires a `#[has_interner]` attr or a parameter that implements HasInterner or Interner" )
111112}
112113
113114#[ derive( Copy , Clone , PartialEq ) ]
@@ -174,7 +175,7 @@ fn derive_any_type_visitable(
174175 } ) ;
175176
176177 if kind == DeriveKind :: FromHasInterner {
177- let param = get_generic_param_name ( input) . unwrap ( ) ;
178+ let param = get_intern_param_name ( input) ;
178179 s. add_where_predicate ( parse_quote ! { #param: :: chalk_ir:: visit:: TypeVisitable <#interner> } ) ;
179180 }
180181
@@ -278,7 +279,7 @@ fn derive_type_foldable(mut s: synstructure::Structure) -> TokenStream {
278279 let input = s. ast ( ) ;
279280
280281 if kind == DeriveKind :: FromHasInterner {
281- let param = get_generic_param_name ( input) . unwrap ( ) ;
282+ let param = get_intern_param_name ( input) ;
282283 s. add_where_predicate ( parse_quote ! { #param: :: chalk_ir:: fold:: TypeFoldable <#interner> } ) ;
283284 } ;
284285
@@ -298,7 +299,14 @@ fn derive_type_foldable(mut s: synstructure::Structure) -> TokenStream {
298299}
299300
300301fn derive_fallible_type_folder ( mut s : synstructure:: Structure ) -> TokenStream {
301- let ( interner, _) = find_interner ( & mut s) ;
302+ let interner = try_find_interner ( & mut s) . map_or_else (
303+ || {
304+ s. add_impl_generic ( parse_quote ! { _I } ) ;
305+ s. add_where_predicate ( parse_quote ! { _I: :: chalk_ir:: interner:: Interner } ) ;
306+ quote ! { _I }
307+ } ,
308+ |( interner, _) | interner,
309+ ) ;
302310 s. underscore_const ( true ) ;
303311 s. unbound_impl (
304312 quote ! ( :: chalk_ir:: fold:: FallibleTypeFolder <#interner>) ,
0 commit comments