Skip to content

Commit aa541c0

Browse files
MaxBetzDLRHenrZujubickercharlie0614
authored
1351 pymio interface for sde sir type models (#1371)
Co-authored-by: HenrZu <69154294+HenrZu@users.noreply.github.com> Co-authored-by: jubicker <113909589+jubicker@users.noreply.github.com> Co-authored-by: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com>
1 parent a0e3a4f commit aa541c0

File tree

15 files changed

+772
-16
lines changed

15 files changed

+772
-16
lines changed

cpp/examples/sde_sirs.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ int main()
4848
model.parameters.get<mio::ssirs::ContactPatterns<ScalarType>>().get_baseline()(0, 0) = 20.7;
4949
model.parameters.get<mio::ssirs::ContactPatterns<ScalarType>>().add_damping(0.6,
5050
mio::SimulationTime<ScalarType>(12.5));
51-
51+
model.parameters.set<mio::ssirs::StartDay<ScalarType>>(60);
52+
model.parameters.set<mio::ssirs::Seasonality<ScalarType>>(0.2);
5253
model.check_constraints();
5354

5455
auto ssirs = mio::simulate_stochastic(t0, tmax, dt, model);

cpp/models/sde_sirs/model.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ class Model
5656
Eigen::Ref<Eigen::VectorX<FP>> flows) const
5757
{
5858
auto& params = Base::parameters;
59-
FP coeffStoI = params.template get<ContactPatterns<FP>>().get_matrix_at(SimulationTime<FP>(t))(0, 0) *
59+
// effective contact rate by contact rate between groups i and j and damping j
60+
FP season_val =
61+
(1 + params.template get<Seasonality<FP>>() *
62+
sin(std::numbers::pi_v<ScalarType> * ((params.template get<StartDay<FP>>() + t) / 182.5 + 0.5)));
63+
64+
FP coeffStoI = season_val * params.template get<ContactPatterns<FP>>().get_matrix_at(SimulationTime<FP>(t))(0, 0) *
6065
params.template get<TransmissionProbabilityOnContact<FP>>() / Base::populations.get_total();
6166

6267
flows[this->template get_flat_flow_index<InfectionState::Susceptible, InfectionState::Infected>()] =

cpp/models/sde_sirs/parameters.h

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,46 @@ struct ContactPatterns {
9898
}
9999
};
100100

101+
/**
102+
* @brief The start day in the SIRS model
103+
* The start day defines in which season the simulation can be started
104+
* If the start day is 180 and simulation takes place from t0=0 to
105+
* tmax=100 the days 180 to 280 of the year are simulated
106+
*/
107+
template <typename FP>
108+
struct StartDay {
109+
using Type = FP;
110+
static Type get_default()
111+
{
112+
return Type(0.0);
113+
}
114+
static std::string name()
115+
{
116+
return "StartDay";
117+
}
118+
};
119+
120+
/**
121+
* @brief The seasonality in the SIRS model.
122+
* The seasonality is given as (1+k*sin()) where the sine
123+
* curve is below one in summer and above one in winter
124+
*/
125+
template <typename FP>
126+
struct Seasonality {
127+
using Type = UncertainValue<FP>;
128+
static Type get_default()
129+
{
130+
return Type(0.);
131+
}
132+
static std::string name()
133+
{
134+
return "Seasonality";
135+
}
136+
};
137+
101138
template <typename FP>
102139
using ParametersBase =
103-
ParameterSet<TransmissionProbabilityOnContact<FP>, TimeInfected<FP>, ContactPatterns<FP>, TimeImmune<FP>>;
140+
ParameterSet<TransmissionProbabilityOnContact<FP>, TimeInfected<FP>, ContactPatterns<FP>, TimeImmune<FP>, Seasonality<FP>, StartDay<FP>>;
104141

105142
/**
106143
* @brief Parameters of SIR model.
@@ -132,6 +169,12 @@ class Parameters : public ParametersBase<FP>
132169
FP tol_times = 1e-1;
133170

134171
int corrected = false;
172+
if (this->template get<Seasonality<FP>>() < 0.0 || this->template get<Seasonality<FP>>() > 0.5) {
173+
log_warning("Constraint check: Parameter Seasonality changed from {:0.4f} to {:d}",
174+
this->template get<Seasonality<FP>>(), 0);
175+
this->template set<Seasonality<FP>>(0);
176+
corrected = true;
177+
}
135178
if (this->template get<TimeInfected<FP>>() < tol_times) {
136179
log_warning("Constraint check: Parameter TimeInfected changed from {:.4f} to {:.4f}. Please note that "
137180
"unreasonably small compartment stays lead to massively increased run time. Consider to cancel "
@@ -167,6 +210,10 @@ class Parameters : public ParametersBase<FP>
167210
{
168211
FP tol_times = 1e-1;
169212

213+
if (this->template get<Seasonality<FP>>() < 0.0 || this->template get<Seasonality<FP>>() > 0.5) {
214+
log_error("Constraint check: Parameter Seasonality smaller {:d} or larger {:d}", 0, 0.5);
215+
return true;
216+
}
170217
if (this->template get<TimeInfected<FP>>() < tol_times) {
171218
log_error("Constraint check: Parameter TimeInfected {:.4f} smaller or equal {:.4f}. Please note that "
172219
"unreasonably small compartment stays lead to massively increased run time. Consider to cancel "

cpp/tests/test_sde_sirs.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ TEST(TestSdeSirs, check_constraints_parameters)
9090
parameters.set<mio::ssirs::TimeImmune<double>>(6);
9191
parameters.set<mio::ssirs::TransmissionProbabilityOnContact<double>>(0.04);
9292
parameters.get<mio::ssirs::ContactPatterns<double>>().get_baseline()(0, 0) = 10;
93+
parameters.set<mio::ssirs::StartDay<double>>(30);
94+
parameters.set<mio::ssirs::Seasonality<double>>(0.3);
9395

9496
// model.check_constraints() combines the functions from population and parameters.
9597
// We only want to test the functions for the parameters defined in parameters.h
@@ -107,6 +109,10 @@ TEST(TestSdeSirs, check_constraints_parameters)
107109
parameters.set<mio::ssirs::TimeImmune<double>>(6);
108110
parameters.set<mio::ssirs::TransmissionProbabilityOnContact<double>>(10.);
109111
EXPECT_EQ(parameters.check_constraints(), 1);
112+
113+
parameters.set<mio::ssirs::TransmissionProbabilityOnContact<double>>(0.04);
114+
parameters.set<mio::ssirs::Seasonality<double>>(-2.);
115+
EXPECT_EQ(parameters.check_constraints(), 1);
110116
mio::set_log_level(mio::LogLevel::warn);
111117
}
112118

@@ -119,6 +125,8 @@ TEST(TestSdeSirs, apply_constraints_parameters)
119125
parameters.set<mio::ssirs::TimeImmune<double>>(6);
120126
parameters.set<mio::ssirs::TransmissionProbabilityOnContact<double>>(0.04);
121127
parameters.get<mio::ssirs::ContactPatterns<double>>().get_baseline()(0, 0) = 10;
128+
parameters.set<mio::ssirs::StartDay<double>>(30);
129+
parameters.set<mio::ssirs::Seasonality<double>>(0.3);
122130

123131
EXPECT_EQ(parameters.apply_constraints(), 0);
124132

@@ -135,5 +143,9 @@ TEST(TestSdeSirs, apply_constraints_parameters)
135143
parameters.set<mio::ssirs::TransmissionProbabilityOnContact<double>>(10.);
136144
EXPECT_EQ(parameters.apply_constraints(), 1);
137145
EXPECT_NEAR(parameters.get<mio::ssirs::TransmissionProbabilityOnContact<double>>(), 0.0, 1e-14);
146+
147+
parameters.set<mio::ssirs::Seasonality<double>>(-2.);
148+
EXPECT_EQ(parameters.apply_constraints(), 1);
149+
EXPECT_NEAR(parameters.get<mio::ssirs::Seasonality<double>>(), 0.0, 1e-14);
138150
mio::set_log_level(mio::LogLevel::warn);
139151
}

docs/source/cpp/models/ssirs.rst

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,21 @@ Below is an overview of the model architecture and its compartments.
1616
.. image:: https://martinkuehn.eu/research/images/sirs.png
1717
:alt: SIR_model
1818

19-
+-------------------------------+-----------------------------------------------+--------------------------------------------------------------------------------------------------+
20-
| Mathematical variable | C++ variable name | Description |
21-
+===============================+===============================================+==================================================================================================+
22-
| :math:`\phi` | ``ContactPatterns`` | Daily contact rate / Number of daily contacts. |
23-
+-------------------------------+-----------------------------------------------+--------------------------------------------------------------------------------------------------+
24-
| :math:`\rho` | ``TransmissionProbabilityOnContact`` | Transmission risk for people located in the Susceptible compartment. |
25-
+-------------------------------+-----------------------------------------------+--------------------------------------------------------------------------------------------------+
26-
| :math:`N` | ``populations.get_total()`` | Total population. |
27-
+-------------------------------+-----------------------------------------------+--------------------------------------------------------------------------------------------------+
28-
| :math:`T_{I}` | ``TimeInfected`` | Time in days an individual stays in the Infected compartment. |
29-
+-------------------------------+-----------------------------------------------+--------------------------------------------------------------------------------------------------+
30-
| :math:`T_{R}` | ``TimeImmune`` | Time in days an individual stays in the Recovered compartment before becoming Susceptible again. |
31-
+-------------------------------+-----------------------------------------------+--------------------------------------------------------------------------------------------------+
19+
+-------------------------------+-----------------------------------------------+------------------------------------------------------------------------------------------------------------+
20+
| Mathematical variable | C++ variable name | Description |
21+
+===============================+===============================================+============================================================================================================+
22+
| :math:`\phi` | ``ContactPatterns`` | Daily contact rate / Number of daily contacts. |
23+
+-------------------------------+-----------------------------------------------+------------------------------------------------------------------------------------------------------------+
24+
| :math:`\rho` | ``TransmissionProbabilityOnContact`` | Transmission risk for people located in the Susceptible compartment. |
25+
+-------------------------------+-----------------------------------------------+------------------------------------------------------------------------------------------------------------+
26+
| :math:`N` | ``populations.get_total()`` | Total population. |
27+
+-------------------------------+-----------------------------------------------+------------------------------------------------------------------------------------------------------------+
28+
| :math:`T_{I}` | ``TimeInfected`` | Time in days an individual stays in the Infected compartment. |
29+
+-------------------------------+-----------------------------------------------+------------------------------------------------------------------------------------------------------------+
30+
| :math:`T_{R}` | ``TimeImmune`` | Time in days an individual stays in the Recovered compartment before becoming Susceptible again. |
31+
+-------------------------------+-----------------------------------------------+------------------------------------------------------------------------------------------------------------+
32+
| :math:`k` | ``Seasonality`` | Influence of the seasons is given by :math:`s_k(t) = 1 + k \sin \left(\frac{t}{182.5} + \frac{1}{2}\right)`|
33+
+-------------------------------+-----------------------------------------------+------------------------------------------------------------------------------------------------------------+
3234

3335
An example can be found in the
3436
`examples/ode_sir.cpp <https://github.com/SciCompMod/memilio/blob/main/cpp/examples/sde_sirs.cpp>`_.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#############################################################################
2+
# Copyright (C) 2020-2025 MEmilio
3+
#
4+
# Authors: Maximilian Betz
5+
#
6+
# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
#############################################################################
20+
import argparse
21+
22+
import numpy as np
23+
24+
from memilio.simulation import AgeGroup, Damping
25+
from memilio.simulation.ssir import InfectionState as State
26+
from memilio.simulation.ssir import (Model, simulate_stochastic)
27+
28+
29+
def run_sde_sir_simulation():
30+
"""Runs SDE SIR model"""
31+
32+
tmax = 5. # simulation time frame
33+
dt = 0.1
34+
35+
# Initialize model
36+
model = Model()
37+
38+
# Mean time in Infected compartment
39+
model.parameters.TimeInfected.value = 10.
40+
41+
model.parameters.TransmissionProbabilityOnContact.value = 1.
42+
43+
# Initial number of people per compartment
44+
total_population = 10000
45+
model.populations[State.Infected] = 100
46+
model.populations[State.Recovered] = 1000
47+
model.populations.set_difference_from_total(
48+
(State.Susceptible), total_population)
49+
50+
model.parameters.ContactPatterns.baseline = np.ones(
51+
(1, 1)) * 2.7
52+
model.parameters.ContactPatterns.minimum = np.zeros(
53+
(1, 1))
54+
model.parameters.ContactPatterns.add_damping(
55+
Damping(coeffs=np.r_[0.6], t=2., level=0, type=0))
56+
57+
# Check parameter constraints
58+
model.check_constraints()
59+
60+
# Run Simulation
61+
result = simulate_stochastic(0., tmax, dt, model)
62+
63+
result.print_table(False, ["S", "I", "R"], 16, 5)
64+
65+
66+
if __name__ == "__main__":
67+
arg_parser = argparse.ArgumentParser(
68+
'sde_sir_simple',
69+
description='Simple example demonstrating the setup and simulation of the SDE SIR model.')
70+
args = arg_parser.parse_args()
71+
run_sde_sir_simulation()
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#############################################################################
2+
# Copyright (C) 2020-2025 MEmilio
3+
#
4+
# Authors: Maximilian Betz
5+
#
6+
# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
#############################################################################
20+
import argparse
21+
22+
import numpy as np
23+
24+
from memilio.simulation import AgeGroup, Damping
25+
from memilio.simulation.ssirs import InfectionState as State
26+
from memilio.simulation.ssirs import (
27+
Model, simulate_stochastic, interpolate_simulation_result)
28+
29+
30+
def run_sde_sirs_simulation():
31+
"""Runs SDE SIRS model"""
32+
33+
tmax = 5. # simulation time frame
34+
dt = 0.001
35+
36+
# Initialize Model
37+
model = Model()
38+
39+
# Mean time in Infected compartment
40+
model.parameters.TimeInfected.value = 10.
41+
model.parameters.TimeImmune.value = 100.
42+
43+
model.parameters.TransmissionProbabilityOnContact.value = 1.
44+
45+
# Initial number of people per compartment
46+
total_population = 10000
47+
model.populations[State.Infected] = 100
48+
model.populations[State.Recovered] = 1000
49+
model.populations.set_difference_from_total(
50+
(State.Susceptible), total_population)
51+
52+
model.parameters.ContactPatterns.baseline = np.ones(
53+
(1, 1)) * 20.7
54+
model.parameters.ContactPatterns.minimum = np.zeros(
55+
(1, 1))
56+
model.parameters.ContactPatterns.add_damping(
57+
Damping(coeffs=np.r_[0.6], t=2, level=0, type=0))
58+
59+
# Check parameter constraints
60+
model.check_constraints()
61+
62+
# Run Simulation
63+
result = simulate_stochastic(0., days, dt, model)
64+
65+
# Interpolate results
66+
result = interpolate_simulation_result(result)
67+
68+
result.print_table(False, ["Susceptible", "Infected", "Recovered"], 16, 5)
69+
70+
71+
if __name__ == "__main__":
72+
arg_parser = argparse.ArgumentParser(
73+
'sde_sirs_simple',
74+
description='Simple example demonstrating the setup and simulation of the SDE SIRS model.')
75+
args = arg_parser.parse_args()
76+
run_sde_sirs_simulation()

pycode/memilio-simulation/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ add_pymio_module(_simulation_osecirvvs
8787
SOURCES memilio/simulation/bindings/models/osecirvvs.cpp
8888
)
8989

90+
add_pymio_module(_simulation_ssir
91+
LINKED_LIBRARIES memilio sde_sir
92+
SOURCES memilio/simulation/bindings/models/ssir.cpp
93+
)
94+
95+
add_pymio_module(_simulation_ssirs
96+
LINKED_LIBRARIES memilio sde_sirs
97+
SOURCES memilio/simulation/bindings/models/ssirs.cpp
98+
)
9099
add_pymio_module(_simulation_omseirs4
91100
LINKED_LIBRARIES memilio ode_mseirs4
92101
SOURCES memilio/simulation/bindings/models/omseirs4.cpp

pycode/memilio-simulation/memilio/simulation/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ def __getattr__(attr):
4545
elif attr == "osecirvvs":
4646
import memilio.simulation.osecirvvs as osecirvvs
4747
return osecirvvs
48+
elif attr == "ssir":
49+
import memilio.simulation.ssir as ssir
50+
return ssir
51+
52+
elif attr == "ssirs":
53+
import memilio.simulation.ssirs as ssirs
54+
return ssirs
4855

4956
raise AttributeError("module {!r} has no attribute "
5057
"{!r}".format(__name__, attr))

0 commit comments

Comments
 (0)