@@ -9,7 +9,7 @@ use syn::{
99 parse:: { Parse , ParseStream } ,
1010 parse_macro_input,
1111 spanned:: Spanned ,
12- Expr , Ident , Item , ItemEnum , Token , Variant ,
12+ Expr , Ident , DeriveInput , Data , Token , Variant ,
1313} ;
1414
1515struct Flag < ' a > {
@@ -58,14 +58,8 @@ pub fn bitflags_internal(
5858 input : proc_macro:: TokenStream ,
5959) -> proc_macro:: TokenStream {
6060 let Parameters { default } = parse_macro_input ! ( attr as Parameters ) ;
61- let mut ast = parse_macro_input ! ( input as Item ) ;
62- let output = match ast {
63- Item :: Enum ( ref mut item_enum) => gen_enumflags ( item_enum, default) ,
64- _ => Err ( syn:: Error :: new_spanned (
65- & ast,
66- "#[bitflags] requires an enum" ,
67- ) ) ,
68- } ;
61+ let mut ast = parse_macro_input ! ( input as DeriveInput ) ;
62+ let output = gen_enumflags ( & mut ast, default) ;
6963
7064 output
7165 . unwrap_or_else ( |err| {
@@ -247,17 +241,29 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
247241 }
248242}
249243
250- fn gen_enumflags ( ast : & mut ItemEnum , default : Vec < Ident > ) -> Result < TokenStream , syn:: Error > {
244+ fn gen_enumflags ( ast : & mut DeriveInput , default : Vec < Ident > ) -> Result < TokenStream , syn:: Error > {
251245 let ident = & ast. ident ;
252246
253247 let span = Span :: call_site ( ) ;
254248
249+ let ast_variants = match & mut ast. data {
250+ Data :: Enum ( ref mut data) => & mut data. variants ,
251+ Data :: Struct ( data) => {
252+ return Err ( syn:: Error :: new_spanned ( & data. struct_token ,
253+ "expected enum for #[bitflags], found struct" ) ) ;
254+ }
255+ Data :: Union ( data) => {
256+ return Err ( syn:: Error :: new_spanned ( & data. union_token ,
257+ "expected enum for #[bitflags], found union" ) ) ;
258+ }
259+ } ;
260+
255261 let repr = extract_repr ( & ast. attrs ) ?
256262 . ok_or_else ( || syn:: Error :: new_spanned ( ident,
257263 "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield." ) ) ?;
258264 let bits = type_bits ( & repr) ?;
259265
260- let mut variants = collect_flags ( ast . variants . iter_mut ( ) ) ?;
266+ let mut variants = collect_flags ( ast_variants . iter_mut ( ) ) ?;
261267 let deferred = variants
262268 . iter ( )
263269 . flat_map ( |variant| check_flag ( ident, variant, bits) . transpose ( ) )
@@ -273,7 +279,12 @@ fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream,
273279 }
274280
275281 let std = quote_spanned ! ( span => :: enumflags2:: _internal:: core) ;
276- let variant_names = ast. variants . iter ( ) . map ( |v| & v. ident ) . collect :: < Vec < _ > > ( ) ;
282+ let ast_variants = match & ast. data {
283+ Data :: Enum ( ref data) => & data. variants ,
284+ _ => unreachable ! ( ) ,
285+ } ;
286+
287+ let variant_names = ast_variants. iter ( ) . map ( |v| & v. ident ) . collect :: < Vec < _ > > ( ) ;
277288
278289 Ok ( quote_spanned ! {
279290 span =>
0 commit comments