Skip to content

Commit 809a2df

Browse files
simulation under ReactiveNetwork
1 parent 8c35916 commit 809a2df

File tree

2 files changed

+82
-114
lines changed

2 files changed

+82
-114
lines changed

src/solvers.jl

Lines changed: 79 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -298,53 +298,6 @@ function free_blocked_species!(state)
298298
end
299299
end
300300

301-
"""
302-
Transform an `acs` to a `DiscreteProblem` instance, compatible with standard solvers.
303-
304-
# Examples
305-
306-
```julia
307-
transform(DiscreteProblem, acs; schedule = schedule_weighted!)
308-
```
309-
"""
310-
function transform(
311-
::Type{DiffEqBase.DiscreteProblem},
312-
state::ReactiveNetwork;
313-
kwargs...,
314-
)
315-
f = function (du, u, p, t)
316-
state = p[:__state__]
317-
free_blocked_species!(state)
318-
du .= state.u
319-
update_observables(state)
320-
sample_transitions!(state)
321-
evolve!(du, state)
322-
finish!(du, state)
323-
update_u!(state, du)
324-
event_action!(state)
325-
326-
du .= state.u
327-
push!(
328-
state.log,
329-
(:valuation, t, du' * [state[i, :specValuation] for i = 1:nparts(state, :S)]),
330-
)
331-
332-
t = (state.t += state.solverargs[:tstep])
333-
update_u!(state, du)
334-
save!(state)
335-
sync_p!(p, state)
336-
337-
return du
338-
end
339-
340-
return DiffEqBase.DiscreteProblem(
341-
f,
342-
state.u,
343-
(0.0, 2.0),
344-
Dict(state.p..., :__state__ => state, :__state0__ => deepcopy(state));
345-
kwargs...,
346-
)
347-
end
348301

349302
## resolve tspan, tstep
350303

@@ -358,21 +311,11 @@ function get_tcontrol(tspan, args)
358311
return ((0.0, tspan), tstep)
359312
end
360313

361-
"""
362-
Transform an `acs` to a `DiscreteProblem` instance, compatible with standard solvers.
363-
364-
Optionally accepts initial values and parameters, which take precedence over specifications in `acs`.
365-
366-
# Examples
367-
368-
```julia
369-
DiscreteProblem(acs, u0, p; tspan = (0.0, 100.0), schedule = schedule_weighted!)
370-
```
371-
"""
372-
function DiffEqBase.DiscreteProblem(
314+
function ReactiveNetwork(
373315
acs::ReactionNetwork,
374316
u0 = Dict(),
375317
p = DiffEqBase.NullParameters();
318+
name = "reactive_network",
376319
kwargs...,
377320
)
378321
assign_defaults!(acs)
@@ -386,30 +329,20 @@ function DiffEqBase.DiscreteProblem(
386329

387330
acs = remove_choose(acs)
388331
attrs, transitions, wrap_fun = compile_attrs(acs)
389-
state = ReactiveNetwork(
390-
acs,
391-
attrs,
392-
transitions,
393-
wrap_fun,
394-
keywords[:tspan][1];
395-
name = "rn_state",
396-
keywords...,
397-
)
332+
398333
init_u!(state)
399334
save!(state)
400335

401-
prob = transform(DiffEqBase.DiscreteProblem, state; kwargs...)
336+
u0_init = zeros(nparts(state, :S))
402337

403338
u0 isa Dict && foreach(
404339
i ->
405-
prob.u0[i] =
406-
if !isnothing(acs[i, :specName]) && haskey(u0, acs[i, :specName])
407-
u0[acs[i, :specName]]
408-
else
409-
prob.u0[i]
410-
end,
340+
if !isnothing(acs[i, :specName]) && haskey(u0, acs[i, :specName])
341+
u0_init[i] = u0[acs[i, :specName]]
342+
end,
411343
1:nparts(state, :S),
412344
)
345+
413346
p_ = p == DiffEqBase.NullParameters() ? Dict() : Dict(k => v for (k, v) in p)
414347
prob = remake(
415348
prob;
@@ -426,7 +359,77 @@ function DiffEqBase.DiscreteProblem(
426359
),
427360
)
428361

429-
return prob
362+
ongoing_transitions = Transition[]
363+
log = NamedTuple[]
364+
observables = compile_observables(acs)
365+
transitions_attrs =
366+
setdiff(
367+
filter(a -> contains(string(a), "trans"), propertynames(acs.subparts)),
368+
(:trans,),
369+
) [:transLHS, :transRHS, :transToSpawn, :transHash]
370+
transitions = Dict{Symbol,Vector}(a => [] for a in transitions_attrs)
371+
372+
return ReactiveNetwork(
373+
name,
374+
acs,
375+
attrs,
376+
transition_recipes,
377+
u0_init,
378+
merge(
379+
prob.p,
380+
p_,
381+
Dict(
382+
:tstep => get(keywords, :tstep, 1),
383+
:strategy => get(keywords, :alloc_strategy, :weighted),
384+
),
385+
),
386+
t,
387+
keywords[:tspan][1],
388+
keywords[:tspan],
389+
get(keywords, :tstep, 1),
390+
transitions,
391+
ongoing_transitions,
392+
log,
393+
observables,
394+
kwargs,
395+
wrap_fun,
396+
Vector{Float64}[],
397+
Float64[],
398+
)
399+
end
400+
401+
function AlgebraicAgents.step!(state::ReactiveNetwork)
402+
du = copy(state.u)
403+
free_blocked_species!(state)
404+
du .= state.u
405+
update_observables(state)
406+
sample_transitions!(state)
407+
evolve!(du, state)
408+
finish!(du, state)
409+
update_u!(state, du)
410+
event_action!(state)
411+
412+
du .= state.u
413+
push!(
414+
state.log,
415+
(:valuation, t, du' * [state[i, :specValuation] for i = 1:nparts(state, :S)]),
416+
)
417+
418+
t = (state.t += state.solverargs[:tstep])
419+
update_u!(state, du)
420+
save!(state)
421+
sync_p!(p, state)
422+
423+
state.u .= du
424+
state.t += state.dt
425+
end
426+
427+
function AlgebraicAgents._projected_to(state::ReactiveNetwork)
428+
if state.t >= state.tspan[2]
429+
true
430+
else
431+
state.t
432+
end
430433
end
431434

432435
function fetch_params(acs::ReactionNetwork)

src/state.jl

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ end
4141
p::Any
4242
t::Float64
4343

44+
tspan::Tuple{Float64,Float64}
45+
dt::Float64
46+
4447
transitions::Dict{Symbol,Vector}
4548
ongoing_transitions::Vector{Transition}
4649
log::Vector{Tuple}
@@ -53,44 +56,6 @@ end
5356
history_t::Vector{Float64}
5457
end
5558

56-
function ReactiveNetwork(
57-
acs::ReactionNetwork,
58-
attrs,
59-
transition_recipes,
60-
wrap_fun,
61-
t0 = 0;
62-
name = "rn_state",
63-
kwargs...,
64-
)
65-
ongoing_transitions = Transition[]
66-
log = NamedTuple[]
67-
observables = compile_observables(acs)
68-
transitions_attrs =
69-
setdiff(
70-
filter(a -> contains(string(a), "trans"), propertynames(acs.subparts)),
71-
(:trans,),
72-
) [:transLHS, :transRHS, :transToSpawn, :transHash]
73-
transitions = Dict{Symbol,Vector}(a => [] for a in transitions_attrs)
74-
75-
return ReactiveNetwork(
76-
name,
77-
acs,
78-
attrs,
79-
transition_recipes,
80-
zeros(nparts(acs, :S)),
81-
fetch_params(acs),
82-
t0,
83-
transitions,
84-
ongoing_transitions,
85-
log,
86-
observables,
87-
kwargs,
88-
wrap_fun,
89-
Vector{Float64}[],
90-
Float64[],
91-
)
92-
end
93-
9459
# get value of a numeric expression
9560
# evaluate compiled numeric expression in context of (u, p, t)
9661
function context_eval(state::ReactiveNetwork, o)

0 commit comments

Comments
 (0)