@@ -78,3 +78,84 @@ Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunable
7878 sol = solve (newprob, Tsit5 ())
7979 return sum (sol. u[end ])
8080end
81+
82+ using OrdinaryDiffEq
83+ using Random, Lux
84+ using ComponentArrays
85+ using SciMLSensitivity
86+ import SciMLStructures as SS
87+ using Zygote
88+ using ADTypes
89+ using Test
90+
91+ mutable struct myparam{M,P,S}
92+ model:: M
93+ ps :: P
94+ st :: S
95+ α :: Float64
96+ β :: Float64
97+ γ :: Float64
98+ end
99+
100+ SS. isscimlstructure (:: myparam ) = true
101+ SS. ismutablescimlstructure (:: myparam ) = true
102+ SS. hasportion (:: SS.Tunable , :: myparam ) = true
103+ function SS. canonicalize (:: SS.Tunable , p:: myparam )
104+ buffer = copy (p. ps)
105+ repack = let p = p
106+ function repack (newbuffer)
107+ SS. replace (SS. Tunable (), p, newbuffer)
108+ end
109+ end
110+ return buffer, repack, false
111+ end
112+ function SS. replace (:: SS.Tunable , p:: myparam , newbuffer)
113+ return myparam (p. model, newbuffer, p. st, p. α, p. β, p. γ)
114+ end
115+ function SS. replace! (:: SS.Tunable , p:: myparam , newbuffer)
116+ p. ps = newbuffer
117+ return p
118+ end
119+ function initialize ()
120+ # Defining the neural network
121+ U = Lux. Chain (Lux. Dense (3 ,30 ,tanh),Lux. Dense (30 ,30 ,tanh),Lux. Dense (30 ,1 ))
122+ rng = Random. GLOBAL_RNG
123+ _para,st = Lux. setup (rng,U)
124+ _para = ComponentArray (_para)
125+ # Setting the parameters
126+ α = 0.5
127+ β = 0.1
128+ γ = 0.01
129+ return myparam (U,_para,st,α,β,γ)
130+ end
131+ function UDE_model! (du, u, p, t)
132+ o = p. model (u,p. ps, p. st)[1 ][1 ]
133+ du[1 ] = o * p. α * u[1 ] + p. β * u[2 ] + p. γ * u[3 ]
134+ du[2 ] = - p. α * u[1 ] + p. β * u[2 ] - p. γ * u[3 ]
135+ du[3 ] = p. α * u[1 ] - p. β * u[2 ] + p. γ * u[3 ]
136+ nothing
137+ end
138+
139+ p = initialize ()
140+ function run_diff (ps)
141+ u01 = [1.0 , 0.0 , 0.0 ]
142+ tspan = (0.0 , 10.0 )
143+ prob = ODEProblem (UDE_model!, u01, tspan, ps)
144+ sol = solve (prob, Rosenbrock23 (), saveat = 0.1 )
145+ return sol. u |> last |> sum
146+ end
147+
148+ run_diff (initialize ())
149+ @test ! iszero (Zygote. gradient (run_diff, initialize ())[1 ]. ps)
150+
151+ function run_diff (ps,sensealg)
152+ u01 = [1.0 , 0.0 , 0.0 ]
153+ tspan = (0.0 , 10.0 )
154+ prob = ODEProblem (UDE_model!, u01, tspan, ps)
155+ sol = solve (prob, Rosenbrock23 (), saveat = 0.1 , sensealg= sensealg)
156+ return sol. u |> last |> sum
157+ end
158+
159+ run_diff (initialize ())
160+ @test ! iszero (Zygote. gradient (run_diff, initialize (), GaussAdjoint ())[1 ]. ps)
161+ @test ! iszero (Zygote. gradient (run_diff, initialize (), GaussAdjoint (autojacvec= false ))[1 ]. ps)
0 commit comments