@@ -4,17 +4,16 @@ use ::tt::Ident;
44use base_db:: { CrateOrigin , LangCrateOrigin } ;
55use itertools:: izip;
66use mbe:: TokenMap ;
7- use std :: collections :: HashSet ;
7+ use rustc_hash :: FxHashSet ;
88use stdx:: never;
99use tracing:: debug;
1010
11- use crate :: tt:: { self , TokenId } ;
12- use syntax:: {
13- ast:: {
14- self , AstNode , FieldList , HasAttrs , HasGenericParams , HasModuleItem , HasName ,
15- HasTypeBounds , PathType ,
16- } ,
17- match_ast,
11+ use crate :: {
12+ name:: { AsName , Name } ,
13+ tt:: { self , TokenId } ,
14+ } ;
15+ use syntax:: ast:: {
16+ self , AstNode , FieldList , HasAttrs , HasGenericParams , HasModuleItem , HasName , HasTypeBounds ,
1817} ;
1918
2019use crate :: { db:: ExpandDatabase , name, quote, ExpandError , ExpandResult , MacroCallId } ;
@@ -201,41 +200,54 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
201200 debug ! ( "no module item parsed" ) ;
202201 ExpandError :: Other ( "no item found" . into ( ) )
203202 } ) ?;
204- let node = item. syntax ( ) ;
205- let ( name, params, shape) = match_ast ! {
206- match node {
207- ast:: Struct ( it) => ( it. name( ) , it. generic_param_list( ) , AdtShape :: Struct ( VariantShape :: from( it. field_list( ) , & token_map) ?) ) ,
208- ast:: Enum ( it) => {
209- let default_variant = it. variant_list( ) . into_iter( ) . flat_map( |x| x. variants( ) ) . position( |x| x. attrs( ) . any( |x| x. simple_name( ) == Some ( "default" . into( ) ) ) ) ;
210- (
211- it. name( ) ,
212- it. generic_param_list( ) ,
213- AdtShape :: Enum {
214- default_variant,
215- variants: it. variant_list( )
216- . into_iter( )
217- . flat_map( |x| x. variants( ) )
218- . map( |x| Ok ( ( name_to_token( & token_map, x. name( ) ) ?, VariantShape :: from( x. field_list( ) , & token_map) ?) ) ) . collect:: <Result <_, ExpandError >>( ) ?
219- }
220- )
221- } ,
222- ast:: Union ( it) => ( it. name( ) , it. generic_param_list( ) , AdtShape :: Union ) ,
223- _ => {
224- debug!( "unexpected node is {:?}" , node) ;
225- return Err ( ExpandError :: Other ( "expected struct, enum or union" . into( ) ) )
226- } ,
203+ let adt = ast:: Adt :: cast ( item. syntax ( ) . clone ( ) ) . ok_or_else ( || {
204+ debug ! ( "expected adt, found: {:?}" , item) ;
205+ ExpandError :: Other ( "expected struct, enum or union" . into ( ) )
206+ } ) ?;
207+ let ( name, generic_param_list, shape) = match & adt {
208+ ast:: Adt :: Struct ( it) => (
209+ it. name ( ) ,
210+ it. generic_param_list ( ) ,
211+ AdtShape :: Struct ( VariantShape :: from ( it. field_list ( ) , & token_map) ?) ,
212+ ) ,
213+ ast:: Adt :: Enum ( it) => {
214+ let default_variant = it
215+ . variant_list ( )
216+ . into_iter ( )
217+ . flat_map ( |x| x. variants ( ) )
218+ . position ( |x| x. attrs ( ) . any ( |x| x. simple_name ( ) == Some ( "default" . into ( ) ) ) ) ;
219+ (
220+ it. name ( ) ,
221+ it. generic_param_list ( ) ,
222+ AdtShape :: Enum {
223+ default_variant,
224+ variants : it
225+ . variant_list ( )
226+ . into_iter ( )
227+ . flat_map ( |x| x. variants ( ) )
228+ . map ( |x| {
229+ Ok ( (
230+ name_to_token ( & token_map, x. name ( ) ) ?,
231+ VariantShape :: from ( x. field_list ( ) , & token_map) ?,
232+ ) )
233+ } )
234+ . collect :: < Result < _ , ExpandError > > ( ) ?,
235+ } ,
236+ )
227237 }
238+ ast:: Adt :: Union ( it) => ( it. name ( ) , it. generic_param_list ( ) , AdtShape :: Union ) ,
228239 } ;
229- let mut param_type_set: HashSet < String > = HashSet :: new ( ) ;
230- let param_types = params
240+
241+ let mut param_type_set: FxHashSet < Name > = FxHashSet :: default ( ) ;
242+ let param_types = generic_param_list
231243 . into_iter ( )
232244 . flat_map ( |param_list| param_list. type_or_const_params ( ) )
233245 . map ( |param| {
234246 let name = {
235247 let this = param. name ( ) ;
236248 match this {
237249 Some ( x) => {
238- param_type_set. insert ( x. to_string ( ) ) ;
250+ param_type_set. insert ( x. as_name ( ) ) ;
239251 mbe:: syntax_node_to_token_tree ( x. syntax ( ) ) . 0
240252 }
241253 None => tt:: Subtree :: empty ( ) ,
@@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
259271 ( name, ty, bounds)
260272 } )
261273 . collect ( ) ;
262- let is_associated_type = |p : & PathType | {
263- if let Some ( p) = p. path ( ) {
264- if let Some ( parent) = p. qualifier ( ) {
265- if let Some ( x) = parent. segment ( ) {
266- if let Some ( x) = x. path_type ( ) {
267- if let Some ( x) = x. path ( ) {
268- if let Some ( pname) = x. as_single_name_ref ( ) {
269- if param_type_set. contains ( & pname. to_string ( ) ) {
270- // <T as Trait>::Assoc
271- return true ;
272- }
273- }
274- }
275- }
276- }
277- if let Some ( pname) = parent. as_single_name_ref ( ) {
278- if param_type_set. contains ( & pname. to_string ( ) ) {
279- // T::Assoc
280- return true ;
281- }
282- }
283- }
284- }
285- false
274+
275+ // For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
276+ // types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
277+ // also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
278+ // does not do that for some unknown reason.
279+ //
280+ // See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
281+ // [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
282+
283+ // It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
284+ // `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
285+ // we should not inspect `ast::PathType`s in parameter bounds and where clauses.
286+ let field_list = match adt {
287+ ast:: Adt :: Enum ( it) => it. variant_list ( ) . map ( |list| list. syntax ( ) . clone ( ) ) ,
288+ ast:: Adt :: Struct ( it) => it. field_list ( ) . map ( |list| list. syntax ( ) . clone ( ) ) ,
289+ ast:: Adt :: Union ( it) => it. record_field_list ( ) . map ( |list| list. syntax ( ) . clone ( ) ) ,
286290 } ;
287- let associated_types = node
288- . descendants ( )
289- . filter_map ( PathType :: cast)
290- . filter ( is_associated_type)
291+ let associated_types = field_list
292+ . into_iter ( )
293+ . flat_map ( |it| it. descendants ( ) )
294+ . filter_map ( ast:: PathType :: cast)
295+ . filter_map ( |p| {
296+ let name = p. path ( ) ?. qualifier ( ) ?. as_single_name_ref ( ) ?. as_name ( ) ;
297+ param_type_set. contains ( & name) . then_some ( p)
298+ } )
291299 . map ( |x| mbe:: syntax_node_to_token_tree ( x. syntax ( ) ) . 0 )
292- . collect :: < Vec < _ > > ( ) ;
300+ . collect ( ) ;
293301 let name_token = name_to_token ( & token_map, name) ?;
294302 Ok ( BasicAdtInfo { name : name_token, shape, param_types, associated_types } )
295303}
@@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
334342/// }
335343/// ```
336344///
337- /// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
345+ /// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
338346/// therefore does not get bound by the derived trait.
339347fn expand_simple_derive (
340348 tt : & tt:: Subtree ,
341349 trait_path : tt:: Subtree ,
342- trait_body : impl FnOnce ( & BasicAdtInfo ) -> tt:: Subtree ,
350+ make_trait_body : impl FnOnce ( & BasicAdtInfo ) -> tt:: Subtree ,
343351) -> ExpandResult < tt:: Subtree > {
344352 let info = match parse_adt ( tt) {
345353 Ok ( info) => info,
346354 Err ( e) => return ExpandResult :: new ( tt:: Subtree :: empty ( ) , e) ,
347355 } ;
348- let trait_body = trait_body ( & info) ;
356+ let trait_body = make_trait_body ( & info) ;
349357 let mut where_block = vec ! [ ] ;
350358 let ( params, args) : ( Vec < _ > , Vec < _ > ) = info
351359 . param_types
0 commit comments