Skip to content

Commit 818c8c2

Browse files
test: test AD after subset_tunables
1 parent 2ae41fb commit 818c8c2

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

test/mtkparameters.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33
using SymbolicIndexingInterface, StaticArrays
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
using ModelingToolkitStandardLibrary.Electrical, ModelingToolkitStandardLibrary.Blocks
56
using BlockArrays: BlockedArray, BlockedVector, Block
67
using OrdinaryDiffEq
78
using ForwardDiff
@@ -379,3 +380,52 @@ with_updated_parameter_timeseries_values(
379380
ps2 = remake_buffer(sys, ps, [p], [:a])
380381
@test ps2.nonnumeric isa Tuple{Vector{Any}}
381382
end
383+
384+
@testset "Issue#3925: Autodiff after `subset_tunables`" begin
385+
function circuit_model()
386+
@named resistor1 = Resistor(R=5.0)
387+
@named resistor2 = Resistor(R=2.0)
388+
@named capacitor1 = Capacitor(C=2.4)
389+
@named capacitor2 = Capacitor(C=60.0)
390+
@named source = Voltage()
391+
@named input_signal = Sine(frequency=1.0)
392+
@named ground = Ground()
393+
@named ampermeter = CurrentSensor()
394+
395+
eqs = [connect(input_signal.output, source.V)
396+
connect(source.p, capacitor1.n, capacitor2.n)
397+
connect(source.n, resistor1.p, resistor2.p, ground.g)
398+
connect(resistor1.n, capacitor1.p, ampermeter.n)
399+
connect(resistor2.n, capacitor2.p, ampermeter.p)]
400+
401+
@named circuit_model = System(eqs, t,
402+
systems=[
403+
resistor1, resistor2, capacitor1, capacitor2,
404+
source, input_signal, ground, ampermeter
405+
])
406+
end
407+
408+
model = circuit_model()
409+
sys = mtkcompile(model)
410+
411+
tunable_parameters(sys)
412+
413+
sub_sys = subset_tunables(sys, [sys.capacitor2.C])
414+
415+
tunable_parameters(sub_sys)
416+
417+
prob = ODEProblem(sub_sys, [sys.capacitor2.v => 0.0], (0, 3.))
418+
419+
setter = setsym_oop(prob, [sys.capacitor2.C]);
420+
421+
function loss(x, ps)
422+
setter, prob = ps
423+
u0, p = setter(prob, x)
424+
new_prob = remake(prob; u0, p)
425+
sol = solve(new_prob, Rodas5P())
426+
sum(sol)
427+
end
428+
429+
grad = ForwardDiff.gradient(Base.Fix2(loss, (setter, prob)), [3.0])
430+
@test grad [0.14882627068752538] atol=1e-10
431+
end

0 commit comments

Comments
 (0)