|
85 | 85 |
|
86 | 86 | @testset "muladd: $T" for T in (Float64, ComplexF64) |
87 | 87 | @testset "add $(typeof(z))" for z in [rand(), rand(T, 3), rand(T, 3, 3), false] |
88 | | - @testset "forward mode" begin |
89 | | - @gpu test_frule(muladd, rand(T, 3, 5), rand(T, 5, 3), z) |
90 | | - end |
91 | 88 | @testset "matrix * matrix" begin |
92 | 89 | A = rand(T, 3, 3) |
93 | 90 | B = rand(T, 3, 3) |
94 | 91 | @gpu test_rrule(muladd, A, B, z) |
95 | 92 | @gpu test_rrule(muladd, A', B, z) |
96 | 93 | @gpu test_rrule(muladd, A , B', z) |
| 94 | + @gpu test_frule(muladd, A, B, z) |
| 95 | + @gpu test_frule(muladd, A', B, z) |
| 96 | + @gpu test_frule(muladd, A , B', z) |
97 | 97 |
|
98 | 98 | C = rand(T, 3, 5) |
99 | 99 | D = rand(T, 5, 3) |
100 | 100 | @gpu test_rrule(muladd, C, D, z) |
| 101 | + @gpu test_frule(muladd, C, D, z) |
101 | 102 | end |
102 | 103 | if ndims(z) <= 1 |
103 | 104 | @testset "matrix * vector" begin |
104 | 105 | A, B = rand(T, 3, 3), rand(T, 3) |
105 | 106 | test_rrule(muladd, A, B, z) |
106 | 107 | test_rrule(muladd, A, B ⊢ rand(T, 3,1), z) |
| 108 | + test_frule(muladd, A, B, z) |
107 | 109 | end |
108 | 110 | @testset "adjoint * matrix" begin |
109 | 111 | At, B = rand(T, 3)', rand(T, 3, 3) |
110 | 112 | test_rrule(muladd, At, B, z') |
111 | 113 | test_rrule(muladd, At ⊢ rand(T,1,3), B, z') |
| 114 | + test_frule(muladd, At, B, z') |
112 | 115 | end |
113 | 116 | end |
114 | 117 | if ndims(z) == 0 |
115 | 118 | @testset "adjoint * vector" begin # like dot |
116 | 119 | A, B = rand(T, 3)', rand(T, 3) |
117 | 120 | test_rrule(muladd, A, B, z) |
118 | 121 | test_rrule(muladd, A ⊢ rand(T,1,3), B, z') |
| 122 | + test_frule(muladd, A, B, z) |
119 | 123 | end |
120 | 124 | end |
121 | 125 | if ndims(z) == 2 # other dims lead to e.g. muladd(ones(4), ones(1,4), 1) |
122 | 126 | @testset "vector * adjoint" begin # outer product |
123 | 127 | A, B = rand(T, 3), rand(T, 3)' |
124 | 128 | test_rrule(muladd, A, B, z) |
125 | 129 | test_rrule(muladd, A, B ⊢ rand(T,1,3), z) |
| 130 | + test_frule(muladd, A, B, z) |
126 | 131 | end |
127 | 132 | end |
128 | 133 | end |
|
0 commit comments