@@ -4,7 +4,7 @@ using LinearAlgebra
44using Optimisers: Optimisers
55using Functors: fmap
66
7- export train!, update!, adjust!, FluxState,
7+ export train!, update!, adjust!, FluxState, @train_autodiff ,
88 Descent, Adam, Momentum, Nesterov, RMSProp,
99 AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief # ,
1010 # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
@@ -108,6 +108,48 @@ include("implicit_train.jl") # Params etc, Zygote only
108108
109109explicit_withgradient (f, args... ) = Zygote. withgradient (f, args... ) # can overload this to use e.g. Yota / Diffractor
110110
111+ """
112+ @train_autodiff Diffractor
113+ @train_autodiff Yota
114+ @train_autodiff Zygote
115+
116+ This macro allows the use of `train!` with various automatic differentiation packages,
117+ instead of the default Zygote.jl. You should load the package, then call this macro.
118+
119+ Only affects "explicit-mode" versions `train!(loss, model, data, opt)` or `train!(loss, model, opt)`,
120+ since the (deprecated) "implicit-mode" `train!(loss, ps::Params, data, opt)` is Zygote-specific.
121+
122+ !!! note
123+ Experimental
124+ """
125+ macro train_autodiff (pkg)
126+ if pkg == :Diffractor
127+ return quote
128+ Diffractor. gradient (sin, 0.0 )[1 ] ≈ 1.0 # ensures an error if not loaded
129+ function Flux. Train. explicit_withgradient (f, args... )
130+ y, back = Diffractor.∂⃖¹ (f, args... )
131+ dy1 = Flux. Zygote. sensitivity (y) # Zygote is loaded, and this gives nice errors
132+ return (; value = y, gradient = Base. tail (back (dy1)))
133+ end
134+ end |> esc
135+ elseif pkg == :Yota
136+ return quote
137+ Yota. grad (sin, 0.0 ) # [2][1] ≈ 1.0
138+ function Flux. Train. explicit_withgradient (f, args... )
139+ value, (_, gradient... ) = Yota. grad (f, args... )
140+ return (; value, gradient)
141+ end
142+ end |> esc
143+ elseif pkg == :Zygote
144+ return quote
145+ Flux. Train. explicit_withgradient (f, args... ) = Flux. Zygote. withgradient (f, args... )
146+ end |> esc
147+ else
148+ throw (" @train_autodiff expects either Zygote, Yota, or Diffractor. No other arguments are understood." )
149+ end
150+ end
151+
152+
111153# ## Misc. related utilities
112154
113155"""
0 commit comments