@@ -2,7 +2,7 @@ use std::str::FromStr;
22
33use rustc_abi:: ExternAbi ;
44use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
5- use rustc_ast:: { MetaItem , MetaItemInner , attr} ;
5+ use rustc_ast:: { LitKind , MetaItem , MetaItemInner , attr} ;
66use rustc_attr_parsing:: ReprAttr :: ReprAlign ;
77use rustc_attr_parsing:: { AttributeKind , InlineAttr , InstructionSetAttr , OptimizeAttr } ;
88use rustc_data_structures:: fx:: FxHashMap ;
@@ -805,8 +805,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
805805 return Some ( AutoDiffAttrs :: source ( ) ) ;
806806 }
807807
808- let [ mode, input_activities @ .., ret_activity] = & list[ ..] else {
809- span_bug ! ( attr. span( ) , "rustc_autodiff attribute must contain mode and activities" ) ;
808+ let [ mode, width_meta , input_activities @ .., ret_activity] = & list[ ..] else {
809+ span_bug ! ( attr. span( ) , "rustc_autodiff attribute must contain mode, width and activities" ) ;
810810 } ;
811811 let mode = if let MetaItemInner :: MetaItem ( MetaItem { path : p1, .. } ) = mode {
812812 p1. segments . first ( ) . unwrap ( ) . ident
@@ -823,6 +823,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
823823 }
824824 } ;
825825
826+ let width: u32 = match width_meta {
827+ MetaItemInner :: MetaItem ( MetaItem { path : p1, .. } ) => {
828+ let w = p1. segments . first ( ) . unwrap ( ) . ident ;
829+ match w. as_str ( ) . parse ( ) {
830+ Ok ( val) => val,
831+ Err ( _) => {
832+ span_bug ! ( w. span, "rustc_autodiff width should fit u32" ) ;
833+ }
834+ }
835+ }
836+ MetaItemInner :: Lit ( lit) => {
837+ if let LitKind :: Int ( val, _) = lit. kind {
838+ match val. get ( ) . try_into ( ) {
839+ Ok ( val) => val,
840+ Err ( _) => {
841+ span_bug ! ( lit. span, "rustc_autodiff width should fit u32" ) ;
842+ }
843+ }
844+ } else {
845+ span_bug ! ( lit. span, "rustc_autodiff width should be an integer" ) ;
846+ }
847+ }
848+ } ;
849+
826850 // First read the ret symbol from the attribute
827851 let ret_symbol = if let MetaItemInner :: MetaItem ( MetaItem { path : p1, .. } ) = ret_activity {
828852 p1. segments . first ( ) . unwrap ( ) . ident
@@ -860,7 +884,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
860884 }
861885 }
862886
863- Some ( AutoDiffAttrs { mode, width : 1 , ret_activity, input_activity : arg_activities } )
887+ Some ( AutoDiffAttrs { mode, width, ret_activity, input_activity : arg_activities } )
864888}
865889
866890pub ( crate ) fn provide ( providers : & mut Providers ) {
0 commit comments