11//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
22//@ no-prefer-dynamic
33//@ needs-enzyme
4+
45#![ feature( autodiff) ]
56
67use std:: autodiff:: autodiff;
78
8- #[ autodiff( d_square, Reverse , 4 , Duplicated , Active ) ]
9+ #[ autodiff( d_square3, Forward , Dual , DualOnly ) ]
10+ #[ no_mangle]
11+ fn squaref ( x : & f32 ) -> f32 {
12+ 2.0 * x * x
13+ }
14+
15+
16+ #[ autodiff( d_square2, Forward , 4 , Dual , DualOnly ) ]
17+ #[ autodiff( d_square, Forward , 4 , Dual , Dual ) ]
918#[ no_mangle]
10- fn square ( x : & f64 ) -> f64 {
19+ fn square ( x : & f32 ) -> f32 {
1120 x * x
1221}
1322
@@ -33,21 +42,31 @@ fn square(x: &f64) -> f64 {
3342// CHECK-NEXT:}
3443
3544fn main ( ) {
36- let x = 3.0 ;
45+ let x = std :: hint :: black_box ( 3.0 ) ;
3746 let output = square ( & x) ;
47+ dbg ! ( & output) ;
3848 assert_eq ! ( 9.0 , output) ;
49+ dbg ! ( squaref( & x) ) ;
3950
40- let mut df_dx1 = 0 .0;
41- let mut df_dx2 = 0 .0;
42- let mut df_dx3 = 0 .0;
51+ let mut df_dx1 = 1 .0;
52+ let mut df_dx2 = 2 .0;
53+ let mut df_dx3 = 3 .0;
4354 let mut df_dx4 = 0.0 ;
44- let [ o1, o2, o3, o4] = d_square ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4, 1.0 ) ;
45- assert_eq ! ( output, o1) ;
46- assert_eq ! ( output, o2) ;
47- assert_eq ! ( output, o3) ;
48- assert_eq ! ( output, o4) ;
49- assert_eq ! ( 6.0 , df_dx1) ;
50- assert_eq ! ( 6.0 , df_dx2) ;
51- assert_eq ! ( 6.0 , df_dx3) ;
52- assert_eq ! ( 6.0 , df_dx4) ;
55+ let [ o1, o2, o3, o4] = d_square2 ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
56+ dbg ! ( o1, o2, o3, o4) ;
57+ let [ output2, o1, o2, o3, o4] = d_square ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
58+ dbg ! ( o1, o2, o3, o4) ;
59+ assert_eq ! ( output, output2) ;
60+ assert ! ( ( 6.0 - o1) . abs( ) < 1e-10 ) ;
61+ assert ! ( ( 12.0 - o2) . abs( ) < 1e-10 ) ;
62+ assert ! ( ( 18.0 - o3) . abs( ) < 1e-10 ) ;
63+ assert ! ( ( 0.0 - o4) . abs( ) < 1e-10 ) ;
64+ assert_eq ! ( 1.0 , df_dx1) ;
65+ assert_eq ! ( 2.0 , df_dx2) ;
66+ assert_eq ! ( 3.0 , df_dx3) ;
67+ assert_eq ! ( 0.0 , df_dx4) ;
68+ assert_eq ! ( d_square3( & x, & mut df_dx1) , 2.0 * o1) ;
69+ assert_eq ! ( d_square3( & x, & mut df_dx2) , 2.0 * o2) ;
70+ assert_eq ! ( d_square3( & x, & mut df_dx3) , 2.0 * o3) ;
71+ assert_eq ! ( d_square3( & x, & mut df_dx4) , 2.0 * o4) ;
5372}
0 commit comments