@@ -59,11 +59,17 @@ function solve_call(_prob,args...;merge_callbacks = true, kwargs...)
5959 else
6060 __solve (_prob,args... ; kwargs... )# ::T
6161 end
62+ end
6263
64+ function solve (prob:: DEProblem ,args... ;sensealg= nothing ,
65+ u0 = nothing , p = nothing ,kwargs... )
66+ u0 = u0 != = nothing ? u0 : prob. u0
67+ p = p != = nothing ? p : prob. p
68+ solve_up (prob,sensealg,u0,p,args... ;kwargs... )
6369end
6470
65- function solve (prob:: DEProblem ,args... ;kwargs... )
66- _prob = get_concrete_problem (prob, kwargs)
71+ function solve_up (prob:: DEProblem ,sensealg,u0,p ,args... ;kwargs... )
72+ _prob = get_concrete_problem (prob;u0 = u0,p = p, kwargs... )
6773 if haskey (kwargs,:alg ) && (isempty (args) || args[1 ] === nothing )
6874 alg = kwargs[:alg ]
6975 isadaptive (alg) &&
@@ -93,21 +99,21 @@ function solve(prob::EnsembleProblem,args...;kwargs...)
9399 end
94100end
95101
96- function solve (prob:: AbstractNoiseProblem ,args... ;kwargs... )
102+ function solve (prob:: AbstractNoiseProblem ,args... ; kwargs... )
97103 __solve (prob,args... ;kwargs... )
98104end
99105
100- function get_concrete_problem (prob:: AbstractJumpProblem , kwargs)
106+ function get_concrete_problem (prob:: AbstractJumpProblem ; kwargs... )
101107 prob
102108end
103109
104- function get_concrete_problem (prob:: AbstractSteadyStateProblem , kwargs)
110+ function get_concrete_problem (prob:: AbstractSteadyStateProblem ; kwargs... )
105111 u0 = get_concrete_u0 (prob, Inf , kwargs)
106112 u0 = promote_u0 (u0, prob. p, nothing )
107113 remake (prob; u0 = u0)
108114end
109115
110- function get_concrete_problem (prob:: AbstractEnsembleProblem , kwargs)
116+ function get_concrete_problem (prob:: AbstractEnsembleProblem ; kwargs... )
111117 prob
112118end
113119
@@ -118,45 +124,45 @@ end
118124
119125function discretize end
120126
121- function get_concrete_problem (prob, kwargs)
122- tspan = get_concrete_tspan (prob, kwargs)
127+ function get_concrete_problem (prob; kwargs... )
128+ p = get_concrete_p (prob, kwargs)
129+ tspan = get_concrete_tspan (prob, kwargs, p)
123130 u0 = get_concrete_u0 (prob, tspan[1 ], kwargs)
124- u0_promote = promote_u0 (u0, prob . p, tspan[1 ])
125- tspan_promote = promote_tspan (u0, prob . p, tspan, prob, kwargs)
131+ u0_promote = promote_u0 (u0, p, tspan[1 ])
132+ tspan_promote = promote_tspan (u0, p, tspan, prob, kwargs)
126133 if isconcreteu0 (prob, tspan[1 ], kwargs) && typeof (u0_promote) === typeof (u0) &&
127134 prob. tspan == tspan && typeof (tspan) === typeof (tspan_promote)
128135 return prob
129136 else
130- return remake (prob; u0 = u0_promote, tspan = tspan_promote)
137+ return remake (prob; u0 = u0_promote, p = p, tspan = tspan_promote)
131138 end
132139end
133140
134- function get_concrete_problem (prob:: DDEProblem , kwargs)
135- tspan = get_concrete_tspan (prob, kwargs)
141+ function get_concrete_problem (prob:: DDEProblem ; kwargs... )
142+ p = get_concrete_p (prob, kwargs)
143+ tspan = get_concrete_tspan (prob, kwargs, p)
136144
137145 u0 = get_concrete_u0 (prob, tspan[1 ], kwargs)
138146
139147 if prob. constant_lags isa Function
140- constant_lags = prob. constant_lags (prob . p)
148+ constant_lags = prob. constant_lags (p)
141149 else
142150 constant_lags = prob. constant_lags
143151 end
144152
145- u0 = promote_u0 (u0, prob . p, tspan[1 ])
146- tspan = promote_tspan (u0, prob . p, tspan, prob, kwargs)
153+ u0 = promote_u0 (u0, p, tspan[1 ])
154+ tspan = promote_tspan (u0, p, tspan, prob, kwargs)
147155
148- remake (prob; u0 = u0, tspan = tspan, constant_lags = constant_lags)
156+ remake (prob; u0 = u0, tspan = tspan, p = p, constant_lags = constant_lags)
149157end
150158
151- function get_concrete_tspan (prob, kwargs)
159+ function get_concrete_tspan (prob, kwargs, p )
152160 if prob. tspan isa Function
153- tspan = prob. tspan (prob. p)
154- elseif prob. tspan === (nothing , nothing )
155- if haskey (kwargs, :tspan )
161+ tspan = prob. tspan (p)
162+ elseif haskey (kwargs, :tspan )
156163 tspan = kwargs[:tspan ]
157- else
158- error (" No tspan is set in the problem or chosen in the init/solve call" )
159- end
164+ elseif prob. tspan === (nothing , nothing )
165+ error (" No tspan is set in the problem or chosen in the init/solve call" )
160166 else
161167 tspan = prob. tspan
162168 end
171177function get_concrete_u0 (prob, t0, kwargs)
172178 if eval_u0 (prob. u0)
173179 u0 = prob. u0 (prob. p, t0)
174- elseif prob . u0 === nothing
180+ elseif haskey (kwargs, :u0 )
175181 u0 = kwargs[:u0 ]
176182 else
177183 u0 = prob. u0
@@ -180,6 +186,14 @@ function get_concrete_u0(prob, t0, kwargs)
180186 handle_distribution_u0 (u0)
181187end
182188
189+ function get_concrete_p (prob, kwargs)
190+ if haskey (kwargs,:p )
191+ p = kwargs[:p ]
192+ else
193+ p = prob. p
194+ end
195+ end
196+
183197handle_distribution_u0 (_u0) = _u0
184198eval_u0 (u0:: Function ) = true
185199eval_u0 (u0) = false
@@ -218,38 +232,49 @@ end
218232
219233# ################## Concrete Solve
220234
221- function _concrete_solve end
235+ @deprecate concrete_solve (prob:: DiffEqBase.DEProblem ,alg:: Union{DiffEqBase.DEAlgorithm,Nothing} ,
236+ u0= prob. u0,p= prob. p,args... ;kwargs... ) solve (prob,alg,args... ;u0= u0,p= p,kwargs... )
222237
223- function concrete_solve (prob:: DiffEqBase.DEProblem ,alg:: Union{DiffEqBase.DEAlgorithm,Nothing} ,
224- u0= prob. u0,p= prob. p,args... ;kwargs... )
225- _concrete_solve (prob,alg,u0,p,args... ;kwargs... )
226- end
238+ struct SensitivityADPassThrough <: DiffEqBase.DEAlgorithm end
227239
228- function _concrete_solve (prob:: DiffEqBase.DEProblem ,alg :: Union{DiffEqBase.DEAlgorithm, Nothing} ,
229- u0 = prob . u0,p = prob . p ,args... ;kwargs ... )
230- sol = solve ( remake (prob,u0 = u0,p = p),alg,args ... ; kwargs... )
231- RecursiveArrayTools . DiffEqArray (sol . u,sol . t )
240+ ZygoteRules . @adjoint function solve_up (prob,sensealg :: Union{Nothing,AbstractSensitivityAlgorithm } ,
241+ u0,p ,args... ;
242+ kwargs... )
243+ _solve_adjoint (prob,sensealg,u0,p,args ... ;kwargs ... )
232244end
233245
234- function _concrete_solve (prob:: DiffEqBase.SteadyStateProblem ,alg:: Union{DiffEqBase.DEAlgorithm,Nothing} ,
235- u0= prob. u0,p= prob. p,args... ;kwargs... )
236- sol = solve (remake (prob,u0= u0,p= p),alg,args... ;kwargs... )
237- RecursiveArrayTools. VectorOfArray (sol. u)
246+ function ChainRulesCore. frule (:: typeof (solve_up),prob,
247+ sensealg:: Union{Nothing,AbstractSensitivityAlgorithm} ,
248+ u0,p,args... ;
249+ kwargs... )
250+ _solve_forward (prob,sensealg,u0,p,args... ;kwargs... )
238251end
239252
240- function ChainRulesCore. frule (:: typeof (concrete_solve),prob,alg,u0,p,args... ;
241- sensealg= nothing ,kwargs... )
242- _concrete_solve_forward (prob,alg,sensealg,u0,p,args... ;kwargs... )
253+ function ChainRulesCore. rrule (:: typeof (solve_up),prob,
254+ sensealg:: Union{Nothing,AbstractSensitivityAlgorithm} ,
255+ u0,p,args... ;
256+ kwargs... )
257+ _solve_adjoint (prob,sensealg,u0,p,args... ;kwargs... )
243258end
244259
245- function ChainRulesCore. rrule (:: typeof (concrete_solve),prob,alg,u0,p,args... ;
246- sensealg= nothing ,kwargs... )
247- _concrete_solve_adjoint (prob,alg,sensealg,u0,p,args... ;kwargs... )
260+ # ##
261+ # ## Legacy Dispatches to be Non-Breaking
262+ # ##
263+
264+ function _solve_adjoint (prob,sensealg,u0,p,args... ;kwargs... )
265+ if isempty (args)
266+ _concrete_solve_adjoint (prob,nothing ,sensealg,u0,p;kwargs... )
267+ else
268+ _concrete_solve_adjoint (prob,args[1 ],sensealg,u0,p,Base. tail (args)... ;kwargs... )
269+ end
248270end
249271
250- ZygoteRules. @adjoint function concrete_solve (prob,alg,u0,p,args... ;
251- sensealg= nothing ,kwargs... )
252- _concrete_solve_adjoint (prob,alg,sensealg,u0,p,args... ;kwargs... )
272+ function _solve_forward (prob,sensealg,u0,p,args... ;kwargs... )
273+ if isempty (args)
274+ _concrete_solve_forward (prob,nothing ,sensealg,u0,p;kwargs... )
275+ else
276+ _concrete_solve_forward (prob,args[1 ],sensealg,u0,p,Base. tail (args)... ;kwargs... )
277+ end
253278end
254279
255280function _concrete_solve_adjoint (args... ;kwargs... )
0 commit comments