1- use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
1+ use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode , valid_ret_activity , valid_input_activities } ;
22use rustc_ast:: { ast, attr, MetaItem , MetaItemKind , NestedMetaItem } ;
33use rustc_attr:: { list_contains_name, InlineAttr , InstructionSetAttr , OptimizeAttr } ;
44use rustc_errors:: struct_span_err;
@@ -692,6 +692,11 @@ fn check_link_name_xor_ordinal(
692692 }
693693}
694694
695+ /// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)]
696+ /// macros. There are two forms. The pure one without args to mark primal functions (the functions
697+ /// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
698+ /// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
699+ /// panic, unless we introduced a bug when parsing the autodiff macro.
695700fn autodiff_attrs ( tcx : TyCtxt < ' _ > , id : DefId ) -> AutoDiffAttrs {
696701 let attrs = tcx. get_attrs ( id, sym:: rustc_autodiff) ;
697702
@@ -726,20 +731,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
726731 return AutoDiffAttrs :: source ( ) ;
727732 }
728733
729- let msg_ad_mode = "autodiff attribute must contain autodiff mode" ;
730- let ( mode, list) = match list. split_first ( ) {
731- Some ( (
732- NestedMetaItem :: MetaItem ( MetaItem { path : ref p1, kind : MetaItemKind :: Word , .. } ) ,
733- list,
734- ) ) => ( p1. segments . first ( ) . unwrap ( ) . ident , list) ,
735- _ => {
736- tcx. sess
737- . struct_span_err ( attr. span , msg_ad_mode)
738- . span_label ( attr. span , "empty argument list" )
739- . emit ( ) ;
740-
741- return AutoDiffAttrs :: inactive ( ) ;
742- }
734+ let [ mode, input_activities @ .., ret_activity] = & list[ ..] else {
735+ tcx. sess
736+ . struct_span_err ( attr. span , msg_once)
737+ . span_label ( attr. span , "Implementation bug in autodiff_attrs. Please report this!" )
738+ . emit ( ) ;
739+ return AutoDiffAttrs :: inactive ( ) ;
740+ } ;
741+ let mode = if let NestedMetaItem :: MetaItem ( MetaItem { path : ref p1, .. } ) = mode {
742+ p1. segments . first ( ) . unwrap ( ) . ident
743+ } else {
744+ let msg = "autodiff attribute must contain autodiff mode" ;
745+ tcx. sess
746+ . struct_span_err ( attr. span , msg)
747+ . span_label ( attr. span , "empty argument list" )
748+ . emit ( ) ;
749+ return AutoDiffAttrs :: inactive ( ) ;
743750 } ;
744751
745752 // parse mode
@@ -752,27 +759,23 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
752759 . struct_span_err ( attr. span , msg_mode)
753760 . span_label ( attr. span , "invalid mode" )
754761 . emit ( ) ;
755-
756762 return AutoDiffAttrs :: inactive ( ) ;
757763 }
758764 } ;
759765
760- let msg_ret_activity = "autodiff attribute must contain the return activity" ;
761- let ( ret_symbol, list) = match list. split_last ( ) {
762- Some ( (
763- NestedMetaItem :: MetaItem ( MetaItem { path : ref p1, kind : MetaItemKind :: Word , .. } ) ,
764- list,
765- ) ) => ( p1. segments . first ( ) . unwrap ( ) . ident , list) ,
766- _ => {
767- tcx. sess
768- . struct_span_err ( attr. span , msg_ret_activity)
769- . span_label ( attr. span , "missing return activity" )
770- . emit ( ) ;
771-
772- return AutoDiffAttrs :: inactive ( ) ;
773- }
766+ // First read the ret symbol from the attribute
767+ let ret_symbol = if let NestedMetaItem :: MetaItem ( MetaItem { path : ref p1, .. } ) = ret_activity {
768+ p1. segments . first ( ) . unwrap ( ) . ident
769+ } else {
770+ let msg = "autodiff attribute must contain the return activity" ;
771+ tcx. sess
772+ . struct_span_err ( attr. span , msg)
773+ . span_label ( attr. span , "missing return activity" )
774+ . emit ( ) ;
775+ return AutoDiffAttrs :: inactive ( ) ;
774776 } ;
775777
778+ // Then parse it into an actual DiffActivity
776779 let msg_unknown_ret_activity = "unknown return activity" ;
777780 let ret_activity = match DiffActivity :: from_str ( ret_symbol. as_str ( ) ) {
778781 Ok ( x) => x,
@@ -781,26 +784,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
781784 . struct_span_err ( attr. span , msg_unknown_ret_activity)
782785 . span_label ( attr. span , "invalid return activity" )
783786 . emit ( ) ;
784-
785787 return AutoDiffAttrs :: inactive ( ) ;
786788 }
787789 } ;
788790
791+ // Now parse all the intermediate (inptut) activities
789792 let msg_arg_activity = "autodiff attribute must contain the return activity" ;
790793 let mut arg_activities: Vec < DiffActivity > = vec ! [ ] ;
791- for arg in list {
792- let arg_symbol = match arg {
793- NestedMetaItem :: MetaItem ( MetaItem {
794- path : ref p2, kind : MetaItemKind :: Word , ..
795- } ) => p2. segments . first ( ) . unwrap ( ) . ident ,
796- _ => {
797- tcx. sess
798- . struct_span_err ( attr. span , msg_arg_activity)
799- . span_label ( attr. span , "missing return activity" )
800- . emit ( ) ;
801-
802- return AutoDiffAttrs :: inactive ( ) ;
803- }
794+ for arg in input_activities {
795+ let arg_symbol = if let NestedMetaItem :: MetaItem ( MetaItem { path : ref p2, .. } ) = arg {
796+ p2. segments . first ( ) . unwrap ( ) . ident
797+ } else {
798+ tcx. sess
799+ . struct_span_err ( attr. span , msg_arg_activity)
800+ . span_label ( attr. span , "Implementation bug, please report this!" )
801+ . emit ( ) ;
802+ return AutoDiffAttrs :: inactive ( ) ;
804803 } ;
805804
806805 match DiffActivity :: from_str ( arg_symbol. as_str ( ) ) {
@@ -810,45 +809,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
810809 . struct_span_err ( attr. span , msg_unknown_ret_activity)
811810 . span_label ( attr. span , "invalid input activity" )
812811 . emit ( ) ;
813-
814812 return AutoDiffAttrs :: inactive ( ) ;
815813 }
816814 }
817815 }
818816
819- let msg_fwd_incompatible_ret = "Forward Mode is incompatible with Active ret" ;
820- let msg_fwd_incompatible_arg = "Forward Mode is incompatible with Active ret" ;
821- let msg_rev_incompatible_arg =
822- "Reverse Mode is only compatible with Active, None, or Const ret" ;
823- if mode == DiffMode :: Forward {
824- if ret_activity == DiffActivity :: Active {
825- tcx. sess
826- . struct_span_err ( attr. span , msg_fwd_incompatible_ret)
827- . span_label ( attr. span , "invalid return activity" )
828- . emit ( ) ;
829- return AutoDiffAttrs :: inactive ( ) ;
830- }
831- if arg_activities. iter ( ) . filter ( |& x| * x == DiffActivity :: Active ) . count ( ) > 0 {
832- tcx. sess
833- . struct_span_err ( attr. span , msg_fwd_incompatible_arg)
834- . span_label ( attr. span , "invalid input activity" )
835- . emit ( ) ;
836- return AutoDiffAttrs :: inactive ( ) ;
837- }
817+ let msg = "Invalid activity for mode" ;
818+ let valid_input = valid_input_activities ( mode, & arg_activities) ;
819+ let valid_ret = valid_ret_activity ( mode, ret_activity) ;
820+ if !valid_input || !valid_ret {
821+ tcx. sess
822+ . struct_span_err ( attr. span , msg)
823+ . span_label ( attr. span , "invalid activity" )
824+ . emit ( ) ;
825+ return AutoDiffAttrs :: inactive ( ) ;
838826 }
839827
840- if mode == DiffMode :: Reverse {
841- if ret_activity == DiffActivity :: Duplicated
842- || ret_activity == DiffActivity :: DuplicatedOnly
843- {
844- dbg ! ( "ret_activity = {:?}" , ret_activity) ;
845- tcx. sess
846- . struct_span_err ( attr. span , msg_rev_incompatible_arg)
847- . span_label ( attr. span , "invalid return activity" )
848- . emit ( ) ;
849- return AutoDiffAttrs :: inactive ( ) ;
850- }
851- }
852828
853829 AutoDiffAttrs { mode, ret_activity, input_activity : arg_activities }
854830}
0 commit comments