@@ -96,14 +96,15 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
9696 @testset " Correct definitions" begin
9797 local inplace_used
9898 function ChainRulesCore. frule ((_, ẋ), :: typeof (identity), x:: Array )
99- ẏ = InplaceableThunk (@thunk (ẋ), ȧ -> (inplace_used = true ; ȧ .+ = ẋ))
99+ ẏ = InplaceableThunk (ȧ -> (inplace_used = true ; ȧ .+ = ẋ), @thunk ( ẋ))
100100 return identity (x), ẏ
101101 end
102102 function ChainRulesCore. rrule (:: typeof (identity), x:: Array )
103103 function identity_pullback (ȳ)
104- x̄_ret = InplaceableThunk (
105- @thunk (ȳ), ā -> (inplace_used = true ; ā .+ = ȳ)
106- )
104+ x̄_ret = InplaceableThunk (@thunk (ȳ)) do ā
105+ inplace_used = true
106+ ā .+ = ȳ
107+ end
107108 return (NoTangent (), x̄_ret)
108109 end
109110 return identity (x), identity_pullback
@@ -122,14 +123,14 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
122123 my_identity (value) = value # we will define bad rules on this
123124 function ChainRulesCore. frule ((_, ẋ), :: typeof (my_identity), x:: Array )
124125 # only the in-place part is incorrect
125- ẏ = InplaceableThunk (@thunk (ẋ), ȧ -> ȧ .+ = 200 .* ẋ)
126+ ẏ = InplaceableThunk (ȧ -> ȧ .+ = 200 .* ẋ, @thunk (ẋ) )
126127 return my_identity (x), ẏ
127128 end
128129 function ChainRulesCore. rrule (:: typeof (my_identity), x:: Array )
129130 x_dims = size (x)
130131 function my_identity_pullback (ȳ)
131132 # only the in-place part is incorrect
132- x̄_ret = InplaceableThunk (@thunk (ȳ), ā -> ā .+ = 200 .* ȳ)
133+ x̄_ret = InplaceableThunk (ā -> ā .+ = 200 .* ȳ, @thunk (ȳ) )
133134 return (NoTangent (), x̄_ret)
134135 end
135136 return my_identity (x), my_identity_pullback
0 commit comments