From 25db4dc828ff573e78bbb395bdac3a5e356425a0 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 16 Jul 2025 13:44:03 +0200 Subject: [PATCH 01/73] set up simulation for germany without age groups --- cpp/memilio/io/epi_data.cpp | 8 +- cpp/memilio/io/epi_data.h | 14 +- cpp/memilio/mobility/graph.h | 2 +- cpp/models/ode_secir/parameters_io.cpp | 134 ++++++++++++ cpp/models/ode_secir/parameters_io.h | 106 +++++++++- .../simulation/graph_germany_nuts3.py | 194 ++++++++++++++++++ .../memilio/epidata/getPopulationData.py | 1 + .../simulation/bindings/models/osecir.cpp | 11 +- 8 files changed, 448 insertions(+), 22 deletions(-) create mode 100644 pycode/examples/simulation/graph_germany_nuts3.py diff --git a/cpp/memilio/io/epi_data.cpp b/cpp/memilio/io/epi_data.cpp index 24e86a00e0..97210aef98 100644 --- a/cpp/memilio/io/epi_data.cpp +++ b/cpp/memilio/io/epi_data.cpp @@ -25,12 +25,8 @@ namespace mio { -std::vector ConfirmedCasesDataEntry::age_group_names = {"A00-A04", "A05-A14", "A15-A34", - "A35-A59", "A60-A79", "A80+"}; - -std::vector PopulationDataEntry::age_group_names = { - "<3 years", "3-5 years", "6-14 years", "15-17 years", "18-24 years", "25-29 years", - "30-39 years", "40-49 years", "50-64 years", "65-74 years", ">74 years"}; +std::vector ConfirmedCasesDataEntry::age_group_names = {"Population"}; +std::vector PopulationDataEntry::age_group_names = {"Population"}; std::vector VaccinationDataEntry::age_group_names = {"0-4", "5-14", "15-34", "35-59", "60-79", "80-99"}; diff --git a/cpp/memilio/io/epi_data.h b/cpp/memilio/io/epi_data.h index 3e4572b870..8c25f96c84 100644 --- a/cpp/memilio/io/epi_data.h +++ b/cpp/memilio/io/epi_data.h @@ -77,6 +77,9 @@ class ConfirmedCasesNoAgeEntry double num_recovered; double num_deaths; Date date; + boost::optional state_id; + boost::optional county_id; + boost::optional district_id; template static IOResult deserialize(IOContext& io) @@ -86,12 +89,15 @@ class ConfirmedCasesNoAgeEntry auto num_recovered = obj.expect_element("Recovered", Tag{}); auto num_deaths = obj.expect_element("Deaths", Tag{}); auto date = obj.expect_element("Date", Tag{}); + auto state_id = obj.expect_optional("ID_State", Tag{}); + auto county_id = obj.expect_optional("ID_County", Tag{}); + auto district_id = obj.expect_optional("ID_District", Tag{}); return apply( io, - [](auto&& nc, auto&& nr, auto&& nd, auto&& d) { - return ConfirmedCasesNoAgeEntry{nc, nr, nd, d}; + [](auto&& nc, auto&& nr, auto&& nd, auto&& d, auto&& sid, auto&& cid, auto&& did) { + return ConfirmedCasesNoAgeEntry{nc, nr, nd, d, sid, cid, did}; }, - num_confirmed, num_recovered, num_deaths, date); + num_confirmed, num_recovered, num_deaths, date, state_id, county_id, district_id); } }; @@ -442,7 +448,7 @@ inline IOResult> deserialize_population_data(co * @return list of population data. 
*/ inline IOResult> read_population_data(const std::string& filename, - bool rki_age_group = true) + bool rki_age_group = false) { BOOST_OUTCOME_TRY(auto&& jsvalue, read_json(filename)); return deserialize_population_data(jsvalue, rki_age_group); diff --git a/cpp/memilio/mobility/graph.h b/cpp/memilio/mobility/graph.h index eb8bd35990..5b211fd4a1 100644 --- a/cpp/memilio/mobility/graph.h +++ b/cpp/memilio/mobility/graph.h @@ -368,7 +368,7 @@ IOResult set_edges(const fs::path& mobility_data_file, Graph(); ++age) { for (auto compartment : mobile_compartments) { auto coeff_index = populations.get_flat_index({age, compartment}); - mobility_coeffs[size_t(ContactLocation::Work)].get_baseline()[coeff_index] = + mobility_coeffs[size_t(ContactLocation::Home)].get_baseline()[coeff_index] = commuter_coeff_ij * commuting_weights[size_t(age)]; } } diff --git a/cpp/models/ode_secir/parameters_io.cpp b/cpp/models/ode_secir/parameters_io.cpp index 5a2e122c24..dbaa1242b7 100644 --- a/cpp/models/ode_secir/parameters_io.cpp +++ b/cpp/models/ode_secir/parameters_io.cpp @@ -187,6 +187,140 @@ IOResult read_confirmed_cases_data( return success(); } +IOResult read_confirmed_cases_noage( + std::vector& rki_data, std::vector const& vregion, Date date, + std::vector>& vnum_Exposed, std::vector>& vnum_InfectedNoSymptoms, + std::vector>& vnum_InfectedSymptoms, std::vector>& vnum_InfectedSevere, + std::vector>& vnum_icu, std::vector>& vnum_death, + std::vector>& vnum_rec, const std::vector>& vt_Exposed, + const std::vector>& vt_InfectedNoSymptoms, + const std::vector>& vt_InfectedSymptoms, const std::vector>& vt_InfectedSevere, + const std::vector>& vt_InfectedCritical, const std::vector>& vmu_C_R, + const std::vector>& vmu_I_H, const std::vector>& vmu_H_U, + const std::vector& scaling_factor_inf) +{ + auto max_date_entry = std::max_element(rki_data.begin(), rki_data.end(), [](auto&& a, auto&& b) { + return a.date < b.date; + }); + if (max_date_entry == rki_data.end()) { + log_error("RKI data file is empty."); + return failure(StatusCode::InvalidFileFormat, "RKI file is empty."); + } + auto max_date = max_date_entry->date; + if (max_date < date) { + log_error("Specified date does not exist in RKI data"); + return failure(StatusCode::OutOfRange, "Specified date does not exist in RKI data."); + } + auto days_surplus = std::min(get_offset_in_days(max_date, date) - 6, 0); + + //this statement causes maybe-uninitialized warning for some versions of gcc. 
+ //the error is reported in an included header, so the warning is disabled for the whole file + std::sort(rki_data.begin(), rki_data.end(), [](auto&& a, auto&& b) { + return std::make_tuple(get_region_id(a), a.date) < std::make_tuple(get_region_id(b), b.date); + }); + + for (auto region_idx = size_t(0); region_idx < vregion.size(); ++region_idx) { + auto region_entry_range_it = + std::equal_range(rki_data.begin(), rki_data.end(), vregion[region_idx], [](auto&& a, auto&& b) { + return get_region_id(a) < get_region_id(b); + }); + auto region_entry_range = make_range(region_entry_range_it); + if (region_entry_range.begin() == region_entry_range.end()) { + log_error("No entries found for region {}", vregion[region_idx]); + return failure(StatusCode::InvalidFileFormat, + "No entries found for region " + std::to_string(vregion[region_idx])); + } + for (auto&& region_entry : region_entry_range) { + + auto& t_Exposed = vt_Exposed[region_idx]; + auto& t_InfectedNoSymptoms = vt_InfectedNoSymptoms[region_idx]; + auto& t_InfectedSymptoms = vt_InfectedSymptoms[region_idx]; + auto& t_InfectedSevere = vt_InfectedSevere[region_idx]; + auto& t_InfectedCritical = vt_InfectedCritical[region_idx]; + + auto& num_InfectedNoSymptoms = vnum_InfectedNoSymptoms[region_idx]; + auto& num_InfectedSymptoms = vnum_InfectedSymptoms[region_idx]; + auto& num_rec = vnum_rec[region_idx]; + auto& num_Exposed = vnum_Exposed[region_idx]; + auto& num_InfectedSevere = vnum_InfectedSevere[region_idx]; + auto& num_death = vnum_death[region_idx]; + auto& num_icu = vnum_icu[region_idx]; + + auto& mu_C_R = vmu_C_R[region_idx]; + auto& mu_I_H = vmu_I_H[region_idx]; + auto& mu_H_U = vmu_H_U[region_idx]; + + auto date_df = region_entry.date; + auto age = 0; + + if (date_df == offset_date_by_days(date, 0)) { + num_InfectedSymptoms[age] += scaling_factor_inf[age] * region_entry.num_confirmed; + num_rec[age] += region_entry.num_confirmed; + } + if (date_df == offset_date_by_days(date, days_surplus)) { + num_InfectedNoSymptoms[age] -= + 1 / (1 - mu_C_R[age]) * scaling_factor_inf[age] * region_entry.num_confirmed; + } + if (date_df == offset_date_by_days(date, t_InfectedNoSymptoms[age] + days_surplus)) { + num_InfectedNoSymptoms[age] += + 1 / (1 - mu_C_R[age]) * scaling_factor_inf[age] * region_entry.num_confirmed; + num_Exposed[age] -= 1 / (1 - mu_C_R[age]) * scaling_factor_inf[age] * region_entry.num_confirmed; + } + if (date_df == offset_date_by_days(date, t_Exposed[age] + t_InfectedNoSymptoms[age] + days_surplus)) { + num_Exposed[age] += 1 / (1 - mu_C_R[age]) * scaling_factor_inf[age] * region_entry.num_confirmed; + } + if (date_df == offset_date_by_days(date, -t_InfectedSymptoms[age])) { + num_InfectedSymptoms[age] -= scaling_factor_inf[age] * region_entry.num_confirmed; + num_InfectedSevere[age] += mu_I_H[age] * scaling_factor_inf[age] * region_entry.num_confirmed; + } + if (date_df == offset_date_by_days(date, -t_InfectedSymptoms[age] - t_InfectedSevere[age])) { + num_InfectedSevere[age] -= mu_I_H[age] * scaling_factor_inf[age] * region_entry.num_confirmed; + num_icu[age] += mu_I_H[age] * mu_H_U[age] * scaling_factor_inf[age] * region_entry.num_confirmed; + } + if (date_df == + offset_date_by_days(date, -t_InfectedSymptoms[age] - t_InfectedSevere[age] - t_InfectedCritical[age])) { + num_death[age] += region_entry.num_deaths; + num_icu[age] -= mu_I_H[age] * mu_H_U[age] * scaling_factor_inf[age] * region_entry.num_confirmed; + } + } + } + + for (size_t region_idx = 0; region_idx < vregion.size(); ++region_idx) { + auto region = 
vregion[region_idx]; + + auto& num_InfectedNoSymptoms = vnum_InfectedNoSymptoms[region_idx]; + auto& num_InfectedSymptoms = vnum_InfectedSymptoms[region_idx]; + auto& num_rec = vnum_rec[region_idx]; + auto& num_Exposed = vnum_Exposed[region_idx]; + auto& num_InfectedSevere = vnum_InfectedSevere[region_idx]; + auto& num_death = vnum_death[region_idx]; + auto& num_icu = vnum_icu[region_idx]; + + auto try_fix_constraints = [region](double& value, double error, auto str) { + if (value < error) { + //this should probably return a failure + //but the algorithm is not robust enough to avoid large negative values and there are tests that rely on it + log_error("{:s} is {:.4f} for region {:d}, exceeds expected negative value.", str, value, region); + value = 0.0; + } + else if (value < 0) { + log_info("{:s} is {:.4f} for region {:d}, automatically corrected", str, value, region); + value = 0.0; + } + }; + + try_fix_constraints(num_InfectedSymptoms[0], -5, "InfectedSymptoms"); + try_fix_constraints(num_InfectedNoSymptoms[0], -5, "InfectedNoSymptoms"); + try_fix_constraints(num_Exposed[0], -5, "Exposed"); + try_fix_constraints(num_InfectedSevere[0], -5, "InfectedSevere"); + try_fix_constraints(num_death[0], -5, "Dead"); + try_fix_constraints(num_icu[0], -5, "InfectedCritical"); + try_fix_constraints(num_rec[0], -20, "Recovered"); + } + + return success(); +} + } // namespace details } // namespace osecir } // namespace mio diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index ab334d922d..c721bab032 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -151,6 +151,102 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect return success(); } +IOResult read_confirmed_cases_noage( + std::vector& rki_data, std::vector const& vregion, Date date, + std::vector>& vnum_Exposed, std::vector>& vnum_InfectedNoSymptoms, + std::vector>& vnum_InfectedSymptoms, std::vector>& vnum_InfectedSevere, + std::vector>& vnum_icu, std::vector>& vnum_death, + std::vector>& vnum_rec, const std::vector>& vt_Exposed, + const std::vector>& vt_InfectedNoSymptoms, + const std::vector>& vt_InfectedSymptoms, const std::vector>& vt_InfectedSevere, + const std::vector>& vt_InfectedCritical, const std::vector>& vmu_C_R, + const std::vector>& vmu_I_H, const std::vector>& vmu_H_U, + const std::vector& scaling_factor_inf); + +/** + * @brief Sets populations data from already read case data with multiple age groups into a Model with one age group. + * @tparam FP Floating point data type, e.g., double. + * @param[in, out] model Vector of models in which the data is set. + * @param[in] case_data List of confirmed cases data entries. + * @param[in] region Vector of keys of the region of interest. + * @param[in] date Date at which the data is read. + * @param[in] scaling_factor_inf Factors by which to scale the confirmed cases of rki data. 
+ */ +template +IOResult +set_confirmed_cases_noage(std::vector>& model, std::vector& case_data, + const std::vector& region, Date date, const std::vector& scaling_factor_inf) +{ + std::vector> t_InfectedNoSymptoms{model.size()}; + std::vector> t_Exposed{model.size()}; + std::vector> t_InfectedSymptoms{model.size()}; + std::vector> t_InfectedSevere{model.size()}; + std::vector> t_InfectedCritical{model.size()}; + + std::vector> mu_C_R{model.size()}; + std::vector> mu_I_H{model.size()}; + std::vector> mu_H_U{model.size()}; + std::vector> mu_U_D{model.size()}; + + for (size_t node = 0; node < model.size(); ++node) { + + t_Exposed[node].push_back( + static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)]))); + t_InfectedNoSymptoms[node].push_back(static_cast( + std::round(model[node].parameters.template get>()[AgeGroup(0)]))); + t_InfectedSymptoms[node].push_back( + static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)]))); + t_InfectedSevere[node].push_back( + static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)]))); + t_InfectedCritical[node].push_back( + static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)]))); + + mu_C_R[node].push_back(model[node].parameters.template get>()[AgeGroup(0)]); + mu_I_H[node].push_back(model[node].parameters.template get>()[AgeGroup(0)]); + mu_H_U[node].push_back(model[node].parameters.template get>()[AgeGroup(0)]); + mu_U_D[node].push_back(model[node].parameters.template get>()[AgeGroup(0)]); + } + std::vector> num_InfectedSymptoms(model.size(), std::vector(1, 0.0)); + std::vector> num_death(model.size(), std::vector(1, 0.0)); + std::vector> num_rec(model.size(), std::vector(1, 0.0)); + std::vector> num_Exposed(model.size(), std::vector(1, 0.0)); + std::vector> num_InfectedNoSymptoms(model.size(), std::vector(1, 0.0)); + std::vector> num_InfectedSevere(model.size(), std::vector(1, 0.0)); + std::vector> num_icu(model.size(), std::vector(1, 0.0)); + + BOOST_OUTCOME_TRY(read_confirmed_cases_noage(case_data, region, date, num_Exposed, num_InfectedNoSymptoms, + num_InfectedSymptoms, num_InfectedSevere, num_icu, num_death, num_rec, + t_Exposed, t_InfectedNoSymptoms, t_InfectedSymptoms, t_InfectedSevere, + t_InfectedCritical, mu_C_R, mu_I_H, mu_H_U, scaling_factor_inf)); + + for (size_t node = 0; node < model.size(); node++) { + if (std::accumulate(num_InfectedSymptoms[node].begin(), num_InfectedSymptoms[node].end(), 0.0) > 0) { + size_t num_groups = (size_t)model[node].parameters.get_num_groups(); + for (size_t i = 0; i < num_groups; i++) { + model[node].populations[{AgeGroup(i), InfectionState::Exposed}] = num_Exposed[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptoms}] = + num_InfectedNoSymptoms[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptoms}] = + num_InfectedSymptoms[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i), InfectionState::InfectedSevere}] = num_InfectedSevere[node][i]; + // Only set the number of ICU patients here, if the date is not available in the data. 
+ if (!is_divi_data_available(date)) { + model[node].populations[{AgeGroup(i), InfectionState::InfectedCritical}] = num_icu[node][i]; + } + model[node].populations[{AgeGroup(i), InfectionState::Dead}] = num_death[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i]; + } + } + else { + log_warning("No infections reported on date {} for region {}. Population data has not been set.", date, + region[node]); + } + } + return success(); +} + /** * @brief Sets the infected population for a given model based on confirmed cases data. Here, we * read the case data from a file. @@ -166,8 +262,8 @@ IOResult set_confirmed_cases_data(std::vector>& model, const std std::vector const& region, Date date, const std::vector& scaling_factor_inf) { - BOOST_OUTCOME_TRY(auto&& case_data, mio::read_confirmed_cases_data(path)); - BOOST_OUTCOME_TRY(set_confirmed_cases_data(model, case_data, region, date, scaling_factor_inf)); + BOOST_OUTCOME_TRY(auto&& case_data, mio::read_confirmed_cases_noage(path)); + BOOST_OUTCOME_TRY(set_confirmed_cases_noage(model, case_data, region, date, scaling_factor_inf)); return success(); } @@ -388,10 +484,10 @@ IOResult read_input_data_county(std::vector& model, Date date, cons { BOOST_OUTCOME_TRY( details::set_divi_data(model, path_join(pydata_dir, "county_divi_ma7.json"), county, date, scaling_factor_icu)); - BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_county_age_ma7.json"), + BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_county_ma7.json"), county, date, scaling_factor_inf)); - BOOST_OUTCOME_TRY( - details::set_population_data(model, path_join(pydata_dir, "county_current_population.json"), county)); + BOOST_OUTCOME_TRY(details::set_population_data( + model, path_join(pydata_dir, "county_current_population_aggregated.json"), county)); if (export_time_series) { // Use only if extrapolated real data is needed for comparison. EXPENSIVE ! diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py new file mode 100644 index 0000000000..ff42b87307 --- /dev/null +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -0,0 +1,194 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+############################################################################# +import numpy as np +import datetime +import os +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +import matplotlib.pyplot as plt + +from enum import Enum +from memilio.simulation.osecir import (Model, Simulation, + interpolate_simulation_result) + + +class Simulation: + """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model): + """ + + :param model: + + """ + model.parameters.TimeExposed[mio.AgeGroup(0)] = 3.335 + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 0.07333 + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 + + # start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + """ + + :param model: + + """ + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 + minimum = np.ones((self.num_groups, self.num_groups)) * 0 + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_value): + """ + + :param params: + :param end_date: + + """ + + start_damping = datetime.date( + 2020, 12, 18) + + if start_damping < end_date: + start_date = (start_damping - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) + + def get_graph(self, end_date, damping_value): + """ + + :param end_date: + + """ + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters(model) + self.set_contact_matrices(model) + self.set_npis(model.parameters, end_date, damping_value) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + tnt_capacity_factor = 7.5 / 100000. 
+ + data_dir_Germany = os.path.join(self.data_dir, "Germany") + mobility_data_file = os.path.join( + data_dir_Germany, "mobility", "commuter_mobility_2022.txt") + pydata_dir = os.path.join(data_dir_Germany, "pydata") + + path_population_data = os.path.join(pydata_dir, + "county_current_population_aggregated.json") + + print("Setting nodes...") + mio.osecir.set_nodes( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, True, graph, scaling_factor_infected, + scaling_factor_icu, 0, 0, False, False) + + print("Setting edges...") + mio.osecir.set_edges( + mobility_data_file, graph, 1) + + print("Graph created.") + + return graph + + def run(self, num_days_sim, damping_value, save_graph=True): + """ + + :param num_days_sim: + :param num_runs: (Default value = 10) + :param save_graph: (Default value = True) + :param create_gif: (Default value = True) + + """ + mio.set_log_level(mio.LogLevel.Warning) + end_date = self.start_date + datetime.timedelta(days=num_days_sim) + num_runs = 10 + + graph = self.get_graph(end_date, damping_value) + + if save_graph: + path_graph = os.path.join(self.results_dir, "graph") + if not os.path.exists(path_graph): + os.makedirs(path_graph) + osecir.write_graph(graph, path_graph) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + mobility_graph.add_node(graph.get_node(node_idx).id, graph.get_node(node_idx).property) + for edge_idx in range(graph.num_edges): + mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(0).property.result).export_csv('test.csv') + + return 0 + +def run_germany_nuts3_simulation(damping_value): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=datetime.date(year=2020, month=12, day=12), + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 50 + + sim.run(num_days_sim, damping_value) + + +if __name__ == "__main__": + run_germany_nuts3_simulation(damping_value=0.5) diff --git a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py index c7d0b401de..45f62bb2df 100644 --- a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py +++ b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py @@ -140,6 +140,7 @@ def export_population_dataframe(df_pop: pd.DataFrame, directory: str, file_forma columns=dd.EngEng["idCounty"]) gd.write_dataframe(df_pop_export, directory, filename, file_format) + gd.write_dataframe(df_pop_export.drop(columns=new_cols[2:]), directory, filename + '_aggregated', file_format) return df_pop_export diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp index 50ea9b65b8..591ef4bfc0 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp @@ -252,14 +252,14 @@ PYBIND11_MODULE(_simulation_osecir, m) const std::string& data_dir, const 
std::string& population_data_path, bool is_node_for_county, mio::Graph, mio::MobilityParameters>& params_graph, const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor, - int num_days = 0, bool export_time_series = false) { + int num_days = 0, bool export_time_series = false, bool rki_age_groups = true) { auto result = mio::set_nodes< mio::osecir::TestAndTraceCapacity, mio::osecir::ContactPatterns, mio::osecir::Model, mio::MobilityParameters, mio::osecir::Parameters, decltype(mio::osecir::read_input_data_county>), decltype(mio::get_node_ids)>( params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph, mio::osecir::read_input_data_county>, mio::get_node_ids, scaling_factor_inf, - scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series); + scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series, rki_age_groups); return pymio::check_and_throw(result); }, py::return_value_policy::move); @@ -278,12 +278,11 @@ PYBIND11_MODULE(_simulation_osecir, m) auto mobile_comp = {mio::osecir::InfectionState::Susceptible, mio::osecir::InfectionState::Exposed, mio::osecir::InfectionState::InfectedNoSymptoms, mio::osecir::InfectionState::InfectedSymptoms, mio::osecir::InfectionState::Recovered}; - auto weights = std::vector{0., 0., 1.0, 1.0, 0.33, 0., 0.}; + // auto weights = std::vector{0., 0., 1.0, 1.0, 0.33, 0., 0.}; auto result = mio::set_edges, mio::MobilityParameters, mio::MobilityCoefficientGroup, mio::osecir::InfectionState, - decltype(mio::read_mobility_plain)>(mobility_data_file, params_graph, - mobile_comp, contact_locations_size, - mio::read_mobility_plain, weights); + decltype(mio::read_mobility_plain)>( + mobility_data_file, params_graph, mobile_comp, contact_locations_size, mio::read_mobility_plain, {}); return pymio::check_and_throw(result); }, py::return_value_policy::move); From 1c25e1820362dac877bb3aea0df7252e1595ff35 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 17 Jul 2025 09:25:43 +0200 Subject: [PATCH 02/73] create training data --- .../simulation/graph_germany_nuts3.py | 61 ++++++++++++++++++- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index ff42b87307..5a8bcbd538 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -28,6 +28,8 @@ from memilio.simulation.osecir import (Model, Simulation, interpolate_simulation_result) +import pickle + class Simulation: """ """ @@ -172,10 +174,15 @@ def run(self, num_days_sim, damping_value, save_graph=True): mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) mobility_sim.advance(num_days_sim) + results = [] + for node_idx in range(graph.num_nodes): + results.append(osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result)) + osecir.interpolate_simulation_result( mobility_sim.graph.get_node(0).property.result).export_csv('test.csv') - return 0 + return results def run_germany_nuts3_simulation(damping_value): mio.set_log_level(mio.LogLevel.Warning) @@ -187,8 +194,56 @@ def run_germany_nuts3_simulation(damping_value): results_dir=os.path.join(file_path, "../../../results_osecir")) num_days_sim = 50 - sim.run(num_days_sim, damping_value) + results = sim.run(num_days_sim, damping_value) + + return {"region" + str(region): results[region] for region in range(len(results))} +def prior(): + damping_value 
= np.random.uniform(0.0, 1.0) + return {"damping_value": damping_value} if __name__ == "__main__": - run_germany_nuts3_simulation(damping_value=0.5) + import os + os.environ["KERAS_BACKEND"] = "tensorflow" + + import bayesflow as bf + + simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) + trainings_data = simulator.sample(5) + + with open('trainings_data.pickle', 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + # with open('trainings_data.pickle', 'rb') as f: + # test = pickle.load(f) + + # test = {k:v for k, v in test.items() if k in ('damping_value', 'region0')} + # print("Loaded training data:", test) + + + # trainings_data = simulator.sample(2) + # validation_data = simulator.sample(2) + + # adapter = ( + # bf.Adapter() + # .to_array() + # .convert_dtype("float64", "float32") + # .constrain("damping_value", lower=0.0, upper=1.0) + # .rename("damping_value", "inference_variables") + # .rename("region0", "summary_variables") + # #.standardize("summary_variables") + # ) + + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) + # inference_network = bf.networks.CouplingFlow() + + # workflow = bf.BasicWorkflow( + # simulator=simulator, + # adapter=adapter, + # summary_network=summary_network, + # inference_network=inference_network + # ) + + # history = workflow.fit_offline(data=trainings_data, epochs=2, batch_size=2, validation_data=validation_data) + # f = bf.diagnostics.plots.loss(history) + # run_germany_nuts3_simulation(damping_value=0.5) From 6ee021fac66f79c4341972e348a81d841c1a79f8 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 22 Jul 2025 12:08:06 +0200 Subject: [PATCH 03/73] [ci skip] add first draft for nuts1 simulation of germany --- .../simulation/graph_germany_nuts1.py | 250 ++++++++++++++++++ .../simulation/graph_germany_nuts3.py | 72 ++--- 2 files changed, 287 insertions(+), 35 deletions(-) create mode 100644 pycode/examples/simulation/graph_germany_nuts1.py diff --git a/pycode/examples/simulation/graph_germany_nuts1.py b/pycode/examples/simulation/graph_germany_nuts1.py new file mode 100644 index 0000000000..ca48fbd2fa --- /dev/null +++ b/pycode/examples/simulation/graph_germany_nuts1.py @@ -0,0 +1,250 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+############################################################################# +import numpy as np +import datetime +import os +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +import matplotlib.pyplot as plt + +from enum import Enum +from memilio.simulation.osecir import (Model, Simulation, + interpolate_simulation_result) + +import pickle + + +class Simulation: + """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model): + """ + + :param model: + + """ + model.parameters.TimeExposed[mio.AgeGroup(0)] = 3.335 + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 0.07333 + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 + + # start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + """ + + :param model: + + """ + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 + minimum = np.ones((self.num_groups, self.num_groups)) * 0 + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_value): + """ + + :param params: + :param end_date: + + """ + + start_damping = datetime.date( + 2020, 12, 18) + + if start_damping < end_date: + start_date = (start_damping - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) + + def get_graph(self, end_date, damping_value): + """ + + :param end_date: + + """ + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters(model) + self.set_contact_matrices(model) + self.set_npis(model.parameters, end_date, damping_value) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + tnt_capacity_factor = 7.5 / 100000. 
+ + data_dir_Germany = os.path.join(self.data_dir, "Germany") + mobility_data_file = os.path.join( + data_dir_Germany, "mobility", "commuter_mobility_2022_states.txt") + pydata_dir = os.path.join(data_dir_Germany, "pydata") + + path_population_data = os.path.join(pydata_dir, + "county_current_population_states.json") + + print("Setting nodes...") + mio.osecir.set_nodes( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, True, graph, scaling_factor_infected, + scaling_factor_icu, 0, 0, False, False) + + print("Setting edges...") + mio.osecir.set_edges( + mobility_data_file, graph, 1) + + print("Graph created.") + + return graph + + def run(self, num_days_sim, damping_value, save_graph=True): + """ + + :param num_days_sim: + :param num_runs: (Default value = 10) + :param save_graph: (Default value = True) + :param create_gif: (Default value = True) + + """ + mio.set_log_level(mio.LogLevel.Warning) + end_date = self.start_date + datetime.timedelta(days=num_days_sim) + num_runs = 10 + + graph = self.get_graph(end_date, damping_value) + + if save_graph: + path_graph = os.path.join(self.results_dir, "graph") + if not os.path.exists(path_graph): + os.makedirs(path_graph) + osecir.write_graph(graph, path_graph) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + mobility_graph.add_node(graph.get_node(node_idx).id, graph.get_node(node_idx).property) + for edge_idx in range(graph.num_edges): + mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + results = [] + for node_idx in range(graph.num_nodes): + results.append(osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result)) + + osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(0).property.result).export_csv('test.csv') + + return results + +def run_germany_nuts1_simulation(damping_value): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=datetime.date(year=2020, month=12, day=12), + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 50 + + results = sim.run(num_days_sim, damping_value) + + return {"region" + str(region): results[region] for region in range(len(results))} + +def prior(): + damping_value = np.random.uniform(0.0, 1.0) + return {"damping_value": damping_value} + +if __name__ == "__main__": + + run_germany_nuts1_simulation(0.5) + # import os + # os.environ["KERAS_BACKEND"] = "jax" + + # import bayesflow as bf + + # simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) + # # trainings_data = simulator.sample(5) + + # with open('trainings_data.pickle', 'wb') as f: + # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + # with open('trainings_data.pickle', 'rb') as f: + # trainings_data = pickle.load(f) + + # # trainings_data = {k:v for k, v in trainings_data.items() if k in ('damping_value', 'region0', 'region1')} + # print("Loaded training data:", trainings_data) + + + # trainings_data = simulator.sample(2) + # validation_data = simulator.sample(2) + + # adapter = ( + # bf.Adapter() + # .to_array() + # 
.convert_dtype("float64", "float32") + # .constrain("damping_value", lower=0.0, upper=1.0) + # .concatenate(["region"+str(region) for region in range(len(trainings_data)-1)], into="summary_variables") + # .rename("damping_value", "inference_variables") + # #.standardize("summary_variables") + # ) + + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) + # inference_network = bf.networks.CouplingFlow() + + # workflow = bf.BasicWorkflow( + # simulator=simulator, + # adapter=adapter, + # summary_network=summary_network, + # inference_network=inference_network + # ) + + # history = workflow.fit_offline(data=trainings_data, epochs=2, batch_size=2, validation_data=trainings_data) + # f = bf.diagnostics.plots.loss(history) diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 5a8bcbd538..47faed6887 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -20,13 +20,13 @@ import numpy as np import datetime import os -import memilio.simulation as mio -import memilio.simulation.osecir as osecir +# import memilio.simulation as mio +# import memilio.simulation.osecir as osecir import matplotlib.pyplot as plt from enum import Enum -from memilio.simulation.osecir import (Model, Simulation, - interpolate_simulation_result) +# from memilio.simulation.osecir import (Model, Simulation, + # interpolate_simulation_result) import pickle @@ -204,46 +204,48 @@ def prior(): if __name__ == "__main__": import os - os.environ["KERAS_BACKEND"] = "tensorflow" + os.environ["KERAS_BACKEND"] = "jax" import bayesflow as bf simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - trainings_data = simulator.sample(5) + # trainings_data = simulator.sample(5) - with open('trainings_data.pickle', 'wb') as f: - pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + # with open('trainings_data.pickle', 'wb') as f: + # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - # with open('trainings_data.pickle', 'rb') as f: - # test = pickle.load(f) + with open('trainings_data.pickle', 'rb') as f: + trainings_data = pickle.load(f) - # test = {k:v for k, v in test.items() if k in ('damping_value', 'region0')} - # print("Loaded training data:", test) + trainings_data['region0'] = trainings_data['region0'][8] + + trainings_data = {k:v for k, v in trainings_data.items() if k in ('damping_value', 'region0', 'region1')} + print("Loaded training data:", trainings_data) # trainings_data = simulator.sample(2) # validation_data = simulator.sample(2) - # adapter = ( - # bf.Adapter() - # .to_array() - # .convert_dtype("float64", "float32") - # .constrain("damping_value", lower=0.0, upper=1.0) - # .rename("damping_value", "inference_variables") - # .rename("region0", "summary_variables") - # #.standardize("summary_variables") - # ) - - # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) - # inference_network = bf.networks.CouplingFlow() - - # workflow = bf.BasicWorkflow( - # simulator=simulator, - # adapter=adapter, - # summary_network=summary_network, - # inference_network=inference_network - # ) - - # history = workflow.fit_offline(data=trainings_data, epochs=2, batch_size=2, validation_data=validation_data) - # f = bf.diagnostics.plots.loss(history) - # run_germany_nuts3_simulation(damping_value=0.5) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_value", lower=0.0, upper=1.0) + 
.concatenate(["region"+str(region) for region in range(len(trainings_data)-1)], into="summary_variables") + .rename("damping_value", "inference_variables") + .log("summary_variables", p1=True) + .standardize("summary_variables") + ) + + summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) + inference_network = bf.networks.CouplingFlow() + + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network + ) + + history = workflow.fit_offline(data=trainings_data, epochs=2, batch_size=2, validation_data=trainings_data) + f = bf.diagnostics.plots.loss(history) From 284d5acede851b1ca5232c77567df9a2414c1086 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 22 Jul 2025 12:09:08 +0200 Subject: [PATCH 04/73] [ci skip] aggregate population data to states --- .../memilio-epidata/memilio/epidata/getPopulationData.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py index 45f62bb2df..ad19b30bb6 100644 --- a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py +++ b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py @@ -141,6 +141,7 @@ def export_population_dataframe(df_pop: pd.DataFrame, directory: str, file_forma gd.write_dataframe(df_pop_export, directory, filename, file_format) gd.write_dataframe(df_pop_export.drop(columns=new_cols[2:]), directory, filename + '_aggregated', file_format) + gd.write_dataframe(aggregate_to_state_level(df_pop_export.drop(columns=new_cols[2:])), directory, filename + '_states', file_format) return df_pop_export @@ -444,6 +445,14 @@ def get_population_data(read_data: bool = dd.defaultDict['read_data'], ) return df_pop_export +def aggregate_to_state_level(df_pop: pd.DataFrame): + + countyIDtostateID = geoger.get_countyid_to_stateid_map() + + df_pop['ID_State'] = df_pop[dd.EngEng['idCounty']].map(countyIDtostateID) + df_pop = df_pop.drop(columns='ID_County').groupby('ID_State').sum() + return df_pop + def main(): """ Main program entry.""" From 3f6629dd3a105c1067fbf8a7098bea3173d6a15d Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 31 Jul 2025 08:41:57 +0200 Subject: [PATCH 05/73] changes for fitting and region-specific damping --- .../simulation/graph_germany_nuts3.py | 85 ++++++++++++------- shellscripts/fitting_graphmodel.sh | 13 +++ 2 files changed, 65 insertions(+), 33 deletions(-) create mode 100644 shellscripts/fitting_graphmodel.sh diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 5a8bcbd538..7634a86b39 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -80,7 +80,7 @@ def set_contact_matrices(self, model): minimum = np.ones((self.num_groups, self.num_groups)) * 0 contact_matrices[0] = mio.ContactMatrix(baseline, minimum) model.parameters.ContactPatterns.cont_freq_mat = contact_matrices - + def set_npis(self, params, end_date, damping_value): """ @@ -96,7 +96,7 @@ def set_npis(self, params, end_date, damping_value): start_date = (start_damping - self.start_date).days params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) - def get_graph(self, end_date, damping_value): + def get_graph(self, end_date): """ :param end_date: @@ -106,7 +106,6 @@ def get_graph(self, end_date, damping_value): model = Model(self.num_groups) 
self.set_covid_parameters(model) self.set_contact_matrices(model) - self.set_npis(model.parameters, end_date, damping_value) print("Model initialized.") graph = osecir.ModelGraph() @@ -141,7 +140,7 @@ def get_graph(self, end_date, damping_value): return graph - def run(self, num_days_sim, damping_value, save_graph=True): + def run(self, num_days_sim, damping_values, save_graph=True): """ :param num_days_sim: @@ -154,7 +153,7 @@ def run(self, num_days_sim, damping_value, save_graph=True): end_date = self.start_date + datetime.timedelta(days=num_days_sim) num_runs = 10 - graph = self.get_graph(end_date, damping_value) + graph = self.get_graph(end_date) if save_graph: path_graph = os.path.join(self.results_dir, "graph") @@ -164,13 +163,14 @@ def run(self, num_days_sim, damping_value, save_graph=True): mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): - mobility_graph.add_node(graph.get_node(node_idx).id, graph.get_node(node_idx).property) + node = graph.get_node(node_idx) + self.set_npis(node.property.parameters, end_date, damping_values[node_idx]) + mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( graph.get_edge(edge_idx).start_node_idx, graph.get_edge(edge_idx).end_node_idx, graph.get_edge(edge_idx).property) - mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) mobility_sim.advance(num_days_sim) @@ -178,13 +178,10 @@ def run(self, num_days_sim, damping_value, save_graph=True): for node_idx in range(graph.num_nodes): results.append(osecir.interpolate_simulation_result( mobility_sim.graph.get_node(node_idx).property.result)) - - osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(0).property.result).export_csv('test.csv') return results -def run_germany_nuts3_simulation(damping_value): +def run_germany_nuts3_simulation(damping_values): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) @@ -194,44 +191,60 @@ def run_germany_nuts3_simulation(damping_value): results_dir=os.path.join(file_path, "../../../results_osecir")) num_days_sim = 50 - results = sim.run(num_days_sim, damping_value) + results = sim.run(num_days_sim, damping_values) - return {"region" + str(region): results[region] for region in range(len(results))} + return {f'region{region}': results[region] for region in range(len(results))} def prior(): - damping_value = np.random.uniform(0.0, 1.0) - return {"damping_value": damping_value} + damping_values = np.random.uniform(0.0, 1.0, 400) + return {'damping_values': damping_values} + +# def load_divi_data(): +# file_path = os.path.dirname(os.path.abspath(__file__)) +# divi_path = os.path.join(file_path, "../../../data/Germany/pydata") + +# data = pd.read_json(os.path.join(divi_path, "county_divi_ma7.json")) +# data = data.drop(columns=['County', 'ICU_ventilated', 'Date']) +# divi_dict = {f"region{i}": data[f'region{i}'] for i in range(399)} +# print(divi_dict) + if __name__ == "__main__": + + # import pandas as pd + # load_divi_data() import os - os.environ["KERAS_BACKEND"] = "tensorflow" + os.environ["KERAS_BACKEND"] = "jax" import bayesflow as bf simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - trainings_data = simulator.sample(5) - - with open('trainings_data.pickle', 'wb') as f: - pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + trainings_data = simulator.sample(1000) - # with open('trainings_data.pickle', 'rb') as f: - # test = pickle.load(f) + for region in range(400): 
+ trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] - # test = {k:v for k, v in test.items() if k in ('damping_value', 'region0')} - # print("Loaded training data:", test) + with open('trainings_data10.pickle', 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + # with open('trainings_data1.pickle', 'rb') as f: + # trainings_data = pickle.load(f) + # for i in range(9): + # with open(f'trainings_data{i+2}.pickle', 'rb') as f: + # data = pickle.load(f) + # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - # trainings_data = simulator.sample(2) - # validation_data = simulator.sample(2) + # with open('validation_data.pickle', 'rb') as f: + # validation_data = pickle.load(f) # adapter = ( # bf.Adapter() # .to_array() # .convert_dtype("float64", "float32") - # .constrain("damping_value", lower=0.0, upper=1.0) - # .rename("damping_value", "inference_variables") - # .rename("region0", "summary_variables") - # #.standardize("summary_variables") + # .constrain("damping_values", lower=0.0, upper=1.0) + # .rename("damping_values", "inference_variables") + # .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) + # .log("summary_variables", p1=True) # ) # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) @@ -244,6 +257,12 @@ def prior(): # inference_network=inference_network # ) - # history = workflow.fit_offline(data=trainings_data, epochs=2, batch_size=2, validation_data=validation_data) - # f = bf.diagnostics.plots.loss(history) - # run_germany_nuts3_simulation(damping_value=0.5) + # history = workflow.fit_offline(data=trainings_data, epochs=1, batch_size=32, validation_data=validation_data) + + # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model.keras")) + + # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True}) + # plots['losses'].savefig('losses.png') + # plots['recovery'].savefig('recovery.png') + # plots['calibration_ecdf'].savefig('calibration_ecdf.png') + # plots['z_score_contraction'].savefig('z_score_contraction.png') diff --git a/shellscripts/fitting_graphmodel.sh b/shellscripts/fitting_graphmodel.sh new file mode 100644 index 0000000000..dbaa905ebe --- /dev/null +++ b/shellscripts/fitting_graphmodel.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH -N 1 +#SBATCH -n 1 +#SBATCH -c 1 +#SBATCH -t 5-0:00:00 +#SBATCH --output=shellscripts/create_testdata_fitting-%A.out +#SBATCH --error=shellscripts/create_testdata_fitting-%A.err +#SBATCH --exclude="be-cpu05, be-gpu01" +#SBATCH --job-name=create_testdata_fitting + +module load PrgEnv/gcc13-openmpi-python +source venv/bin/activate +srun --cpu-bind=core python pycode/examples/simulation/graph_germany_nuts3.py \ No newline at end of file From f45f9c36eeb25c6b4f375816061e7622613db81a Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 31 Jul 2025 15:07:50 +0200 Subject: [PATCH 06/73] move to unicorn --- .../simulation/graph_germany_nuts3.py | 79 ++++++++++--------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 14863345c4..a03330448d 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -20,13 +20,13 @@ import numpy as np import datetime import os -# import memilio.simulation as mio -# import 
memilio.simulation.osecir as osecir +import memilio.simulation as mio +import memilio.simulation.osecir as osecir import matplotlib.pyplot as plt from enum import Enum -# from memilio.simulation.osecir import (Model, Simulation, - # interpolate_simulation_result) +from memilio.simulation.osecir import (Model, Simulation, + interpolate_simulation_result) import pickle @@ -151,7 +151,6 @@ def run(self, num_days_sim, damping_values, save_graph=True): """ mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) - num_runs = 10 graph = self.get_graph(end_date) @@ -219,33 +218,35 @@ def prior(): import bayesflow as bf simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - trainings_data = simulator.sample(1000) - - for region in range(400): - trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] - - with open('trainings_data10.pickle', 'wb') as f: - pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - - # with open('trainings_data1.pickle', 'rb') as f: - # trainings_data = pickle.load(f) - # for i in range(9): - # with open(f'trainings_data{i+2}.pickle', 'rb') as f: - # data = pickle.load(f) - # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - - # with open('validation_data.pickle', 'rb') as f: - # validation_data = pickle.load(f) - - # adapter = ( - # bf.Adapter() - # .to_array() - # .convert_dtype("float64", "float32") - # .constrain("damping_values", lower=0.0, upper=1.0) - # .rename("damping_values", "inference_variables") - # .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) - # .log("summary_variables", p1=True) - # ) + # trainings_data = simulator.sample(1000) + + # for region in range(400): + # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] + + # with open('trainings_data7.pickle', 'wb') as f: + # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + with open('trainings_data1.pickle', 'rb') as f: + trainings_data = pickle.load(f) + for i in range(9): + with open(f'trainings_data{i+2}.pickle', 'rb') as f: + data = pickle.load(f) + trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} + + with open('validation_data.pickle', 'rb') as f: + validation_data = pickle.load(f) + + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .rename("damping_values", "inference_variables") + .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) + .log("summary_variables", p1=True) + ) + + print("inference_variables shape:", adapter(trainings_data)["inference_variables"].shape) summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) inference_network = bf.networks.CouplingFlow() @@ -257,12 +258,12 @@ def prior(): inference_network=inference_network ) - # history = workflow.fit_offline(data=trainings_data, epochs=1, batch_size=32, validation_data=validation_data) + history = workflow.fit_offline(data=trainings_data, epochs=1, batch_size=32, validation_data=validation_data) - # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model.keras")) + workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model.keras")) - # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True}) - # 
plots['losses'].savefig('losses.png') - # plots['recovery'].savefig('recovery.png') - # plots['calibration_ecdf'].savefig('calibration_ecdf.png') - # plots['z_score_contraction'].savefig('z_score_contraction.png') + plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True}) + plots['losses'].savefig('losses_couplingflow.png') + plots['recovery'].savefig('recovery_couplingflow.png') + plots['calibration_ecdf'].savefig('calibration_ecdf_couplingflow.png') + plots['z_score_contraction'].savefig('z_score_contraction_couplingflow.png') From 8fb9fe28afe3d3df6ca91659723520479a9f9a6a Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 5 Aug 2025 12:34:00 +0200 Subject: [PATCH 07/73] changes for germany and states simulation --- cpp/memilio/io/epi_data.cpp | 14 ++++- cpp/memilio/io/epi_data.h | 2 + cpp/models/ode_secir/parameters_io.h | 8 +-- .../simulation/graph_germany_nuts1.py | 8 +-- .../memilio/epidata/getPopulationData.py | 5 +- .../metapopulation_mobility_instant.cpp | 2 +- .../simulation/bindings/models/osecir.cpp | 52 +++++++++++++++++-- 7 files changed, 78 insertions(+), 13 deletions(-) diff --git a/cpp/memilio/io/epi_data.cpp b/cpp/memilio/io/epi_data.cpp index 97210aef98..2e95ad24fb 100644 --- a/cpp/memilio/io/epi_data.cpp +++ b/cpp/memilio/io/epi_data.cpp @@ -45,11 +45,14 @@ IOResult> get_node_ids(const std::string& path, bool is_node_fo } } else { - if (entry.district_id) { + if (entry.state_id) { + id.push_back(entry.state_id->get()); + } + else if (entry.district_id) { id.push_back(entry.district_id->get()); } else { - return failure(StatusCode::InvalidValue, "Population data file is missing district ids."); + return failure(StatusCode::InvalidValue, "Population data file is missing district and state ids."); } } } @@ -58,6 +61,13 @@ IOResult> get_node_ids(const std::string& path, bool is_node_fo id.erase(std::unique(id.begin(), id.end()), id.end()); return success(id); } + +IOResult> get_country_id(const std::string& /*path*/, bool /*is_node_for_county*/, + bool /*rki_age_groups*/) +{ + std::vector id = {0}; + return success(id); +} } // namespace mio #endif //MEMILIO_HAS_JSONCPP diff --git a/cpp/memilio/io/epi_data.h b/cpp/memilio/io/epi_data.h index 8c25f96c84..3c3dac4c12 100644 --- a/cpp/memilio/io/epi_data.h +++ b/cpp/memilio/io/epi_data.h @@ -480,6 +480,8 @@ IOResult set_vaccination_data_age_group_names(std::vector nam * @return list of node ids. */ IOResult> get_node_ids(const std::string& path, bool is_node_for_county, bool rki_age_groups = true); +IOResult> get_country_id(const std::string& /*path*/, bool /*is_node_for_county*/, + bool /*rki_age_groups*/ = true); /** * Represents an entry in a vaccination data file. diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index c721bab032..ef5cc74a32 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -429,9 +429,10 @@ IOResult export_input_data_county_timeseries( * @param[in] pydata_dir Directory of files. 
*/ template -IOResult read_input_data_germany(std::vector& model, Date date, +IOResult read_input_data_germany(std::vector& model, Date date, std::vector& /*state*/, const std::vector& scaling_factor_inf, double scaling_factor_icu, - const std::string& pydata_dir) + const std::string& pydata_dir, int /*num_days*/ = 0, + bool /*export_time_series*/ = false) { BOOST_OUTCOME_TRY( details::set_divi_data(model, path_join(pydata_dir, "germany_divi.json"), {0}, date, scaling_factor_icu)); @@ -454,7 +455,8 @@ IOResult read_input_data_germany(std::vector& model, Date date, template IOResult read_input_data_state(std::vector& model, Date date, std::vector& state, const std::vector& scaling_factor_inf, double scaling_factor_icu, - const std::string& pydata_dir) + const std::string& pydata_dir, int /*num_days*/ = 0, + bool /*export_time_series*/ = false) { BOOST_OUTCOME_TRY( diff --git a/pycode/examples/simulation/graph_germany_nuts1.py b/pycode/examples/simulation/graph_germany_nuts1.py index ca48fbd2fa..d9aab2a225 100644 --- a/pycode/examples/simulation/graph_germany_nuts1.py +++ b/pycode/examples/simulation/graph_germany_nuts1.py @@ -68,6 +68,8 @@ def set_covid_parameters(self, model): model.parameters.Seasonality = mio.UncertainValue(0.2) + model.parameters.end_commuter_detection = 50. + def set_contact_matrices(self, model): """ @@ -124,14 +126,14 @@ def get_graph(self, end_date, damping_value): "county_current_population_states.json") print("Setting nodes...") - mio.osecir.set_nodes( + mio.osecir.set_nodes_states( model.parameters, mio.Date(self.start_date.year, self.start_date.month, self.start_date.day), mio.Date(end_date.year, end_date.month, end_date.day), pydata_dir, - path_population_data, True, graph, scaling_factor_infected, - scaling_factor_icu, 0, 0, False, False) + path_population_data, False, graph, scaling_factor_infected, + scaling_factor_icu, 0.5, 0, False, False) print("Setting edges...") mio.osecir.set_edges( diff --git a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py index ad19b30bb6..b96d5d6800 100644 --- a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py +++ b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py @@ -142,6 +142,8 @@ def export_population_dataframe(df_pop: pd.DataFrame, directory: str, file_forma gd.write_dataframe(df_pop_export, directory, filename, file_format) gd.write_dataframe(df_pop_export.drop(columns=new_cols[2:]), directory, filename + '_aggregated', file_format) gd.write_dataframe(aggregate_to_state_level(df_pop_export.drop(columns=new_cols[2:])), directory, filename + '_states', file_format) + df_pop_germany = pd.DataFrame({"ID": [0], "Population": [df_pop_export["Population"].sum()]}) + gd.write_dataframe(df_pop_germany, directory, filename + '_germany', file_format) return df_pop_export @@ -450,7 +452,8 @@ def aggregate_to_state_level(df_pop: pd.DataFrame): countyIDtostateID = geoger.get_countyid_to_stateid_map() df_pop['ID_State'] = df_pop[dd.EngEng['idCounty']].map(countyIDtostateID) - df_pop = df_pop.drop(columns='ID_County').groupby('ID_State').sum() + df_pop = df_pop.drop(columns='ID_County').groupby('ID_State', as_index=True).sum() + df_pop['ID_State'] = df_pop.index return df_pop diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/mobility/metapopulation_mobility_instant.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/mobility/metapopulation_mobility_instant.cpp index 34f7e26ef6..4d28f16107 100644 --- 
a/pycode/memilio-simulation/memilio/simulation/bindings/mobility/metapopulation_mobility_instant.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/mobility/metapopulation_mobility_instant.cpp @@ -53,7 +53,7 @@ void bind_mobility_parameter_edge(py::module_& m, std::string const& name) }) .def_property_readonly( "property", - [](const mio::Edge>& self) -> auto& { + [](const mio::Edge>& self) -> auto& { return self.property; }, py::return_value_policy::reference_internal); diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp index 591ef4bfc0..ecbb6287e2 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp @@ -200,7 +200,15 @@ PYBIND11_MODULE(_simulation_osecir, m) mio::osecir::ParametersBase>(m, "Parameters") .def(py::init()) .def("check_constraints", &mio::osecir::Parameters::check_constraints) - .def("apply_constraints", &mio::osecir::Parameters::apply_constraints); + .def("apply_constraints", &mio::osecir::Parameters::apply_constraints) + .def_property( + "end_commuter_detection", + [](const mio::osecir::Parameters& self) -> auto { + return self.get_end_commuter_detection(); + }, + [](mio::osecir::Parameters& self, double p) { + self.get_end_commuter_detection() = p; + }); using Populations = mio::Populations; pymio::bind_Population(m, "Populations", mio::Tag::Populations>{}); @@ -264,6 +272,44 @@ PYBIND11_MODULE(_simulation_osecir, m) }, py::return_value_policy::move); + m.def( + "set_nodes_states", + [](const mio::osecir::Parameters& params, mio::Date start_date, mio::Date end_date, + const std::string& data_dir, const std::string& population_data_path, bool is_node_for_county, + mio::Graph, mio::MobilityParameters>& params_graph, + const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor, + int num_days = 0, bool export_time_series = false, bool rki_age_groups = true) { + auto result = mio::set_nodes< + mio::osecir::TestAndTraceCapacity, mio::osecir::ContactPatterns, + mio::osecir::Model, mio::MobilityParameters, mio::osecir::Parameters, + decltype(mio::osecir::read_input_data_state>), decltype(mio::get_node_ids)>( + params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph, + mio::osecir::read_input_data_state>, mio::get_node_ids, scaling_factor_inf, + scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series, rki_age_groups); + return pymio::check_and_throw(result); + }, + py::return_value_policy::move); + + m.def( + "set_node_germany", + [](const mio::osecir::Parameters& params, mio::Date start_date, mio::Date end_date, + const std::string& data_dir, const std::string& population_data_path, bool is_node_for_county, + mio::Graph, mio::MobilityParameters>& params_graph, + const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor, + int num_days = 0, bool export_time_series = false, bool rki_age_groups = true) { + auto result = mio::set_nodes, + mio::osecir::ContactPatterns, mio::osecir::Model, + mio::MobilityParameters, mio::osecir::Parameters, + decltype(mio::osecir::read_input_data_germany>), + decltype(mio::get_country_id)>( + params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph, + mio::osecir::read_input_data_germany>, mio::get_country_id, + scaling_factor_inf, scaling_factor_icu, 
tnt_capacity_factor, num_days, export_time_series, + rki_age_groups); + return pymio::check_and_throw(result); + }, + py::return_value_policy::move); + pymio::iterable_enum(m, "ContactLocation") .value("Home", ContactLocation::Home) .value("School", ContactLocation::School) @@ -278,11 +324,11 @@ PYBIND11_MODULE(_simulation_osecir, m) auto mobile_comp = {mio::osecir::InfectionState::Susceptible, mio::osecir::InfectionState::Exposed, mio::osecir::InfectionState::InfectedNoSymptoms, mio::osecir::InfectionState::InfectedSymptoms, mio::osecir::InfectionState::Recovered}; - // auto weights = std::vector{0., 0., 1.0, 1.0, 0.33, 0., 0.}; + // auto weights = std::vector{1.0}; auto result = mio::set_edges, mio::MobilityParameters, mio::MobilityCoefficientGroup, mio::osecir::InfectionState, decltype(mio::read_mobility_plain)>( - mobility_data_file, params_graph, mobile_comp, contact_locations_size, mio::read_mobility_plain, {}); + mobility_data_file, params_graph, mobile_comp, contact_locations_size, mio::read_mobility_plain, {1.0}); return pymio::check_and_throw(result); }, py::return_value_policy::move); From 2dbe4772fa3279d4044d631f69ac591ba0a5a997 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 6 Aug 2025 09:00:18 +0200 Subject: [PATCH 08/73] first draft germany simulation --- .../simulation/graph_germany_nuts0.py | 243 ++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 pycode/examples/simulation/graph_germany_nuts0.py diff --git a/pycode/examples/simulation/graph_germany_nuts0.py b/pycode/examples/simulation/graph_germany_nuts0.py new file mode 100644 index 0000000000..9783230b28 --- /dev/null +++ b/pycode/examples/simulation/graph_germany_nuts0.py @@ -0,0 +1,243 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+############################################################################# +import numpy as np +import datetime +import os +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +import matplotlib.pyplot as plt + +from enum import Enum +from memilio.simulation.osecir import (Model, Simulation, + interpolate_simulation_result) + +import pickle + + +class Simulation: + """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model): + """ + + :param model: + + """ + model.parameters.TimeExposed[mio.AgeGroup(0)] = 3.335 + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 0.07333 + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 + + # start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + """ + + :param model: + + """ + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 + minimum = np.ones((self.num_groups, self.num_groups)) * 0 + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_value): + """ + + :param params: + :param end_date: + + """ + + start_damping = datetime.date( + 2020, 12, 18) + + if start_damping < end_date: + start_date = (start_damping - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) + + def get_graph(self, end_date, damping_value): + """ + + :param end_date: + + """ + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters(model) + self.set_contact_matrices(model) + self.set_npis(model.parameters, end_date, damping_value) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + + data_dir_Germany = os.path.join(self.data_dir, "Germany") + pydata_dir = os.path.join(data_dir_Germany, "pydata") + + path_population_data = os.path.join(pydata_dir, + "county_current_population_germany.json") + + print("Setting nodes...") + mio.osecir.set_node_germany( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, False, graph, scaling_factor_infected, + scaling_factor_icu, 0.9, 0, False, False) + + print("Graph created.") + + return graph + + def run(self, num_days_sim, damping_value, save_graph=True): + """ + + :param num_days_sim: + :param num_runs: (Default value = 
10) + :param save_graph: (Default value = True) + :param create_gif: (Default value = True) + + """ + mio.set_log_level(mio.LogLevel.Warning) + end_date = self.start_date + datetime.timedelta(days=num_days_sim) + num_runs = 10 + + graph = self.get_graph(end_date, damping_value) + + if save_graph: + path_graph = os.path.join(self.results_dir, "graph") + if not os.path.exists(path_graph): + os.makedirs(path_graph) + osecir.write_graph(graph, path_graph) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + mobility_graph.add_node(graph.get_node(node_idx).id, graph.get_node(node_idx).property) + for edge_idx in range(graph.num_edges): + mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + results = [] + for node_idx in range(graph.num_nodes): + results.append(osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result)) + + osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(0).property.result).export_csv('test.csv') + + return results + +def run_germany_nuts0_simulation(damping_value): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=datetime.date(year=2020, month=12, day=12), + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 50 + + results = sim.run(num_days_sim, damping_value) + + return {"region" + str(region): results[region] for region in range(len(results))} + +def prior(): + damping_value = np.random.uniform(0.0, 1.0) + return {"damping_value": damping_value} + +if __name__ == "__main__": + + run_germany_nuts0_simulation(0.5) + # import os + # os.environ["KERAS_BACKEND"] = "jax" + + # import bayesflow as bf + + # simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) + # # trainings_data = simulator.sample(5) + + # with open('trainings_data.pickle', 'wb') as f: + # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + # with open('trainings_data.pickle', 'rb') as f: + # trainings_data = pickle.load(f) + + # # trainings_data = {k:v for k, v in trainings_data.items() if k in ('damping_value', 'region0', 'region1')} + # print("Loaded training data:", trainings_data) + + + # trainings_data = simulator.sample(2) + # validation_data = simulator.sample(2) + + # adapter = ( + # bf.Adapter() + # .to_array() + # .convert_dtype("float64", "float32") + # .constrain("damping_value", lower=0.0, upper=1.0) + # .concatenate(["region"+str(region) for region in range(len(trainings_data)-1)], into="summary_variables") + # .rename("damping_value", "inference_variables") + # #.standardize("summary_variables") + # ) + + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) + # inference_network = bf.networks.CouplingFlow() + + # workflow = bf.BasicWorkflow( + # simulator=simulator, + # adapter=adapter, + # summary_network=summary_network, + # inference_network=inference_network + # ) + + # history = workflow.fit_offline(data=trainings_data, epochs=2, batch_size=2, validation_data=trainings_data) + # f = bf.diagnostics.plots.loss(history) From c4d7cabef9ef826f44bf40a2e41affba485adcd0 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Fri, 8 Aug 2025 09:58:55 +0200 Subject: [PATCH 09/73] update scripts for 
fitting --- .../simulation/graph_germany_nuts1.py | 126 ++++---- .../simulation/graph_germany_nuts3.py | 64 ++-- .../graph_germany_nuts3_16dampings.py | 280 ++++++++++++++++++ shellscripts/fitting_graphmodel.sh | 12 +- 4 files changed, 395 insertions(+), 87 deletions(-) create mode 100644 pycode/examples/simulation/graph_germany_nuts3_16dampings.py diff --git a/pycode/examples/simulation/graph_germany_nuts1.py b/pycode/examples/simulation/graph_germany_nuts1.py index d9aab2a225..31a077171b 100644 --- a/pycode/examples/simulation/graph_germany_nuts1.py +++ b/pycode/examples/simulation/graph_germany_nuts1.py @@ -68,8 +68,6 @@ def set_covid_parameters(self, model): model.parameters.Seasonality = mio.UncertainValue(0.2) - model.parameters.end_commuter_detection = 50. - def set_contact_matrices(self, model): """ @@ -98,7 +96,7 @@ def set_npis(self, params, end_date, damping_value): start_date = (start_damping - self.start_date).days params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) - def get_graph(self, end_date, damping_value): + def get_graph(self, end_date): """ :param end_date: @@ -108,7 +106,6 @@ def get_graph(self, end_date, damping_value): model = Model(self.num_groups) self.set_covid_parameters(model) self.set_contact_matrices(model) - self.set_npis(model.parameters, end_date, damping_value) print("Model initialized.") graph = osecir.ModelGraph() @@ -133,7 +130,7 @@ def get_graph(self, end_date, damping_value): mio.Date(end_date.year, end_date.month, end_date.day), pydata_dir, path_population_data, False, graph, scaling_factor_infected, - scaling_factor_icu, 0.5, 0, False, False) + scaling_factor_icu, 0, 0, False, False) print("Setting edges...") mio.osecir.set_edges( @@ -143,7 +140,7 @@ def get_graph(self, end_date, damping_value): return graph - def run(self, num_days_sim, damping_value, save_graph=True): + def run(self, num_days_sim, damping_values, save_graph=True): """ :param num_days_sim: @@ -154,9 +151,8 @@ def run(self, num_days_sim, damping_value, save_graph=True): """ mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) - num_runs = 10 - graph = self.get_graph(end_date, damping_value) + graph = self.get_graph(end_date) if save_graph: path_graph = os.path.join(self.results_dir, "graph") @@ -166,13 +162,19 @@ def run(self, num_days_sim, damping_value, save_graph=True): mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): - mobility_graph.add_node(graph.get_node(node_idx).id, graph.get_node(node_idx).property) + # if node_idx < 5: + node = graph.get_node(node_idx) + self.set_npis(node.property.parameters, end_date, damping_values[node_idx]) + mobility_graph.add_node(node.id, node.property) + # else: + # node = graph.get_node(node_idx) + # self.set_npis(node.property.parameters, end_date, 0.5) + # mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( graph.get_edge(edge_idx).start_node_idx, graph.get_edge(edge_idx).end_node_idx, graph.get_edge(edge_idx).property) - mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) mobility_sim.advance(num_days_sim) @@ -180,13 +182,10 @@ def run(self, num_days_sim, damping_value, save_graph=True): for node_idx in range(graph.num_nodes): results.append(osecir.interpolate_simulation_result( mobility_sim.graph.get_node(node_idx).property.result)) - - osecir.interpolate_simulation_result( - 
mobility_sim.graph.get_node(0).property.result).export_csv('test.csv') return results -def run_germany_nuts1_simulation(damping_value): +def run_germany_nuts1_simulation(damping_values): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) @@ -196,57 +195,70 @@ def run_germany_nuts1_simulation(damping_value): results_dir=os.path.join(file_path, "../../../results_osecir")) num_days_sim = 50 - results = sim.run(num_days_sim, damping_value) + results = sim.run(num_days_sim, damping_values) - return {"region" + str(region): results[region] for region in range(len(results))} + return {f'region{region}': results[region] for region in range(len(results))} def prior(): - damping_value = np.random.uniform(0.0, 1.0) - return {"damping_value": damping_value} + damping_values = np.random.uniform(0.0, 1.0, 16) + return {'damping_values': damping_values} if __name__ == "__main__": - run_germany_nuts1_simulation(0.5) - # import os - # os.environ["KERAS_BACKEND"] = "jax" - - # import bayesflow as bf - - # simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - # # trainings_data = simulator.sample(5) - - # with open('trainings_data.pickle', 'wb') as f: - # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - - # with open('trainings_data.pickle', 'rb') as f: - # trainings_data = pickle.load(f) - - # # trainings_data = {k:v for k, v in trainings_data.items() if k in ('damping_value', 'region0', 'region1')} - # print("Loaded training data:", trainings_data) + import os + os.environ["KERAS_BACKEND"] = "tensorflow" + import bayesflow as bf - # trainings_data = simulator.sample(2) - # validation_data = simulator.sample(2) + simulator = bf.simulators.make_simulator([prior, run_germany_nuts1_simulation]) + # trainings_data = simulator.sample(1000) - # adapter = ( - # bf.Adapter() - # .to_array() - # .convert_dtype("float64", "float32") - # .constrain("damping_value", lower=0.0, upper=1.0) - # .concatenate(["region"+str(region) for region in range(len(trainings_data)-1)], into="summary_variables") - # .rename("damping_value", "inference_variables") - # #.standardize("summary_variables") - # ) + # for region in range(16): + # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] - # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) - # inference_network = bf.networks.CouplingFlow() - - # workflow = bf.BasicWorkflow( - # simulator=simulator, - # adapter=adapter, - # summary_network=summary_network, - # inference_network=inference_network - # ) + # with open('validation_data_16param.pickle', 'wb') as f: + # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - # history = workflow.fit_offline(data=trainings_data, epochs=2, batch_size=2, validation_data=trainings_data) - # f = bf.diagnostics.plots.loss(history) + with open('trainings_data1_16param.pickle', 'rb') as f: + trainings_data = pickle.load(f) + trainings_data['damping_values'] = trainings_data['damping_values'][:, :16] + for i in range(9): + with open(f'trainings_data{i+2}_16param.pickle', 'rb') as f: + data = pickle.load(f) + data['damping_values'] = data['damping_values'][:, :16] + trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} + + with open('validation_data_16param.pickle', 'rb') as f: + validation_data = pickle.load(f) + validation_data['damping_values'] = validation_data['damping_values'][:, :16] + + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", 
"float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .rename("damping_values", "inference_variables") + .concatenate([f'region{i}' for i in range(16)], into="summary_variables", axis=-1) + .log("summary_variables", p1=True) + ) + + summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32) + inference_network = bf.networks.CouplingFlow() + + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + ) + + history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) + + # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_1params.keras")) + + plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) + plots['losses'].savefig('losses_couplingflow_16param.png') + plots['recovery'].savefig('recovery_couplingflow_16param.png') + plots['calibration_ecdf'].savefig('calibration_ecdf_couplingflow_16param.png') + plots['z_score_contraction'].savefig('z_score_contraction_couplingflow_16param.png') diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index a03330448d..7e8af0852e 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -198,22 +198,35 @@ def prior(): damping_values = np.random.uniform(0.0, 1.0, 400) return {'damping_values': damping_values} -# def load_divi_data(): -# file_path = os.path.dirname(os.path.abspath(__file__)) -# divi_path = os.path.join(file_path, "../../../data/Germany/pydata") - -# data = pd.read_json(os.path.join(divi_path, "county_divi_ma7.json")) -# data = data.drop(columns=['County', 'ICU_ventilated', 'Date']) -# divi_dict = {f"region{i}": data[f'region{i}'] for i in range(399)} -# print(divi_dict) +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Germany/pydata") + + data = pd.read_json(os.path.join(divi_path, "county_divi.json")) + print(data["ID_County"].drop_duplicates().shape) + data = data[data['Date']>= np.datetime64(datetime.date(2020, 8, 1))] + data = data[data['Date'] <= np.datetime64(datetime.date(2020, 8, 1) + datetime.timedelta(days=50))] + print(data["ID_County"].drop_duplicates().shape) + data = data.drop(columns=['County', 'ICU_ventilated', 'Date']) + region_ids = [*dd.County] + divi_dict = {f"region{i}": data[data['ID_County'] == region_ids[i]]['ICU'].to_numpy() for i in range(400)} + # for i in range(100): + # if divi_dict[f'region{i+100}'].size==0: + # print(region_ids[i+100]) + # print(divi_dict[f'region{i+100}'].shape) if __name__ == "__main__": - # import pandas as pd + + # file_path = os.path.dirname(os.path.abspath(__file__)) + # casedata_path = os.path.join(file_path, "../../../data/Germany/pydata/cases_all_county_ma7.jsons") + # county = test[test['ID_County']==7320] + # print(county[county['Date']>= np.datetime64(datetime.date(2020, 8, 1))]) + # from memilio.epidata import defaultDict as dd # load_divi_data() import os - os.environ["KERAS_BACKEND"] = "jax" + os.environ["KERAS_BACKEND"] = "tensorflow" import bayesflow as bf @@ -223,17 +236,17 @@ def prior(): # for region in range(400): # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] - # with open('trainings_data7.pickle', 'wb') as f: + # with 
open('validation_data_400params.pickle', 'wb') as f: # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - with open('trainings_data1.pickle', 'rb') as f: + with open('trainings_data1_400params.pickle', 'rb') as f: trainings_data = pickle.load(f) for i in range(9): - with open(f'trainings_data{i+2}.pickle', 'rb') as f: + with open(f'trainings_data{i+2}_400params.pickle', 'rb') as f: data = pickle.load(f) trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - with open('validation_data.pickle', 'rb') as f: + with open('validation_data_400params.pickle', 'rb') as f: validation_data = pickle.load(f) adapter = ( @@ -246,24 +259,25 @@ def prior(): .log("summary_variables", p1=True) ) - print("inference_variables shape:", adapter(trainings_data)["inference_variables"].shape) + print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) - summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) - inference_network = bf.networks.CouplingFlow() + summary_network = bf.networks.TimeSeriesNetwork(summary_dim=700, recurrent_dim=256) + inference_network = bf.networks.DiffusionModel(subnet_kwargs={'widths': {512, 512, 512, 512, 512}}) workflow = bf.BasicWorkflow( simulator=simulator, adapter=adapter, summary_network=summary_network, - inference_network=inference_network + inference_network=inference_network, + standardize='all' ) - history = workflow.fit_offline(data=trainings_data, epochs=1, batch_size=32, validation_data=validation_data) + history = workflow.fit_offline(data=trainings_data, epochs=1000, batch_size=32, validation_data=validation_data) - workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model.keras")) + # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_10params.keras")) - plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True}) - plots['losses'].savefig('losses_couplingflow.png') - plots['recovery'].savefig('recovery_couplingflow.png') - plots['calibration_ecdf'].savefig('calibration_ecdf_couplingflow.png') - plots['z_score_contraction'].savefig('z_score_contraction_couplingflow.png') + plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) + plots['losses'].savefig('losses_diffusionmodel_400params.png') + plots['recovery'].savefig('recovery_diffusionmodel_400params.png') + plots['calibration_ecdf'].savefig('calibration_ecdf_diffusionmodel_400params.png') + plots['z_score_contraction'].savefig('z_score_contraction_diffusionmodel_400params.png') diff --git a/pycode/examples/simulation/graph_germany_nuts3_16dampings.py b/pycode/examples/simulation/graph_germany_nuts3_16dampings.py new file mode 100644 index 0000000000..8218a7c48b --- /dev/null +++ b/pycode/examples/simulation/graph_germany_nuts3_16dampings.py @@ -0,0 +1,280 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +import numpy as np +import datetime +import os +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +import matplotlib.pyplot as plt + +from enum import Enum +from memilio.simulation.osecir import (Model, Simulation, + interpolate_simulation_result) + +import pickle + + +class Simulation: + """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model): + """ + + :param model: + + """ + model.parameters.TimeExposed[mio.AgeGroup(0)] = 3.335 + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 0.07333 + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 + + # start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + """ + + :param model: + + """ + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 + minimum = np.ones((self.num_groups, self.num_groups)) * 0 + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_value): + """ + + :param params: + :param end_date: + + """ + + start_damping = datetime.date( + 2020, 12, 18) + + if start_damping < end_date: + start_date = (start_damping - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) + + def get_graph(self, end_date): + """ + + :param end_date: + + """ + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters(model) + self.set_contact_matrices(model) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + tnt_capacity_factor = 7.5 / 100000. 
+ + data_dir_Germany = os.path.join(self.data_dir, "Germany") + mobility_data_file = os.path.join( + data_dir_Germany, "mobility", "commuter_mobility_2022.txt") + pydata_dir = os.path.join(data_dir_Germany, "pydata") + + path_population_data = os.path.join(pydata_dir, + "county_current_population_aggregated.json") + + print("Setting nodes...") + mio.osecir.set_nodes( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, True, graph, scaling_factor_infected, + scaling_factor_icu, 0, 0, False, False) + + print("Setting edges...") + mio.osecir.set_edges( + mobility_data_file, graph, 1) + + print("Graph created.") + + return graph + + def run(self, num_days_sim, damping_values, save_graph=True): + """ + + :param num_days_sim: + :param num_runs: (Default value = 10) + :param save_graph: (Default value = True) + :param create_gif: (Default value = True) + + """ + mio.set_log_level(mio.LogLevel.Warning) + end_date = self.start_date + datetime.timedelta(days=num_days_sim) + + graph = self.get_graph(end_date) + + if save_graph: + path_graph = os.path.join(self.results_dir, "graph") + if not os.path.exists(path_graph): + os.makedirs(path_graph) + osecir.write_graph(graph, path_graph) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + node = graph.get_node(node_idx) + self.set_npis(node.property.parameters, end_date, damping_values[node_idx]) + mobility_graph.add_node(node.id, node.property) + for edge_idx in range(graph.num_edges): + mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + print("Simulation finished.") + results = [] + for node_idx in range(graph.num_nodes): + results.append(osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result)) + + return results + +def run_germany_nuts3_simulation(damping_values): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=datetime.date(year=2020, month=12, day=12), + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 50 + + results = sim.run(num_days_sim, damping_values) + + return {f'region{region}': results[region] for region in range(len(results))} + +def prior(): + damping_values = np.random.uniform(0.0, 1.0, 400) + return {'damping_values': damping_values} + +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Germany/pydata") + + data = pd.read_json(os.path.join(divi_path, "county_divi.json")) + print(data["ID_County"].drop_duplicates().shape) + data = data[data['Date']>= np.datetime64(datetime.date(2020, 8, 1))] + data = data[data['Date'] <= np.datetime64(datetime.date(2020, 8, 1) + datetime.timedelta(days=50))] + print(data["ID_County"].drop_duplicates().shape) + data = data.drop(columns=['County', 'ICU_ventilated', 'Date']) + region_ids = [*dd.County] + divi_dict = {f"region{i}": data[data['ID_County'] == region_ids[i]]['ICU'].to_numpy() for i in range(400)} + # for i in range(100): + # if divi_dict[f'region{i+100}'].size==0: + # print(region_ids[i+100]) + # 
print(divi_dict[f'region{i+100}'].shape) + + +if __name__ == "__main__": + + # from memilio.epidata import defaultDict as dd + # import pandas as pd + # load_divi_data() + import os + os.environ["KERAS_BACKEND"] = "tensorflow" + + import bayesflow as bf + + simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) + # trainings_data = simulator.sample(1000) + + # for region in range(400): + # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] + + # with open('validation_data_400params.pickle', 'wb') as f: + # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + with open('trainings_data1_16params_countylvl.pickle', 'rb') as f: + trainings_data = pickle.load(f) + for i in range(9): + with open(f'trainings_data{i+2}_16params_countylvl.pickle', 'rb') as f: + data = pickle.load(f) + trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} + + with open('validation_data_16params_countylvl.pickle', 'rb') as f: + validation_data = pickle.load(f) + + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .rename("damping_values", "inference_variables") + .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) + .log("summary_variables", p1=True) + ) + + print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) + + summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32) #, recurrent_dim=256) + inference_network = bf.networks.CouplingFlow()#subnet_kwargs={'widths': {512, 512, 512, 512, 512}}) + + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + ) + + history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) + + # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_10params.keras")) + + plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) + plots['losses'].savefig('losses_couplingflow_16params_countylvl.png') + plots['recovery'].savefig('recovery_couplingflow_16params_countylvl.png') + plots['calibration_ecdf'].savefig('calibration_ecdf_couplingflow_16params_countylvl.png') + plots['z_score_contraction'].savefig('z_score_contraction_couplingflow_16params_countylvl.png') diff --git a/shellscripts/fitting_graphmodel.sh b/shellscripts/fitting_graphmodel.sh index dbaa905ebe..d92d6dd151 100644 --- a/shellscripts/fitting_graphmodel.sh +++ b/shellscripts/fitting_graphmodel.sh @@ -3,11 +3,13 @@ #SBATCH -n 1 #SBATCH -c 1 #SBATCH -t 5-0:00:00 -#SBATCH --output=shellscripts/create_testdata_fitting-%A.out -#SBATCH --error=shellscripts/create_testdata_fitting-%A.err -#SBATCH --exclude="be-cpu05, be-gpu01" -#SBATCH --job-name=create_testdata_fitting +#SBATCH --output=shellscripts/train_diffusion_model-%A.out +#SBATCH --error=shellscripts/train_diffusion_model-%A.err +#SBATCH --job-name=train_diffusion_model +#SBATCH --partition=gpu +#SBATCH --gpus=1 module load PrgEnv/gcc13-openmpi-python +module load cuda/12.9.0-none-none-6wnenm2 source venv/bin/activate -srun --cpu-bind=core python pycode/examples/simulation/graph_germany_nuts3.py \ No newline at end of file +srun --cpu-bind=core python pycode/examples/simulation/graph_germany_nuts3.py \ No newline at end of file From 
67b5a6ddff9c851c6ac185bba2249f194cfcc3d1 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 12 Aug 2025 09:30:14 +0200 Subject: [PATCH 10/73] try for modeling spain --- cpp/memilio/geography/regions.h | 2 + cpp/memilio/io/epi_data.cpp | 5 +- cpp/memilio/io/epi_data.h | 54 ++++ cpp/memilio/io/parameters_io.cpp | 25 ++ cpp/models/ode_secir/parameters_io.h | 47 ++++ .../examples/simulation/graph_spain_nuts3.py | 260 ++++++++++++++++++ .../memilio/epidata/defaultDict.py | 54 ++++ .../epidata/modifyPopulationDataSpain.py | 21 ++ 8 files changed, 466 insertions(+), 2 deletions(-) create mode 100644 pycode/examples/simulation/graph_spain_nuts3.py create mode 100644 pycode/memilio-epidata/memilio/epidata/modifyPopulationDataSpain.py diff --git a/cpp/memilio/geography/regions.h b/cpp/memilio/geography/regions.h index c8dbeb5978..a359d5bce9 100644 --- a/cpp/memilio/geography/regions.h +++ b/cpp/memilio/geography/regions.h @@ -77,6 +77,8 @@ DECL_TYPESAFE(int, CountyId); DECL_TYPESAFE(int, DistrictId); +DECL_TYPESAFE(int, ProvinciaId); + /** * get the id of the state that the specified county is in. * @param[in, out] county a county id. diff --git a/cpp/memilio/io/epi_data.cpp b/cpp/memilio/io/epi_data.cpp index 2e95ad24fb..0c9fe015f1 100644 --- a/cpp/memilio/io/epi_data.cpp +++ b/cpp/memilio/io/epi_data.cpp @@ -25,8 +25,9 @@ namespace mio { -std::vector ConfirmedCasesDataEntry::age_group_names = {"Population"}; -std::vector PopulationDataEntry::age_group_names = {"Population"}; +std::vector ConfirmedCasesDataEntry::age_group_names = {"Population"}; +std::vector PopulationDataEntry::age_group_names = {"Population"}; +std::vector PopulationDataEntrySpain::age_group_names = {"Population"}; std::vector VaccinationDataEntry::age_group_names = {"0-4", "5-14", "15-34", "35-59", "60-79", "80-99"}; diff --git a/cpp/memilio/io/epi_data.h b/cpp/memilio/io/epi_data.h index 3c3dac4c12..6664e418a2 100644 --- a/cpp/memilio/io/epi_data.h +++ b/cpp/memilio/io/epi_data.h @@ -337,6 +337,35 @@ class PopulationDataEntry } }; +class PopulationDataEntrySpain +{ +public: + static std::vector age_group_names; + + CustomIndexArray population; + boost::optional provincia_id; + + template + static IOResult deserialize(IoContext& io) + { + auto obj = io.expect_object("PopulationDataEntrySpain"); + auto state_id = obj.expect_optional("ID_Provincia", Tag{}); + std::vector> age_groups; + age_groups.reserve(age_group_names.size()); + std::transform(age_group_names.begin(), age_group_names.end(), std::back_inserter(age_groups), + [&obj](auto&& age_name) { + return obj.expect_element(age_name, Tag{}); + }); + return apply( + io, + [](auto&& ag, auto&& pid) { + return PopulationDataEntrySpain{ + CustomIndexArray(AgeGroup(ag.size()), ag.begin(), ag.end()), pid}; + }, + details::unpack_all(age_groups), provincia_id); + } +}; + namespace details { inline void get_rki_age_interpolation_coefficients(const std::vector& age_ranges, @@ -441,6 +470,19 @@ inline IOResult> deserialize_population_data(co } } +/** + * Deserialize population data from a JSON value. + * Age groups are interpolated to RKI age groups. + * @param jsvalue JSON value that contains the population data. + * @param rki_age_groups Specifies whether population data should be interpolated to rki age groups. + * @return list of population data. 
+ */ +inline IOResult> deserialize_population_data_spain(const Json::Value& jsvalue) +{ + BOOST_OUTCOME_TRY(auto&& population_data, deserialize_json(jsvalue, Tag>{})); + return success(population_data); +} + /** * Deserialize population data from a JSON file. * Age groups are interpolated to RKI age groups. @@ -454,6 +496,18 @@ inline IOResult> read_population_data(const std return deserialize_population_data(jsvalue, rki_age_group); } +/** + * Deserialize population data from a JSON file. + * Age groups are interpolated to RKI age groups. + * @param filename JSON file that contains the population data. + * @return list of population data. + */ +inline IOResult> read_population_data_spain(const std::string& filename) +{ + BOOST_OUTCOME_TRY(auto&& jsvalue, read_json(filename)); + return deserialize_population_data_spain(jsvalue); +} + /** * @brief Sets the age groups' names for the ConfirmedCasesDataEntry%s. * @param[in] names age group names diff --git a/cpp/memilio/io/parameters_io.cpp b/cpp/memilio/io/parameters_io.cpp index f9f9f8b3c0..4ca6e86d85 100644 --- a/cpp/memilio/io/parameters_io.cpp +++ b/cpp/memilio/io/parameters_io.cpp @@ -62,6 +62,31 @@ IOResult>> read_population_data(const std::vecto return success(vnum_population); } +IOResult>> +read_population_data_spain(const std::vector& population_data, + const std::vector& vregion) +{ + std::vector> vnum_population(vregion.size(), std::vector(1, 0.0)); + + for (auto&& provincia_entry : population_data) { + printf("Test"); + //find region that this county belongs to + //all counties belong to the country (id = 0) + auto it = std::find_if(vregion.begin(), vregion.end(), [&provincia_entry](auto r) { + return r == 0 || (provincia_entry.provincia_id && regions::ProvinciaId(r) == provincia_entry.provincia_id); + }); + if (it != vregion.end()) { + auto region_idx = size_t(it - vregion.begin()); + auto& num_population = vnum_population[region_idx]; + for (size_t age = 0; age < num_population.size(); age++) { + num_population[age] += provincia_entry.population[AgeGroup(age)]; + } + } + } + + return success(vnum_population); +} + IOResult>> read_population_data(const std::string& path, const std::vector& vregion) { diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index ef5cc74a32..1e6d5d8305 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -352,6 +352,15 @@ IOResult set_population_data(std::vector>& model, const std::str return success(); } +template +IOResult set_population_data_provincias(std::vector>& model, const std::string& path, + const std::vector& vregion) +{ + BOOST_OUTCOME_TRY(const auto&& num_population, read_population_data_spain(path, vregion)); + BOOST_OUTCOME_TRY(set_population_data(model, num_population, vregion)); + return success(); +} + } //namespace details #ifdef MEMILIO_HAS_HDF5 @@ -505,6 +514,44 @@ IOResult read_input_data_county(std::vector& model, Date date, cons return success(); } +/** + * @brief Reads population data from population files for the specefied county. + * @param[in, out] model Vector of model in which the data is set. + * @param[in] date Date for which the data should be read. + * @param[in] county Vector of region keys of counties of interest. + * @param[in] scaling_factor_inf Factors by which to scale the confirmed cases of rki data. + * @param[in] scaling_factor_icu Factor by which to scale the icu cases of divi data. + * @param[in] pydata_dir Directory of files. 
+ * @param[in] num_days [Default: 0] Number of days to be simulated; required to extrapolate real data. + * @param[in] export_time_series [Default: false] If true, reads data for each day of simulation and writes it in the same directory as the input files. + */ +template +IOResult read_input_data_provincias(std::vector& model, Date date, const std::vector& county, + const std::vector& scaling_factor_inf, double scaling_factor_icu, + const std::string& pydata_dir, int num_days = 0, + bool export_time_series = false) +{ + // BOOST_OUTCOME_TRY( + // details::set_divi_data(model, path_join(pydata_dir, "county_divi_ma7.json"), county, date, scaling_factor_icu)); + // BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_county_ma7.json"), + // county, date, scaling_factor_inf)); + BOOST_OUTCOME_TRY(details::set_population_data_provincias( + model, path_join(pydata_dir, "provincias_current_population.json"), county)); + + if (export_time_series) { + // Use only if extrapolated real data is needed for comparison. EXPENSIVE ! + // Run time equals run time of the previous functions times the num_days ! + // (This only represents the vectorization of the previous function over all simulation days...) + log_warning("Exporting time series of extrapolated real data. This may take some minutes. " + "For simulation runs over the same time period, deactivate it."); + BOOST_OUTCOME_TRY(export_input_data_county_timeseries( + model, pydata_dir, county, date, scaling_factor_inf, scaling_factor_icu, num_days, + path_join(pydata_dir, "county_divi_ma7.json"), path_join(pydata_dir, "cases_all_county_age_ma7.json"), + path_join(pydata_dir, "county_current_population.json"))); + } + return success(); +} + /** * @brief reads population data from population files for the specified nodes * @param[in, out] model vector of model in which the data is set diff --git a/pycode/examples/simulation/graph_spain_nuts3.py b/pycode/examples/simulation/graph_spain_nuts3.py new file mode 100644 index 0000000000..a63e1711e1 --- /dev/null +++ b/pycode/examples/simulation/graph_spain_nuts3.py @@ -0,0 +1,260 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+############################################################################# +import numpy as np +import datetime +import os +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +import matplotlib.pyplot as plt + +from enum import Enum +from memilio.simulation.osecir import (Model, Simulation, + interpolate_simulation_result) + +import pickle + + +class Simulation: + """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model): + """ + + :param model: + + """ + model.parameters.TimeExposed[mio.AgeGroup(0)] = 3.335 + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 0.07333 + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 + + # start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + """ + + :param model: + + """ + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 12.32 + minimum = np.ones((self.num_groups, self.num_groups)) * 0 + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_value): + """ + + :param params: + :param end_date: + + """ + + start_damping = datetime.date( + 2020, 12, 18) + + if start_damping < end_date: + start_date = (start_damping - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) + + def get_graph(self, end_date): + """ + + :param end_date: + + """ + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters(model) + self.set_contact_matrices(model) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + tnt_capacity_factor = 7.5 / 100000. 
+ + data_dir_Spain = os.path.join(self.data_dir, "Spain") + mobility_data_file = os.path.join( + data_dir_Spain, "mobility", "commuter_mobility_2022.txt") + pydata_dir = os.path.join(data_dir_Spain, "pydata") + + path_population_data = os.path.join(pydata_dir, + "provincias_current_population.json") + + print("Setting nodes...") + mio.osecir.set_nodes_provincias( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, True, graph, scaling_factor_infected, + scaling_factor_icu, 0, 0, False, False) + + print("Setting edges...") + mio.osecir.set_edges( + mobility_data_file, graph, 1) + + print("Graph created.") + + return graph + + def run(self, num_days_sim, damping_values, save_graph=True): + """ + + :param num_days_sim: + :param num_runs: (Default value = 10) + :param save_graph: (Default value = True) + :param create_gif: (Default value = True) + + """ + mio.set_log_level(mio.LogLevel.Warning) + end_date = self.start_date + datetime.timedelta(days=num_days_sim) + + graph = self.get_graph(end_date) + + if save_graph: + path_graph = os.path.join(self.results_dir, "graph") + if not os.path.exists(path_graph): + os.makedirs(path_graph) + osecir.write_graph(graph, path_graph) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + node = graph.get_node(node_idx) + self.set_npis(node.property.parameters, end_date, damping_values[node_idx]) + mobility_graph.add_node(node.id, node.property) + for edge_idx in range(graph.num_edges): + mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + results = [] + for node_idx in range(graph.num_nodes): + results.append(osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result)) + + return results + +def run_spain_nuts3_simulation(damping_values): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=datetime.date(year=2020, month=12, day=12), + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 50 + + results = sim.run(num_days_sim, damping_values) + + return {f'region{region}': results[region] for region in range(len(results))} + +def prior(): + damping_values = np.random.uniform(0.0, 1.0, 47) + return {'damping_values': damping_values} + + +if __name__ == "__main__": + test = prior() + run_spain_nuts3_simulation(test['damping_values']) + # import os + # os.environ["KERAS_BACKEND"] = "tensorflow" + + # import bayesflow as bf + + # simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) + # # trainings_data = simulator.sample(1000) + + # # for region in range(400): + # # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] + + # # with open('validation_data_400params.pickle', 'wb') as f: + # # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + # with open('trainings_data1_400params.pickle', 'rb') as f: + # trainings_data = pickle.load(f) + # for i in range(9): + # with open(f'trainings_data{i+2}_400params.pickle', 'rb') as f: + # data = pickle.load(f) + # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) 
for k in trainings_data.keys()} + + # with open('validation_data_400params.pickle', 'rb') as f: + # validation_data = pickle.load(f) + + # adapter = ( + # bf.Adapter() + # .to_array() + # .convert_dtype("float64", "float32") + # .constrain("damping_values", lower=0.0, upper=1.0) + # .rename("damping_values", "inference_variables") + # .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) + # .log("summary_variables", p1=True) + # ) + + # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) + + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=700, recurrent_dim=256) + # inference_network = bf.networks.DiffusionModel(subnet_kwargs={'widths': {512, 512, 512, 512, 512}}) + + # workflow = bf.BasicWorkflow( + # simulator=simulator, + # adapter=adapter, + # summary_network=summary_network, + # inference_network=inference_network, + # standardize='all' + # ) + + # history = workflow.fit_offline(data=trainings_data, epochs=1000, batch_size=32, validation_data=validation_data) + + # # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_10params.keras")) + + # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) + # plots['losses'].savefig('losses_diffusionmodel_400params.png') + # plots['recovery'].savefig('recovery_diffusionmodel_400params.png') + # plots['calibration_ecdf'].savefig('calibration_ecdf_diffusionmodel_400params.png') + # plots['z_score_contraction'].savefig('z_score_contraction_diffusionmodel_400params.png') diff --git a/pycode/memilio-epidata/memilio/epidata/defaultDict.py b/pycode/memilio-epidata/memilio/epidata/defaultDict.py index 73389b82d2..8f95fffafd 100644 --- a/pycode/memilio-epidata/memilio/epidata/defaultDict.py +++ b/pycode/memilio-epidata/memilio/epidata/defaultDict.py @@ -704,3 +704,57 @@ def invert_dict(dict_to_invert): """ return {val: key for key, val in dict_to_invert.items()} + +Provincias = { + 5: 'Almería', + 12: 'Cádiz', + 15: 'Córdoba', + 19: 'Granada', + 22: 'Huelva', + 24: 'Jaén', + 30: 'Málaga', + 41: 'Sevilla', + 23: 'Huesca', + 44: 'Teruel', + 50: 'Zaragoza', + 33: 'Asturias', + 8: 'Balears, Illes', + 35: 'Palmas, Las', + 38: 'Santa Cruz de Tenerife', + 39: 'Cantabria', + 6: 'Ávila', + 10: 'Burgos', + 25: 'León', + 34: 'Palencia', + 37: 'Salamanca', + 40: 'Segovia', + 42: 'Soria', + 47: 'Valladolid', + 49: 'Zamora', + 3: 'Albacete', + 14: 'Ciudad Real', + 17: 'Cuenca', + 20: 'Guadalajara', + 45: 'Toledo', + 9: 'Barcelona', + 18: 'Girona', + 26: 'Lleida', + 43: 'Tarragona', + 4: 'Alicante/Alacant', + 13: 'Castellón/Castelló', + 46: 'Valencia/València', + 7: 'Badajoz', + 11: 'Cáceres', + 16: 'Coruña, A', + 28: 'Lugo', + 53: 'Ourense', + 36: 'Pontevedra', + 29: 'Madrid', + 31: 'Murcia', + 32: 'Navarra', + 2: 'Araba/Álava', + 21: 'Gipuzkoa', + 48: 'Bizkaia', + 27: 'Rioja, La', + 51: 'Ceuta', + 52: 'Melilla'} \ No newline at end of file diff --git a/pycode/memilio-epidata/memilio/epidata/modifyPopulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/modifyPopulationDataSpain.py new file mode 100644 index 0000000000..acdba425f8 --- /dev/null +++ b/pycode/memilio-epidata/memilio/epidata/modifyPopulationDataSpain.py @@ -0,0 +1,21 @@ +import pandas as pd +import os + +def read_population_data(file): + df = pd.read_json(file) + df = df[['MetaData', 'Data']] + df['ID_Provincia'] = df['MetaData'].apply(lambda x: x[0]['Id']) + df['Population'] = df['Data'].apply(lambda x: 
x[0]['Valor']) + return df[['ID_Provincia', 'Population']] + +def remove_islands(df): + df = df[~df['ID_Provincia'].isin([51, 52, 8, 35, 38])] + return df + +if __name__ == "__main__": + + data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../data/Spain") + + df = read_population_data(os.path.join(data_dir, 'pydata/67988.json')) + df = remove_islands(df) + df.to_json(os.path.join(data_dir, 'pydata/provincias_current_population.json'), orient='records', force_ascii=False) \ No newline at end of file From e93dced38f4357e56396ebd40f8ff1cf36970ee6 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 12 Aug 2025 11:02:47 +0200 Subject: [PATCH 11/73] change starting date --- .../simulation/graph_germany_nuts3.py | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 7e8af0852e..1588dc459f 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -186,7 +186,7 @@ def run_germany_nuts3_simulation(damping_values): sim = Simulation( data_dir=os.path.join(file_path, "../../../data"), - start_date=datetime.date(year=2020, month=12, day=12), + start_date=datetime.date(year=2020, month=7, day=1), results_dir=os.path.join(file_path, "../../../results_osecir")) num_days_sim = 50 @@ -231,7 +231,7 @@ def load_divi_data(): import bayesflow as bf simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - # trainings_data = simulator.sample(1000) + trainings_data = simulator.sample(1) # for region in range(400): # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] @@ -239,45 +239,45 @@ def load_divi_data(): # with open('validation_data_400params.pickle', 'wb') as f: # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - with open('trainings_data1_400params.pickle', 'rb') as f: - trainings_data = pickle.load(f) - for i in range(9): - with open(f'trainings_data{i+2}_400params.pickle', 'rb') as f: - data = pickle.load(f) - trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - - with open('validation_data_400params.pickle', 'rb') as f: - validation_data = pickle.load(f) - - adapter = ( - bf.Adapter() - .to_array() - .convert_dtype("float64", "float32") - .constrain("damping_values", lower=0.0, upper=1.0) - .rename("damping_values", "inference_variables") - .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) - .log("summary_variables", p1=True) - ) - - print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) - - summary_network = bf.networks.TimeSeriesNetwork(summary_dim=700, recurrent_dim=256) - inference_network = bf.networks.DiffusionModel(subnet_kwargs={'widths': {512, 512, 512, 512, 512}}) - - workflow = bf.BasicWorkflow( - simulator=simulator, - adapter=adapter, - summary_network=summary_network, - inference_network=inference_network, - standardize='all' - ) - - history = workflow.fit_offline(data=trainings_data, epochs=1000, batch_size=32, validation_data=validation_data) - - # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_10params.keras")) - - plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) - plots['losses'].savefig('losses_diffusionmodel_400params.png') - 
plots['recovery'].savefig('recovery_diffusionmodel_400params.png') - plots['calibration_ecdf'].savefig('calibration_ecdf_diffusionmodel_400params.png') - plots['z_score_contraction'].savefig('z_score_contraction_diffusionmodel_400params.png') + # with open('trainings_data1_400params.pickle', 'rb') as f: + # trainings_data = pickle.load(f) + # for i in range(9): + # with open(f'trainings_data{i+2}_400params.pickle', 'rb') as f: + # data = pickle.load(f) + # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} + + # with open('validation_data_400params.pickle', 'rb') as f: + # validation_data = pickle.load(f) + + # adapter = ( + # bf.Adapter() + # .to_array() + # .convert_dtype("float64", "float32") + # .constrain("damping_values", lower=0.0, upper=1.0) + # .rename("damping_values", "inference_variables") + # .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) + # .log("summary_variables", p1=True) + # ) + + # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) + + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=700, recurrent_dim=512) + # inference_network = bf.networks.DiffusionModel(subnet_kwargs={'widths': {512, 512, 512, 512, 512}}) + + # workflow = bf.BasicWorkflow( + # simulator=simulator, + # adapter=adapter, + # summary_network=summary_network, + # inference_network=inference_network, + # standardize='all' + # ) + + # history = workflow.fit_offline(data=trainings_data, epochs=1000, batch_size=32, validation_data=validation_data) + + # # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_10params.keras")) + + # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) + # plots['losses'].savefig('losses_diffusionmodel_400params.png') + # plots['recovery'].savefig('recovery_diffusionmodel_400params.png') + # plots['calibration_ecdf'].savefig('calibration_ecdf_diffusionmodel_400params.png') + # plots['z_score_contraction'].savefig('z_score_contraction_diffusionmodel_400params.png') From 19f7bc8256f55bc940842185806e074783fbc682 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 12 Aug 2025 12:37:44 +0200 Subject: [PATCH 12/73] comment everything out for spain --- cpp/memilio/io/epi_data.cpp | 2 +- cpp/memilio/io/epi_data.h | 76 ++++++++++++++-------------- cpp/memilio/io/parameters_io.cpp | 44 ++++++++-------- cpp/models/ode_secir/parameters_io.h | 68 ++++++++++++------------- 4 files changed, 95 insertions(+), 95 deletions(-) diff --git a/cpp/memilio/io/epi_data.cpp b/cpp/memilio/io/epi_data.cpp index 0c9fe015f1..1b5569f57c 100644 --- a/cpp/memilio/io/epi_data.cpp +++ b/cpp/memilio/io/epi_data.cpp @@ -27,7 +27,7 @@ namespace mio std::vector ConfirmedCasesDataEntry::age_group_names = {"Population"}; std::vector PopulationDataEntry::age_group_names = {"Population"}; -std::vector PopulationDataEntrySpain::age_group_names = {"Population"}; +// std::vector PopulationDataEntrySpain::age_group_names = {"Population"}; std::vector VaccinationDataEntry::age_group_names = {"0-4", "5-14", "15-34", "35-59", "60-79", "80-99"}; diff --git a/cpp/memilio/io/epi_data.h b/cpp/memilio/io/epi_data.h index 6664e418a2..fcb9fe147c 100644 --- a/cpp/memilio/io/epi_data.h +++ b/cpp/memilio/io/epi_data.h @@ -337,34 +337,34 @@ class PopulationDataEntry } }; -class PopulationDataEntrySpain -{ -public: - static std::vector age_group_names; - - CustomIndexArray population; - 
boost::optional provincia_id; - - template - static IOResult deserialize(IoContext& io) - { - auto obj = io.expect_object("PopulationDataEntrySpain"); - auto state_id = obj.expect_optional("ID_Provincia", Tag{}); - std::vector> age_groups; - age_groups.reserve(age_group_names.size()); - std::transform(age_group_names.begin(), age_group_names.end(), std::back_inserter(age_groups), - [&obj](auto&& age_name) { - return obj.expect_element(age_name, Tag{}); - }); - return apply( - io, - [](auto&& ag, auto&& pid) { - return PopulationDataEntrySpain{ - CustomIndexArray(AgeGroup(ag.size()), ag.begin(), ag.end()), pid}; - }, - details::unpack_all(age_groups), provincia_id); - } -}; +// class PopulationDataEntrySpain +// { +// public: +// static std::vector age_group_names; + +// CustomIndexArray population; +// boost::optional provincia_id; + +// template +// static IOResult deserialize(IoContext& io) +// { +// auto obj = io.expect_object("PopulationDataEntrySpain"); +// auto state_id = obj.expect_optional("ID_Provincia", Tag{}); +// std::vector> age_groups; +// age_groups.reserve(age_group_names.size()); +// std::transform(age_group_names.begin(), age_group_names.end(), std::back_inserter(age_groups), +// [&obj](auto&& age_name) { +// return obj.expect_element(age_name, Tag{}); +// }); +// return apply( +// io, +// [](auto&& ag, auto&& pid) { +// return PopulationDataEntrySpain{ +// CustomIndexArray(AgeGroup(ag.size()), ag.begin(), ag.end()), pid}; +// }, +// details::unpack_all(age_groups), provincia_id); +// } +// }; namespace details { @@ -477,11 +477,11 @@ inline IOResult> deserialize_population_data(co * @param rki_age_groups Specifies whether population data should be interpolated to rki age groups. * @return list of population data. */ -inline IOResult> deserialize_population_data_spain(const Json::Value& jsvalue) -{ - BOOST_OUTCOME_TRY(auto&& population_data, deserialize_json(jsvalue, Tag>{})); - return success(population_data); -} +// inline IOResult> deserialize_population_data_spain(const Json::Value& jsvalue) +// { +// BOOST_OUTCOME_TRY(auto&& population_data, deserialize_json(jsvalue, Tag>{})); +// return success(population_data); +// } /** * Deserialize population data from a JSON file. @@ -502,11 +502,11 @@ inline IOResult> read_population_data(const std * @param filename JSON file that contains the population data. * @return list of population data. */ -inline IOResult> read_population_data_spain(const std::string& filename) -{ - BOOST_OUTCOME_TRY(auto&& jsvalue, read_json(filename)); - return deserialize_population_data_spain(jsvalue); -} +// inline IOResult> read_population_data_spain(const std::string& filename) +// { +// BOOST_OUTCOME_TRY(auto&& jsvalue, read_json(filename)); +// return deserialize_population_data_spain(jsvalue); +// } /** * @brief Sets the age groups' names for the ConfirmedCasesDataEntry%s. 
diff --git a/cpp/memilio/io/parameters_io.cpp b/cpp/memilio/io/parameters_io.cpp index 4ca6e86d85..1e90999121 100644 --- a/cpp/memilio/io/parameters_io.cpp +++ b/cpp/memilio/io/parameters_io.cpp @@ -62,30 +62,30 @@ IOResult>> read_population_data(const std::vecto return success(vnum_population); } -IOResult>> -read_population_data_spain(const std::vector& population_data, - const std::vector& vregion) -{ - std::vector> vnum_population(vregion.size(), std::vector(1, 0.0)); +// IOResult>> +// read_population_data_spain(const std::vector& population_data, +// const std::vector& vregion) +// { +// std::vector> vnum_population(vregion.size(), std::vector(1, 0.0)); - for (auto&& provincia_entry : population_data) { - printf("Test"); - //find region that this county belongs to - //all counties belong to the country (id = 0) - auto it = std::find_if(vregion.begin(), vregion.end(), [&provincia_entry](auto r) { - return r == 0 || (provincia_entry.provincia_id && regions::ProvinciaId(r) == provincia_entry.provincia_id); - }); - if (it != vregion.end()) { - auto region_idx = size_t(it - vregion.begin()); - auto& num_population = vnum_population[region_idx]; - for (size_t age = 0; age < num_population.size(); age++) { - num_population[age] += provincia_entry.population[AgeGroup(age)]; - } - } - } +// for (auto&& provincia_entry : population_data) { +// printf("Test"); +// //find region that this county belongs to +// //all counties belong to the country (id = 0) +// auto it = std::find_if(vregion.begin(), vregion.end(), [&provincia_entry](auto r) { +// return r == 0 || (provincia_entry.provincia_id && regions::ProvinciaId(r) == provincia_entry.provincia_id); +// }); +// if (it != vregion.end()) { +// auto region_idx = size_t(it - vregion.begin()); +// auto& num_population = vnum_population[region_idx]; +// for (size_t age = 0; age < num_population.size(); age++) { +// num_population[age] += provincia_entry.population[AgeGroup(age)]; +// } +// } +// } - return success(vnum_population); -} +// return success(vnum_population); +// } IOResult>> read_population_data(const std::string& path, const std::vector& vregion) diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 1e6d5d8305..7f44ba95a1 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -352,14 +352,14 @@ IOResult set_population_data(std::vector>& model, const std::str return success(); } -template -IOResult set_population_data_provincias(std::vector>& model, const std::string& path, - const std::vector& vregion) -{ - BOOST_OUTCOME_TRY(const auto&& num_population, read_population_data_spain(path, vregion)); - BOOST_OUTCOME_TRY(set_population_data(model, num_population, vregion)); - return success(); -} +// template +// IOResult set_population_data_provincias(std::vector>& model, const std::string& path, +// const std::vector& vregion) +// { +// BOOST_OUTCOME_TRY(const auto&& num_population, read_population_data_spain(path, vregion)); +// BOOST_OUTCOME_TRY(set_population_data(model, num_population, vregion)); +// return success(); +// } } //namespace details @@ -525,32 +525,32 @@ IOResult read_input_data_county(std::vector& model, Date date, cons * @param[in] num_days [Default: 0] Number of days to be simulated; required to extrapolate real data. * @param[in] export_time_series [Default: false] If true, reads data for each day of simulation and writes it in the same directory as the input files. 
*/ -template -IOResult read_input_data_provincias(std::vector& model, Date date, const std::vector& county, - const std::vector& scaling_factor_inf, double scaling_factor_icu, - const std::string& pydata_dir, int num_days = 0, - bool export_time_series = false) -{ - // BOOST_OUTCOME_TRY( - // details::set_divi_data(model, path_join(pydata_dir, "county_divi_ma7.json"), county, date, scaling_factor_icu)); - // BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_county_ma7.json"), - // county, date, scaling_factor_inf)); - BOOST_OUTCOME_TRY(details::set_population_data_provincias( - model, path_join(pydata_dir, "provincias_current_population.json"), county)); - - if (export_time_series) { - // Use only if extrapolated real data is needed for comparison. EXPENSIVE ! - // Run time equals run time of the previous functions times the num_days ! - // (This only represents the vectorization of the previous function over all simulation days...) - log_warning("Exporting time series of extrapolated real data. This may take some minutes. " - "For simulation runs over the same time period, deactivate it."); - BOOST_OUTCOME_TRY(export_input_data_county_timeseries( - model, pydata_dir, county, date, scaling_factor_inf, scaling_factor_icu, num_days, - path_join(pydata_dir, "county_divi_ma7.json"), path_join(pydata_dir, "cases_all_county_age_ma7.json"), - path_join(pydata_dir, "county_current_population.json"))); - } - return success(); -} +// template +// IOResult read_input_data_provincias(std::vector& model, Date date, const std::vector& county, +// const std::vector& scaling_factor_inf, double scaling_factor_icu, +// const std::string& pydata_dir, int num_days = 0, +// bool export_time_series = false) +// { +// // BOOST_OUTCOME_TRY( +// // details::set_divi_data(model, path_join(pydata_dir, "county_divi_ma7.json"), county, date, scaling_factor_icu)); +// // BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_county_ma7.json"), +// // county, date, scaling_factor_inf)); +// BOOST_OUTCOME_TRY(details::set_population_data_provincias( +// model, path_join(pydata_dir, "provincias_current_population.json"), county)); + +// if (export_time_series) { +// // Use only if extrapolated real data is needed for comparison. EXPENSIVE ! +// // Run time equals run time of the previous functions times the num_days ! +// // (This only represents the vectorization of the previous function over all simulation days...) +// log_warning("Exporting time series of extrapolated real data. This may take some minutes. 
" +// "For simulation runs over the same time period, deactivate it."); +// BOOST_OUTCOME_TRY(export_input_data_county_timeseries( +// model, pydata_dir, county, date, scaling_factor_inf, scaling_factor_icu, num_days, +// path_join(pydata_dir, "county_divi_ma7.json"), path_join(pydata_dir, "cases_all_county_age_ma7.json"), +// path_join(pydata_dir, "county_current_population.json"))); +// } +// return success(); +// } /** * @brief reads population data from population files for the specified nodes From 2d0dda4dcf03ff044b70135ec2c608d7e5dbef19 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 12 Aug 2025 14:33:15 +0200 Subject: [PATCH 13/73] made more errors --- cpp/models/ode_secir/parameters_io.cpp | 71 ++++++++-------- cpp/models/ode_secir/parameters_io.h | 83 ++++++++++--------- .../simulation/graph_germany_nuts3.py | 2 +- 3 files changed, 78 insertions(+), 78 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.cpp b/cpp/models/ode_secir/parameters_io.cpp index dbaa1242b7..f29db18e43 100644 --- a/cpp/models/ode_secir/parameters_io.cpp +++ b/cpp/models/ode_secir/parameters_io.cpp @@ -189,15 +189,15 @@ IOResult read_confirmed_cases_data( IOResult read_confirmed_cases_noage( std::vector& rki_data, std::vector const& vregion, Date date, - std::vector>& vnum_Exposed, std::vector>& vnum_InfectedNoSymptoms, - std::vector>& vnum_InfectedSymptoms, std::vector>& vnum_InfectedSevere, - std::vector>& vnum_icu, std::vector>& vnum_death, - std::vector>& vnum_rec, const std::vector>& vt_Exposed, - const std::vector>& vt_InfectedNoSymptoms, - const std::vector>& vt_InfectedSymptoms, const std::vector>& vt_InfectedSevere, - const std::vector>& vt_InfectedCritical, const std::vector>& vmu_C_R, - const std::vector>& vmu_I_H, const std::vector>& vmu_H_U, - const std::vector& scaling_factor_inf) + std::vector vnum_Exposed, std::vector vnum_InfectedNoSymptoms, + std::vector vnum_InfectedSymptoms, std::vector vnum_InfectedSevere, + std::vector vnum_icu, std::vector vnum_death, + std::vector vnum_rec, const std::vector vt_Exposed, + const std::vector vt_InfectedNoSymptoms, + const std::vector vt_InfectedSymptoms, const std::vector vt_InfectedSevere, + const std::vector vt_InfectedCritical, const std::vector vmu_C_R, + const std::vector vmu_I_H, const std::vector vmu_H_U, + const double scaling_factor_inf) { auto max_date_entry = std::max_element(rki_data.begin(), rki_data.end(), [](auto&& a, auto&& b) { return a.date < b.date; @@ -251,36 +251,35 @@ IOResult read_confirmed_cases_noage( auto& mu_H_U = vmu_H_U[region_idx]; auto date_df = region_entry.date; - auto age = 0; if (date_df == offset_date_by_days(date, 0)) { - num_InfectedSymptoms[age] += scaling_factor_inf[age] * region_entry.num_confirmed; - num_rec[age] += region_entry.num_confirmed; + num_InfectedSymptoms += scaling_factor_inf * region_entry.num_confirmed; + num_rec += region_entry.num_confirmed; } if (date_df == offset_date_by_days(date, days_surplus)) { - num_InfectedNoSymptoms[age] -= - 1 / (1 - mu_C_R[age]) * scaling_factor_inf[age] * region_entry.num_confirmed; + num_InfectedNoSymptoms -= + 1 / (1 - mu_C_R) * scaling_factor_inf * region_entry.num_confirmed; } - if (date_df == offset_date_by_days(date, t_InfectedNoSymptoms[age] + days_surplus)) { - num_InfectedNoSymptoms[age] += - 1 / (1 - mu_C_R[age]) * scaling_factor_inf[age] * region_entry.num_confirmed; - num_Exposed[age] -= 1 / (1 - mu_C_R[age]) * scaling_factor_inf[age] * region_entry.num_confirmed; + if (date_df == offset_date_by_days(date, t_InfectedNoSymptoms + 
days_surplus)) { + num_InfectedNoSymptoms += + 1 / (1 - mu_C_R) * scaling_factor_inf * region_entry.num_confirmed; + num_Exposed -= 1 / (1 - mu_C_R) * scaling_factor_inf * region_entry.num_confirmed; } - if (date_df == offset_date_by_days(date, t_Exposed[age] + t_InfectedNoSymptoms[age] + days_surplus)) { - num_Exposed[age] += 1 / (1 - mu_C_R[age]) * scaling_factor_inf[age] * region_entry.num_confirmed; + if (date_df == offset_date_by_days(date, t_Exposed + t_InfectedNoSymptoms + days_surplus)) { + num_Exposed += 1 / (1 - mu_C_R) * scaling_factor_inf * region_entry.num_confirmed; } - if (date_df == offset_date_by_days(date, -t_InfectedSymptoms[age])) { - num_InfectedSymptoms[age] -= scaling_factor_inf[age] * region_entry.num_confirmed; - num_InfectedSevere[age] += mu_I_H[age] * scaling_factor_inf[age] * region_entry.num_confirmed; + if (date_df == offset_date_by_days(date, -t_InfectedSymptoms)) { + num_InfectedSymptoms -= scaling_factor_inf * region_entry.num_confirmed; + num_InfectedSevere += mu_I_H * scaling_factor_inf * region_entry.num_confirmed; } - if (date_df == offset_date_by_days(date, -t_InfectedSymptoms[age] - t_InfectedSevere[age])) { - num_InfectedSevere[age] -= mu_I_H[age] * scaling_factor_inf[age] * region_entry.num_confirmed; - num_icu[age] += mu_I_H[age] * mu_H_U[age] * scaling_factor_inf[age] * region_entry.num_confirmed; + if (date_df == offset_date_by_days(date, -t_InfectedSymptoms - t_InfectedSevere)) { + num_InfectedSevere -= mu_I_H * scaling_factor_inf * region_entry.num_confirmed; + num_icu += mu_I_H * mu_H_U * scaling_factor_inf * region_entry.num_confirmed; } if (date_df == - offset_date_by_days(date, -t_InfectedSymptoms[age] - t_InfectedSevere[age] - t_InfectedCritical[age])) { - num_death[age] += region_entry.num_deaths; - num_icu[age] -= mu_I_H[age] * mu_H_U[age] * scaling_factor_inf[age] * region_entry.num_confirmed; + offset_date_by_days(date, -t_InfectedSymptoms - t_InfectedSevere - t_InfectedCritical)) { + num_death += region_entry.num_deaths; + num_icu -= mu_I_H * mu_H_U * scaling_factor_inf * region_entry.num_confirmed; } } } @@ -309,13 +308,13 @@ IOResult read_confirmed_cases_noage( } }; - try_fix_constraints(num_InfectedSymptoms[0], -5, "InfectedSymptoms"); - try_fix_constraints(num_InfectedNoSymptoms[0], -5, "InfectedNoSymptoms"); - try_fix_constraints(num_Exposed[0], -5, "Exposed"); - try_fix_constraints(num_InfectedSevere[0], -5, "InfectedSevere"); - try_fix_constraints(num_death[0], -5, "Dead"); - try_fix_constraints(num_icu[0], -5, "InfectedCritical"); - try_fix_constraints(num_rec[0], -20, "Recovered"); + try_fix_constraints(num_InfectedSymptoms, -5, "InfectedSymptoms"); + try_fix_constraints(num_InfectedNoSymptoms, -5, "InfectedNoSymptoms"); + try_fix_constraints(num_Exposed, -5, "Exposed"); + try_fix_constraints(num_InfectedSevere, -5, "InfectedSevere"); + try_fix_constraints(num_death, -5, "Dead"); + try_fix_constraints(num_icu, -5, "InfectedCritical"); + try_fix_constraints(num_rec, -20, "Recovered"); } return success(); diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 7f44ba95a1..79f4a3a232 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -153,15 +153,15 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect IOResult read_confirmed_cases_noage( std::vector& rki_data, std::vector const& vregion, Date date, - std::vector>& vnum_Exposed, std::vector>& vnum_InfectedNoSymptoms, - std::vector>& vnum_InfectedSymptoms, std::vector>& 
vnum_InfectedSevere, - std::vector>& vnum_icu, std::vector>& vnum_death, - std::vector>& vnum_rec, const std::vector>& vt_Exposed, - const std::vector>& vt_InfectedNoSymptoms, - const std::vector>& vt_InfectedSymptoms, const std::vector>& vt_InfectedSevere, - const std::vector>& vt_InfectedCritical, const std::vector>& vmu_C_R, - const std::vector>& vmu_I_H, const std::vector>& vmu_H_U, - const std::vector& scaling_factor_inf); + std::vector& vnum_Exposed, std::vector& vnum_InfectedNoSymptoms, + std::vector& vnum_InfectedSymptoms, std::vector& vnum_InfectedSevere, + std::vector& vnum_icu, std::vector& vnum_death, + std::vector& vnum_rec, const std::vector& vt_Exposed, + const std::vector& vt_InfectedNoSymptoms, + const std::vector& vt_InfectedSymptoms, const std::vector& vt_InfectedSevere, + const std::vector& vt_InfectedCritical, const std::vector& vmu_C_R, + const std::vector& vmu_I_H, const std::vector& vmu_H_U, + const double& scaling_factor_inf); /** * @brief Sets populations data from already read case data with multiple age groups into a Model with one age group. @@ -175,44 +175,44 @@ IOResult read_confirmed_cases_noage( template IOResult set_confirmed_cases_noage(std::vector>& model, std::vector& case_data, - const std::vector& region, Date date, const std::vector& scaling_factor_inf) + const std::vector& region, Date date, const double scaling_factor_inf) { - std::vector> t_InfectedNoSymptoms{model.size()}; - std::vector> t_Exposed{model.size()}; - std::vector> t_InfectedSymptoms{model.size()}; - std::vector> t_InfectedSevere{model.size()}; - std::vector> t_InfectedCritical{model.size()}; + std::vector t_InfectedNoSymptoms{model.size()}; + std::vector t_Exposed{model.size()}; + std::vector t_InfectedSymptoms{model.size()}; + std::vector t_InfectedSevere{model.size()}; + std::vector t_InfectedCritical{model.size()}; - std::vector> mu_C_R{model.size()}; - std::vector> mu_I_H{model.size()}; - std::vector> mu_H_U{model.size()}; - std::vector> mu_U_D{model.size()}; + std::vector mu_C_R{model.size()}; + std::vector mu_I_H{model.size()}; + std::vector mu_H_U{model.size()}; + std::vector mu_U_D{model.size()}; for (size_t node = 0; node < model.size(); ++node) { - t_Exposed[node].push_back( - static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)]))); - t_InfectedNoSymptoms[node].push_back(static_cast( - std::round(model[node].parameters.template get>()[AgeGroup(0)]))); - t_InfectedSymptoms[node].push_back( - static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)]))); - t_InfectedSevere[node].push_back( - static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)]))); - t_InfectedCritical[node].push_back( - static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)]))); - - mu_C_R[node].push_back(model[node].parameters.template get>()[AgeGroup(0)]); - mu_I_H[node].push_back(model[node].parameters.template get>()[AgeGroup(0)]); - mu_H_U[node].push_back(model[node].parameters.template get>()[AgeGroup(0)]); - mu_U_D[node].push_back(model[node].parameters.template get>()[AgeGroup(0)]); + t_Exposed[node] = + static_cast(std::round(model[node].parameters.template get>())); + t_InfectedNoSymptoms[node]= static_cast( + std::round(model[node].parameters.template get>()[AgeGroup(0)])); + t_InfectedSymptoms[node] = + static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)])); + t_InfectedSevere[node] = + static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)])); + t_InfectedCritical[node] = + 
static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)])); + + mu_C_R[node] = model[node].parameters.template get>()[AgeGroup(0)]; + mu_I_H[node] = model[node].parameters.template get>()[AgeGroup(0)]; + mu_H_U[node] = model[node].parameters.template get>()[AgeGroup(0)]; + mu_U_D[node] = model[node].parameters.template get>()[AgeGroup(0)]; } - std::vector> num_InfectedSymptoms(model.size(), std::vector(1, 0.0)); - std::vector> num_death(model.size(), std::vector(1, 0.0)); - std::vector> num_rec(model.size(), std::vector(1, 0.0)); - std::vector> num_Exposed(model.size(), std::vector(1, 0.0)); - std::vector> num_InfectedNoSymptoms(model.size(), std::vector(1, 0.0)); - std::vector> num_InfectedSevere(model.size(), std::vector(1, 0.0)); - std::vector> num_icu(model.size(), std::vector(1, 0.0)); + std::vector num_InfectedSymptoms(model.size(), 0.0); + std::vector num_death(model.size(), 0.0); + std::vector num_rec(model.size(), 0.0); + std::vector num_Exposed(model.size(), 0.0); + std::vector num_InfectedNoSymptoms(model.size(), 0.0); + std::vector num_InfectedSevere(model.size(), 0.0); + std::vector num_icu(model.size(), 0.0); BOOST_OUTCOME_TRY(read_confirmed_cases_noage(case_data, region, date, num_Exposed, num_InfectedNoSymptoms, num_InfectedSymptoms, num_InfectedSevere, num_icu, num_death, num_rec, @@ -262,6 +262,7 @@ IOResult set_confirmed_cases_data(std::vector>& model, const std std::vector const& region, Date date, const std::vector& scaling_factor_inf) { + printf("Path\n"); BOOST_OUTCOME_TRY(auto&& case_data, mio::read_confirmed_cases_noage(path)); BOOST_OUTCOME_TRY(set_confirmed_cases_noage(model, case_data, region, date, scaling_factor_inf)); return success(); diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 1588dc459f..a627d753fb 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -110,7 +110,7 @@ def get_graph(self, end_date): graph = osecir.ModelGraph() - scaling_factor_infected = [2.5] + scaling_factor_infected = [1.0] scaling_factor_icu = 1.0 tnt_capacity_factor = 7.5 / 100000. 
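For reference, the initialization that the patch above rewrites for a single age group works by differencing cumulative confirmed cases at day offsets given by the stage durations, weighted by the transition probabilities. The following is a minimal, self-contained C++ sketch of that differencing logic only; it does not use MEmilio's API, and all case numbers, stage durations and probabilities are assumed values chosen for illustration.

// Build: g++ -std=c++17 sketch.cpp
#include <cstdio>
#include <vector>

int main()
{
    // Cumulative confirmed cases per day for one region; index = day. All values assumed.
    std::vector<double> confirmed = {100, 110, 125, 150, 180, 220, 270, 330, 400, 480, 560,
                                     630, 690, 740, 780, 810, 830, 845, 855, 860, 863};
    const int day = 14; // reference day for the initialization

    // Assumed stage durations (days) and transition probabilities.
    const int t_E = 3, t_C = 3, t_I = 6, t_H = 4, t_U = 4;
    const double mu_C_R = 0.2;  // fraction that never develops symptoms
    const double mu_I_H = 0.05; // symptomatic -> severe
    const double mu_H_U = 0.25; // severe -> critical (ICU)
    const double scale  = 1.0;  // scaling factor for under-reporting

    auto c = [&](int d) { return confirmed[d]; };

    // Currently symptomatic: confirmed by 'day' minus those confirmed more than t_I days ago.
    double I = scale * (c(day) - c(day - t_I));
    // Pre-symptomatic: those confirmed within the next t_C days; 1/(1-mu_C_R) corrects for
    // infections that never turn symptomatic and therefore never appear as confirmed.
    // (The real code shifts this window back by 'days_surplus' if the data ends too early.)
    double C = scale / (1.0 - mu_C_R) * (c(day + t_C) - c(day));
    // Exposed: those confirmed between t_C and t_C + t_E days in the future.
    double E = scale / (1.0 - mu_C_R) * (c(day + t_C + t_E) - c(day + t_C));
    // Severe: fraction mu_I_H of those who left the symptomatic stage during the last t_H days.
    double H = mu_I_H * scale * (c(day - t_I) - c(day - t_I - t_H));
    // Critical: fraction mu_H_U of those who left the severe stage during the last t_U days.
    double U = mu_I_H * mu_H_U * scale * (c(day - t_I - t_H) - c(day - t_I - t_H - t_U));

    std::printf("E=%.1f  C=%.1f  I=%.1f  H=%.1f  U=%.1f\n", E, C, I, H, U);
    return 0;
}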
From 65f63c030dc58bce9a623e04586066cbab3882e5 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 12 Aug 2025 14:38:48 +0200 Subject: [PATCH 14/73] fix some again --- cpp/models/ode_secir/parameters_io.h | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 79f4a3a232..481a71c746 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -220,24 +220,21 @@ set_confirmed_cases_noage(std::vector>& model, std::vector 0) { - size_t num_groups = (size_t)model[node].parameters.get_num_groups(); - for (size_t i = 0; i < num_groups; i++) { - model[node].populations[{AgeGroup(i), InfectionState::Exposed}] = num_Exposed[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptoms}] = - num_InfectedNoSymptoms[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptomsConfirmed}] = 0; - model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptoms}] = - num_InfectedSymptoms[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptomsConfirmed}] = 0; - model[node].populations[{AgeGroup(i), InfectionState::InfectedSevere}] = num_InfectedSevere[node][i]; + if (num_InfectedSymptoms[node] > 0) { + model[node].populations[{InfectionState::Exposed}] = num_Exposed[node]; + model[node].populations[{InfectionState::InfectedNoSymptoms}] = + num_InfectedNoSymptoms[node]; + model[node].populations[{InfectionState::InfectedNoSymptomsConfirmed}] = 0; + model[node].populations[{InfectionState::InfectedSymptoms}] = + num_InfectedSymptoms[node]; + model[node].populations[{InfectionState::InfectedSymptomsConfirmed}] = 0; + model[node].populations[{InfectionState::InfectedSevere}] = num_InfectedSevere[node]; // Only set the number of ICU patients here, if the date is not available in the data. if (!is_divi_data_available(date)) { - model[node].populations[{AgeGroup(i), InfectionState::InfectedCritical}] = num_icu[node][i]; + model[node].populations[{InfectionState::InfectedCritical}] = num_icu[node]; } - model[node].populations[{AgeGroup(i), InfectionState::Dead}] = num_death[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i]; - } + model[node].populations[{InfectionState::Dead}] = num_death[node]; + model[node].populations[{InfectionState::Recovered}] = num_rec[node]; } else { log_warning("No infections reported on date {} for region {}. 
Population data has not been set.", date, From 36002bd0e924f2251023ccf7a3a94b742e8037bb Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 12 Aug 2025 14:59:43 +0200 Subject: [PATCH 15/73] builds but cannot be imported --- cpp/models/ode_secir/parameters_io.h | 76 ++++++++++++++-------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 481a71c746..332420deaa 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -177,21 +177,21 @@ IOResult set_confirmed_cases_noage(std::vector>& model, std::vector& case_data, const std::vector& region, Date date, const double scaling_factor_inf) { - std::vector t_InfectedNoSymptoms{model.size()}; - std::vector t_Exposed{model.size()}; - std::vector t_InfectedSymptoms{model.size()}; - std::vector t_InfectedSevere{model.size()}; - std::vector t_InfectedCritical{model.size()}; + std::vector t_InfectedNoSymptoms{1}; + std::vector t_Exposed{1}; + std::vector t_InfectedSymptoms{1}; + std::vector t_InfectedSevere{1}; + std::vector t_InfectedCritical{1}; - std::vector mu_C_R{model.size()}; - std::vector mu_I_H{model.size()}; - std::vector mu_H_U{model.size()}; - std::vector mu_U_D{model.size()}; + std::vector mu_C_R{1}; + std::vector mu_I_H{1}; + std::vector mu_H_U{1}; + std::vector mu_U_D{1}; for (size_t node = 0; node < model.size(); ++node) { t_Exposed[node] = - static_cast(std::round(model[node].parameters.template get>())); + static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)])); t_InfectedNoSymptoms[node]= static_cast( std::round(model[node].parameters.template get>()[AgeGroup(0)])); t_InfectedSymptoms[node] = @@ -206,13 +206,13 @@ set_confirmed_cases_noage(std::vector>& model, std::vector>()[AgeGroup(0)]; mu_U_D[node] = model[node].parameters.template get>()[AgeGroup(0)]; } - std::vector num_InfectedSymptoms(model.size(), 0.0); - std::vector num_death(model.size(), 0.0); - std::vector num_rec(model.size(), 0.0); - std::vector num_Exposed(model.size(), 0.0); - std::vector num_InfectedNoSymptoms(model.size(), 0.0); - std::vector num_InfectedSevere(model.size(), 0.0); - std::vector num_icu(model.size(), 0.0); + std::vector num_InfectedSymptoms(1, 0.0); + std::vector num_death(1, 0.0); + std::vector num_rec(1, 0.0); + std::vector num_Exposed(1, 0.0); + std::vector num_InfectedNoSymptoms(1, 0.0); + std::vector num_InfectedSevere(1, 0.0); + std::vector num_icu(1, 0.0); BOOST_OUTCOME_TRY(read_confirmed_cases_noage(case_data, region, date, num_Exposed, num_InfectedNoSymptoms, num_InfectedSymptoms, num_InfectedSevere, num_icu, num_death, num_rec, @@ -221,20 +221,20 @@ set_confirmed_cases_noage(std::vector>& model, std::vector 0) { - model[node].populations[{InfectionState::Exposed}] = num_Exposed[node]; - model[node].populations[{InfectionState::InfectedNoSymptoms}] = + model[node].populations[{AgeGroup(0), InfectionState::Exposed}] = num_Exposed[node]; + model[node].populations[{AgeGroup(0), InfectionState::InfectedNoSymptoms}] = num_InfectedNoSymptoms[node]; - model[node].populations[{InfectionState::InfectedNoSymptomsConfirmed}] = 0; - model[node].populations[{InfectionState::InfectedSymptoms}] = + model[node].populations[{AgeGroup(0), InfectionState::InfectedNoSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(0), InfectionState::InfectedSymptoms}] = num_InfectedSymptoms[node]; - model[node].populations[{InfectionState::InfectedSymptomsConfirmed}] = 0; - 
model[node].populations[{InfectionState::InfectedSevere}] = num_InfectedSevere[node]; + model[node].populations[{AgeGroup(0), InfectionState::InfectedSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(0), InfectionState::InfectedSevere}] = num_InfectedSevere[node]; // Only set the number of ICU patients here, if the date is not available in the data. if (!is_divi_data_available(date)) { - model[node].populations[{InfectionState::InfectedCritical}] = num_icu[node]; + model[node].populations[{AgeGroup(0), InfectionState::InfectedCritical}] = num_icu[node]; } - model[node].populations[{InfectionState::Dead}] = num_death[node]; - model[node].populations[{InfectionState::Recovered}] = num_rec[node]; + model[node].populations[{AgeGroup(0), InfectionState::Dead}] = num_death[node]; + model[node].populations[{AgeGroup(0), InfectionState::Recovered}] = num_rec[node]; } else { log_warning("No infections reported on date {} for region {}. Population data has not been set.", date, @@ -261,7 +261,7 @@ IOResult set_confirmed_cases_data(std::vector>& model, const std { printf("Path\n"); BOOST_OUTCOME_TRY(auto&& case_data, mio::read_confirmed_cases_noage(path)); - BOOST_OUTCOME_TRY(set_confirmed_cases_noage(model, case_data, region, date, scaling_factor_inf)); + BOOST_OUTCOME_TRY(set_confirmed_cases_noage(model, case_data, region, date, scaling_factor_inf[0])); return success(); } @@ -498,17 +498,17 @@ IOResult read_input_data_county(std::vector& model, Date date, cons BOOST_OUTCOME_TRY(details::set_population_data( model, path_join(pydata_dir, "county_current_population_aggregated.json"), county)); - if (export_time_series) { - // Use only if extrapolated real data is needed for comparison. EXPENSIVE ! - // Run time equals run time of the previous functions times the num_days ! - // (This only represents the vectorization of the previous function over all simulation days...) - log_warning("Exporting time series of extrapolated real data. This may take some minutes. " - "For simulation runs over the same time period, deactivate it."); - BOOST_OUTCOME_TRY(export_input_data_county_timeseries( - model, pydata_dir, county, date, scaling_factor_inf, scaling_factor_icu, num_days, - path_join(pydata_dir, "county_divi_ma7.json"), path_join(pydata_dir, "cases_all_county_age_ma7.json"), - path_join(pydata_dir, "county_current_population.json"))); - } + // if (export_time_series) { + // // Use only if extrapolated real data is needed for comparison. EXPENSIVE ! + // // Run time equals run time of the previous functions times the num_days ! + // // (This only represents the vectorization of the previous function over all simulation days...) + // log_warning("Exporting time series of extrapolated real data. This may take some minutes. 
" + // "For simulation runs over the same time period, deactivate it."); + // // BOOST_OUTCOME_TRY(export_input_data_county_timeseries( + // // model, pydata_dir, county, date, scaling_factor_inf, scaling_factor_icu, num_days, + // // path_join(pydata_dir, "county_divi_ma7.json"), path_join(pydata_dir, "cases_all_county_age_ma7.json"), + // // path_join(pydata_dir, "county_current_population.json"))); + // } return success(); } From 837b331b927752eec18077a6d85af71c19504888 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 13 Aug 2025 09:08:37 +0200 Subject: [PATCH 16/73] fix read data without age groups --- cpp/models/ode_secir/parameters_io.cpp | 16 +++++------ cpp/models/ode_secir/parameters_io.h | 37 +++++++++++++------------- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.cpp b/cpp/models/ode_secir/parameters_io.cpp index f29db18e43..64b0bbce91 100644 --- a/cpp/models/ode_secir/parameters_io.cpp +++ b/cpp/models/ode_secir/parameters_io.cpp @@ -189,14 +189,14 @@ IOResult read_confirmed_cases_data( IOResult read_confirmed_cases_noage( std::vector& rki_data, std::vector const& vregion, Date date, - std::vector vnum_Exposed, std::vector vnum_InfectedNoSymptoms, - std::vector vnum_InfectedSymptoms, std::vector vnum_InfectedSevere, - std::vector vnum_icu, std::vector vnum_death, - std::vector vnum_rec, const std::vector vt_Exposed, - const std::vector vt_InfectedNoSymptoms, - const std::vector vt_InfectedSymptoms, const std::vector vt_InfectedSevere, - const std::vector vt_InfectedCritical, const std::vector vmu_C_R, - const std::vector vmu_I_H, const std::vector vmu_H_U, + std::vector& vnum_Exposed, std::vector& vnum_InfectedNoSymptoms, + std::vector& vnum_InfectedSymptoms, std::vector& vnum_InfectedSevere, + std::vector& vnum_icu, std::vector& vnum_death, + std::vector& vnum_rec, const std::vector& vt_Exposed, + const std::vector& vt_InfectedNoSymptoms, + const std::vector& vt_InfectedSymptoms, const std::vector& vt_InfectedSevere, + const std::vector& vt_InfectedCritical, const std::vector& vmu_C_R, + const std::vector& vmu_I_H, const std::vector& vmu_H_U, const double scaling_factor_inf) { auto max_date_entry = std::max_element(rki_data.begin(), rki_data.end(), [](auto&& a, auto&& b) { diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 332420deaa..4ea862d88c 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -161,7 +161,7 @@ IOResult read_confirmed_cases_noage( const std::vector& vt_InfectedSymptoms, const std::vector& vt_InfectedSevere, const std::vector& vt_InfectedCritical, const std::vector& vmu_C_R, const std::vector& vmu_I_H, const std::vector& vmu_H_U, - const double& scaling_factor_inf); + const double scaling_factor_inf); /** * @brief Sets populations data from already read case data with multiple age groups into a Model with one age group. 
@@ -177,16 +177,16 @@ IOResult set_confirmed_cases_noage(std::vector>& model, std::vector& case_data, const std::vector& region, Date date, const double scaling_factor_inf) { - std::vector t_InfectedNoSymptoms{1}; - std::vector t_Exposed{1}; - std::vector t_InfectedSymptoms{1}; - std::vector t_InfectedSevere{1}; - std::vector t_InfectedCritical{1}; + std::vector t_InfectedNoSymptoms(model.size()); + std::vector t_Exposed(model.size()); + std::vector t_InfectedSymptoms(model.size()); + std::vector t_InfectedSevere(model.size()); + std::vector t_InfectedCritical(model.size()); - std::vector mu_C_R{1}; - std::vector mu_I_H{1}; - std::vector mu_H_U{1}; - std::vector mu_U_D{1}; + std::vector mu_C_R(model.size()); + std::vector mu_I_H(model.size()); + std::vector mu_H_U(model.size()); + std::vector mu_U_D(model.size()); for (size_t node = 0; node < model.size(); ++node) { @@ -206,13 +206,13 @@ set_confirmed_cases_noage(std::vector>& model, std::vector>()[AgeGroup(0)]; mu_U_D[node] = model[node].parameters.template get>()[AgeGroup(0)]; } - std::vector num_InfectedSymptoms(1, 0.0); - std::vector num_death(1, 0.0); - std::vector num_rec(1, 0.0); - std::vector num_Exposed(1, 0.0); - std::vector num_InfectedNoSymptoms(1, 0.0); - std::vector num_InfectedSevere(1, 0.0); - std::vector num_icu(1, 0.0); + std::vector num_InfectedSymptoms(model.size(), 0.0); + std::vector num_death(model.size(), 0.0); + std::vector num_rec(model.size(), 0.0); + std::vector num_Exposed(model.size(), 0.0); + std::vector num_InfectedNoSymptoms(model.size(), 0.0); + std::vector num_InfectedSevere(model.size(), 0.0); + std::vector num_icu(model.size(), 0.0); BOOST_OUTCOME_TRY(read_confirmed_cases_noage(case_data, region, date, num_Exposed, num_InfectedNoSymptoms, num_InfectedSymptoms, num_InfectedSevere, num_icu, num_death, num_rec, @@ -259,7 +259,6 @@ IOResult set_confirmed_cases_data(std::vector>& model, const std std::vector const& region, Date date, const std::vector& scaling_factor_inf) { - printf("Path\n"); BOOST_OUTCOME_TRY(auto&& case_data, mio::read_confirmed_cases_noage(path)); BOOST_OUTCOME_TRY(set_confirmed_cases_noage(model, case_data, region, date, scaling_factor_inf[0])); return success(); @@ -489,7 +488,7 @@ IOResult read_input_data_state(std::vector& model, Date date, std:: template IOResult read_input_data_county(std::vector& model, Date date, const std::vector& county, const std::vector& scaling_factor_inf, double scaling_factor_icu, - const std::string& pydata_dir, int num_days = 0, bool export_time_series = false) + const std::string& pydata_dir,int /*num_days*/ = 0, bool /*export_time_series*/ = false) { BOOST_OUTCOME_TRY( details::set_divi_data(model, path_join(pydata_dir, "county_divi_ma7.json"), county, date, scaling_factor_icu)); From c4e283d7ef4a17d402d12bf7c6ce9e2d85591ded Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 13 Aug 2025 18:59:36 +0200 Subject: [PATCH 17/73] ready for trainingsdata creation --- cpp/models/ode_secir/parameters_io.cpp | 2 +- .../simulation/graph_germany_nuts3.py | 67 ++++++++----------- 2 files changed, 30 insertions(+), 39 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.cpp b/cpp/models/ode_secir/parameters_io.cpp index 64b0bbce91..f21ab9d471 100644 --- a/cpp/models/ode_secir/parameters_io.cpp +++ b/cpp/models/ode_secir/parameters_io.cpp @@ -254,7 +254,7 @@ IOResult read_confirmed_cases_noage( if (date_df == offset_date_by_days(date, 0)) { num_InfectedSymptoms += scaling_factor_inf * region_entry.num_confirmed; - num_rec += 
region_entry.num_confirmed; + num_rec += region_entry.num_recovered; } if (date_df == offset_date_by_days(date, days_surplus)) { num_InfectedNoSymptoms -= diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index a627d753fb..9919ee81c9 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -27,9 +27,12 @@ from enum import Enum from memilio.simulation.osecir import (Model, Simulation, interpolate_simulation_result) - +import pandas as pd +from memilio.epidata import defaultDict as dd import pickle +excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, 11007, 11008, 11009, 11010, 11011, 11012, 16056, 7338, 9374, 9473, 9573] +region_ids = [region_id for region_id in dd.County.keys() if region_id not in excluded_ids] class Simulation: """ """ @@ -90,7 +93,7 @@ def set_npis(self, params, end_date, damping_value): """ start_damping = datetime.date( - 2020, 12, 18) + 2020, 7, 8) if start_damping < end_date: start_date = (start_damping - self.start_date).days @@ -110,7 +113,7 @@ def get_graph(self, end_date): graph = osecir.ModelGraph() - scaling_factor_infected = [1.0] + scaling_factor_infected = [1] scaling_factor_icu = 1.0 tnt_capacity_factor = 7.5 / 100000. @@ -163,7 +166,7 @@ def run(self, num_days_sim, damping_values, save_graph=True): mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): node = graph.get_node(node_idx) - self.set_npis(node.property.parameters, end_date, damping_values[node_idx]) + self.set_npis(node.property.parameters, end_date, damping_values[node_idx // 1000 -1]) mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( @@ -192,59 +195,47 @@ def run_germany_nuts3_simulation(damping_values): results = sim.run(num_days_sim, damping_values) - return {f'region{region}': results[region] for region in range(len(results))} + return {f'region{region_idx}': results[region_idx] for region_idx, region_id in enumerate(region_ids)} def prior(): - damping_values = np.random.uniform(0.0, 1.0, 400) + damping_values = np.random.uniform(0.0, 1.0, 16) return {'damping_values': damping_values} def load_divi_data(): file_path = os.path.dirname(os.path.abspath(__file__)) divi_path = os.path.join(file_path, "../../../data/Germany/pydata") - data = pd.read_json(os.path.join(divi_path, "county_divi.json")) - print(data["ID_County"].drop_duplicates().shape) - data = data[data['Date']>= np.datetime64(datetime.date(2020, 8, 1))] - data = data[data['Date'] <= np.datetime64(datetime.date(2020, 8, 1) + datetime.timedelta(days=50))] - print(data["ID_County"].drop_duplicates().shape) + data = pd.read_json(os.path.join(divi_path, "county_divi_ma7.json")) + data = data[data['Date']>= np.datetime64(datetime.date(2020, 7, 1))] + data = data[data['Date'] <= np.datetime64(datetime.date(2020, 7, 1) + datetime.timedelta(days=50))] data = data.drop(columns=['County', 'ICU_ventilated', 'Date']) - region_ids = [*dd.County] - divi_dict = {f"region{i}": data[data['ID_County'] == region_ids[i]]['ICU'].to_numpy() for i in range(400)} - # for i in range(100): - # if divi_dict[f'region{i+100}'].size==0: - # print(region_ids[i+100]) - # print(divi_dict[f'region{i+100}'].shape) + divi_dict = {f"region{i}": data[data['ID_County'] == region_id]['ICU'].to_numpy().reshape((1, 51, 1)) for i, region_id in enumerate(region_ids)} + + return divi_dict if __name__ == "__main__": - # import pandas as pd - - # file_path = 
os.path.dirname(os.path.abspath(__file__)) - # casedata_path = os.path.join(file_path, "../../../data/Germany/pydata/cases_all_county_ma7.jsons") - # county = test[test['ID_County']==7320] - # print(county[county['Date']>= np.datetime64(datetime.date(2020, 8, 1))]) - # from memilio.epidata import defaultDict as dd - # load_divi_data() + import os os.environ["KERAS_BACKEND"] = "tensorflow" import bayesflow as bf simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - trainings_data = simulator.sample(1) + trainings_data = simulator.sample(100) - # for region in range(400): - # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] + for region in range(len(region_ids)): + trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] - # with open('validation_data_400params.pickle', 'wb') as f: - # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + with open('validation_data_counties.pickle', 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) # with open('trainings_data1_400params.pickle', 'rb') as f: # trainings_data = pickle.load(f) - # for i in range(9): - # with open(f'trainings_data{i+2}_400params.pickle', 'rb') as f: - # data = pickle.load(f) - # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} + # # for i in range(9): + # # with open(f'trainings_data{i+2}_400params.pickle', 'rb') as f: + # # data = pickle.load(f) + # # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} # with open('validation_data_400params.pickle', 'rb') as f: # validation_data = pickle.load(f) @@ -261,8 +252,8 @@ def load_divi_data(): # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) - # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=700, recurrent_dim=512) - # inference_network = bf.networks.DiffusionModel(subnet_kwargs={'widths': {512, 512, 512, 512, 512}}) + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32, recurrent_dim=32) + # inference_network = bf.networks.CouplingFlow() # workflow = bf.BasicWorkflow( # simulator=simulator, @@ -272,9 +263,9 @@ def load_divi_data(): # standardize='all' # ) - # history = workflow.fit_offline(data=trainings_data, epochs=1000, batch_size=32, validation_data=validation_data) + # history = workflow.fit_offline(data=trainings_data, epochs=1, batch_size=32, validation_data=validation_data) - # # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_10params.keras")) + # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_test.keras")) # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) # plots['losses'].savefig('losses_diffusionmodel_400params.png') From 4d4bd83a3f839a11df664324e7136b884e50f76e Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 14 Aug 2025 09:16:08 +0200 Subject: [PATCH 18/73] save all regions and handle summary_variables keys different --- .../simulation/graph_germany_nuts3.py | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 9919ee81c9..1d6e10e445 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -31,7 +31,8 @@ from 
memilio.epidata import defaultDict as dd import pickle -excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, 11007, 11008, 11009, 11010, 11011, 11012, 16056, 7338, 9374, 9473, 9573] +excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, 11007, 11008, 11009, 11010, 11011, 11012, 16056] +no_icu_ids = [7338, 9374, 9473, 9573] region_ids = [region_id for region_id in dd.County.keys() if region_id not in excluded_ids] class Simulation: @@ -93,7 +94,7 @@ def set_npis(self, params, end_date, damping_value): """ start_damping = datetime.date( - 2020, 7, 8) + year=2020, month=7, day=8) if start_damping < end_date: start_date = (start_damping - self.start_date).days @@ -166,7 +167,7 @@ def run(self, num_days_sim, damping_values, save_graph=True): mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): node = graph.get_node(node_idx) - self.set_npis(node.property.parameters, end_date, damping_values[node_idx // 1000 -1]) + self.set_npis(node.property.parameters, end_date, damping_values[node.id // 1000 -1]) mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( @@ -176,10 +177,13 @@ def run(self, num_days_sim, damping_values, save_graph=True): mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) mobility_sim.advance(num_days_sim) - results = [] - for node_idx in range(graph.num_nodes): - results.append(osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(node_idx).property.result)) + results = {} + for node_idx in range(mobility_sim.graph.num_nodes): + node = mobility_sim.graph.get_node(node_idx) + if node.id in no_icu_ids: + results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result) + else: + results[f'region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result) return results @@ -195,7 +199,7 @@ def run_germany_nuts3_simulation(damping_values): results = sim.run(num_days_sim, damping_values) - return {f'region{region_idx}': results[region_idx] for region_idx, region_id in enumerate(region_ids)} + return results def prior(): damping_values = np.random.uniform(0.0, 1.0, 16) @@ -209,7 +213,7 @@ def load_divi_data(): data = data[data['Date']>= np.datetime64(datetime.date(2020, 7, 1))] data = data[data['Date'] <= np.datetime64(datetime.date(2020, 7, 1) + datetime.timedelta(days=50))] data = data.drop(columns=['County', 'ICU_ventilated', 'Date']) - divi_dict = {f"region{i}": data[data['ID_County'] == region_id]['ICU'].to_numpy().reshape((1, 51, 1)) for i, region_id in enumerate(region_ids)} + divi_dict = {f"region{i}": data[data['ID_County'] == region_id]['ICU'].to_numpy().reshape((1, 51, 1)) for i, region_id in enumerate(region_ids) if region_id not in no_icu_ids} return divi_dict @@ -222,22 +226,23 @@ def load_divi_data(): import bayesflow as bf simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - trainings_data = simulator.sample(100) + trainings_data = simulator.sample(1000) - for region in range(len(region_ids)): - trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] + for key in trainings_data.keys(): + if key != 'damping_values': + trainings_data[key] = trainings_data[key][:, :, 8][..., np.newaxis] - with open('validation_data_counties.pickle', 'wb') as f: + with open('trainings_data10_counties.pickle', 'wb') as f: pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - # with open('trainings_data1_400params.pickle', 'rb') as f: + # with 
open('trainings_data1_counties.pickle', 'rb') as f: # trainings_data = pickle.load(f) - # # for i in range(9): - # # with open(f'trainings_data{i+2}_400params.pickle', 'rb') as f: - # # data = pickle.load(f) - # # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} + # for i in range(9): + # with open(f'trainings_data{i+2}_counties.pickle', 'rb') as f: + # data = pickle.load(f) + # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - # with open('validation_data_400params.pickle', 'rb') as f: + # with open('validation_data_counties.pickle', 'rb') as f: # validation_data = pickle.load(f) # adapter = ( @@ -246,7 +251,7 @@ def load_divi_data(): # .convert_dtype("float64", "float32") # .constrain("damping_values", lower=0.0, upper=1.0) # .rename("damping_values", "inference_variables") - # .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) + # .concatenate([f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids], into="summary_variables", axis=-1) # .log("summary_variables", p1=True) # ) @@ -263,12 +268,12 @@ def load_divi_data(): # standardize='all' # ) - # history = workflow.fit_offline(data=trainings_data, epochs=1, batch_size=32, validation_data=validation_data) + # history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) - # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_test.keras")) + # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_countylvl.keras")) # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) - # plots['losses'].savefig('losses_diffusionmodel_400params.png') - # plots['recovery'].savefig('recovery_diffusionmodel_400params.png') - # plots['calibration_ecdf'].savefig('calibration_ecdf_diffusionmodel_400params.png') - # plots['z_score_contraction'].savefig('z_score_contraction_diffusionmodel_400params.png') + # plots['losses'].savefig('losses_countylvl.png') + # plots['recovery'].savefig('recovery_countylvl.png') + # plots['calibration_ecdf'].savefig('calibration_ecdf_countylvl.png') + # plots['z_score_contraction'].savefig('z_score_contraction_countylvl.png') From f23568379d385c080c854228beb76f98aa02f624 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 14 Aug 2025 15:46:55 +0200 Subject: [PATCH 19/73] changes for fitting countylvl --- .../simulation/graph_germany_nuts3.py | 66 +++++++++++-------- .../memilio/epidata/defaultDict.py | 4 +- shellscripts/fitting_graphmodel.sh | 6 +- 3 files changed, 44 insertions(+), 32 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 1d6e10e445..5c0ddc6554 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -181,9 +181,9 @@ def run(self, num_days_sim, damping_values, save_graph=True): for node_idx in range(mobility_sim.graph.num_nodes): node = mobility_sim.graph.get_node(node_idx) if node.id in no_icu_ids: - results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result) + results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result).as_ndarray() else: - results[f'region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result) + 
results[f'region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result).as_ndarray() return results @@ -213,7 +213,7 @@ def load_divi_data(): data = data[data['Date']>= np.datetime64(datetime.date(2020, 7, 1))] data = data[data['Date'] <= np.datetime64(datetime.date(2020, 7, 1) + datetime.timedelta(days=50))] data = data.drop(columns=['County', 'ICU_ventilated', 'Date']) - divi_dict = {f"region{i}": data[data['ID_County'] == region_id]['ICU'].to_numpy().reshape((1, 51, 1)) for i, region_id in enumerate(region_ids) if region_id not in no_icu_ids} + divi_dict = {f"region{i}": data[data['ID_County'] == region_id]['ICU'].to_numpy()[None, :, None] for i, region_id in enumerate(region_ids) if region_id not in no_icu_ids} return divi_dict @@ -224,16 +224,17 @@ def load_divi_data(): os.environ["KERAS_BACKEND"] = "tensorflow" import bayesflow as bf + from tensorflow import keras simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - trainings_data = simulator.sample(1000) + # trainings_data = simulator.sample(100) - for key in trainings_data.keys(): - if key != 'damping_values': - trainings_data[key] = trainings_data[key][:, :, 8][..., np.newaxis] + # for key in trainings_data.keys(): + # if key != 'damping_values': + # trainings_data[key] = trainings_data[key][:, :, 8][..., np.newaxis] - with open('trainings_data10_counties.pickle', 'wb') as f: - pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + # with open('validation_data_counties.pickle', 'wb') as f: + # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) # with open('trainings_data1_counties.pickle', 'rb') as f: # trainings_data = pickle.load(f) @@ -245,28 +246,28 @@ def load_divi_data(): # with open('validation_data_counties.pickle', 'rb') as f: # validation_data = pickle.load(f) - # adapter = ( - # bf.Adapter() - # .to_array() - # .convert_dtype("float64", "float32") - # .constrain("damping_values", lower=0.0, upper=1.0) - # .rename("damping_values", "inference_variables") - # .concatenate([f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids], into="summary_variables", axis=-1) - # .log("summary_variables", p1=True) - # ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .rename("damping_values", "inference_variables") + .concatenate([f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids], into="summary_variables", axis=-1) + .log("summary_variables", p1=True) + ) # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) - # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32, recurrent_dim=32) - # inference_network = bf.networks.CouplingFlow() + summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32, recurrent_dim=32) + inference_network = bf.networks.CouplingFlow() - # workflow = bf.BasicWorkflow( - # simulator=simulator, - # adapter=adapter, - # summary_network=summary_network, - # inference_network=inference_network, - # standardize='all' - # ) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + ) # history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) @@ -277,3 +278,14 @@ def load_divi_data(): # plots['recovery'].savefig('recovery_countylvl.png') # plots['calibration_ecdf'].savefig('calibration_ecdf_countylvl.png') # 
plots['z_score_contraction'].savefig('z_score_contraction_countylvl.png') + + test = load_divi_data() + workflow.approximator = keras.models.load_model(os.path.join(os.path.dirname(__file__), "model_countylvl.keras")) + + samples = workflow.sample(conditions=test, num_samples=10) + # samples = workflow.samples_to_data_frame(samples) + # print(samples.head()) + samples['damping_values'] = np.squeeze(samples['damping_values']) + for i in range(samples['damping_values'].shape[0]): + test = run_germany_nuts3_simulation(samples['damping_values'][i]) + print(test) \ No newline at end of file diff --git a/pycode/memilio-epidata/memilio/epidata/defaultDict.py b/pycode/memilio-epidata/memilio/epidata/defaultDict.py index 8f95fffafd..7c3016844f 100644 --- a/pycode/memilio-epidata/memilio/epidata/defaultDict.py +++ b/pycode/memilio-epidata/memilio/epidata/defaultDict.py @@ -45,9 +45,9 @@ 'out_folder': default_file_path, 'update_data': False, 'start_date': date(2020, 1, 1), - 'end_date': date.today(), + 'end_date': date(2021, 1, 1), 'split_berlin': False, - 'impute_dates': False, + 'impute_dates': True, 'moving_average': 0, 'file_format': 'json_timeasstring', 'no_raw': False, diff --git a/shellscripts/fitting_graphmodel.sh b/shellscripts/fitting_graphmodel.sh index d92d6dd151..fa53bc085b 100644 --- a/shellscripts/fitting_graphmodel.sh +++ b/shellscripts/fitting_graphmodel.sh @@ -3,9 +3,9 @@ #SBATCH -n 1 #SBATCH -c 1 #SBATCH -t 5-0:00:00 -#SBATCH --output=shellscripts/train_diffusion_model-%A.out -#SBATCH --error=shellscripts/train_diffusion_model-%A.err -#SBATCH --job-name=train_diffusion_model +#SBATCH --output=shellscripts/train_countylvl-%A.out +#SBATCH --error=shellscripts/train_countylvl-%A.err +#SBATCH --job-name=train_countylvl #SBATCH --partition=gpu #SBATCH --gpus=1 From 6a45444601ad8add46c1b117d5a827ac33816520 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:45:54 +0200 Subject: [PATCH 20/73] [ci skip] sort provincias --- .../memilio/epidata/defaultDict.py | 106 +++++++++--------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/defaultDict.py b/pycode/memilio-epidata/memilio/epidata/defaultDict.py index 7c3016844f..f4013a491b 100644 --- a/pycode/memilio-epidata/memilio/epidata/defaultDict.py +++ b/pycode/memilio-epidata/memilio/epidata/defaultDict.py @@ -705,56 +705,56 @@ def invert_dict(dict_to_invert): """ return {val: key for key, val in dict_to_invert.items()} -Provincias = { - 5: 'Almería', - 12: 'Cádiz', - 15: 'Córdoba', - 19: 'Granada', - 22: 'Huelva', - 24: 'Jaén', - 30: 'Málaga', - 41: 'Sevilla', - 23: 'Huesca', - 44: 'Teruel', - 50: 'Zaragoza', - 33: 'Asturias', - 8: 'Balears, Illes', - 35: 'Palmas, Las', - 38: 'Santa Cruz de Tenerife', - 39: 'Cantabria', - 6: 'Ávila', - 10: 'Burgos', - 25: 'León', - 34: 'Palencia', - 37: 'Salamanca', - 40: 'Segovia', - 42: 'Soria', - 47: 'Valladolid', - 49: 'Zamora', - 3: 'Albacete', - 14: 'Ciudad Real', - 17: 'Cuenca', - 20: 'Guadalajara', - 45: 'Toledo', - 9: 'Barcelona', - 18: 'Girona', - 26: 'Lleida', - 43: 'Tarragona', - 4: 'Alicante/Alacant', - 13: 'Castellón/Castelló', - 46: 'Valencia/València', - 7: 'Badajoz', - 11: 'Cáceres', - 16: 'Coruña, A', - 28: 'Lugo', - 53: 'Ourense', - 36: 'Pontevedra', - 29: 'Madrid', - 31: 'Murcia', - 32: 'Navarra', - 2: 'Araba/Álava', - 21: 'Gipuzkoa', - 48: 'Bizkaia', - 27: 'Rioja, La', - 51: 'Ceuta', - 52: 'Melilla'} \ No newline at end of file + +Provincias = {2: 'Araba/Álava', + 3: 'Albacete', + 
4: 'Alicante/Alacant', + 5: 'Almería', + 6: 'Ávila', + 7: 'Badajoz', + 8: 'Balears, Illes', + 9: 'Barcelona', + 10: 'Burgos', + 11: 'Cáceres', + 12: 'Cádiz', + 13: 'Castellón/Castelló', + 14: 'Ciudad Real', + 15: 'Córdoba', + 16: 'Coruña, A', + 17: 'Cuenca', + 18: 'Girona', + 19: 'Granada', + 20: 'Guadalajara', + 21: 'Gipuzkoa', + 22: 'Huelva', + 23: 'Huesca', + 24: 'Jaén', + 25: 'León', + 26: 'Lleida', + 27: 'Rioja, La', + 28: 'Lugo', + 29: 'Madrid', + 30: 'Málaga', + 31: 'Murcia', + 32: 'Navarra', + 33: 'Asturias', + 34: 'Palencia', + 35: 'Palmas, Las', + 36: 'Pontevedra', + 37: 'Salamanca', + 38: 'Santa Cruz de Tenerife', + 39: 'Cantabria', + 40: 'Segovia', + 41: 'Sevilla', + 42: 'Soria', + 43: 'Tarragona', + 44: 'Teruel', + 45: 'Toledo', + 46: 'Valencia/València', + 47: 'Valladolid', + 48: 'Bizkaia', + 49: 'Zamora', + 50: 'Zaragoza', + 51: 'Ceuta', + 52: 'Melilla', + 53: 'Ourense'} From 54548eebb0100d0a5e1f496b7a6240960ea4da58 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Fri, 15 Aug 2025 13:58:21 +0200 Subject: [PATCH 21/73] rm no age functions for io --- cpp/models/ode_secir/parameters_io.cpp | 133 ------------------------- cpp/models/ode_secir/parameters_io.h | 100 +------------------ 2 files changed, 4 insertions(+), 229 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.cpp b/cpp/models/ode_secir/parameters_io.cpp index f21ab9d471..5a2e122c24 100644 --- a/cpp/models/ode_secir/parameters_io.cpp +++ b/cpp/models/ode_secir/parameters_io.cpp @@ -187,139 +187,6 @@ IOResult read_confirmed_cases_data( return success(); } -IOResult read_confirmed_cases_noage( - std::vector& rki_data, std::vector const& vregion, Date date, - std::vector& vnum_Exposed, std::vector& vnum_InfectedNoSymptoms, - std::vector& vnum_InfectedSymptoms, std::vector& vnum_InfectedSevere, - std::vector& vnum_icu, std::vector& vnum_death, - std::vector& vnum_rec, const std::vector& vt_Exposed, - const std::vector& vt_InfectedNoSymptoms, - const std::vector& vt_InfectedSymptoms, const std::vector& vt_InfectedSevere, - const std::vector& vt_InfectedCritical, const std::vector& vmu_C_R, - const std::vector& vmu_I_H, const std::vector& vmu_H_U, - const double scaling_factor_inf) -{ - auto max_date_entry = std::max_element(rki_data.begin(), rki_data.end(), [](auto&& a, auto&& b) { - return a.date < b.date; - }); - if (max_date_entry == rki_data.end()) { - log_error("RKI data file is empty."); - return failure(StatusCode::InvalidFileFormat, "RKI file is empty."); - } - auto max_date = max_date_entry->date; - if (max_date < date) { - log_error("Specified date does not exist in RKI data"); - return failure(StatusCode::OutOfRange, "Specified date does not exist in RKI data."); - } - auto days_surplus = std::min(get_offset_in_days(max_date, date) - 6, 0); - - //this statement causes maybe-uninitialized warning for some versions of gcc. 
- //the error is reported in an included header, so the warning is disabled for the whole file - std::sort(rki_data.begin(), rki_data.end(), [](auto&& a, auto&& b) { - return std::make_tuple(get_region_id(a), a.date) < std::make_tuple(get_region_id(b), b.date); - }); - - for (auto region_idx = size_t(0); region_idx < vregion.size(); ++region_idx) { - auto region_entry_range_it = - std::equal_range(rki_data.begin(), rki_data.end(), vregion[region_idx], [](auto&& a, auto&& b) { - return get_region_id(a) < get_region_id(b); - }); - auto region_entry_range = make_range(region_entry_range_it); - if (region_entry_range.begin() == region_entry_range.end()) { - log_error("No entries found for region {}", vregion[region_idx]); - return failure(StatusCode::InvalidFileFormat, - "No entries found for region " + std::to_string(vregion[region_idx])); - } - for (auto&& region_entry : region_entry_range) { - - auto& t_Exposed = vt_Exposed[region_idx]; - auto& t_InfectedNoSymptoms = vt_InfectedNoSymptoms[region_idx]; - auto& t_InfectedSymptoms = vt_InfectedSymptoms[region_idx]; - auto& t_InfectedSevere = vt_InfectedSevere[region_idx]; - auto& t_InfectedCritical = vt_InfectedCritical[region_idx]; - - auto& num_InfectedNoSymptoms = vnum_InfectedNoSymptoms[region_idx]; - auto& num_InfectedSymptoms = vnum_InfectedSymptoms[region_idx]; - auto& num_rec = vnum_rec[region_idx]; - auto& num_Exposed = vnum_Exposed[region_idx]; - auto& num_InfectedSevere = vnum_InfectedSevere[region_idx]; - auto& num_death = vnum_death[region_idx]; - auto& num_icu = vnum_icu[region_idx]; - - auto& mu_C_R = vmu_C_R[region_idx]; - auto& mu_I_H = vmu_I_H[region_idx]; - auto& mu_H_U = vmu_H_U[region_idx]; - - auto date_df = region_entry.date; - - if (date_df == offset_date_by_days(date, 0)) { - num_InfectedSymptoms += scaling_factor_inf * region_entry.num_confirmed; - num_rec += region_entry.num_recovered; - } - if (date_df == offset_date_by_days(date, days_surplus)) { - num_InfectedNoSymptoms -= - 1 / (1 - mu_C_R) * scaling_factor_inf * region_entry.num_confirmed; - } - if (date_df == offset_date_by_days(date, t_InfectedNoSymptoms + days_surplus)) { - num_InfectedNoSymptoms += - 1 / (1 - mu_C_R) * scaling_factor_inf * region_entry.num_confirmed; - num_Exposed -= 1 / (1 - mu_C_R) * scaling_factor_inf * region_entry.num_confirmed; - } - if (date_df == offset_date_by_days(date, t_Exposed + t_InfectedNoSymptoms + days_surplus)) { - num_Exposed += 1 / (1 - mu_C_R) * scaling_factor_inf * region_entry.num_confirmed; - } - if (date_df == offset_date_by_days(date, -t_InfectedSymptoms)) { - num_InfectedSymptoms -= scaling_factor_inf * region_entry.num_confirmed; - num_InfectedSevere += mu_I_H * scaling_factor_inf * region_entry.num_confirmed; - } - if (date_df == offset_date_by_days(date, -t_InfectedSymptoms - t_InfectedSevere)) { - num_InfectedSevere -= mu_I_H * scaling_factor_inf * region_entry.num_confirmed; - num_icu += mu_I_H * mu_H_U * scaling_factor_inf * region_entry.num_confirmed; - } - if (date_df == - offset_date_by_days(date, -t_InfectedSymptoms - t_InfectedSevere - t_InfectedCritical)) { - num_death += region_entry.num_deaths; - num_icu -= mu_I_H * mu_H_U * scaling_factor_inf * region_entry.num_confirmed; - } - } - } - - for (size_t region_idx = 0; region_idx < vregion.size(); ++region_idx) { - auto region = vregion[region_idx]; - - auto& num_InfectedNoSymptoms = vnum_InfectedNoSymptoms[region_idx]; - auto& num_InfectedSymptoms = vnum_InfectedSymptoms[region_idx]; - auto& num_rec = vnum_rec[region_idx]; - auto& num_Exposed = 
vnum_Exposed[region_idx]; - auto& num_InfectedSevere = vnum_InfectedSevere[region_idx]; - auto& num_death = vnum_death[region_idx]; - auto& num_icu = vnum_icu[region_idx]; - - auto try_fix_constraints = [region](double& value, double error, auto str) { - if (value < error) { - //this should probably return a failure - //but the algorithm is not robust enough to avoid large negative values and there are tests that rely on it - log_error("{:s} is {:.4f} for region {:d}, exceeds expected negative value.", str, value, region); - value = 0.0; - } - else if (value < 0) { - log_info("{:s} is {:.4f} for region {:d}, automatically corrected", str, value, region); - value = 0.0; - } - }; - - try_fix_constraints(num_InfectedSymptoms, -5, "InfectedSymptoms"); - try_fix_constraints(num_InfectedNoSymptoms, -5, "InfectedNoSymptoms"); - try_fix_constraints(num_Exposed, -5, "Exposed"); - try_fix_constraints(num_InfectedSevere, -5, "InfectedSevere"); - try_fix_constraints(num_death, -5, "Dead"); - try_fix_constraints(num_icu, -5, "InfectedCritical"); - try_fix_constraints(num_rec, -20, "Recovered"); - } - - return success(); -} - } // namespace details } // namespace osecir } // namespace mio diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 4ea862d88c..80ecde125f 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -151,99 +151,6 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect return success(); } -IOResult read_confirmed_cases_noage( - std::vector& rki_data, std::vector const& vregion, Date date, - std::vector& vnum_Exposed, std::vector& vnum_InfectedNoSymptoms, - std::vector& vnum_InfectedSymptoms, std::vector& vnum_InfectedSevere, - std::vector& vnum_icu, std::vector& vnum_death, - std::vector& vnum_rec, const std::vector& vt_Exposed, - const std::vector& vt_InfectedNoSymptoms, - const std::vector& vt_InfectedSymptoms, const std::vector& vt_InfectedSevere, - const std::vector& vt_InfectedCritical, const std::vector& vmu_C_R, - const std::vector& vmu_I_H, const std::vector& vmu_H_U, - const double scaling_factor_inf); - -/** - * @brief Sets populations data from already read case data with multiple age groups into a Model with one age group. - * @tparam FP Floating point data type, e.g., double. - * @param[in, out] model Vector of models in which the data is set. - * @param[in] case_data List of confirmed cases data entries. - * @param[in] region Vector of keys of the region of interest. - * @param[in] date Date at which the data is read. - * @param[in] scaling_factor_inf Factors by which to scale the confirmed cases of rki data. 
- */ -template -IOResult -set_confirmed_cases_noage(std::vector>& model, std::vector& case_data, - const std::vector& region, Date date, const double scaling_factor_inf) -{ - std::vector t_InfectedNoSymptoms(model.size()); - std::vector t_Exposed(model.size()); - std::vector t_InfectedSymptoms(model.size()); - std::vector t_InfectedSevere(model.size()); - std::vector t_InfectedCritical(model.size()); - - std::vector mu_C_R(model.size()); - std::vector mu_I_H(model.size()); - std::vector mu_H_U(model.size()); - std::vector mu_U_D(model.size()); - - for (size_t node = 0; node < model.size(); ++node) { - - t_Exposed[node] = - static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)])); - t_InfectedNoSymptoms[node]= static_cast( - std::round(model[node].parameters.template get>()[AgeGroup(0)])); - t_InfectedSymptoms[node] = - static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)])); - t_InfectedSevere[node] = - static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)])); - t_InfectedCritical[node] = - static_cast(std::round(model[node].parameters.template get>()[AgeGroup(0)])); - - mu_C_R[node] = model[node].parameters.template get>()[AgeGroup(0)]; - mu_I_H[node] = model[node].parameters.template get>()[AgeGroup(0)]; - mu_H_U[node] = model[node].parameters.template get>()[AgeGroup(0)]; - mu_U_D[node] = model[node].parameters.template get>()[AgeGroup(0)]; - } - std::vector num_InfectedSymptoms(model.size(), 0.0); - std::vector num_death(model.size(), 0.0); - std::vector num_rec(model.size(), 0.0); - std::vector num_Exposed(model.size(), 0.0); - std::vector num_InfectedNoSymptoms(model.size(), 0.0); - std::vector num_InfectedSevere(model.size(), 0.0); - std::vector num_icu(model.size(), 0.0); - - BOOST_OUTCOME_TRY(read_confirmed_cases_noage(case_data, region, date, num_Exposed, num_InfectedNoSymptoms, - num_InfectedSymptoms, num_InfectedSevere, num_icu, num_death, num_rec, - t_Exposed, t_InfectedNoSymptoms, t_InfectedSymptoms, t_InfectedSevere, - t_InfectedCritical, mu_C_R, mu_I_H, mu_H_U, scaling_factor_inf)); - - for (size_t node = 0; node < model.size(); node++) { - if (num_InfectedSymptoms[node] > 0) { - model[node].populations[{AgeGroup(0), InfectionState::Exposed}] = num_Exposed[node]; - model[node].populations[{AgeGroup(0), InfectionState::InfectedNoSymptoms}] = - num_InfectedNoSymptoms[node]; - model[node].populations[{AgeGroup(0), InfectionState::InfectedNoSymptomsConfirmed}] = 0; - model[node].populations[{AgeGroup(0), InfectionState::InfectedSymptoms}] = - num_InfectedSymptoms[node]; - model[node].populations[{AgeGroup(0), InfectionState::InfectedSymptomsConfirmed}] = 0; - model[node].populations[{AgeGroup(0), InfectionState::InfectedSevere}] = num_InfectedSevere[node]; - // Only set the number of ICU patients here, if the date is not available in the data. - if (!is_divi_data_available(date)) { - model[node].populations[{AgeGroup(0), InfectionState::InfectedCritical}] = num_icu[node]; - } - model[node].populations[{AgeGroup(0), InfectionState::Dead}] = num_death[node]; - model[node].populations[{AgeGroup(0), InfectionState::Recovered}] = num_rec[node]; - } - else { - log_warning("No infections reported on date {} for region {}. Population data has not been set.", date, - region[node]); - } - } - return success(); -} - /** * @brief Sets the infected population for a given model based on confirmed cases data. Here, we * read the case data from a file. 
@@ -259,8 +166,8 @@ IOResult set_confirmed_cases_data(std::vector>& model, const std std::vector const& region, Date date, const std::vector& scaling_factor_inf) { - BOOST_OUTCOME_TRY(auto&& case_data, mio::read_confirmed_cases_noage(path)); - BOOST_OUTCOME_TRY(set_confirmed_cases_noage(model, case_data, region, date, scaling_factor_inf[0])); + BOOST_OUTCOME_TRY(auto&& case_data, mio::read_confirmed_cases_data(path)); + BOOST_OUTCOME_TRY(set_confirmed_cases_data(model, case_data, region, date, scaling_factor_inf[0])); return success(); } @@ -488,7 +395,8 @@ IOResult read_input_data_state(std::vector& model, Date date, std:: template IOResult read_input_data_county(std::vector& model, Date date, const std::vector& county, const std::vector& scaling_factor_inf, double scaling_factor_icu, - const std::string& pydata_dir,int /*num_days*/ = 0, bool /*export_time_series*/ = false) + const std::string& pydata_dir, int /*num_days*/ = 0, + bool /*export_time_series*/ = false) { BOOST_OUTCOME_TRY( details::set_divi_data(model, path_join(pydata_dir, "county_divi_ma7.json"), county, date, scaling_factor_icu)); From a5ab06b9ec6df49ea9a457067da8aed47f786d7c Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:33:00 +0200 Subject: [PATCH 22/73] simple example --- cpp/examples/CMakeLists.txt | 4 + cpp/examples/ode_secir_read_input_one_age.cpp | 94 +++++++++++++++++++ cpp/models/ode_secir/parameters_io.h | 93 +++++++++++++----- 3 files changed, 165 insertions(+), 26 deletions(-) create mode 100644 cpp/examples/ode_secir_read_input_one_age.cpp diff --git a/cpp/examples/CMakeLists.txt b/cpp/examples/CMakeLists.txt index 980164d324..11141aee09 100644 --- a/cpp/examples/CMakeLists.txt +++ b/cpp/examples/CMakeLists.txt @@ -171,6 +171,10 @@ if(MEMILIO_HAS_HDF5 AND MEMILIO_HAS_JSONCPP) add_executable(ode_secir_parameter_study_graph ode_secir_parameter_study_graph.cpp) target_link_libraries(ode_secir_parameter_study_graph PRIVATE memilio ode_secir) target_compile_options(ode_secir_parameter_study_graph PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS}) + + add_executable(ode_secir_read_input_one_age_example ode_secir_read_input_one_age.cpp) + target_link_libraries(ode_secir_read_input_one_age_example PRIVATE memilio ode_secir) + target_compile_options(ode_secir_read_input_one_age_example PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS}) endif() if(MEMILIO_HAS_JSONCPP) diff --git a/cpp/examples/ode_secir_read_input_one_age.cpp b/cpp/examples/ode_secir_read_input_one_age.cpp new file mode 100644 index 0000000000..61e31b3e97 --- /dev/null +++ b/cpp/examples/ode_secir_read_input_one_age.cpp @@ -0,0 +1,94 @@ +/* +* Copyright (C) 2020-2025 MEmilio +* +* Authors: Henrik Zunker +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ +#include "ode_secir/model.h" +#include "ode_secir/parameters_io.h" +#include "memilio/mobility/metapopulation_mobility_instant.h" +#include "memilio/mobility/graph.h" +#include "memilio/utils/logging.h" +#include +#include +#include + +int main() +{ + mio::set_log_level(mio::LogLevel::info); + using FP = double; + + // params for 1 age group + mio::osecir::Parameters params(1); + params.get>()[mio::AgeGroup(0)] = 3.2; + params.get>()[mio::AgeGroup(0)] = 2.0; + params.get>()[mio::AgeGroup(0)] = 5.8; + params.get>()[mio::AgeGroup(0)] = 9.5; + params.get>()[mio::AgeGroup(0)] = 7.1; + params.get>()[mio::AgeGroup(0)] = 0.05; + params.get>()[mio::AgeGroup(0)] = 0.7; + params.get>()[mio::AgeGroup(0)] = 0.09; + params.get>()[mio::AgeGroup(0)] = 0.2; + params.get>()[mio::AgeGroup(0)] = 0.25; + params.get>()[mio::AgeGroup(0)] = 0.3; + + // input data + const mio::Date date{2020, 12, 1}; + const auto& data_dir = "/localdata1/code_2025/memilio/data"; + const std::string pydata_dir = mio::path_join(data_dir, "Germany", "pydata"); + const std::string population_data_path = mio::path_join(pydata_dir, "county_current_population.json"); + + // scaling factors + std::vector scaling_factor_inf(static_cast(params.get_num_groups()), 1.0); + const double scaling_factor_icu = 1.0; + const double tnt_capacity_factor = 7.5 / 100000.0; + + // grraph + mio::Graph, mio::MobilityParameters> graph; + + const auto& read_function_nodes = mio::osecir::read_input_data_county>; + auto node_id_function = [](const std::string&, bool, bool) -> mio::IOResult> { + return mio::success(std::vector{1002}); + }; + + const auto& set_node_function = + mio::set_nodes, mio::osecir::ContactPatterns, mio::osecir::Model, + mio::MobilityParameters, mio::osecir::Parameters, decltype(read_function_nodes), + decltype(node_id_function), FP>; + + auto io = set_node_function(params, date, date, pydata_dir, population_data_path, /*is_county=*/true, graph, + read_function_nodes, std::move(node_id_function), scaling_factor_inf, + scaling_factor_icu, tnt_capacity_factor, /*num_days=*/0, + /*export_time_series=*/false, /*rki_age_groups=*/true); + if (!io) { + std::cerr << io.error().formatted_message() << std::endl; + return 1; + } + + // icu should be 7. 
+ + // check output + const auto& m = graph.nodes()[0].property; + const auto ag = mio::AgeGroup(0); + std::cout << "Initialized via set_nodes for county 1002 on " << date << "\n"; + std::cout << "S=" << m.populations[{ag, mio::osecir::InfectionState::Susceptible}].value() << ", "; + std::cout << "E=" << m.populations[{ag, mio::osecir::InfectionState::Exposed}].value() << ", "; + std::cout << "I=" << m.populations[{ag, mio::osecir::InfectionState::InfectedSymptoms}].value() << ", "; + std::cout << "R=" << m.populations[{ag, mio::osecir::InfectionState::Recovered}].value() << ", "; + std::cout << "D=" << m.populations[{ag, mio::osecir::InfectionState::Dead}].value() << std::endl; + + return 0; +} diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index ab334d922d..6fd2484e7a 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -75,7 +75,17 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect const std::vector& scaling_factor_inf) { const size_t num_age_groups = ConfirmedCasesDataEntry::age_group_names.size(); - assert(scaling_factor_inf.size() == num_age_groups); + // allow single scalar scaling that is broadcast to all age groups + assert(scaling_factor_inf.size() == 1 || scaling_factor_inf.size() == num_age_groups); + + // Broadcast scaling factors to match RKI age groups (6) + std::vector scaling_factor_inf_full(num_age_groups, 1.0); + if (scaling_factor_inf.size() == 1) { + scaling_factor_inf_full.assign(num_age_groups, scaling_factor_inf[0]); + } + else { + scaling_factor_inf_full = scaling_factor_inf; + } std::vector> t_InfectedNoSymptoms{model.size()}; std::vector> t_Exposed{model.size()}; @@ -89,25 +99,30 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect std::vector> mu_U_D{model.size()}; for (size_t node = 0; node < model.size(); ++node) { + const size_t model_groups = (size_t)model[node].parameters.get_num_groups(); + assert(model_groups == 1 || model_groups == num_age_groups); for (size_t group = 0; group < num_age_groups; group++) { + // If the model has fewer groups than casedata entries available, + // reuse group 0 parameters for all RKI age groups + const size_t pidx = (model_groups == num_age_groups) ? 
group : 0; t_Exposed[node].push_back( - static_cast(std::round(model[node].parameters.template get>()[(AgeGroup)group]))); + static_cast(std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); t_InfectedNoSymptoms[node].push_back(static_cast( - std::round(model[node].parameters.template get>()[(AgeGroup)group]))); + std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); t_InfectedSymptoms[node].push_back(static_cast( - std::round(model[node].parameters.template get>()[(AgeGroup)group]))); + std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); t_InfectedSevere[node].push_back(static_cast( - std::round(model[node].parameters.template get>()[(AgeGroup)group]))); + std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); t_InfectedCritical[node].push_back(static_cast( - std::round(model[node].parameters.template get>()[(AgeGroup)group]))); + std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); mu_C_R[node].push_back( - model[node].parameters.template get>()[(AgeGroup)group]); + model[node].parameters.template get>()[(AgeGroup)pidx]); mu_I_H[node].push_back( - model[node].parameters.template get>()[(AgeGroup)group]); - mu_H_U[node].push_back(model[node].parameters.template get>()[(AgeGroup)group]); - mu_U_D[node].push_back(model[node].parameters.template get>()[(AgeGroup)group]); + model[node].parameters.template get>()[(AgeGroup)pidx]); + mu_H_U[node].push_back(model[node].parameters.template get>()[(AgeGroup)pidx]); + mu_U_D[node].push_back(model[node].parameters.template get>()[(AgeGroup)pidx]); } } std::vector> num_InfectedSymptoms(model.size(), std::vector(num_age_groups, 0.0)); @@ -121,26 +136,52 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect BOOST_OUTCOME_TRY(read_confirmed_cases_data(case_data, region, date, num_Exposed, num_InfectedNoSymptoms, num_InfectedSymptoms, num_InfectedSevere, num_icu, num_death, num_rec, t_Exposed, t_InfectedNoSymptoms, t_InfectedSymptoms, t_InfectedSevere, - t_InfectedCritical, mu_C_R, mu_I_H, mu_H_U, scaling_factor_inf)); + t_InfectedCritical, mu_C_R, mu_I_H, mu_H_U, scaling_factor_inf_full)); for (size_t node = 0; node < model.size(); node++) { if (std::accumulate(num_InfectedSymptoms[node].begin(), num_InfectedSymptoms[node].end(), 0.0) > 0) { size_t num_groups = (size_t)model[node].parameters.get_num_groups(); - for (size_t i = 0; i < num_groups; i++) { - model[node].populations[{AgeGroup(i), InfectionState::Exposed}] = num_Exposed[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptoms}] = - num_InfectedNoSymptoms[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptomsConfirmed}] = 0; - model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptoms}] = - num_InfectedSymptoms[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptomsConfirmed}] = 0; - model[node].populations[{AgeGroup(i), InfectionState::InfectedSevere}] = num_InfectedSevere[node][i]; - // Only set the number of ICU patients here, if the date is not available in the data. 
+ if (num_groups == num_age_groups) { + for (size_t i = 0; i < num_groups; i++) { + model[node].populations[{AgeGroup(i), InfectionState::Exposed}] = num_Exposed[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptoms}] = + num_InfectedNoSymptoms[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptoms}] = + num_InfectedSymptoms[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i), InfectionState::InfectedSevere}] = + num_InfectedSevere[node][i]; + // Only set the number of ICU patients here, if the date is not available in the data. + if (!is_divi_data_available(date)) { + model[node].populations[{AgeGroup(i), InfectionState::InfectedCritical}] = num_icu[node][i]; + } + model[node].populations[{AgeGroup(i), InfectionState::Dead}] = num_death[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i]; + } + } + else if (num_groups == 1) { + const auto sum_vec = [](const std::vector& v) { + return std::accumulate(v.begin(), v.end(), 0.0); + }; + const size_t i0 = 0; + model[node].populations[{AgeGroup(i0), InfectionState::Exposed}] = sum_vec(num_Exposed[node]); + model[node].populations[{AgeGroup(i0), InfectionState::InfectedNoSymptoms}] = + sum_vec(num_InfectedNoSymptoms[node]); + model[node].populations[{AgeGroup(i0), InfectionState::InfectedNoSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i0), InfectionState::InfectedSymptoms}] = + sum_vec(num_InfectedSymptoms[node]); + model[node].populations[{AgeGroup(i0), InfectionState::InfectedSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i0), InfectionState::InfectedSevere}] = + sum_vec(num_InfectedSevere[node]); if (!is_divi_data_available(date)) { - model[node].populations[{AgeGroup(i), InfectionState::InfectedCritical}] = num_icu[node][i]; + model[node].populations[{AgeGroup(i0), InfectionState::InfectedCritical}] = sum_vec(num_icu[node]); } - model[node].populations[{AgeGroup(i), InfectionState::Dead}] = num_death[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i]; + model[node].populations[{AgeGroup(i0), InfectionState::Dead}] = sum_vec(num_death[node]); + model[node].populations[{AgeGroup(i0), InfectionState::Recovered}] = sum_vec(num_rec[node]); + } + else { + assert(false && "Unsupported number of age groups in model; expected 1 or RKI groups."); } } else { @@ -283,8 +324,8 @@ IOResult export_input_data_county_timeseries( const std::string& divi_data_path, const std::string& confirmed_cases_path, const std::string& population_data_path) { const auto num_age_groups = (size_t)models[0].parameters.get_num_groups(); - assert(scaling_factor_inf.size() == num_age_groups); - assert(num_age_groups == ConfirmedCasesDataEntry::age_group_names.size()); + // allow scalar scaling factor as convenience for 1-group models + assert(scaling_factor_inf.size() == 1 || scaling_factor_inf.size() == num_age_groups); assert(models.size() == region.size()); std::vector> extrapolated_data( region.size(), TimeSeries::zero(num_days + 1, (size_t)InfectionState::Count * num_age_groups)); From 873218090964bfd53226960a6ef057070057024f Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:03:54 +0200 Subject: [PATCH 23/73] lot of tests --- cpp/models/ode_secir/parameters_io.h | 18 
+- cpp/tests/test_odesecir.cpp | 338 +++++++++++++++++++++++++++ 2 files changed, 353 insertions(+), 3 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 6fd2484e7a..94a524ab1d 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -272,10 +272,22 @@ IOResult set_population_data(std::vector>& model, assert(num_population.size() == vregion.size()); assert(model.size() == vregion.size()); for (size_t region = 0; region < vregion.size(); region++) { - auto num_groups = model[region].parameters.get_num_groups(); - for (auto i = AgeGroup(0); i < num_groups; i++) { + const auto model_groups = (size_t)model[region].parameters.get_num_groups(); + const auto data_groups = num_population[region].size(); + + if (data_groups == model_groups) { + for (auto i = AgeGroup(0); i < model[region].parameters.get_num_groups(); i++) { + model[region].populations.template set_difference_from_group_total( + {i, InfectionState::Susceptible}, num_population[region][(size_t)i]); + } + } + else if (model_groups == 1 && data_groups >= 1) { + const double total = std::accumulate(num_population[region].begin(), num_population[region].end(), 0.0); model[region].populations.template set_difference_from_group_total( - {i, InfectionState::Susceptible}, num_population[region][size_t(i)]); + {AgeGroup(0), InfectionState::Susceptible}, total); + } + else { + assert(false && "Dimension of population data not supported."); } } return success(); diff --git a/cpp/tests/test_odesecir.cpp b/cpp/tests/test_odesecir.cpp index 023c56fc6b..88e48a1c24 100644 --- a/cpp/tests/test_odesecir.cpp +++ b/cpp/tests/test_odesecir.cpp @@ -29,6 +29,7 @@ #include "memilio/io/parameters_io.h" #include "memilio/data/analyze_result.h" #include "memilio/math/adapt_rk.h" +#include "memilio/geography/regions.h" #include @@ -1518,5 +1519,342 @@ TEST(TestOdeSecir, read_population_data_failure) EXPECT_EQ(result.error().message(), "File with county population expected."); } +TEST(TestOdeSecirIO, read_input_data_county_aggregates_one_group) +{ + // Set up two models with different age groups. 
+    const size_t num_age_groups = 6;
+    std::vector> models6{mio::osecir::Model((int)num_age_groups)};
+    std::vector> models1{mio::osecir::Model(1)};
+
+    // set parameters for both models
+    models6[0].parameters.set(60);
+    models6[0].parameters.set>(0.2);
+    models1[0].parameters.set(60);
+    models1[0].parameters.set>(0.2);
+
+    // parameters for model with 6 age groups
+    for (auto i = mio::AgeGroup(0); i < (mio::AgeGroup)num_age_groups; ++i) {
+        models6[0].parameters.get>()[i] = 3.2;
+        models6[0].parameters.get>()[i] = 2.0;
+        models6[0].parameters.get>()[i] = 5.8;
+        models6[0].parameters.get>()[i] = 9.5;
+        models6[0].parameters.get>()[i] = 7.1;
+
+        models6[0].parameters.get>()[i] = 0.05;
+        models6[0].parameters.get>()[i] = 0.7;
+        models6[0].parameters.get>()[i] = 0.09;
+        models6[0].parameters.get>()[i] = 0.25;
+        models6[0].parameters.get>()[i] = 0.45;
+        models6[0].parameters.get>()[i] = 0.2;
+        models6[0].parameters.get>()[i] = 0.25;
+        models6[0].parameters.get>()[i] = 0.3;
+    }
+
+    // parameters for model with 1 age group
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 3.2;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 2.0;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 5.8;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 9.5;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 7.1;
+
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.05;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.7;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.09;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.45;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.2;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25;
+    models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.3;
+
+    models6[0].check_constraints();
+    models1[0].check_constraints();
+
+    const auto pydata_dir_Germany = mio::path_join(TEST_DATA_DIR, "Germany", "pydata");
+    const std::vector counties{1002};
+    const auto date = mio::Date(2020, 12, 1);
+
+    std::vector scale6(num_age_groups, 1.0);
+    std::vector scale1{1.0};
+
+    // Initialize both models
+    ASSERT_THAT(mio::osecir::read_input_data_county(models6, date, counties, scale6, 1.0, pydata_dir_Germany),
+                IsSuccess());
+    ASSERT_THAT(mio::osecir::read_input_data_county(models1, date, counties, scale1, 1.0, pydata_dir_Germany),
+                IsSuccess());
+
+    // Aggregate the results from the model with 6 age groups and compare with the model with 1 age group
+    const auto& m6 = models6[0];
+    const auto& m1 = models1[0];
+    const double tol = 1e-13;
+    for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) {
+        double sum6 = 0.0;
+        for (size_t ag = 0; ag < num_age_groups; ++ag) {
+            sum6 += m6.populations[{mio::AgeGroup(ag), (mio::osecir::InfectionState)s}].value();
+        }
+        const double v1 = m1.populations[{mio::AgeGroup(0), (mio::osecir::InfectionState)s}].value();
+        EXPECT_NEAR(sum6, v1, tol);
+    }
+
+    // Total population
+    EXPECT_NEAR(m6.populations.get_total(), m1.populations.get_total(), 1e-13);
+}
+
+TEST(TestOdeSecirIO, set_population_data_single_age_group)
+{
+    const size_t num_age_groups = 6;
+
+    // Create two models: one with 6 age groups, one with 1 age group
+    std::vector> models6{mio::osecir::Model((int)num_age_groups)};
+    std::vector> models1{mio::osecir::Model(1)};
+
+    // Set basic parameters for both models
+    models6[0].parameters.set(60);
+    models6[0].parameters.set>(0.2);
+    models1[0].parameters.set(60);
+    models1[0].parameters.set>(0.2);
+
+    // Set parameters for 6-age-group model
+    for (auto i =
mio::AgeGroup(0); i < (mio::AgeGroup)num_age_groups; ++i) { + models6[0].parameters.get>()[i] = 3.2; + models6[0].parameters.get>()[i] = 2.0; + models6[0].parameters.get>()[i] = 5.8; + models6[0].parameters.get>()[i] = 9.5; + models6[0].parameters.get>()[i] = 7.1; + models6[0].parameters.get>()[i] = 0.05; + models6[0].parameters.get>()[i] = 0.7; + models6[0].parameters.get>()[i] = 0.09; + models6[0].parameters.get>()[i] = 0.25; + models6[0].parameters.get>()[i] = 0.45; + models6[0].parameters.get>()[i] = 0.2; + models6[0].parameters.get>()[i] = 0.25; + models6[0].parameters.get>()[i] = 0.3; + } + + // Set parameters for 1-age-group model (same values) + models1[0].parameters.get>()[mio::AgeGroup(0)] = 3.2; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 2.0; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 5.8; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 9.5; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 7.1; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.05; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.7; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.09; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.45; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.3; + + models6[0].check_constraints(); + models1[0].check_constraints(); + + // Test population data with 6 different values for age groups + std::vector> population_data6 = {{10000.0, 20000.0, 30000.0, 25000.0, 15000.0, 8000.0}}; + std::vector> population_data1 = {{108000.0}}; // sum of all age groups + std::vector regions = {1002}; + + // Set population data for both models + EXPECT_THAT(mio::osecir::details::set_population_data(models6, population_data6, regions), IsSuccess()); + EXPECT_THAT(mio::osecir::details::set_population_data(models1, population_data1, regions), IsSuccess()); + + // Sum all compartments across age groups in 6-group model and compare 1-group model + const double tol = 1e-13; + for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) { + double sum6 = 0.0; + for (size_t ag = 0; ag < num_age_groups; ++ag) { + sum6 += models6[0].populations[{mio::AgeGroup(ag), (mio::osecir::InfectionState)s}].value(); + } + double val1 = models1[0].populations[{mio::AgeGroup(0), (mio::osecir::InfectionState)s}].value(); + + EXPECT_NEAR(sum6, val1, tol); + } + + // Total population should also match + EXPECT_NEAR(models6[0].populations.get_total(), models1[0].populations.get_total(), tol); +} + +TEST(TestOdeSecirIO, set_confirmed_cases_data_single_age_group) +{ + const size_t num_age_groups = 6; + + // Create two models: one with 6 age groups, one with 1 age group + std::vector> models6{mio::osecir::Model((int)num_age_groups)}; + std::vector> models1{mio::osecir::Model(1)}; + + // Set identical parameters for both models + models6[0].parameters.set(60); + models6[0].parameters.set>(0.2); + models1[0].parameters.set(60); + models1[0].parameters.set>(0.2); + + for (auto i = mio::AgeGroup(0); i < (mio::AgeGroup)num_age_groups; ++i) { + models6[0].parameters.get>()[i] = 3.2; + models6[0].parameters.get>()[i] = 2.0; + models6[0].parameters.get>()[i] = 5.8; + models6[0].parameters.get>()[i] = 9.5; + models6[0].parameters.get>()[i] = 7.1; + models6[0].parameters.get>()[i] = 0.05; + models6[0].parameters.get>()[i] = 0.7; + models6[0].parameters.get>()[i] = 0.09; + models6[0].parameters.get>()[i] = 0.25; + 
models6[0].parameters.get>()[i] = 0.45; + models6[0].parameters.get>()[i] = 0.2; + models6[0].parameters.get>()[i] = 0.25; + models6[0].parameters.get>()[i] = 0.3; + } + + models1[0].parameters.get>()[mio::AgeGroup(0)] = 3.2; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 2.0; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 5.8; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 9.5; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 7.1; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.05; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.7; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.09; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.45; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.3; + + models6[0].check_constraints(); + models1[0].check_constraints(); + + // Create case data for all 6 age groups over multiple days (current day + 6 days back) + std::vector case_data; + + for (int day_offset = -6; day_offset <= 0; ++day_offset) { + mio::Date current_date = mio::offset_date_by_days(mio::Date(2020, 12, 1), day_offset); + + for (int age_group = 0; age_group < 6; ++age_group) { + double base_confirmed = 80.0 + age_group * 8.0 + (day_offset + 6) * 5.0; + double base_recovered = 40.0 + age_group * 4.0 + (day_offset + 6) * 3.0; + double base_deaths = 3.0 + age_group * 0.5 + (day_offset + 6) * 0.5; + + mio::ConfirmedCasesDataEntry entry{base_confirmed, + base_recovered, + base_deaths, + current_date, + mio::AgeGroup(age_group), + {}, + mio::regions::CountyId(1002), + {}}; + case_data.push_back(entry); + } + } + + std::vector regions = {1002}; + std::vector scaling_factors = {1.0}; + + // Set confirmed cases data for both models + EXPECT_THAT(mio::osecir::details::set_confirmed_cases_data(models6, case_data, regions, mio::Date(2020, 12, 1), + scaling_factors), + IsSuccess()); + EXPECT_THAT(mio::osecir::details::set_confirmed_cases_data(models1, case_data, regions, mio::Date(2020, 12, 1), + scaling_factors), + IsSuccess()); + + // Sum all compartments across age groups in 6-group model should be equal to 1-group model + for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) { + double sum6 = 0.0; + for (size_t ag = 0; ag < num_age_groups; ++ag) { + sum6 += models6[0].populations[{mio::AgeGroup(ag), (mio::osecir::InfectionState)s}].value(); + } + + double val1 = models1[0].populations[{mio::AgeGroup(0), (mio::osecir::InfectionState)s}].value(); + + EXPECT_NEAR(sum6, val1, 1e-10); + } + + // Total population + EXPECT_NEAR(models6[0].populations.get_total(), models1[0].populations.get_total(), 1e-10); +} + +TEST(TestOdeSecirIO, set_divi_data_single_age_group) +{ + // Create models with 6 age groups and 1 age group + std::vector> models_6_groups{mio::osecir::Model(6)}; + std::vector> models_1_group{mio::osecir::Model(1)}; + + // Set identical parameters for both models + models_6_groups[0].parameters.set(60); + models_1_group[0].parameters.set(60); + models_6_groups[0].parameters.set>(0.2); + models_1_group[0].parameters.set>(0.2); + + // Set parameters for all age groups + for (int i = 0; i < 6; i++) { + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 3.2; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 2.0; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 5.8; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 9.5; + 
models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 7.1; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = + 0.05; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = + 0.7; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = + 0.09; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = + 0.25; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = + 0.45; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 0.2; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 0.25; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 0.3; + } + + // Set parameters for 1 age group model (same values) + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 3.2; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 2.0; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 5.8; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 9.5; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 7.1; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.05; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.7; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.09; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.45; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.3; + + // Set initial ICU populations to known values + double icu_per_age_group = 100.0; + for (int i = 0; i < 6; i++) { + models_6_groups[0].populations[{mio::AgeGroup(i), mio::osecir::InfectionState::InfectedCritical}] = + icu_per_age_group; + } + models_1_group[0].populations[{mio::AgeGroup(0), mio::osecir::InfectionState::InfectedCritical}] = + 6.0 * icu_per_age_group; + + models_6_groups[0].check_constraints(); + models_1_group[0].check_constraints(); + + // Apply DIVI data to both models + std::vector regions = {1002}; + double scaling_factor_icu = 1.0; + mio::Date date(2020, 12, 1); + std::string divi_data_path = mio::path_join(TEST_DATA_DIR, "Germany", "pydata", "county_divi_ma7.json"); + auto result_6_groups = + mio::osecir::details::set_divi_data(models_6_groups, divi_data_path, regions, date, scaling_factor_icu); + auto result_1_group = + mio::osecir::details::set_divi_data(models_1_group, divi_data_path, regions, date, scaling_factor_icu); + + EXPECT_THAT(result_6_groups, IsSuccess()); + EXPECT_THAT(result_1_group, IsSuccess()); + + // Calculate totals after applying DIVI data + double total_icu_6_groups_after = 0.0; + for (int i = 0; i < 6; i++) { + total_icu_6_groups_after += + models_6_groups[0].populations[{mio::AgeGroup(i), mio::osecir::InfectionState::InfectedCritical}].value(); + } + double icu_1_group_after = + models_1_group[0].populations[{mio::AgeGroup(0), mio::osecir::InfectionState::InfectedCritical}].value(); + + EXPECT_NEAR(total_icu_6_groups_after, icu_1_group_after, 1e-10); +} + #endif #endif From 1c537a6b52f07dace3cc92c486505ff9efc9519b Mon Sep 17 00:00:00 2001 From: Henrik Zunker <69154294+HenrZu@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:31:55 +0200 Subject: [PATCH 24/73] merge from fork --- cpp/examples/ode_secir_read_input_one_age.cpp | 94 ------------------- cpp/models/ode_secir/parameters_io.h | 22 ++--- 2 files changed, 11 insertions(+), 105 deletions(-) delete mode 100644 cpp/examples/ode_secir_read_input_one_age.cpp diff --git a/cpp/examples/ode_secir_read_input_one_age.cpp 
b/cpp/examples/ode_secir_read_input_one_age.cpp deleted file mode 100644 index 61e31b3e97..0000000000 --- a/cpp/examples/ode_secir_read_input_one_age.cpp +++ /dev/null @@ -1,94 +0,0 @@ -/* -* Copyright (C) 2020-2025 MEmilio -* -* Authors: Henrik Zunker -* -* Contact: Martin J. Kuehn -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -#include "ode_secir/model.h" -#include "ode_secir/parameters_io.h" -#include "memilio/mobility/metapopulation_mobility_instant.h" -#include "memilio/mobility/graph.h" -#include "memilio/utils/logging.h" -#include -#include -#include - -int main() -{ - mio::set_log_level(mio::LogLevel::info); - using FP = double; - - // params for 1 age group - mio::osecir::Parameters params(1); - params.get>()[mio::AgeGroup(0)] = 3.2; - params.get>()[mio::AgeGroup(0)] = 2.0; - params.get>()[mio::AgeGroup(0)] = 5.8; - params.get>()[mio::AgeGroup(0)] = 9.5; - params.get>()[mio::AgeGroup(0)] = 7.1; - params.get>()[mio::AgeGroup(0)] = 0.05; - params.get>()[mio::AgeGroup(0)] = 0.7; - params.get>()[mio::AgeGroup(0)] = 0.09; - params.get>()[mio::AgeGroup(0)] = 0.2; - params.get>()[mio::AgeGroup(0)] = 0.25; - params.get>()[mio::AgeGroup(0)] = 0.3; - - // input data - const mio::Date date{2020, 12, 1}; - const auto& data_dir = "/localdata1/code_2025/memilio/data"; - const std::string pydata_dir = mio::path_join(data_dir, "Germany", "pydata"); - const std::string population_data_path = mio::path_join(pydata_dir, "county_current_population.json"); - - // scaling factors - std::vector scaling_factor_inf(static_cast(params.get_num_groups()), 1.0); - const double scaling_factor_icu = 1.0; - const double tnt_capacity_factor = 7.5 / 100000.0; - - // grraph - mio::Graph, mio::MobilityParameters> graph; - - const auto& read_function_nodes = mio::osecir::read_input_data_county>; - auto node_id_function = [](const std::string&, bool, bool) -> mio::IOResult> { - return mio::success(std::vector{1002}); - }; - - const auto& set_node_function = - mio::set_nodes, mio::osecir::ContactPatterns, mio::osecir::Model, - mio::MobilityParameters, mio::osecir::Parameters, decltype(read_function_nodes), - decltype(node_id_function), FP>; - - auto io = set_node_function(params, date, date, pydata_dir, population_data_path, /*is_county=*/true, graph, - read_function_nodes, std::move(node_id_function), scaling_factor_inf, - scaling_factor_icu, tnt_capacity_factor, /*num_days=*/0, - /*export_time_series=*/false, /*rki_age_groups=*/true); - if (!io) { - std::cerr << io.error().formatted_message() << std::endl; - return 1; - } - - // icu should be 7. 
- - // check output - const auto& m = graph.nodes()[0].property; - const auto ag = mio::AgeGroup(0); - std::cout << "Initialized via set_nodes for county 1002 on " << date << "\n"; - std::cout << "S=" << m.populations[{ag, mio::osecir::InfectionState::Susceptible}].value() << ", "; - std::cout << "E=" << m.populations[{ag, mio::osecir::InfectionState::Exposed}].value() << ", "; - std::cout << "I=" << m.populations[{ag, mio::osecir::InfectionState::InfectedSymptoms}].value() << ", "; - std::cout << "R=" << m.populations[{ag, mio::osecir::InfectionState::Recovered}].value() << ", "; - std::cout << "D=" << m.populations[{ag, mio::osecir::InfectionState::Dead}].value() << std::endl; - - return 0; -} diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 94a524ab1d..43994083aa 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -101,28 +101,28 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect for (size_t node = 0; node < model.size(); ++node) { const size_t model_groups = (size_t)model[node].parameters.get_num_groups(); assert(model_groups == 1 || model_groups == num_age_groups); - for (size_t group = 0; group < num_age_groups; group++) { + for (size_t ag = 0; ag < num_age_groups; ag++) { // If the model has fewer groups than casedata entries available, // reuse group 0 parameters for all RKI age groups - const size_t pidx = (model_groups == num_age_groups) ? group : 0; + const size_t group = (model_groups == num_age_groups) ? ag : 0; t_Exposed[node].push_back( - static_cast(std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); + static_cast(std::round(model[node].parameters.template get>()[(AgeGroup)group]))); t_InfectedNoSymptoms[node].push_back(static_cast( - std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); + std::round(model[node].parameters.template get>()[(AgeGroup)group]))); t_InfectedSymptoms[node].push_back(static_cast( - std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); + std::round(model[node].parameters.template get>()[(AgeGroup)group]))); t_InfectedSevere[node].push_back(static_cast( - std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); + std::round(model[node].parameters.template get>()[(AgeGroup)group]))); t_InfectedCritical[node].push_back(static_cast( - std::round(model[node].parameters.template get>()[(AgeGroup)pidx]))); + std::round(model[node].parameters.template get>()[(AgeGroup)group]))); mu_C_R[node].push_back( - model[node].parameters.template get>()[(AgeGroup)pidx]); + model[node].parameters.template get>()[(AgeGroup)group]); mu_I_H[node].push_back( - model[node].parameters.template get>()[(AgeGroup)pidx]); - mu_H_U[node].push_back(model[node].parameters.template get>()[(AgeGroup)pidx]); - mu_U_D[node].push_back(model[node].parameters.template get>()[(AgeGroup)pidx]); + model[node].parameters.template get>()[(AgeGroup)group]); + mu_H_U[node].push_back(model[node].parameters.template get>()[(AgeGroup)group]); + mu_U_D[node].push_back(model[node].parameters.template get>()[(AgeGroup)group]); } } std::vector> num_InfectedSymptoms(model.size(), std::vector(num_age_groups, 0.0)); From eb77e1698105d834d91a51ab329d8467d8aef4f1 Mon Sep 17 00:00:00 2001 From: Henrik Zunker Date: Mon, 18 Aug 2025 13:45:39 +0200 Subject: [PATCH 25/73] rm from example cmakelist --- cpp/examples/CMakeLists.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cpp/examples/CMakeLists.txt b/cpp/examples/CMakeLists.txt index 
11141aee09..980164d324 100644 --- a/cpp/examples/CMakeLists.txt +++ b/cpp/examples/CMakeLists.txt @@ -171,10 +171,6 @@ if(MEMILIO_HAS_HDF5 AND MEMILIO_HAS_JSONCPP) add_executable(ode_secir_parameter_study_graph ode_secir_parameter_study_graph.cpp) target_link_libraries(ode_secir_parameter_study_graph PRIVATE memilio ode_secir) target_compile_options(ode_secir_parameter_study_graph PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS}) - - add_executable(ode_secir_read_input_one_age_example ode_secir_read_input_one_age.cpp) - target_link_libraries(ode_secir_read_input_one_age_example PRIVATE memilio ode_secir) - target_compile_options(ode_secir_read_input_one_age_example PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS}) endif() if(MEMILIO_HAS_JSONCPP) From e20f2833ee6d0e166aaee7fa06a535ee25cdd6c3 Mon Sep 17 00:00:00 2001 From: Henrik Zunker Date: Mon, 18 Aug 2025 14:53:56 +0200 Subject: [PATCH 26/73] better place for asserts --- cpp/models/ode_secir/parameters_io.h | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 43994083aa..165b4d0c10 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -180,9 +180,6 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect model[node].populations[{AgeGroup(i0), InfectionState::Dead}] = sum_vec(num_death[node]); model[node].populations[{AgeGroup(i0), InfectionState::Recovered}] = sum_vec(num_rec[node]); } - else { - assert(false && "Unsupported number of age groups in model; expected 1 or RKI groups."); - } } else { log_warning("No infections reported on date {} for region {}. Population data has not been set.", date, @@ -274,6 +271,7 @@ IOResult set_population_data(std::vector>& model, for (size_t region = 0; region < vregion.size(); region++) { const auto model_groups = (size_t)model[region].parameters.get_num_groups(); const auto data_groups = num_population[region].size(); + assert(data_groups == model_groups || (model_groups == 1 && data_groups >= 1)); if (data_groups == model_groups) { for (auto i = AgeGroup(0); i < model[region].parameters.get_num_groups(); i++) { @@ -286,9 +284,6 @@ IOResult set_population_data(std::vector>& model, model[region].populations.template set_difference_from_group_total( {AgeGroup(0), InfectionState::Susceptible}, total); } - else { - assert(false && "Dimension of population data not supported."); - } } return success(); } From 6870b4c0326affe7943b20a169ee61bc19db1a14 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 21 Aug 2025 15:01:28 +0200 Subject: [PATCH 27/73] try new io without age groups --- cpp/memilio/io/epi_data.cpp | 7 +- cpp/models/ode_secir/parameters_io.h | 8 +- .../graph_germany_nuts3_16dampings.py | 280 ------------------ 3 files changed, 9 insertions(+), 286 deletions(-) delete mode 100644 pycode/examples/simulation/graph_germany_nuts3_16dampings.py diff --git a/cpp/memilio/io/epi_data.cpp b/cpp/memilio/io/epi_data.cpp index 1b5569f57c..c91c7f9dd5 100644 --- a/cpp/memilio/io/epi_data.cpp +++ b/cpp/memilio/io/epi_data.cpp @@ -25,8 +25,11 @@ namespace mio { -std::vector ConfirmedCasesDataEntry::age_group_names = {"Population"}; -std::vector PopulationDataEntry::age_group_names = {"Population"}; +std::vector ConfirmedCasesDataEntry::age_group_names = {"A00-A04", "A05-A14", "A15-A34", + "A35-A59", "A60-A79", "A80+"}; +std::vector PopulationDataEntry::age_group_names = { + "<3 years", "3-5 years", "6-14 years", "15-17 years", "18-24 years", 
"25-29 years", + "30-39 years", "40-49 years", "50-64 years", "65-74 years", ">74 years"}; // std::vector PopulationDataEntrySpain::age_group_names = {"Population"}; std::vector VaccinationDataEntry::age_group_names = {"0-4", "5-14", "15-34", "35-59", "60-79", "80-99"}; diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 8c9aeb2487..a601f71479 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -205,7 +205,7 @@ IOResult set_confirmed_cases_data(std::vector>& model, const std const std::vector& scaling_factor_inf) { BOOST_OUTCOME_TRY(auto&& case_data, mio::read_confirmed_cases_data(path)); - BOOST_OUTCOME_TRY(set_confirmed_cases_data(model, case_data, region, date, scaling_factor_inf[0])); + BOOST_OUTCOME_TRY(set_confirmed_cases_data(model, case_data, region, date, scaling_factor_inf)); return success(); } @@ -448,10 +448,10 @@ IOResult read_input_data_county(std::vector& model, Date date, cons { BOOST_OUTCOME_TRY( details::set_divi_data(model, path_join(pydata_dir, "county_divi_ma7.json"), county, date, scaling_factor_icu)); - BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_county_ma7.json"), + BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_county_age_ma7.json"), county, date, scaling_factor_inf)); - BOOST_OUTCOME_TRY(details::set_population_data( - model, path_join(pydata_dir, "county_current_population_aggregated.json"), county)); + BOOST_OUTCOME_TRY( + details::set_population_data(model, path_join(pydata_dir, "county_current_population.json"), county)); // if (export_time_series) { // // Use only if extrapolated real data is needed for comparison. EXPENSIVE ! diff --git a/pycode/examples/simulation/graph_germany_nuts3_16dampings.py b/pycode/examples/simulation/graph_germany_nuts3_16dampings.py deleted file mode 100644 index 8218a7c48b..0000000000 --- a/pycode/examples/simulation/graph_germany_nuts3_16dampings.py +++ /dev/null @@ -1,280 +0,0 @@ -############################################################################# -# Copyright (C) 2020-2025 MEmilio -# -# Authors: Henrik Zunker -# -# Contact: Martin J. Kuehn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-############################################################################# -import numpy as np -import datetime -import os -import memilio.simulation as mio -import memilio.simulation.osecir as osecir -import matplotlib.pyplot as plt - -from enum import Enum -from memilio.simulation.osecir import (Model, Simulation, - interpolate_simulation_result) - -import pickle - - -class Simulation: - """ """ - - def __init__(self, data_dir, start_date, results_dir): - self.num_groups = 1 - self.data_dir = data_dir - self.start_date = start_date - self.results_dir = results_dir - if not os.path.exists(self.results_dir): - os.makedirs(self.results_dir) - - def set_covid_parameters(self, model): - """ - - :param model: - - """ - model.parameters.TimeExposed[mio.AgeGroup(0)] = 3.335 - model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 - model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 - model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 - model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 - - # probabilities - model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 0.07333 - model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 - - model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 - model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 - model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 - model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 - - # start day is set to the n-th day of the year - model.parameters.StartDay = self.start_date.timetuple().tm_yday - - model.parameters.Seasonality = mio.UncertainValue(0.2) - - def set_contact_matrices(self, model): - """ - - :param model: - - """ - contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) - - baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 - minimum = np.ones((self.num_groups, self.num_groups)) * 0 - contact_matrices[0] = mio.ContactMatrix(baseline, minimum) - model.parameters.ContactPatterns.cont_freq_mat = contact_matrices - - def set_npis(self, params, end_date, damping_value): - """ - - :param params: - :param end_date: - - """ - - start_damping = datetime.date( - 2020, 12, 18) - - if start_damping < end_date: - start_date = (start_damping - self.start_date).days - params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) - - def get_graph(self, end_date): - """ - - :param end_date: - - """ - print("Initializing model...") - model = Model(self.num_groups) - self.set_covid_parameters(model) - self.set_contact_matrices(model) - print("Model initialized.") - - graph = osecir.ModelGraph() - - scaling_factor_infected = [2.5] - scaling_factor_icu = 1.0 - tnt_capacity_factor = 7.5 / 100000. 
- - data_dir_Germany = os.path.join(self.data_dir, "Germany") - mobility_data_file = os.path.join( - data_dir_Germany, "mobility", "commuter_mobility_2022.txt") - pydata_dir = os.path.join(data_dir_Germany, "pydata") - - path_population_data = os.path.join(pydata_dir, - "county_current_population_aggregated.json") - - print("Setting nodes...") - mio.osecir.set_nodes( - model.parameters, - mio.Date(self.start_date.year, - self.start_date.month, self.start_date.day), - mio.Date(end_date.year, - end_date.month, end_date.day), pydata_dir, - path_population_data, True, graph, scaling_factor_infected, - scaling_factor_icu, 0, 0, False, False) - - print("Setting edges...") - mio.osecir.set_edges( - mobility_data_file, graph, 1) - - print("Graph created.") - - return graph - - def run(self, num_days_sim, damping_values, save_graph=True): - """ - - :param num_days_sim: - :param num_runs: (Default value = 10) - :param save_graph: (Default value = True) - :param create_gif: (Default value = True) - - """ - mio.set_log_level(mio.LogLevel.Warning) - end_date = self.start_date + datetime.timedelta(days=num_days_sim) - - graph = self.get_graph(end_date) - - if save_graph: - path_graph = os.path.join(self.results_dir, "graph") - if not os.path.exists(path_graph): - os.makedirs(path_graph) - osecir.write_graph(graph, path_graph) - - mobility_graph = osecir.MobilityGraph() - for node_idx in range(graph.num_nodes): - node = graph.get_node(node_idx) - self.set_npis(node.property.parameters, end_date, damping_values[node_idx]) - mobility_graph.add_node(node.id, node.property) - for edge_idx in range(graph.num_edges): - mobility_graph.add_edge( - graph.get_edge(edge_idx).start_node_idx, - graph.get_edge(edge_idx).end_node_idx, - graph.get_edge(edge_idx).property) - mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) - mobility_sim.advance(num_days_sim) - - print("Simulation finished.") - results = [] - for node_idx in range(graph.num_nodes): - results.append(osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(node_idx).property.result)) - - return results - -def run_germany_nuts3_simulation(damping_values): - mio.set_log_level(mio.LogLevel.Warning) - file_path = os.path.dirname(os.path.abspath(__file__)) - - sim = Simulation( - data_dir=os.path.join(file_path, "../../../data"), - start_date=datetime.date(year=2020, month=12, day=12), - results_dir=os.path.join(file_path, "../../../results_osecir")) - num_days_sim = 50 - - results = sim.run(num_days_sim, damping_values) - - return {f'region{region}': results[region] for region in range(len(results))} - -def prior(): - damping_values = np.random.uniform(0.0, 1.0, 400) - return {'damping_values': damping_values} - -def load_divi_data(): - file_path = os.path.dirname(os.path.abspath(__file__)) - divi_path = os.path.join(file_path, "../../../data/Germany/pydata") - - data = pd.read_json(os.path.join(divi_path, "county_divi.json")) - print(data["ID_County"].drop_duplicates().shape) - data = data[data['Date']>= np.datetime64(datetime.date(2020, 8, 1))] - data = data[data['Date'] <= np.datetime64(datetime.date(2020, 8, 1) + datetime.timedelta(days=50))] - print(data["ID_County"].drop_duplicates().shape) - data = data.drop(columns=['County', 'ICU_ventilated', 'Date']) - region_ids = [*dd.County] - divi_dict = {f"region{i}": data[data['ID_County'] == region_ids[i]]['ICU'].to_numpy() for i in range(400)} - # for i in range(100): - # if divi_dict[f'region{i+100}'].size==0: - # print(region_ids[i+100]) - # 
print(divi_dict[f'region{i+100}'].shape) - - -if __name__ == "__main__": - - # from memilio.epidata import defaultDict as dd - # import pandas as pd - # load_divi_data() - import os - os.environ["KERAS_BACKEND"] = "tensorflow" - - import bayesflow as bf - - simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - # trainings_data = simulator.sample(1000) - - # for region in range(400): - # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] - - # with open('validation_data_400params.pickle', 'wb') as f: - # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - - with open('trainings_data1_16params_countylvl.pickle', 'rb') as f: - trainings_data = pickle.load(f) - for i in range(9): - with open(f'trainings_data{i+2}_16params_countylvl.pickle', 'rb') as f: - data = pickle.load(f) - trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - - with open('validation_data_16params_countylvl.pickle', 'rb') as f: - validation_data = pickle.load(f) - - adapter = ( - bf.Adapter() - .to_array() - .convert_dtype("float64", "float32") - .constrain("damping_values", lower=0.0, upper=1.0) - .rename("damping_values", "inference_variables") - .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) - .log("summary_variables", p1=True) - ) - - print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) - - summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32) #, recurrent_dim=256) - inference_network = bf.networks.CouplingFlow()#subnet_kwargs={'widths': {512, 512, 512, 512, 512}}) - - workflow = bf.BasicWorkflow( - simulator=simulator, - adapter=adapter, - summary_network=summary_network, - inference_network=inference_network, - standardize='all' - ) - - history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) - - # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_10params.keras")) - - plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) - plots['losses'].savefig('losses_couplingflow_16params_countylvl.png') - plots['recovery'].savefig('recovery_couplingflow_16params_countylvl.png') - plots['calibration_ecdf'].savefig('calibration_ecdf_couplingflow_16params_countylvl.png') - plots['z_score_contraction'].savefig('z_score_contraction_couplingflow_16params_countylvl.png') From 79f8e51d2fd47dba24a003f3152edfa1bbb678d1 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 21 Aug 2025 15:13:38 +0200 Subject: [PATCH 28/73] rm hardcoded rki_age_group --- cpp/memilio/io/epi_data.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/memilio/io/epi_data.h b/cpp/memilio/io/epi_data.h index fcb9fe147c..2b62dd59b4 100644 --- a/cpp/memilio/io/epi_data.h +++ b/cpp/memilio/io/epi_data.h @@ -490,7 +490,7 @@ inline IOResult> deserialize_population_data(co * @return list of population data. 
*/ inline IOResult> read_population_data(const std::string& filename, - bool rki_age_group = false) + bool rki_age_group = true) { BOOST_OUTCOME_TRY(auto&& jsvalue, read_json(filename)); return deserialize_population_data(jsvalue, rki_age_group); From 5483e52b7ddaaac066cfc395ef2a1569d7a3936f Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 21 Aug 2025 17:09:21 +0200 Subject: [PATCH 29/73] [ci skip] working germany examples --- .../simulation/graph_germany_nuts0.py | 2 +- .../simulation/graph_germany_nuts1.py | 99 ++++++++++--------- .../simulation/graph_germany_nuts3.py | 92 ++++++++--------- .../memilio/epidata/getPopulationData.py | 14 ++- .../simulation/bindings/models/osecir.cpp | 13 ++- 5 files changed, 114 insertions(+), 106 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts0.py b/pycode/examples/simulation/graph_germany_nuts0.py index 9783230b28..f5cfb44ca1 100644 --- a/pycode/examples/simulation/graph_germany_nuts0.py +++ b/pycode/examples/simulation/graph_germany_nuts0.py @@ -128,7 +128,7 @@ def get_graph(self, end_date, damping_value): mio.Date(end_date.year, end_date.month, end_date.day), pydata_dir, path_population_data, False, graph, scaling_factor_infected, - scaling_factor_icu, 0.9, 0, False, False) + scaling_factor_icu, 0.9, 0, False) print("Graph created.") diff --git a/pycode/examples/simulation/graph_germany_nuts1.py b/pycode/examples/simulation/graph_germany_nuts1.py index 31a077171b..843dfb2d83 100644 --- a/pycode/examples/simulation/graph_germany_nuts1.py +++ b/pycode/examples/simulation/graph_germany_nuts1.py @@ -130,7 +130,7 @@ def get_graph(self, end_date): mio.Date(end_date.year, end_date.month, end_date.day), pydata_dir, path_population_data, False, graph, scaling_factor_infected, - scaling_factor_icu, 0, 0, False, False) + scaling_factor_icu, 0, 0, False) print("Setting edges...") mio.osecir.set_edges( @@ -196,6 +196,7 @@ def run_germany_nuts1_simulation(damping_values): num_days_sim = 50 results = sim.run(num_days_sim, damping_values) + results[0].export_csv('test.csv') return {f'region{region}': results[region] for region in range(len(results))} @@ -204,14 +205,16 @@ def prior(): return {'damping_values': damping_values} if __name__ == "__main__": + test = prior() + run_germany_nuts1_simulation(test['damping_values']) - import os - os.environ["KERAS_BACKEND"] = "tensorflow" + # import os + # os.environ["KERAS_BACKEND"] = "tensorflow" - import bayesflow as bf + # import bayesflow as bf - simulator = bf.simulators.make_simulator([prior, run_germany_nuts1_simulation]) - # trainings_data = simulator.sample(1000) + # simulator = bf.simulators.make_simulator([prior, run_germany_nuts1_simulation]) + # trainings_data = simulator.sample(1) # for region in range(16): # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] @@ -219,46 +222,46 @@ def prior(): # with open('validation_data_16param.pickle', 'wb') as f: # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - with open('trainings_data1_16param.pickle', 'rb') as f: - trainings_data = pickle.load(f) - trainings_data['damping_values'] = trainings_data['damping_values'][:, :16] - for i in range(9): - with open(f'trainings_data{i+2}_16param.pickle', 'rb') as f: - data = pickle.load(f) - data['damping_values'] = data['damping_values'][:, :16] - trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - - with open('validation_data_16param.pickle', 'rb') as f: - validation_data = pickle.load(f) - 
validation_data['damping_values'] = validation_data['damping_values'][:, :16] - - adapter = ( - bf.Adapter() - .to_array() - .convert_dtype("float64", "float32") - .constrain("damping_values", lower=0.0, upper=1.0) - .rename("damping_values", "inference_variables") - .concatenate([f'region{i}' for i in range(16)], into="summary_variables", axis=-1) - .log("summary_variables", p1=True) - ) + # with open('trainings_data1_16param.pickle', 'rb') as f: + # trainings_data = pickle.load(f) + # trainings_data['damping_values'] = trainings_data['damping_values'][:, :16] + # for i in range(9): + # with open(f'trainings_data{i+2}_16param.pickle', 'rb') as f: + # data = pickle.load(f) + # data['damping_values'] = data['damping_values'][:, :16] + # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} + + # with open('validation_data_16param.pickle', 'rb') as f: + # validation_data = pickle.load(f) + # validation_data['damping_values'] = validation_data['damping_values'][:, :16] + + # adapter = ( + # bf.Adapter() + # .to_array() + # .convert_dtype("float64", "float32") + # .constrain("damping_values", lower=0.0, upper=1.0) + # .rename("damping_values", "inference_variables") + # .concatenate([f'region{i}' for i in range(16)], into="summary_variables", axis=-1) + # .log("summary_variables", p1=True) + # ) - summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32) - inference_network = bf.networks.CouplingFlow() - - workflow = bf.BasicWorkflow( - simulator=simulator, - adapter=adapter, - summary_network=summary_network, - inference_network=inference_network, - standardize='all' - ) - - history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) - - # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_1params.keras")) - - plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) - plots['losses'].savefig('losses_couplingflow_16param.png') - plots['recovery'].savefig('recovery_couplingflow_16param.png') - plots['calibration_ecdf'].savefig('calibration_ecdf_couplingflow_16param.png') - plots['z_score_contraction'].savefig('z_score_contraction_couplingflow_16param.png') + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32) + # inference_network = bf.networks.CouplingFlow() + + # workflow = bf.BasicWorkflow( + # simulator=simulator, + # adapter=adapter, + # summary_network=summary_network, + # inference_network=inference_network, + # standardize='all' + # ) + + # history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) + + # # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_1params.keras")) + + # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) + # plots['losses'].savefig('losses_couplingflow_16param.png') + # plots['recovery'].savefig('recovery_couplingflow_16param.png') + # plots['calibration_ecdf'].savefig('calibration_ecdf_couplingflow_16param.png') + # plots['z_score_contraction'].savefig('z_score_contraction_couplingflow_16param.png') diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 5c0ddc6554..7010cff1dc 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -124,7 +124,7 @@ 
def get_graph(self, end_date): pydata_dir = os.path.join(data_dir_Germany, "pydata") path_population_data = os.path.join(pydata_dir, - "county_current_population_aggregated.json") + "county_current_population.json") print("Setting nodes...") mio.osecir.set_nodes( @@ -134,7 +134,7 @@ def get_graph(self, end_date): mio.Date(end_date.year, end_date.month, end_date.day), pydata_dir, path_population_data, True, graph, scaling_factor_infected, - scaling_factor_icu, 0, 0, False, False) + scaling_factor_icu, 0, 0, False) print("Setting edges...") mio.osecir.set_edges( @@ -227,7 +227,7 @@ def load_divi_data(): from tensorflow import keras simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - # trainings_data = simulator.sample(100) + trainings_data = simulator.sample(1) # for key in trainings_data.keys(): # if key != 'damping_values': @@ -246,46 +246,46 @@ def load_divi_data(): # with open('validation_data_counties.pickle', 'rb') as f: # validation_data = pickle.load(f) - adapter = ( - bf.Adapter() - .to_array() - .convert_dtype("float64", "float32") - .constrain("damping_values", lower=0.0, upper=1.0) - .rename("damping_values", "inference_variables") - .concatenate([f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids], into="summary_variables", axis=-1) - .log("summary_variables", p1=True) - ) - - # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) - - summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32, recurrent_dim=32) - inference_network = bf.networks.CouplingFlow() - - workflow = bf.BasicWorkflow( - simulator=simulator, - adapter=adapter, - summary_network=summary_network, - inference_network=inference_network, - standardize='all' - ) - - # history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) - - # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_countylvl.keras")) - - # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) - # plots['losses'].savefig('losses_countylvl.png') - # plots['recovery'].savefig('recovery_countylvl.png') - # plots['calibration_ecdf'].savefig('calibration_ecdf_countylvl.png') - # plots['z_score_contraction'].savefig('z_score_contraction_countylvl.png') - - test = load_divi_data() - workflow.approximator = keras.models.load_model(os.path.join(os.path.dirname(__file__), "model_countylvl.keras")) - - samples = workflow.sample(conditions=test, num_samples=10) - # samples = workflow.samples_to_data_frame(samples) - # print(samples.head()) - samples['damping_values'] = np.squeeze(samples['damping_values']) - for i in range(samples['damping_values'].shape[0]): - test = run_germany_nuts3_simulation(samples['damping_values'][i]) - print(test) \ No newline at end of file + # adapter = ( + # bf.Adapter() + # .to_array() + # .convert_dtype("float64", "float32") + # .constrain("damping_values", lower=0.0, upper=1.0) + # .rename("damping_values", "inference_variables") + # .concatenate([f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids], into="summary_variables", axis=-1) + # .log("summary_variables", p1=True) + # ) + + # # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) + + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32, recurrent_dim=32) + # inference_network = bf.networks.CouplingFlow() + + # workflow = bf.BasicWorkflow( + # 
simulator=simulator, + # adapter=adapter, + # summary_network=summary_network, + # inference_network=inference_network, + # standardize='all' + # ) + + # # history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) + + # # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_countylvl.keras")) + + # # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) + # # plots['losses'].savefig('losses_countylvl.png') + # # plots['recovery'].savefig('recovery_countylvl.png') + # # plots['calibration_ecdf'].savefig('calibration_ecdf_countylvl.png') + # # plots['z_score_contraction'].savefig('z_score_contraction_countylvl.png') + + # test = load_divi_data() + # workflow.approximator = keras.models.load_model(os.path.join(os.path.dirname(__file__), "model_countylvl.keras")) + + # samples = workflow.sample(conditions=test, num_samples=10) + # # samples = workflow.samples_to_data_frame(samples) + # # print(samples.head()) + # samples['damping_values'] = np.squeeze(samples['damping_values']) + # for i in range(samples['damping_values'].shape[0]): + # test = run_germany_nuts3_simulation(samples['damping_values'][i]) + # print(test) \ No newline at end of file diff --git a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py index b96d5d6800..331860bcca 100644 --- a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py +++ b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py @@ -140,10 +140,10 @@ def export_population_dataframe(df_pop: pd.DataFrame, directory: str, file_forma columns=dd.EngEng["idCounty"]) gd.write_dataframe(df_pop_export, directory, filename, file_format) - gd.write_dataframe(df_pop_export.drop(columns=new_cols[2:]), directory, filename + '_aggregated', file_format) - gd.write_dataframe(aggregate_to_state_level(df_pop_export.drop(columns=new_cols[2:])), directory, filename + '_states', file_format) - df_pop_germany = pd.DataFrame({"ID": [0], "Population": [df_pop_export["Population"].sum()]}) - gd.write_dataframe(df_pop_germany, directory, filename + '_germany', file_format) + gd.write_dataframe(aggregate_to_state_level(df_pop_export), directory, filename + '_states', file_format) + gd.write_dataframe(aggregate_to_country_level(df_pop_export), directory, filename + '_germany', file_format) + # df_pop_germany = pd.DataFrame({"ID": [0], "Population": [df_pop_export["Population"].sum()]}) + # gd.write_dataframe(df_pop_germany, directory, filename + '_germany', file_format) return df_pop_export @@ -456,6 +456,12 @@ def aggregate_to_state_level(df_pop: pd.DataFrame): df_pop['ID_State'] = df_pop.index return df_pop +def aggregate_to_country_level(df_pop: pd.DataFrame): + + df_pop['ID_Country'] = 0 + df_pop = df_pop.drop(columns=['ID_County', 'ID_State']).groupby('ID_Country', as_index=True).sum() + df_pop['ID_Country'] = df_pop.index + return df_pop def main(): """ Main program entry.""" diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp index ecbb6287e2..c08eb591df 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp @@ -260,14 +260,14 @@ PYBIND11_MODULE(_simulation_osecir, m) const std::string& data_dir, const std::string& 
population_data_path, bool is_node_for_county, mio::Graph, mio::MobilityParameters>& params_graph, const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor, - int num_days = 0, bool export_time_series = false, bool rki_age_groups = true) { + int num_days = 0, bool export_time_series = false) { auto result = mio::set_nodes< mio::osecir::TestAndTraceCapacity, mio::osecir::ContactPatterns, mio::osecir::Model, mio::MobilityParameters, mio::osecir::Parameters, decltype(mio::osecir::read_input_data_county>), decltype(mio::get_node_ids)>( params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph, mio::osecir::read_input_data_county>, mio::get_node_ids, scaling_factor_inf, - scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series, rki_age_groups); + scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series); return pymio::check_and_throw(result); }, py::return_value_policy::move); @@ -278,14 +278,14 @@ PYBIND11_MODULE(_simulation_osecir, m) const std::string& data_dir, const std::string& population_data_path, bool is_node_for_county, mio::Graph, mio::MobilityParameters>& params_graph, const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor, - int num_days = 0, bool export_time_series = false, bool rki_age_groups = true) { + int num_days = 0, bool export_time_series = false) { auto result = mio::set_nodes< mio::osecir::TestAndTraceCapacity, mio::osecir::ContactPatterns, mio::osecir::Model, mio::MobilityParameters, mio::osecir::Parameters, decltype(mio::osecir::read_input_data_state>), decltype(mio::get_node_ids)>( params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph, mio::osecir::read_input_data_state>, mio::get_node_ids, scaling_factor_inf, - scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series, rki_age_groups); + scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series); return pymio::check_and_throw(result); }, py::return_value_policy::move); @@ -296,7 +296,7 @@ PYBIND11_MODULE(_simulation_osecir, m) const std::string& data_dir, const std::string& population_data_path, bool is_node_for_county, mio::Graph, mio::MobilityParameters>& params_graph, const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor, - int num_days = 0, bool export_time_series = false, bool rki_age_groups = true) { + int num_days = 0, bool export_time_series = false) { auto result = mio::set_nodes, mio::osecir::ContactPatterns, mio::osecir::Model, mio::MobilityParameters, mio::osecir::Parameters, @@ -304,8 +304,7 @@ PYBIND11_MODULE(_simulation_osecir, m) decltype(mio::get_country_id)>( params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph, mio::osecir::read_input_data_germany>, mio::get_country_id, - scaling_factor_inf, scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series, - rki_age_groups); + scaling_factor_inf, scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series); return pymio::check_and_throw(result); }, py::return_value_policy::move); From bde63e4768a0287d4695a05cf2834c2be11c5565 Mon Sep 17 00:00:00 2001 From: Henrik Zunker <69154294+HenrZu@users.noreply.github.com> Date: Fri, 22 Aug 2025 09:12:22 +0200 Subject: [PATCH 30/73] Apply suggestions from code review Co-authored-by: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> --- cpp/models/ode_secir/parameters_io.h | 2 +- 
cpp/tests/test_odesecir.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 165b4d0c10..d822bcce81 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -160,7 +160,7 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i]; } } - else if (num_groups == 1) { + else { const auto sum_vec = [](const std::vector& v) { return std::accumulate(v.begin(), v.end(), 0.0); }; diff --git a/cpp/tests/test_odesecir.cpp b/cpp/tests/test_odesecir.cpp index 88e48a1c24..db64d1ea93 100644 --- a/cpp/tests/test_odesecir.cpp +++ b/cpp/tests/test_odesecir.cpp @@ -1521,7 +1521,7 @@ TEST(TestOdeSecir, read_population_data_failure) TEST(TestOdeSecirIO, read_input_data_county_aggregates_one_group) { - // Set up two models with different age groups. + // Set up two models with different number of age groups. const size_t num_age_groups = 6; std::vector> models6{mio::osecir::Model((int)num_age_groups)}; std::vector> models1{mio::osecir::Model(1)}; From 39691f9540825a209afacf07fe658dd67b746e6c Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Fri, 22 Aug 2025 09:36:15 +0200 Subject: [PATCH 31/73] rm code tests, formats param io --- cpp/models/ode_secir/parameters_io.h | 6 +- cpp/tests/test_odesecir.cpp | 94 ++-------------------------- 2 files changed, 7 insertions(+), 93 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index d822bcce81..8471e25a00 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -78,8 +78,8 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect // allow single scalar scaling that is broadcast to all age groups assert(scaling_factor_inf.size() == 1 || scaling_factor_inf.size() == num_age_groups); - // Broadcast scaling factors to match RKI age groups (6) - std::vector scaling_factor_inf_full(num_age_groups, 1.0); + // Set scaling factors to match num age groups + std::vector scaling_factor_inf_full; if (scaling_factor_inf.size() == 1) { scaling_factor_inf_full.assign(num_age_groups, scaling_factor_inf[0]); } @@ -160,7 +160,7 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i]; } } - else { + else { const auto sum_vec = [](const std::vector& v) { return std::accumulate(v.begin(), v.end(), 0.0); }; diff --git a/cpp/tests/test_odesecir.cpp b/cpp/tests/test_odesecir.cpp index db64d1ea93..99c6dcb99e 100644 --- a/cpp/tests/test_odesecir.cpp +++ b/cpp/tests/test_odesecir.cpp @@ -1607,47 +1607,6 @@ TEST(TestOdeSecirIO, set_population_data_single_age_group) std::vector> models6{mio::osecir::Model((int)num_age_groups)}; std::vector> models1{mio::osecir::Model(1)}; - // Set basic parameters for both models - models6[0].parameters.set(60); - models6[0].parameters.set>(0.2); - models1[0].parameters.set(60); - models1[0].parameters.set>(0.2); - - // Set parameters for 6-age-group model - for (auto i = mio::AgeGroup(0); i < (mio::AgeGroup)num_age_groups; ++i) { - models6[0].parameters.get>()[i] = 3.2; - models6[0].parameters.get>()[i] = 2.0; - models6[0].parameters.get>()[i] = 5.8; - models6[0].parameters.get>()[i] = 9.5; - models6[0].parameters.get>()[i] = 7.1; - models6[0].parameters.get>()[i] = 0.05; - 
models6[0].parameters.get>()[i] = 0.7; - models6[0].parameters.get>()[i] = 0.09; - models6[0].parameters.get>()[i] = 0.25; - models6[0].parameters.get>()[i] = 0.45; - models6[0].parameters.get>()[i] = 0.2; - models6[0].parameters.get>()[i] = 0.25; - models6[0].parameters.get>()[i] = 0.3; - } - - // Set parameters for 1-age-group model (same values) - models1[0].parameters.get>()[mio::AgeGroup(0)] = 3.2; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 2.0; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 5.8; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 9.5; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 7.1; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.05; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.7; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.09; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.45; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.3; - - models6[0].check_constraints(); - models1[0].check_constraints(); - // Test population data with 6 different values for age groups std::vector> population_data6 = {{10000.0, 20000.0, 30000.0, 25000.0, 15000.0, 8000.0}}; std::vector> population_data1 = {{108000.0}}; // sum of all age groups @@ -1776,60 +1735,15 @@ TEST(TestOdeSecirIO, set_divi_data_single_age_group) std::vector> models_6_groups{mio::osecir::Model(6)}; std::vector> models_1_group{mio::osecir::Model(1)}; - // Set identical parameters for both models - models_6_groups[0].parameters.set(60); - models_1_group[0].parameters.set(60); - models_6_groups[0].parameters.set>(0.2); - models_1_group[0].parameters.set>(0.2); - - // Set parameters for all age groups + // Set relevant parameters for all age groups for (int i = 0; i < 6; i++) { - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 3.2; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 2.0; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 5.8; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 9.5; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 7.1; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = - 0.05; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = - 0.7; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = - 0.09; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = - 0.25; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = - 0.45; models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 0.2; models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 0.25; - models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 0.3; - } - - // Set parameters for 1 age group model (same values) - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 3.2; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 2.0; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 5.8; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 9.5; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 7.1; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.05; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.7; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.09; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.45; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 
0.25; - models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.3; - - // Set initial ICU populations to known values - double icu_per_age_group = 100.0; - for (int i = 0; i < 6; i++) { - models_6_groups[0].populations[{mio::AgeGroup(i), mio::osecir::InfectionState::InfectedCritical}] = - icu_per_age_group; } - models_1_group[0].populations[{mio::AgeGroup(0), mio::osecir::InfectionState::InfectedCritical}] = - 6.0 * icu_per_age_group; - models_6_groups[0].check_constraints(); - models_1_group[0].check_constraints(); + // Set relevant parameters for 1 age group model + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; // Apply DIVI data to both models std::vector regions = {1002}; From afe10d9f29ff5babb9c3050c4a3a6aa8e66cff71 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Fri, 22 Aug 2025 09:59:14 +0200 Subject: [PATCH 32/73] lower tol --- cpp/tests/test_odesecir.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/tests/test_odesecir.cpp b/cpp/tests/test_odesecir.cpp index 99c6dcb99e..9ff2d25a66 100644 --- a/cpp/tests/test_odesecir.cpp +++ b/cpp/tests/test_odesecir.cpp @@ -1585,7 +1585,7 @@ TEST(TestOdeSecirIO, read_input_data_county_aggregates_one_group) // Aggreagate the results from the model with 6 age groups and compare with the model with 1 age group const auto& m6 = models6[0]; const auto& m1 = models1[0]; - const double tol = 1e-13; + const double tol = 1e-10; for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) { double sum6 = 0.0; for (size_t ag = 0; ag < num_age_groups; ++ag) { @@ -1596,7 +1596,7 @@ TEST(TestOdeSecirIO, read_input_data_county_aggregates_one_group) } // Total population - EXPECT_NEAR(m6.populations.get_total(), m1.populations.get_total(), 1e-13); + EXPECT_NEAR(m6.populations.get_total(), m1.populations.get_total(), tol); } TEST(TestOdeSecirIO, set_population_data_single_age_group) @@ -1617,7 +1617,7 @@ TEST(TestOdeSecirIO, set_population_data_single_age_group) EXPECT_THAT(mio::osecir::details::set_population_data(models1, population_data1, regions), IsSuccess()); // Sum all compartments across age groups in 6-group model and compare 1-group model - const double tol = 1e-13; + const double tol = 1e-10; for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) { double sum6 = 0.0; for (size_t ag = 0; ag < num_age_groups; ++ag) { From 43d44a3a2e90f1e43ab169318ae2d43b063d242e Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Mon, 25 Aug 2025 13:18:30 +0200 Subject: [PATCH 33/73] [ci skip] remove comments --- pycode/memilio-epidata/memilio/epidata/getPopulationData.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py index 331860bcca..07b10896c0 100644 --- a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py +++ b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py @@ -142,8 +142,6 @@ def export_population_dataframe(df_pop: pd.DataFrame, directory: str, file_forma gd.write_dataframe(df_pop_export, directory, filename, file_format) gd.write_dataframe(aggregate_to_state_level(df_pop_export), directory, filename + '_states', file_format) gd.write_dataframe(aggregate_to_country_level(df_pop_export), directory, filename + '_germany', file_format) - # df_pop_germany = pd.DataFrame({"ID": [0], "Population": [df_pop_export["Population"].sum()]}) - # 
gd.write_dataframe(df_pop_germany, directory, filename + '_germany', file_format) return df_pop_export From 02688ac452eca0c5f6668c63be637cb7ec8e6b13 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 26 Aug 2025 11:28:24 +0200 Subject: [PATCH 34/73] add covid parameters to estimation --- .../simulation/graph_germany_nuts3.py | 251 +++++++++++++----- 1 file changed, 189 insertions(+), 62 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 7010cff1dc..c0d6e4497e 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -35,6 +35,94 @@ no_icu_ids = [7338, 9374, 9473, 9573] region_ids = [region_id for region_id in dd.County.keys() if region_id not in excluded_ids] +inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', 'transmission_prob'] + +def plot_region_median_mad( + data: np.ndarray, + region: int, + true_data = None, + ax = None, + label = None, + color = "red" +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + n_samples, n_time, n_regions = data.shape + if not (0 <= region < n_regions): + raise IndexError + + x = np.arange(n_time) + vals = data[:, :, region] # (samples, time_points) + med = np.median(vals, axis=0) + mad = np.median(np.abs(vals - med), axis=0) + + if ax is None: + fig, ax = plt.subplots() + + line, = ax.plot(x, med, lw=2, label=label or f"Region {region}", color=color) + band = ax.fill_between(x, med - mad, med + mad, alpha=0.25, color=color) + if true_data is not None: + true_vals = true_data[:, region] # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time") + ax.set_ylabel("ICU") + ax.set_title(f"Region {region}") + if label is not None: + ax.legend() + return line, band + + +def plot_aggregated_over_regions( + data: np.ndarray, + region_agg = np.sum, + true_data = None, + ax = None, + label = None, + color = 'red' +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + # Aggregate over regions + agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) + + # Aggregate over samples + agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) + agg_mad = np.median( + np.abs(agg_over_regions - agg_median[None]), + axis=0 + ) + + x = np.arange(agg_median.shape[0]) + if ax is None: + fig, ax = plt.subplots() + + line, = ax.plot(x, agg_median, lw=2, label=label or "Aggregated over regions", color=color) + band = ax.fill_between( + x, + agg_median - agg_mad, + agg_median + agg_mad, + alpha=0.25, + color=color + ) + if true_data is not None: + true_vals = region_agg(true_data, axis=-1) # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time") + ax.set_ylabel("ICU") + if label is not None: + ax.legend() + + return line, band + class Simulation: """ """ @@ -46,20 +134,20 @@ def __init__(self, data_dir, start_date, results_dir): if not os.path.exists(self.results_dir): os.makedirs(self.results_dir) - def set_covid_parameters(self, model): + def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, transmission_prob): """ :param model: """ - 
model.parameters.TimeExposed[mio.AgeGroup(0)] = 3.335 - model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 - model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 - model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 - model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 + model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr # probabilities - model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 0.07333 + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = transmission_prob model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 @@ -81,7 +169,7 @@ def set_contact_matrices(self, model): contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 - minimum = np.ones((self.num_groups, self.num_groups)) * 0 + minimum = np.zeros((self.num_groups, self.num_groups)) contact_matrices[0] = mio.ContactMatrix(baseline, minimum) model.parameters.ContactPatterns.cont_freq_mat = contact_matrices @@ -94,13 +182,13 @@ def set_npis(self, params, end_date, damping_value): """ start_damping = datetime.date( - year=2020, month=7, day=8) + year=2020, month=10, day=8) if start_damping < end_date: start_date = (start_damping - self.start_date).days params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) - def get_graph(self, end_date): + def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, transmission_prob): """ :param end_date: @@ -108,13 +196,13 @@ def get_graph(self, end_date): """ print("Initializing model...") model = Model(self.num_groups) - self.set_covid_parameters(model) + self.set_covid_parameters(model, t_E, t_ISy, t_ISev, t_Cr, transmission_prob) self.set_contact_matrices(model) print("Model initialized.") graph = osecir.ModelGraph() - scaling_factor_infected = [1] + scaling_factor_infected = [2.5] scaling_factor_icu = 1.0 tnt_capacity_factor = 7.5 / 100000. 
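        # Assumed semantics (annotation, not part of this commit): scaling_factor_infected
        # rescales reported case counts to compensate for under-detection, scaling_factor_icu
        # rescales ICU occupancy, and tnt_capacity_factor sets the test-and-trace capacity
        # relative to the local population (here 7.5 per 100,000 inhabitants).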
@@ -144,7 +232,7 @@ def get_graph(self, end_date): return graph - def run(self, num_days_sim, damping_values, save_graph=True): + def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, transmission_prob, save_graph=True): """ :param num_days_sim: @@ -156,7 +244,7 @@ def run(self, num_days_sim, damping_values, save_graph=True): mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) - graph = self.get_graph(end_date) + graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, t_Cr, transmission_prob) if save_graph: path_graph = os.path.join(self.results_dir, "graph") @@ -181,41 +269,52 @@ def run(self, num_days_sim, damping_values, save_graph=True): for node_idx in range(mobility_sim.graph.num_nodes): node = mobility_sim.graph.get_node(node_idx) if node.id in no_icu_ids: - results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result).as_ndarray() + results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result) else: - results[f'region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result).as_ndarray() + results[f'region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result) return results -def run_germany_nuts3_simulation(damping_values): +def run_germany_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, transmission_prob): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) sim = Simulation( data_dir=os.path.join(file_path, "../../../data"), - start_date=datetime.date(year=2020, month=7, day=1), + start_date=datetime.date(year=2020, month=10, day=1), results_dir=os.path.join(file_path, "../../../results_osecir")) - num_days_sim = 50 + num_days_sim = 60 - results = sim.run(num_days_sim, damping_values) + results = sim.run(num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, transmission_prob) return results def prior(): damping_values = np.random.uniform(0.0, 1.0, 16) - return {'damping_values': damping_values} + t_E = np.random.uniform(1., 5.2) + t_ISy = np.random.uniform(4., 10.) + t_ISev = np.random.uniform(5., 10.) + t_Cr = np.random.uniform(9., 17.) 
+ transmission_prob = np.random.uniform(0., 0.2) + return {'damping_values': damping_values, + 't_E': t_E, + 't_ISy': t_ISy, + 't_ISev': t_ISev, + 't_Cr': t_Cr, + 'transmission_prob': transmission_prob} def load_divi_data(): file_path = os.path.dirname(os.path.abspath(__file__)) divi_path = os.path.join(file_path, "../../../data/Germany/pydata") - data = pd.read_json(os.path.join(divi_path, "county_divi_ma7.json")) - data = data[data['Date']>= np.datetime64(datetime.date(2020, 7, 1))] - data = data[data['Date'] <= np.datetime64(datetime.date(2020, 7, 1) + datetime.timedelta(days=50))] - data = data.drop(columns=['County', 'ICU_ventilated', 'Date']) - divi_dict = {f"region{i}": data[data['ID_County'] == region_id]['ICU'].to_numpy()[None, :, None] for i, region_id in enumerate(region_ids) if region_id not in no_icu_ids} + data = pd.read_json(os.path.join(divi_path, "county_divi_all_dates.json")) + data = data[data['Date']>= np.datetime64(datetime.date(2020, 10, 1))] + data = data[data['Date'] <= np.datetime64(datetime.date(2020, 10, 1) + datetime.timedelta(days=50))] + data = data.sort_values(by=['ID_County', 'Date']) + divi_data = data.pivot(index='Date', columns='ID_County', values='ICU') + divi_dict = {f"region{i}": divi_data[region_id].to_numpy()[None, :, None] for i, region_id in enumerate(region_ids) if region_id not in no_icu_ids} - return divi_dict + return divi_data.to_numpy(), divi_dict if __name__ == "__main__": @@ -227,39 +326,45 @@ def load_divi_data(): from tensorflow import keras simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - trainings_data = simulator.sample(1) + trainings_data = simulator.sample(1000) - # for key in trainings_data.keys(): - # if key != 'damping_values': - # trainings_data[key] = trainings_data[key][:, :, 8][..., np.newaxis] + for key in trainings_data.keys(): + if key not in inference_params: + trainings_data[key] = trainings_data[key][:, :, 7][..., np.newaxis] - # with open('validation_data_counties.pickle', 'wb') as f: - # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + with open('trainings_data10_counties_wcovidparams_oct.pickle', 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - # with open('trainings_data1_counties.pickle', 'rb') as f: + # with open('trainings_data1_counties_wcovidparams_oct.pickle', 'rb') as f: # trainings_data = pickle.load(f) - # for i in range(9): - # with open(f'trainings_data{i+2}_counties.pickle', 'rb') as f: + # trainings_data = {k: np.round(v) if ('region' in k) else v for k, v in trainings_data.items()} + # for i in range(19): + # with open(f'trainings_data{i+2}_counties_wcovidparams_oct.pickle', 'rb') as f: # data = pickle.load(f) - # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} + # trainings_data = {k: np.concatenate([trainings_data[k], np.round(data[k])]) if ('region' in k) else np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - # with open('validation_data_counties.pickle', 'rb') as f: + # with open('validation_data_counties_wcovidparams_oct.pickle', 'rb') as f: # validation_data = pickle.load(f) + # divi_dict = {k: np.round(v) if ('region' in k) else v for k, v in validation_data.items()} # adapter = ( # bf.Adapter() # .to_array() # .convert_dtype("float64", "float32") # .constrain("damping_values", lower=0.0, upper=1.0) - # .rename("damping_values", "inference_variables") + # .constrain("t_E", lower=1.0, upper=6.0) + # .constrain("t_ISy", lower=5.0, upper=10.0) + # 
.constrain("t_ISev", lower=2.0, upper=8.0) + # .concatenate(["damping_values", "t_E", "t_ISy", "t_ISev"], into="inference_variables", axis=-1) # .concatenate([f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids], into="summary_variables", axis=-1) # .log("summary_variables", p1=True) # ) - # # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) + # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) + # print("inference_variables shape:", adapter(trainings_data)["inference_variables"].shape) - # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32, recurrent_dim=32) - # inference_network = bf.networks.CouplingFlow() + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=38) + # inference_network = bf.networks.CouplingFlow(depth=7, transform='spline') # workflow = bf.BasicWorkflow( # simulator=simulator, @@ -269,23 +374,45 @@ def load_divi_data(): # standardize='all' # ) - # # history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) - - # # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_countylvl.keras")) - - # # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) - # # plots['losses'].savefig('losses_countylvl.png') - # # plots['recovery'].savefig('recovery_countylvl.png') - # # plots['calibration_ecdf'].savefig('calibration_ecdf_countylvl.png') - # # plots['z_score_contraction'].savefig('z_score_contraction_countylvl.png') - - # test = load_divi_data() - # workflow.approximator = keras.models.load_model(os.path.join(os.path.dirname(__file__), "model_countylvl.keras")) - - # samples = workflow.sample(conditions=test, num_samples=10) - # # samples = workflow.samples_to_data_frame(samples) - # # print(samples.head()) - # samples['damping_values'] = np.squeeze(samples['damping_values']) - # for i in range(samples['damping_values'].shape[0]): - # test = run_germany_nuts3_simulation(samples['damping_values'][i]) - # print(test) \ No newline at end of file + # history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) + + # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_countylvl_wcovidparams_oct.keras")) + + # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) + # plots['losses'].savefig('losses_countylvl_wcovidparams2_oct.png') + # plots['recovery'].savefig('recovery_countylvl_wcovidparams2_oct.png') + # plots['calibration_ecdf'].savefig('calibration_ecdf_countylvl_wcovidparams2_oct.png') + # plots['z_score_contraction'].savefig('z_score_contraction_countylvl_wcovidparams2_oct.png') + + # divi_data, divi_dict = load_divi_data() + # # divi_data = np.concatenate( + # # [validation_data[f'region{i}'] for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids], + # # axis=-1 + # # ) + # workflow.approximator = keras.models.load_model(os.path.join(os.path.dirname(__file__), "model_countylvl_wcovidparams_oct.keras")) + + # samples = workflow.sample(conditions=divi_dict, num_samples=1000) + # samples = np.concatenate([samples[key] for key in inference_params], axis=-1) + # samples = np.squeeze(samples) + # sims = [] + # for i in range(samples.shape[0]): + # result = run_germany_nuts3_simulation(samples[i][:16], *samples[i][16:]) + # for key in 
result.keys(): + # result[key] = np.array(result[key])[:, 7, None] + # sims.append(np.concatenate([result[key] for key in result.keys() if key.startswith('region')], axis=-1)) + # sims = np.array(sims) + # sims = np.floor(sims) + + # np.random.seed(42) + # fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(12, 5), layout="constrained") + # ax = ax.flatten() + # rand_index = np.random.choice(sims.shape[-1], replace=False, size=len(ax)) + # for i, a in enumerate(ax): + # plot_region_median_mad(sims, region=rand_index[i], true_data=divi_data, label=r"Median $\pm$ Mad", ax=a) + # plt.savefig('random_regions_wcovidparams_oct.png') + # # plt.show() + # #%% + # plot_aggregated_over_regions(sims, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad") + # plt.savefig('region_aggregated_wcovidparams_oct.png') + # # plt.show() + # # %% \ No newline at end of file From 67468af61677eb8ab9da8a28e6a0b50195f57010 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Tue, 26 Aug 2025 14:14:04 +0200 Subject: [PATCH 35/73] only relevant inits --- cpp/tests/test_odesecir.cpp | 84 +++---------------------------------- 1 file changed, 6 insertions(+), 78 deletions(-) diff --git a/cpp/tests/test_odesecir.cpp b/cpp/tests/test_odesecir.cpp index 9ff2d25a66..010fc84202 100644 --- a/cpp/tests/test_odesecir.cpp +++ b/cpp/tests/test_odesecir.cpp @@ -1526,48 +1526,15 @@ TEST(TestOdeSecirIO, read_input_data_county_aggregates_one_group) std::vector> models6{mio::osecir::Model((int)num_age_groups)}; std::vector> models1{mio::osecir::Model(1)}; - // set parameters for both models - models6[0].parameters.set(60); - models6[0].parameters.set>(0.2); - models1[0].parameters.set(60); - models1[0].parameters.set>(0.2); - - // parameters for model with 6 age groups + // Relevant parameters for model with 6 age groups for (auto i = mio::AgeGroup(0); i < (mio::AgeGroup)num_age_groups; ++i) { - models6[0].parameters.get>()[i] = 3.2; - models6[0].parameters.get>()[i] = 2.0; - models6[0].parameters.get>()[i] = 5.8; - models6[0].parameters.get>()[i] = 9.5; - models6[0].parameters.get>()[i] = 7.1; - - models6[0].parameters.get>()[i] = 0.05; - models6[0].parameters.get>()[i] = 0.7; - models6[0].parameters.get>()[i] = 0.09; - models6[0].parameters.get>()[i] = 0.25; - models6[0].parameters.get>()[i] = 0.45; - models6[0].parameters.get>()[i] = 0.2; - models6[0].parameters.get>()[i] = 0.25; - models6[0].parameters.get>()[i] = 0.3; + models6[0].parameters.get>()[i] = 0.2; + models6[0].parameters.get>()[i] = 0.25; } - // parameters for model with 1 age group - models1[0].parameters.get>()[mio::AgeGroup(0)] = 3.2; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 2.0; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 5.8; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 9.5; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 7.1; - - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.05; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.7; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.09; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.45; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.3; - - models6[0].check_constraints(); - models1[0].check_constraints(); + // Relevant parameters for model with 1 age group + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; + models1[0].parameters.get>()[mio::AgeGroup(0)] 
= 0.25; const auto pydata_dir_Germany = mio::path_join(TEST_DATA_DIR, "Germany", "pydata"); const std::vector counties{1002}; @@ -1640,45 +1607,6 @@ TEST(TestOdeSecirIO, set_confirmed_cases_data_single_age_group) std::vector> models6{mio::osecir::Model((int)num_age_groups)}; std::vector> models1{mio::osecir::Model(1)}; - // Set identical parameters for both models - models6[0].parameters.set(60); - models6[0].parameters.set>(0.2); - models1[0].parameters.set(60); - models1[0].parameters.set>(0.2); - - for (auto i = mio::AgeGroup(0); i < (mio::AgeGroup)num_age_groups; ++i) { - models6[0].parameters.get>()[i] = 3.2; - models6[0].parameters.get>()[i] = 2.0; - models6[0].parameters.get>()[i] = 5.8; - models6[0].parameters.get>()[i] = 9.5; - models6[0].parameters.get>()[i] = 7.1; - models6[0].parameters.get>()[i] = 0.05; - models6[0].parameters.get>()[i] = 0.7; - models6[0].parameters.get>()[i] = 0.09; - models6[0].parameters.get>()[i] = 0.25; - models6[0].parameters.get>()[i] = 0.45; - models6[0].parameters.get>()[i] = 0.2; - models6[0].parameters.get>()[i] = 0.25; - models6[0].parameters.get>()[i] = 0.3; - } - - models1[0].parameters.get>()[mio::AgeGroup(0)] = 3.2; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 2.0; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 5.8; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 9.5; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 7.1; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.05; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.7; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.09; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.45; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; - models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.3; - - models6[0].check_constraints(); - models1[0].check_constraints(); - // Create case data for all 6 age groups over multiple days (current day + 6 days back) std::vector case_data; From 1e8007c33ae15d7b08437563d931603fe2c8b419 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 26 Aug 2025 15:39:17 +0200 Subject: [PATCH 36/73] [ci skip] improve fetch spain data --- .../memilio/epidata/getSimulationDataSpain.py | 62 +++++++++++++++++++ .../epidata/modifyPopulationDataSpain.py | 21 ------- 2 files changed, 62 insertions(+), 21 deletions(-) create mode 100644 pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py delete mode 100644 pycode/memilio-epidata/memilio/epidata/modifyPopulationDataSpain.py diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py new file mode 100644 index 0000000000..ec84d4a0c6 --- /dev/null +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -0,0 +1,62 @@ +import pandas as pd +import os +import io +import requests + +def fetch_population_data(): + download_url = 'https://servicios.ine.es/wstempus/js/es/DATOS_TABLA/67988?tip=AM&' + req = requests.get(download_url) + req.encoding = 'ISO-8859-1' + + df = pd.read_json(io.StringIO(req.text)) + df = df[['MetaData', 'Data']] + df['ID_Provincia'] = df['MetaData'].apply(lambda x: x[0]['Id']) + df['Population'] = df['Data'].apply(lambda x: x[0]['Valor']) + return df[['ID_Provincia', 'Population']] + +def remove_islands(df): + df = df[~df['ID_Provincia'].isin([51, 52, 8, 35, 38])] + return df + +def get_population_data(): + df = fetch_population_data() + df = remove_islands(df) 
+ + return df + +def fetch_icu_data(): + download_url = 'https://www.sanidad.gob.es/areas/alertasEmergenciasSanitarias/alertasActuales/nCov/documentos/Datos_Capacidad_Asistencial_Historico_14072023.csv' + req = requests.get(download_url) + req.encoding = 'ISO-8859-1' + + df = pd.read_csv(io.StringIO(req.text), sep=';') + return df + +def preprocess_icu_data(df): + df_without_ventilator = df[df["Unidad"] == "U. Críticas SIN respirador"] + df_with_ventilator = df[df["Unidad"] == "U. Críticas CON respirador"] + + df_icu = df_without_ventilator[['Fecha', 'ID_Provincia', 'Provincia', 'OCUPADAS_COVID19']].rename(columns={'OCUPADAS_COVID19': 'ICU'}) + df_icu_vent = df_with_ventilator[['Fecha', 'ID_Provincia', 'Provincia', 'OCUPADAS_COVID19']].rename(columns={'OCUPADAS_COVID19': 'ICU_ventilated'}) + + df_merged = pd.merge(df_icu, df_icu_vent, on=['Fecha', 'ID_Provincia', 'Provincia'], how='outer') + df_merged['Fecha'] = pd.to_datetime(df_merged['Fecha'], format='%d/%m/%Y') + return df_merged + +def get_icu_data(): + df = fetch_icu_data() + df.rename(columns={'Cod_Provincia': 'ID_Provincia'}, inplace=True) + df = remove_islands(df) + df = preprocess_icu_data(df) + + return df + +if __name__ == "__main__": + + data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../data/Spain") + + df = get_population_data() + df.to_json(os.path.join(data_dir, 'pydata/provincias_current_population.json'), orient='records') + + df = get_icu_data() + df.to_json(os.path.join(data_dir, 'pydata/provincia_icu.json'), orient='records') \ No newline at end of file diff --git a/pycode/memilio-epidata/memilio/epidata/modifyPopulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/modifyPopulationDataSpain.py deleted file mode 100644 index acdba425f8..0000000000 --- a/pycode/memilio-epidata/memilio/epidata/modifyPopulationDataSpain.py +++ /dev/null @@ -1,21 +0,0 @@ -import pandas as pd -import os - -def read_population_data(file): - df = pd.read_json(file) - df = df[['MetaData', 'Data']] - df['ID_Provincia'] = df['MetaData'].apply(lambda x: x[0]['Id']) - df['Population'] = df['Data'].apply(lambda x: x[0]['Valor']) - return df[['ID_Provincia', 'Population']] - -def remove_islands(df): - df = df[~df['ID_Provincia'].isin([51, 52, 8, 35, 38])] - return df - -if __name__ == "__main__": - - data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../data/Spain") - - df = read_population_data(os.path.join(data_dir, 'pydata/67988.json')) - df = remove_islands(df) - df.to_json(os.path.join(data_dir, 'pydata/provincias_current_population.json'), orient='records', force_ascii=False) \ No newline at end of file From 4b5dbdc5bdafe845883b8ba96d9c6f4462a3102e Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 26 Aug 2025 15:55:36 +0200 Subject: [PATCH 37/73] [ci skip] fix get spanish icu data --- .../memilio/epidata/getSimulationDataSpain.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index ec84d4a0c6..41ba384d64 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -33,14 +33,16 @@ def fetch_icu_data(): return df def preprocess_icu_data(df): - df_without_ventilator = df[df["Unidad"] == "U. Críticas SIN respirador"] - df_with_ventilator = df[df["Unidad"] == "U. Críticas CON respirador"] + df_icu = df[df["Unidad"] == "U. 
Críticas SIN respirador"] + df_icu_vent = df[df["Unidad"] == "U. Críticas CON respirador"] - df_icu = df_without_ventilator[['Fecha', 'ID_Provincia', 'Provincia', 'OCUPADAS_COVID19']].rename(columns={'OCUPADAS_COVID19': 'ICU'}) - df_icu_vent = df_with_ventilator[['Fecha', 'ID_Provincia', 'Provincia', 'OCUPADAS_COVID19']].rename(columns={'OCUPADAS_COVID19': 'ICU_ventilated'}) + df_icu = df_icu[['Fecha', 'ID_Provincia', 'OCUPADAS_COVID19']].rename(columns={'OCUPADAS_COVID19': 'ICU'}) + df_icu_vent = df_icu_vent[['Fecha', 'ID_Provincia', 'OCUPADAS_COVID19']].rename(columns={'OCUPADAS_COVID19': 'ICU_ventilated'}) - df_merged = pd.merge(df_icu, df_icu_vent, on=['Fecha', 'ID_Provincia', 'Provincia'], how='outer') - df_merged['Fecha'] = pd.to_datetime(df_merged['Fecha'], format='%d/%m/%Y') + df_merged = pd.merge(df_icu, df_icu_vent, on=['Fecha', 'ID_Provincia'], how='outer') + df_merged['Fecha'] = pd.to_datetime(df_merged['Fecha'], format='%d/%m/%Y').dt.strftime('%Y-%m-%d') + df_merged.rename(columns={'Fecha': 'Date'}, inplace=True) + print(df_merged) return df_merged def get_icu_data(): From e4a930b2f99d478ebcee063dfe373621e04f23d5 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 27 Aug 2025 08:59:01 +0200 Subject: [PATCH 38/73] add get case data --- .../memilio/epidata/defaultDict.py | 54 +++++++++++++++++++ .../memilio/epidata/getSimulationDataSpain.py | 30 ++++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/defaultDict.py b/pycode/memilio-epidata/memilio/epidata/defaultDict.py index f4013a491b..281f88c130 100644 --- a/pycode/memilio-epidata/memilio/epidata/defaultDict.py +++ b/pycode/memilio-epidata/memilio/epidata/defaultDict.py @@ -758,3 +758,57 @@ def invert_dict(dict_to_invert): 51: 'Ceuta', 52: 'Melilla', 53: 'Ourense'} + + +Provincia_ISO_to_ID = {'A': 4, + 'AB': 3, + 'AL': 5, + 'AV': 6, + 'B': 9, + 'BA': 7, + 'BI': 48, + 'BU': 10, + 'C': 16, + 'CA': 12, + 'CC': 11, + 'CE': 51, + 'CO': 15, + 'CR': 14, + 'CS': 13, + 'CU': 17, + 'GC': 35, + 'GI': 18, + 'GR': 19, + 'GU': 20, + 'H': 22, + 'HU': 23, + 'J': 24, + 'L': 26, + 'LE': 25, + 'LO': 27, + 'LU': 28, + 'M': 29, + 'MA': 30, + 'ML': 52, + 'MU': 31, + 'NA': 32, + 'O': 33, + 'OR': 53, + 'P': 34, + 'PM': 8, + 'PO': 36, + 'SA': 37, + 'S': 39, + 'SE': 41, + 'SG': 40, + 'SO': 42, + 'SS': 21, + 'T': 43, + 'TE': 44, + 'TF': 38, + 'TO': 45, + 'V': 46, + 'VA': 47, + 'VI': 2, + 'ZA': 48, + 'Z': 50} diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index 41ba384d64..6f3fca3749 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -3,6 +3,8 @@ import io import requests +import defaultDict as dd + def fetch_population_data(): download_url = 'https://servicios.ine.es/wstempus/js/es/DATOS_TABLA/67988?tip=AM&' req = requests.get(download_url) @@ -42,7 +44,7 @@ def preprocess_icu_data(df): df_merged = pd.merge(df_icu, df_icu_vent, on=['Fecha', 'ID_Provincia'], how='outer') df_merged['Fecha'] = pd.to_datetime(df_merged['Fecha'], format='%d/%m/%Y').dt.strftime('%Y-%m-%d') df_merged.rename(columns={'Fecha': 'Date'}, inplace=True) - print(df_merged) + return df_merged def get_icu_data(): @@ -53,6 +55,27 @@ def get_icu_data(): return df +def fetch_case_data(): + download_url = 'https://cnecovid.isciii.es/covid19/resources/casos_diagnostico_provincia.csv' + req = requests.get(download_url) + + df = 
pd.read_csv(io.StringIO(req.text), sep=',') + + return df + +def preprocess_case_data(df): + df['provincia_iso'] = df['provincia_iso'].map(dd.Provincia_ISO_to_ID) + df = df.rename(columns={'provincia_iso': 'ID_Provincia', 'fecha': 'Date', 'num_casos': 'Confirmed'})[['ID_Provincia', 'Date', 'Confirmed']] + + return df + + +def get_case_data(): + df = fetch_case_data() + df = preprocess_case_data(df) + + return df + if __name__ == "__main__": data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../data/Spain") @@ -61,4 +84,7 @@ def get_icu_data(): df.to_json(os.path.join(data_dir, 'pydata/provincias_current_population.json'), orient='records') df = get_icu_data() - df.to_json(os.path.join(data_dir, 'pydata/provincia_icu.json'), orient='records') \ No newline at end of file + df.to_json(os.path.join(data_dir, 'pydata/provincia_icu.json'), orient='records') + + df = get_case_data() + df.to_json(os.path.join(data_dir, 'pydata/cases_all_pronvincias.json'), orient='records') \ No newline at end of file From 2f55a4448d7a8fd3e66d66193626052a9af3bd9a Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 27 Aug 2025 09:01:34 +0200 Subject: [PATCH 39/73] remove comments --- cpp/models/ode_secir/parameters_io.h | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index a601f71479..8eaad47685 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -443,8 +443,7 @@ IOResult read_input_data_state(std::vector& model, Date date, std:: template IOResult read_input_data_county(std::vector& model, Date date, const std::vector& county, const std::vector& scaling_factor_inf, double scaling_factor_icu, - const std::string& pydata_dir, int /*num_days*/ = 0, - bool /*export_time_series*/ = false) + const std::string& pydata_dir, int num_days = 0, bool export_time_series = false) { BOOST_OUTCOME_TRY( details::set_divi_data(model, path_join(pydata_dir, "county_divi_ma7.json"), county, date, scaling_factor_icu)); @@ -453,17 +452,17 @@ IOResult read_input_data_county(std::vector& model, Date date, cons BOOST_OUTCOME_TRY( details::set_population_data(model, path_join(pydata_dir, "county_current_population.json"), county)); - // if (export_time_series) { - // // Use only if extrapolated real data is needed for comparison. EXPENSIVE ! - // // Run time equals run time of the previous functions times the num_days ! - // // (This only represents the vectorization of the previous function over all simulation days...) - // log_warning("Exporting time series of extrapolated real data. This may take some minutes. " - // "For simulation runs over the same time period, deactivate it."); - // // BOOST_OUTCOME_TRY(export_input_data_county_timeseries( - // // model, pydata_dir, county, date, scaling_factor_inf, scaling_factor_icu, num_days, - // // path_join(pydata_dir, "county_divi_ma7.json"), path_join(pydata_dir, "cases_all_county_age_ma7.json"), - // // path_join(pydata_dir, "county_current_population.json"))); - // } + if (export_time_series) { + // Use only if extrapolated real data is needed for comparison. EXPENSIVE ! + // Run time equals run time of the previous functions times the num_days ! + // (This only represents the vectorization of the previous function over all simulation days...) + log_warning("Exporting time series of extrapolated real data. This may take some minutes. 
" + "For simulation runs over the same time period, deactivate it."); + BOOST_OUTCOME_TRY(export_input_data_county_timeseries( + model, pydata_dir, county, date, scaling_factor_inf, scaling_factor_icu, num_days, + path_join(pydata_dir, "county_divi_ma7.json"), path_join(pydata_dir, "cases_all_county_age_ma7.json"), + path_join(pydata_dir, "county_current_population.json"))); + } return success(); } From fee12c739aecd1ce4900d44b3713310beac30c72 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 27 Aug 2025 09:02:49 +0200 Subject: [PATCH 40/73] remove writing of graph --- pycode/examples/simulation/graph_germany_nuts3.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index c0d6e4497e..5153d0f7a7 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -246,12 +246,6 @@ def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, transmissi graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, t_Cr, transmission_prob) - if save_graph: - path_graph = os.path.join(self.results_dir, "graph") - if not os.path.exists(path_graph): - os.makedirs(path_graph) - osecir.write_graph(graph, path_graph) - mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): node = graph.get_node(node_idx) From 845a84d20222ca1c05abbb7ffa40c7b2a19a2850 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 27 Aug 2025 09:16:20 +0200 Subject: [PATCH 41/73] format + create dir --- .../memilio/epidata/getSimulationDataSpain.py | 43 +++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index 6f3fca3749..d69b6aabba 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -1,14 +1,15 @@ -import pandas as pd +import pandas as pd import os import io import requests import defaultDict as dd + def fetch_population_data(): download_url = 'https://servicios.ine.es/wstempus/js/es/DATOS_TABLA/67988?tip=AM&' req = requests.get(download_url) - req.encoding = 'ISO-8859-1' + req.encoding = 'ISO-8859-1' df = pd.read_json(io.StringIO(req.text)) df = df[['MetaData', 'Data']] @@ -16,37 +17,46 @@ def fetch_population_data(): df['Population'] = df['Data'].apply(lambda x: x[0]['Valor']) return df[['ID_Provincia', 'Population']] + def remove_islands(df): df = df[~df['ID_Provincia'].isin([51, 52, 8, 35, 38])] return df + def get_population_data(): df = fetch_population_data() df = remove_islands(df) return df + def fetch_icu_data(): download_url = 'https://www.sanidad.gob.es/areas/alertasEmergenciasSanitarias/alertasActuales/nCov/documentos/Datos_Capacidad_Asistencial_Historico_14072023.csv' req = requests.get(download_url) - req.encoding = 'ISO-8859-1' + req.encoding = 'ISO-8859-1' df = pd.read_csv(io.StringIO(req.text), sep=';') return df + def preprocess_icu_data(df): df_icu = df[df["Unidad"] == "U. Críticas SIN respirador"] df_icu_vent = df[df["Unidad"] == "U. 
Críticas CON respirador"] - df_icu = df_icu[['Fecha', 'ID_Provincia', 'OCUPADAS_COVID19']].rename(columns={'OCUPADAS_COVID19': 'ICU'}) - df_icu_vent = df_icu_vent[['Fecha', 'ID_Provincia', 'OCUPADAS_COVID19']].rename(columns={'OCUPADAS_COVID19': 'ICU_ventilated'}) + df_icu = df_icu[['Fecha', 'ID_Provincia', 'OCUPADAS_COVID19']].rename( + columns={'OCUPADAS_COVID19': 'ICU'}) + df_icu_vent = df_icu_vent[['Fecha', 'ID_Provincia', 'OCUPADAS_COVID19']].rename( + columns={'OCUPADAS_COVID19': 'ICU_ventilated'}) - df_merged = pd.merge(df_icu, df_icu_vent, on=['Fecha', 'ID_Provincia'], how='outer') - df_merged['Fecha'] = pd.to_datetime(df_merged['Fecha'], format='%d/%m/%Y').dt.strftime('%Y-%m-%d') + df_merged = pd.merge(df_icu, df_icu_vent, on=[ + 'Fecha', 'ID_Provincia'], how='outer') + df_merged['Fecha'] = pd.to_datetime( + df_merged['Fecha'], format='%d/%m/%Y').dt.strftime('%Y-%m-%d') df_merged.rename(columns={'Fecha': 'Date'}, inplace=True) return df_merged + def get_icu_data(): df = fetch_icu_data() df.rename(columns={'Cod_Provincia': 'ID_Provincia'}, inplace=True) @@ -55,6 +65,7 @@ def get_icu_data(): return df + def fetch_case_data(): download_url = 'https://cnecovid.isciii.es/covid19/resources/casos_diagnostico_provincia.csv' req = requests.get(download_url) @@ -63,9 +74,11 @@ def fetch_case_data(): return df + def preprocess_case_data(df): df['provincia_iso'] = df['provincia_iso'].map(dd.Provincia_ISO_to_ID) - df = df.rename(columns={'provincia_iso': 'ID_Provincia', 'fecha': 'Date', 'num_casos': 'Confirmed'})[['ID_Provincia', 'Date', 'Confirmed']] + df = df.rename(columns={'provincia_iso': 'ID_Provincia', 'fecha': 'Date', + 'num_casos': 'Confirmed'})[['ID_Provincia', 'Date', 'Confirmed']] return df @@ -76,15 +89,21 @@ def get_case_data(): return df + if __name__ == "__main__": - data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../data/Spain") + data_dir = os.path.join(os.path.dirname( + os.path.abspath(__file__)), "../../../../data/Spain") + pydata_dir = os.path.join(data_dir, 'pydata') + os.makedirs(pydata_dir, exist_ok=True) df = get_population_data() - df.to_json(os.path.join(data_dir, 'pydata/provincias_current_population.json'), orient='records') + df.to_json(os.path.join( + pydata_dir, 'provincias_current_population.json'), orient='records') df = get_icu_data() - df.to_json(os.path.join(data_dir, 'pydata/provincia_icu.json'), orient='records') + df.to_json(os.path.join(pydata_dir, 'provincia_icu.json'), orient='records') df = get_case_data() - df.to_json(os.path.join(data_dir, 'pydata/cases_all_pronvincias.json'), orient='records') \ No newline at end of file + df.to_json(os.path.join( + pydata_dir, 'cases_all_pronvincias.json'), orient='records') From 9281a00adec66f8e293e220eccee5e7a2cbf7162 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 27 Aug 2025 11:02:10 +0200 Subject: [PATCH 42/73] format --- .../simulation/graph_germany_nuts1.py | 33 +++++++++++------- .../examples/simulation/graph_spain_nuts3.py | 34 +++++++++++-------- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts1.py b/pycode/examples/simulation/graph_germany_nuts1.py index 843dfb2d83..40389ee00f 100644 --- a/pycode/examples/simulation/graph_germany_nuts1.py +++ b/pycode/examples/simulation/graph_germany_nuts1.py @@ -55,10 +55,12 @@ def set_covid_parameters(self, model): model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 # probabilities - model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 
0.07333 + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( + 0)] = 0.07333 model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 - model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( + 0)] = 0.2069 model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 @@ -94,7 +96,8 @@ def set_npis(self, params, end_date, damping_value): if start_damping < end_date: start_date = (start_damping - self.start_date).days - params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_value], t=start_date)) def get_graph(self, end_date): """ @@ -119,8 +122,8 @@ def get_graph(self, end_date): data_dir_Germany, "mobility", "commuter_mobility_2022_states.txt") pydata_dir = os.path.join(data_dir_Germany, "pydata") - path_population_data = os.path.join(pydata_dir, - "county_current_population_states.json") + path_population_data = os.path.join( + pydata_dir, "county_current_population_states.json") print("Setting nodes...") mio.osecir.set_nodes_states( @@ -135,7 +138,7 @@ def get_graph(self, end_date): print("Setting edges...") mio.osecir.set_edges( mobility_data_file, graph, 1) - + print("Graph created.") return graph @@ -164,7 +167,8 @@ def run(self, num_days_sim, damping_values, save_graph=True): for node_idx in range(graph.num_nodes): # if node_idx < 5: node = graph.get_node(node_idx) - self.set_npis(node.property.parameters, end_date, damping_values[node_idx]) + self.set_npis(node.property.parameters, + end_date, damping_values[node_idx]) mobility_graph.add_node(node.id, node.property) # else: # node = graph.get_node(node_idx) @@ -181,10 +185,11 @@ def run(self, num_days_sim, damping_values, save_graph=True): results = [] for node_idx in range(graph.num_nodes): results.append(osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(node_idx).property.result)) - + mobility_sim.graph.get_node(node_idx).property.result)) + return results + def run_germany_nuts1_simulation(damping_values): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) @@ -194,21 +199,23 @@ def run_germany_nuts1_simulation(damping_values): start_date=datetime.date(year=2020, month=12, day=12), results_dir=os.path.join(file_path, "../../../results_osecir")) num_days_sim = 50 - + results = sim.run(num_days_sim, damping_values) results[0].export_csv('test.csv') return {f'region{region}': results[region] for region in range(len(results))} + def prior(): damping_values = np.random.uniform(0.0, 1.0, 16) return {'damping_values': damping_values} + if __name__ == "__main__": test = prior() run_germany_nuts1_simulation(test['damping_values']) - # import os + # import os # os.environ["KERAS_BACKEND"] = "tensorflow" # import bayesflow as bf @@ -244,12 +251,12 @@ def prior(): # .concatenate([f'region{i}' for i in range(16)], into="summary_variables", axis=-1) # .log("summary_variables", p1=True) # ) - + # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32) # inference_network = bf.networks.CouplingFlow() # workflow = bf.BasicWorkflow( - # simulator=simulator, + # simulator=simulator, # adapter=adapter, # summary_network=summary_network, # inference_network=inference_network, diff --git 
a/pycode/examples/simulation/graph_spain_nuts3.py b/pycode/examples/simulation/graph_spain_nuts3.py index a63e1711e1..081fb1ea22 100644 --- a/pycode/examples/simulation/graph_spain_nuts3.py +++ b/pycode/examples/simulation/graph_spain_nuts3.py @@ -55,10 +55,12 @@ def set_covid_parameters(self, model): model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 # probabilities - model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 0.07333 + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( + 0)] = 0.07333 model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 - model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( + 0)] = 0.2069 model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 @@ -80,7 +82,7 @@ def set_contact_matrices(self, model): minimum = np.ones((self.num_groups, self.num_groups)) * 0 contact_matrices[0] = mio.ContactMatrix(baseline, minimum) model.parameters.ContactPatterns.cont_freq_mat = contact_matrices - + def set_npis(self, params, end_date, damping_value): """ @@ -94,7 +96,8 @@ def set_npis(self, params, end_date, damping_value): if start_damping < end_date: start_date = (start_damping - self.start_date).days - params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_value], t=start_date)) def get_graph(self, end_date): """ @@ -119,8 +122,8 @@ def get_graph(self, end_date): data_dir_Spain, "mobility", "commuter_mobility_2022.txt") pydata_dir = os.path.join(data_dir_Spain, "pydata") - path_population_data = os.path.join(pydata_dir, - "provincias_current_population.json") + path_population_data = os.path.join( + pydata_dir, "provincias_current_population.json") print("Setting nodes...") mio.osecir.set_nodes_provincias( @@ -135,7 +138,7 @@ def get_graph(self, end_date): print("Setting edges...") mio.osecir.set_edges( mobility_data_file, graph, 1) - + print("Graph created.") return graph @@ -163,7 +166,8 @@ def run(self, num_days_sim, damping_values, save_graph=True): mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): node = graph.get_node(node_idx) - self.set_npis(node.property.parameters, end_date, damping_values[node_idx]) + self.set_npis(node.property.parameters, + end_date, damping_values[node_idx]) mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( @@ -176,10 +180,11 @@ def run(self, num_days_sim, damping_values, save_graph=True): results = [] for node_idx in range(graph.num_nodes): results.append(osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(node_idx).property.result)) - + mobility_sim.graph.get_node(node_idx).property.result)) + return results + def run_spain_nuts3_simulation(damping_values): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) @@ -189,11 +194,12 @@ def run_spain_nuts3_simulation(damping_values): start_date=datetime.date(year=2020, month=12, day=12), results_dir=os.path.join(file_path, "../../../results_osecir")) num_days_sim = 50 - + results = sim.run(num_days_sim, damping_values) return {f'region{region}': results[region] for region in range(len(results))} + def prior(): damping_values = 
np.random.uniform(0.0, 1.0, 47) return {'damping_values': damping_values} @@ -202,7 +208,7 @@ def prior(): if __name__ == "__main__": test = prior() run_spain_nuts3_simulation(test['damping_values']) - # import os + # import os # os.environ["KERAS_BACKEND"] = "tensorflow" # import bayesflow as bf @@ -242,11 +248,11 @@ def prior(): # inference_network = bf.networks.DiffusionModel(subnet_kwargs={'widths': {512, 512, 512, 512, 512}}) # workflow = bf.BasicWorkflow( - # simulator=simulator, + # simulator=simulator, # adapter=adapter, # summary_network=summary_network, # inference_network=inference_network, - # standardize='all' + # standardize='all' # ) # history = workflow.fit_offline(data=trainings_data, epochs=1000, batch_size=32, validation_data=validation_data) From d64af349578b7275cbc3e84abccdfafda4bded62 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 27 Aug 2025 11:04:40 +0200 Subject: [PATCH 43/73] format --- .../simulation/graph_germany_nuts0.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts0.py b/pycode/examples/simulation/graph_germany_nuts0.py index f5cfb44ca1..b9d4843430 100644 --- a/pycode/examples/simulation/graph_germany_nuts0.py +++ b/pycode/examples/simulation/graph_germany_nuts0.py @@ -55,10 +55,12 @@ def set_covid_parameters(self, model): model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 # probabilities - model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = 0.07333 + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( + 0)] = 0.07333 model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 - model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( + 0)] = 0.2069 model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 @@ -94,7 +96,8 @@ def set_npis(self, params, end_date, damping_value): if start_damping < end_date: start_date = (start_damping - self.start_date).days - params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_value], t=start_date)) def get_graph(self, end_date, damping_value): """ @@ -117,8 +120,8 @@ def get_graph(self, end_date, damping_value): data_dir_Germany = os.path.join(self.data_dir, "Germany") pydata_dir = os.path.join(data_dir_Germany, "pydata") - path_population_data = os.path.join(pydata_dir, - "county_current_population_germany.json") + path_population_data = os.path.join( + pydata_dir, "county_current_population_germany.json") print("Setting nodes...") mio.osecir.set_node_germany( @@ -129,7 +132,7 @@ def get_graph(self, end_date, damping_value): end_date.month, end_date.day), pydata_dir, path_population_data, False, graph, scaling_factor_infected, scaling_factor_icu, 0.9, 0, False) - + print("Graph created.") return graph @@ -157,7 +160,8 @@ def run(self, num_days_sim, damping_value, save_graph=True): mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): - mobility_graph.add_node(graph.get_node(node_idx).id, graph.get_node(node_idx).property) + mobility_graph.add_node(graph.get_node( + node_idx).id, graph.get_node(node_idx).property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( 
graph.get_edge(edge_idx).start_node_idx, @@ -170,13 +174,15 @@ def run(self, num_days_sim, damping_value, save_graph=True): results = [] for node_idx in range(graph.num_nodes): results.append(osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(node_idx).property.result)) + mobility_sim.graph.get_node(node_idx).property.result)) osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(0).property.result).export_csv('test.csv') - + mobility_sim.graph.get_node(0).property.result).export_csv( + 'test.csv') + return results + def run_germany_nuts0_simulation(damping_value): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) @@ -186,19 +192,21 @@ def run_germany_nuts0_simulation(damping_value): start_date=datetime.date(year=2020, month=12, day=12), results_dir=os.path.join(file_path, "../../../results_osecir")) num_days_sim = 50 - + results = sim.run(num_days_sim, damping_value) return {"region" + str(region): results[region] for region in range(len(results))} + def prior(): damping_value = np.random.uniform(0.0, 1.0) return {"damping_value": damping_value} + if __name__ == "__main__": run_germany_nuts0_simulation(0.5) - # import os + # import os # os.environ["KERAS_BACKEND"] = "jax" # import bayesflow as bf @@ -215,7 +223,6 @@ def prior(): # # trainings_data = {k:v for k, v in trainings_data.items() if k in ('damping_value', 'region0', 'region1')} # print("Loaded training data:", trainings_data) - # trainings_data = simulator.sample(2) # validation_data = simulator.sample(2) @@ -233,7 +240,7 @@ def prior(): # inference_network = bf.networks.CouplingFlow() # workflow = bf.BasicWorkflow( - # simulator=simulator, + # simulator=simulator, # adapter=adapter, # summary_network=summary_network, # inference_network=inference_network From 6a900bca6d11bdff3280832630b3dbc1a242f1c5 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 27 Aug 2025 11:17:25 +0200 Subject: [PATCH 44/73] format --- .../memilio/epidata/getPopulationData.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py index 07b10896c0..07e8ab7415 100644 --- a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py +++ b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py @@ -55,8 +55,9 @@ def read_population_data(ref_year): req = requests.get(download_url) df_pop_raw = pd.read_csv(io.StringIO(req.text), sep=';', header=5) except pd.errors.ParserError: - gd.default_print('Warning', 'Data for year '+str(ref_year) + - ' is not available; downloading newest data instead.') + gd.default_print( + 'Warning', 'Data for year ' + str(ref_year) + + ' is not available; downloading newest data instead.') ref_year = None if ref_year is None: download_url = 'https://www.regionalstatistik.de/genesis/online?operation=download&code=12411-02-03-4&option=csv' @@ -66,7 +67,9 @@ def read_population_data(ref_year): return df_pop_raw, ref_year -def export_population_dataframe(df_pop: pd.DataFrame, directory: str, file_format: str, merge_eisenach: bool, ref_year): +def export_population_dataframe( + df_pop: pd.DataFrame, directory: str, file_format: str, + merge_eisenach: bool, ref_year): """ Writes population dataframe into directory with new column names and age groups :param df_pop: Population data DataFrame to be exported pd.DataFrame @@ -140,8 +143,10 @@ def export_population_dataframe(df_pop: pd.DataFrame, 
directory: str, file_forma columns=dd.EngEng["idCounty"]) gd.write_dataframe(df_pop_export, directory, filename, file_format) - gd.write_dataframe(aggregate_to_state_level(df_pop_export), directory, filename + '_states', file_format) - gd.write_dataframe(aggregate_to_country_level(df_pop_export), directory, filename + '_germany', file_format) + gd.write_dataframe(aggregate_to_state_level(df_pop_export), + directory, filename + '_states', file_format) + gd.write_dataframe(aggregate_to_country_level(df_pop_export), + directory, filename + '_germany', file_format) return df_pop_export @@ -192,8 +197,8 @@ def assign_population_data(df_pop_raw, counties, age_cols, idCounty_idx): # direct assignment of population data found df_pop.loc[df_pop[dd.EngEng['idCounty']] == df_pop_raw.loc [start_idx, dd.EngEng['idCounty']], - age_cols] = df_pop_raw.loc[start_idx: start_idx + num_age_groups - 1, dd.EngEng - ['number']].values.astype(int) + age_cols] = df_pop_raw.loc[start_idx: start_idx + + num_age_groups - 1, dd.EngEng['number']].values.astype(int) # Berlin and Hamburg elif county_id + '000' in counties[:, 1]: # direct assignment of population data found @@ -264,7 +269,8 @@ def fetch_population_data(read_data: bool = dd.defaultDict['read_data'], if read_data == True: gd.default_print( - 'Warning', 'Read_data is not supportet for getPopulationData.py. Setting read_data = False') + 'Warning', + 'Read_data is not supportet for getPopulationData.py. Setting read_data = False') read_data = False directory = os.path.join(out_folder, 'Germany', 'pydata') @@ -445,22 +451,29 @@ def get_population_data(read_data: bool = dd.defaultDict['read_data'], ) return df_pop_export + def aggregate_to_state_level(df_pop: pd.DataFrame): countyIDtostateID = geoger.get_countyid_to_stateid_map() df_pop['ID_State'] = df_pop[dd.EngEng['idCounty']].map(countyIDtostateID) - df_pop = df_pop.drop(columns='ID_County').groupby('ID_State', as_index=True).sum() + df_pop = df_pop.drop( + columns='ID_County').groupby( + 'ID_State', as_index=True).sum() df_pop['ID_State'] = df_pop.index return df_pop + def aggregate_to_country_level(df_pop: pd.DataFrame): df_pop['ID_Country'] = 0 - df_pop = df_pop.drop(columns=['ID_County', 'ID_State']).groupby('ID_Country', as_index=True).sum() + df_pop = df_pop.drop( + columns=['ID_County', 'ID_State']).groupby( + 'ID_Country', as_index=True).sum() df_pop['ID_Country'] = df_pop.index return df_pop + def main(): """ Main program entry.""" From 4271e61107ad875b3ff216c8df875f9c46309065 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 27 Aug 2025 11:53:07 +0200 Subject: [PATCH 45/73] fix reading spanish population data --- .../memilio/epidata/getSimulationDataSpain.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index d69b6aabba..482ac520d1 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -13,6 +13,10 @@ def fetch_population_data(): df = pd.read_json(io.StringIO(req.text)) df = df[['MetaData', 'Data']] + df = df[df['MetaData'].apply( + lambda x: x[0]['T3_Variable'] == 'Provincias')] + df = df[df['MetaData'].apply( + lambda x: x[1]['Nombre'] == 'Total')] df['ID_Provincia'] = df['MetaData'].apply(lambda x: x[0]['Id']) df['Population'] = df['Data'].apply(lambda x: x[0]['Valor']) return df[['ID_Provincia', 'Population']] @@ -77,8 +81,10 
@@ def fetch_case_data(): def preprocess_case_data(df): df['provincia_iso'] = df['provincia_iso'].map(dd.Provincia_ISO_to_ID) - df = df.rename(columns={'provincia_iso': 'ID_Provincia', 'fecha': 'Date', - 'num_casos': 'Confirmed'})[['ID_Provincia', 'Date', 'Confirmed']] + df = df.rename( + columns={'provincia_iso': 'ID_Provincia', 'fecha': 'Date', + 'num_casos': 'Confirmed'})[ + ['ID_Provincia', 'Date', 'Confirmed']] return df @@ -102,7 +108,8 @@ def get_case_data(): pydata_dir, 'provincias_current_population.json'), orient='records') df = get_icu_data() - df.to_json(os.path.join(pydata_dir, 'provincia_icu.json'), orient='records') + df.to_json(os.path.join(pydata_dir, 'provincia_icu.json'), + orient='records') df = get_case_data() df.to_json(os.path.join( From 3f61e48c14979b244863319238c7b64f3ed995bf Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 27 Aug 2025 12:40:10 +0200 Subject: [PATCH 46/73] read spanish pop --- cpp/memilio/io/epi_data.cpp | 2 +- cpp/memilio/io/epi_data.h | 76 ++++++++++++++-------------- cpp/memilio/io/parameters_io.cpp | 48 ++++++++++-------- cpp/memilio/io/parameters_io.h | 18 +++++++ cpp/models/ode_secir/parameters_io.h | 52 +++++++------------ 5 files changed, 101 insertions(+), 95 deletions(-) diff --git a/cpp/memilio/io/epi_data.cpp b/cpp/memilio/io/epi_data.cpp index c91c7f9dd5..e95c6ef809 100644 --- a/cpp/memilio/io/epi_data.cpp +++ b/cpp/memilio/io/epi_data.cpp @@ -30,7 +30,7 @@ std::vector ConfirmedCasesDataEntry::age_group_names = {"A00-A04", std::vector PopulationDataEntry::age_group_names = { "<3 years", "3-5 years", "6-14 years", "15-17 years", "18-24 years", "25-29 years", "30-39 years", "40-49 years", "50-64 years", "65-74 years", ">74 years"}; -// std::vector PopulationDataEntrySpain::age_group_names = {"Population"}; +std::vector PopulationDataEntrySpain::age_group_names = {"Population"}; std::vector VaccinationDataEntry::age_group_names = {"0-4", "5-14", "15-34", "35-59", "60-79", "80-99"}; diff --git a/cpp/memilio/io/epi_data.h b/cpp/memilio/io/epi_data.h index 2b62dd59b4..2c101fcb28 100644 --- a/cpp/memilio/io/epi_data.h +++ b/cpp/memilio/io/epi_data.h @@ -337,34 +337,34 @@ class PopulationDataEntry } }; -// class PopulationDataEntrySpain -// { -// public: -// static std::vector age_group_names; - -// CustomIndexArray population; -// boost::optional provincia_id; - -// template -// static IOResult deserialize(IoContext& io) -// { -// auto obj = io.expect_object("PopulationDataEntrySpain"); -// auto state_id = obj.expect_optional("ID_Provincia", Tag{}); -// std::vector> age_groups; -// age_groups.reserve(age_group_names.size()); -// std::transform(age_group_names.begin(), age_group_names.end(), std::back_inserter(age_groups), -// [&obj](auto&& age_name) { -// return obj.expect_element(age_name, Tag{}); -// }); -// return apply( -// io, -// [](auto&& ag, auto&& pid) { -// return PopulationDataEntrySpain{ -// CustomIndexArray(AgeGroup(ag.size()), ag.begin(), ag.end()), pid}; -// }, -// details::unpack_all(age_groups), provincia_id); -// } -// }; +class PopulationDataEntrySpain +{ +public: + static std::vector age_group_names; + + CustomIndexArray population; + boost::optional provincia_id; + + template + static IOResult deserialize(IoContext& io) + { + auto obj = io.expect_object("PopulationDataEntrySpain"); + auto provincia = obj.expect_optional("ID_Provincia", Tag{}); + std::vector> age_groups; + age_groups.reserve(age_group_names.size()); + std::transform(age_group_names.begin(), age_group_names.end(), 
std::back_inserter(age_groups), + [&obj](auto&& age_name) { + return obj.expect_element(age_name, Tag{}); + }); + return apply( + io, + [](auto&& ag, auto&& pid) { + return PopulationDataEntrySpain{ + CustomIndexArray(AgeGroup(ag.size()), ag.begin(), ag.end()), pid}; + }, + details::unpack_all(age_groups), provincia); + } +}; namespace details { @@ -477,11 +477,11 @@ inline IOResult> deserialize_population_data(co * @param rki_age_groups Specifies whether population data should be interpolated to rki age groups. * @return list of population data. */ -// inline IOResult> deserialize_population_data_spain(const Json::Value& jsvalue) -// { -// BOOST_OUTCOME_TRY(auto&& population_data, deserialize_json(jsvalue, Tag>{})); -// return success(population_data); -// } +inline IOResult> deserialize_population_data_spain(const Json::Value& jsvalue) +{ + BOOST_OUTCOME_TRY(auto&& population_data, deserialize_json(jsvalue, Tag>{})); + return success(population_data); +} /** * Deserialize population data from a JSON file. @@ -502,11 +502,11 @@ inline IOResult> read_population_data(const std * @param filename JSON file that contains the population data. * @return list of population data. */ -// inline IOResult> read_population_data_spain(const std::string& filename) -// { -// BOOST_OUTCOME_TRY(auto&& jsvalue, read_json(filename)); -// return deserialize_population_data_spain(jsvalue); -// } +inline IOResult> read_population_data_spain(const std::string& filename) +{ + BOOST_OUTCOME_TRY(auto&& jsvalue, read_json(filename)); + return deserialize_population_data_spain(jsvalue); +} /** * @brief Sets the age groups' names for the ConfirmedCasesDataEntry%s. diff --git a/cpp/memilio/io/parameters_io.cpp b/cpp/memilio/io/parameters_io.cpp index 1e90999121..4d75504a82 100644 --- a/cpp/memilio/io/parameters_io.cpp +++ b/cpp/memilio/io/parameters_io.cpp @@ -62,30 +62,27 @@ IOResult>> read_population_data(const std::vecto return success(vnum_population); } -// IOResult>> -// read_population_data_spain(const std::vector& population_data, -// const std::vector& vregion) -// { -// std::vector> vnum_population(vregion.size(), std::vector(1, 0.0)); +IOResult>> +read_population_data_spain(const std::vector& population_data, + const std::vector& vregion) +{ + std::vector> vnum_population(vregion.size(), std::vector(1, 0.0)); -// for (auto&& provincia_entry : population_data) { -// printf("Test"); -// //find region that this county belongs to -// //all counties belong to the country (id = 0) -// auto it = std::find_if(vregion.begin(), vregion.end(), [&provincia_entry](auto r) { -// return r == 0 || (provincia_entry.provincia_id && regions::ProvinciaId(r) == provincia_entry.provincia_id); -// }); -// if (it != vregion.end()) { -// auto region_idx = size_t(it - vregion.begin()); -// auto& num_population = vnum_population[region_idx]; -// for (size_t age = 0; age < num_population.size(); age++) { -// num_population[age] += provincia_entry.population[AgeGroup(age)]; -// } -// } -// } + for (auto&& provincia_entry : population_data) { + auto it = std::find_if(vregion.begin(), vregion.end(), [&provincia_entry](auto r) { + return r == 0 || (provincia_entry.provincia_id && regions::ProvinciaId(r) == provincia_entry.provincia_id); + }); + if (it != vregion.end()) { + auto region_idx = size_t(it - vregion.begin()); + auto& num_population = vnum_population[region_idx]; + for (size_t age = 0; age < num_population.size(); age++) { + num_population[age] += provincia_entry.population[AgeGroup(age)]; + } + } + } -// return 
success(vnum_population); -// } + return success(vnum_population); +} IOResult>> read_population_data(const std::string& path, const std::vector& vregion) @@ -93,5 +90,12 @@ IOResult>> read_population_data(const std::strin BOOST_OUTCOME_TRY(auto&& population_data, mio::read_population_data(path)); return read_population_data(population_data, vregion); } + +IOResult>> read_population_data_spain(const std::string& path, + const std::vector& vregion) +{ + BOOST_OUTCOME_TRY(auto&& population_data, mio::read_population_data_spain(path)); + return read_population_data_spain(population_data, vregion); +} } // namespace mio #endif //MEMILIO_HAS_JSONCPP diff --git a/cpp/memilio/io/parameters_io.h b/cpp/memilio/io/parameters_io.h index 3ea00fc73b..bbb0769d51 100644 --- a/cpp/memilio/io/parameters_io.h +++ b/cpp/memilio/io/parameters_io.h @@ -136,6 +136,24 @@ IOResult>> read_population_data(const std::vecto IOResult>> read_population_data(const std::string& path, const std::vector& vregion); +/** + * @brief Reads (Spain) population data from a vector of Spain population data entries. + * Uses provincias IDs and a single aggregated age group. + * @param[in] population_data Vector of Spain population data entries. + * @param[in] vregion Vector of keys representing the provincias (or country = 0) of interest. + */ +IOResult>> +read_population_data_spain(const std::vector& population_data, + const std::vector& vregion); + +/** + * @brief Reads (Spain) population data from census data file with provincias. + * @param[in] path Path to the provincias population data file. + * @param[in] vregion Vector of keys representing the provincias (or country = 0) of interest. + */ +IOResult>> read_population_data_spain(const std::string& path, + const std::vector& vregion); + } // namespace mio #endif //MEMILIO_HAS_JSONCPP diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 8eaad47685..8c4b6d05cc 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -304,14 +304,14 @@ IOResult set_population_data(std::vector>& model, const std::str return success(); } -// template -// IOResult set_population_data_provincias(std::vector>& model, const std::string& path, -// const std::vector& vregion) -// { -// BOOST_OUTCOME_TRY(const auto&& num_population, read_population_data_spain(path, vregion)); -// BOOST_OUTCOME_TRY(set_population_data(model, num_population, vregion)); -// return success(); -// } +template +IOResult set_population_data_provincias(std::vector>& model, const std::string& path, + const std::vector& vregion) +{ + BOOST_OUTCOME_TRY(const auto&& num_population, read_population_data_spain(path, vregion)); + BOOST_OUTCOME_TRY(set_population_data(model, num_population, vregion)); + return success(); +} } //namespace details @@ -477,32 +477,16 @@ IOResult read_input_data_county(std::vector& model, Date date, cons * @param[in] num_days [Default: 0] Number of days to be simulated; required to extrapolate real data. * @param[in] export_time_series [Default: false] If true, reads data for each day of simulation and writes it in the same directory as the input files. 
 */
-// template <class Model>
-// IOResult<void> read_input_data_provincias(std::vector<Model>& model, Date date, const std::vector<int>& county,
-//                                           const std::vector<double>& scaling_factor_inf, double scaling_factor_icu,
-//                                           const std::string& pydata_dir, int num_days = 0,
-//                                           bool export_time_series = false)
-// {
-//     // BOOST_OUTCOME_TRY(
-//     //     details::set_divi_data(model, path_join(pydata_dir, "county_divi_ma7.json"), county, date, scaling_factor_icu));
-//     // BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_county_ma7.json"),
-//     //                                                     county, date, scaling_factor_inf));
-//     BOOST_OUTCOME_TRY(details::set_population_data_provincias(
-//         model, path_join(pydata_dir, "provincias_current_population.json"), county));

-//     if (export_time_series) {
-//         // Use only if extrapolated real data is needed for comparison. EXPENSIVE !
-//         // Run time equals run time of the previous functions times the num_days !
-//         // (This only represents the vectorization of the previous function over all simulation days...)
-//         log_warning("Exporting time series of extrapolated real data. This may take some minutes. "
-//                     "For simulation runs over the same time period, deactivate it.");
-//         BOOST_OUTCOME_TRY(export_input_data_county_timeseries(
-//             model, pydata_dir, county, date, scaling_factor_inf, scaling_factor_icu, num_days,
-//             path_join(pydata_dir, "county_divi_ma7.json"), path_join(pydata_dir, "cases_all_county_age_ma7.json"),
-//             path_join(pydata_dir, "county_current_population.json")));
-//     }
-//     return success();
-// }
+template <class Model>
+IOResult<void> read_input_data_provincias(std::vector<Model>& model, Date /*date*/, const std::vector<int>& provincias,
+                                          const std::vector<double>& /*scaling_factor_inf*/,
+                                          double /*scaling_factor_icu*/, const std::string& pydata_dir,
+                                          int /*num_days*/ = 0, bool /*export_time_series*/ = false)
+{
+    BOOST_OUTCOME_TRY(details::set_population_data_provincias(
+        model, path_join(pydata_dir, "provincias_current_population.json"), provincias));
+    return success();
+}

 /**
 * @brief reads population data from population files for the specified nodes

From ea6f566534f64056f0ed07438898876cac35dd4c Mon Sep 17 00:00:00 2001
From: HenrZu <69154294+HenrZu@users.noreply.github.com>
Date: Wed, 27 Aug 2025 12:49:37 +0200
Subject: [PATCH 47/73] fix init pop for provincias

---
 cpp/memilio/io/parameters_io.cpp | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/cpp/memilio/io/parameters_io.cpp b/cpp/memilio/io/parameters_io.cpp
index 4d75504a82..450a38f902 100644
--- a/cpp/memilio/io/parameters_io.cpp
+++ b/cpp/memilio/io/parameters_io.cpp
@@ -69,14 +69,17 @@ read_population_data_spain(const std::vector<PopulationDataEntrySpain>& population_data,
     std::vector<std::vector<double>> vnum_population(vregion.size(), std::vector<double>(1, 0.0));

     for (auto&& provincia_entry : population_data) {
-        auto it = std::find_if(vregion.begin(), vregion.end(), [&provincia_entry](auto r) {
-            return r == 0 || (provincia_entry.provincia_id && regions::ProvinciaId(r) == provincia_entry.provincia_id);
-        });
-        if (it != vregion.end()) {
-            auto region_idx = size_t(it - vregion.begin());
-            auto& num_population = vnum_population[region_idx];
-            for (size_t age = 0; age < num_population.size(); age++) {
-                num_population[age] += provincia_entry.population[AgeGroup(age)];
+        if (provincia_entry.provincia_id) {
+            for (size_t idx = 0; idx < vregion.size(); ++idx) {
+                if (vregion[idx] == provincia_entry.provincia_id->get()) {
+                    vnum_population[idx][0] += provincia_entry.population[AgeGroup(0)];
+                }
+            }
+        }
+        // id 0 means the whole country
+        for (size_t idx = 0; idx < vregion.size(); ++idx) {
+            if (vregion[idx] == 0) {
+                vnum_population[idx][0] += provincia_entry.population[AgeGroup(0)];
             }
         }
     }

From 715c711670ccc812ba039523ccc29d8ce7269e29 Mon Sep 17 00:00:00 2001
From: HenrZu <69154294+HenrZu@users.noreply.github.com>
Date: Wed, 27 Aug 2025 14:31:45 +0200
Subject: [PATCH 48/73] spanish divi and case data as input for secir model

---
 cpp/memilio/io/epi_data.h                     | 39 ++++++++++++-------
 cpp/models/ode_secir/parameters_io.h          | 13 +++++--
 .../memilio/epidata/getSimulationDataSpain.py | 28 ++++++++++---
 3 files changed, 58 insertions(+), 22 deletions(-)

diff --git a/cpp/memilio/io/epi_data.h b/cpp/memilio/io/epi_data.h
index 2c101fcb28..048eb33c12 100644
--- a/cpp/memilio/io/epi_data.h
+++ b/cpp/memilio/io/epi_data.h
@@ -148,29 +148,42 @@ class ConfirmedCasesDataEntry
     {
         auto obj = io.expect_object("ConfirmedCasesDataEntry");
         auto num_confirmed = obj.expect_element("Confirmed", Tag<double>{});
-        auto num_recovered = obj.expect_element("Recovered", Tag<double>{});
-        auto num_deaths = obj.expect_element("Deaths", Tag<double>{});
+        auto num_recovered = obj.expect_optional("Recovered", Tag<double>{});
+        auto num_deaths = obj.expect_optional("Deaths", Tag<double>{});
         auto date = obj.expect_element("Date", Tag<Date>{});
-        auto age_group_str = obj.expect_element("Age_RKI", Tag<std::string>{});
+        auto age_group_str = obj.expect_optional("Age_RKI", Tag<std::string>{});
         auto state_id = obj.expect_optional("ID_State", Tag<regions::StateId>{});
         auto county_id = obj.expect_optional("ID_County", Tag<regions::CountyId>{});
         auto district_id = obj.expect_optional("ID_District", Tag<regions::DistrictId>{});
         return apply(
             io,
-            [](auto&& nc, auto&& nr, auto&& nd, auto&& d, auto&& a_str, auto&& sid, auto&& cid,
+            [](auto&& nc, auto&& nr, auto&& nd, auto&& d, auto&& a_str_opt, auto&& sid, auto&& cid,
               auto&& did) -> IOResult<ConfirmedCasesDataEntry> {
-                auto a = AgeGroup(0);
-                auto it = std::find(age_group_names.begin(), age_group_names.end(), a_str);
-                if (it != age_group_names.end()) {
-                    a = AgeGroup(size_t(it - age_group_names.begin()));
-                }
-                else if (a_str == "unknown") {
-                    a = AgeGroup(age_group_names.size());
+                auto a = AgeGroup(0); // default if Age_RKI missing
+
+                // if no age group is given, use "total" as name
+                if (!a_str_opt) {
+                    if (age_group_names.size() != 1) {
+                        age_group_names.clear();
+                        age_group_names.push_back("total");
+                    }
                 }
                 else {
-                    return failure(StatusCode::InvalidValue, "Invalid confirmed cases data age group.");
+                    const auto& a_str = *a_str_opt;
+                    auto it = std::find(age_group_names.begin(), age_group_names.end(), a_str);
+                    if (it != age_group_names.end()) {
+                        a = AgeGroup(size_t(it - age_group_names.begin()));
+                    }
+                    else if (a_str == "unknown") {
+                        a = AgeGroup(age_group_names.size());
+                    }
+                    else {
+                        return failure(StatusCode::InvalidValue, "Invalid confirmed cases data age group.");
+                    }
                 }
-                return success(ConfirmedCasesDataEntry{nc, nr, nd, d, a, sid, cid, did});
+                double nrec = nr ? *nr : 0.0;
+                double ndead = nd ? *nd : 0.0;
+                return success(ConfirmedCasesDataEntry{nc, nrec, ndead, d, a, sid, cid, did});
             },
             num_confirmed, num_recovered, num_deaths, date, age_group_str, state_id, county_id, district_id);
     }
diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h
index e5d0b44831..60e7256c10 100644
--- a/cpp/models/ode_secir/parameters_io.h
+++ b/cpp/models/ode_secir/parameters_io.h
@@ -478,11 +478,16 @@ IOResult<void> read_input_data_county(std::vector<Model>& model, Date date, const std::vector<int>& county,
 * @param[in] export_time_series [Default: false] If true, reads data for each day of simulation and writes it in the same directory as the input files.
*/ template -IOResult read_input_data_provincias(std::vector& model, Date /*date*/, const std::vector& provincias, - const std::vector& /*scaling_factor_inf*/, - double /*scaling_factor_icu*/, const std::string& pydata_dir, - int /*num_days*/ = 0, bool /*export_time_series*/ = false) +IOResult read_input_data_provincias(std::vector& model, Date date, const std::vector& provincias, + const std::vector& /*scaling_factor_inf*/, double scaling_factor_icu, + const std::string& pydata_dir, int /*num_days*/ = 0, + bool /*export_time_series*/ = false) { + BOOST_OUTCOME_TRY(details::set_divi_data(model, path_join(pydata_dir, "provincia_icu.json"), provincias, date, + scaling_factor_icu)); + + BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_pronvincias.json"), + provincias, date, std::vector(1, 1.0))); BOOST_OUTCOME_TRY(details::set_population_data_provincias( model, path_join(pydata_dir, "provincias_current_population.json"), provincias)); return success(); diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index 482ac520d1..feb9507617 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -82,10 +82,14 @@ def fetch_case_data(): def preprocess_case_data(df): df['provincia_iso'] = df['provincia_iso'].map(dd.Provincia_ISO_to_ID) df = df.rename( - columns={'provincia_iso': 'ID_Provincia', 'fecha': 'Date', + columns={'provincia_iso': 'ID_County', 'fecha': 'Date', 'num_casos': 'Confirmed'})[ - ['ID_Provincia', 'Date', 'Confirmed']] - + ['ID_County', 'Date', 'Confirmed']] + # ensure correct types + df['ID_County'] = pd.to_numeric( + df['ID_County'], errors='coerce').astype('Int64') + df = df[df['ID_County'].notna()].copy() + df['ID_County'] = df['ID_County'].astype(int) return df @@ -108,9 +112,23 @@ def get_case_data(): pydata_dir, 'provincias_current_population.json'), orient='records') df = get_icu_data() - df.to_json(os.path.join(pydata_dir, 'provincia_icu.json'), - orient='records') + # rename to match expected column name in parameters_io.h + if 'ID_Provincia' in df.columns: + df.rename(columns={'ID_Provincia': 'ID_County'}, inplace=True) + # and should also be int + if 'ID_County' in df.columns: + df['ID_County'] = pd.to_numeric(df['ID_County'], errors='coerce') + df = df[df['ID_County'].notna()].copy() + df['ID_County'] = df['ID_County'].astype(int) + df.to_json(os.path.join(pydata_dir, 'provincia_icu.json'), orient='records') df = get_case_data() + # same for case data + if 'ID_Provincia' in df.columns: + df.rename(columns={'ID_Provincia': 'ID_County'}, inplace=True) + if 'ID_County' in df.columns: + df['ID_County'] = pd.to_numeric(df['ID_County'], errors='coerce') + df = df[df['ID_County'].notna()].copy() + df['ID_County'] = df['ID_County'].astype(int) df.to_json(os.path.join( pydata_dir, 'cases_all_pronvincias.json'), orient='records') From ab44aebfab10ad88f915c141f3012b8ac161de5b Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Thu, 28 Aug 2025 11:15:54 +0200 Subject: [PATCH 49/73] [ci skip] bindings for read_input_data_provincias --- cpp/models/ode_secir/parameters_io.h | 4 ++-- .../memilio/simulation/bindings/models/osecir.cpp | 14 +++++++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index 60e7256c10..6f9a89acda 100644 --- 
a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -479,7 +479,7 @@ IOResult read_input_data_county(std::vector& model, Date date, cons */ template IOResult read_input_data_provincias(std::vector& model, Date date, const std::vector& provincias, - const std::vector& /*scaling_factor_inf*/, double scaling_factor_icu, + const std::vector& scaling_factor_inf, double scaling_factor_icu, const std::string& pydata_dir, int /*num_days*/ = 0, bool /*export_time_series*/ = false) { @@ -487,7 +487,7 @@ IOResult read_input_data_provincias(std::vector& model, Date date, scaling_factor_icu)); BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_pronvincias.json"), - provincias, date, std::vector(1, 1.0))); + provincias, date, scaling_factor_inf)); BOOST_OUTCOME_TRY(details::set_population_data_provincias( model, path_join(pydata_dir, "provincias_current_population.json"), provincias)); return success(); diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp index c08eb591df..b0d82a94e9 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp @@ -203,9 +203,7 @@ PYBIND11_MODULE(_simulation_osecir, m) .def("apply_constraints", &mio::osecir::Parameters::apply_constraints) .def_property( "end_commuter_detection", - [](const mio::osecir::Parameters& self) -> auto { - return self.get_end_commuter_detection(); - }, + [](const mio::osecir::Parameters& self) -> auto { return self.get_end_commuter_detection(); }, [](mio::osecir::Parameters& self, double p) { self.get_end_commuter_detection() = p; }); @@ -348,6 +346,16 @@ PYBIND11_MODULE(_simulation_osecir, m) return pymio::check_and_throw(result); }, py::return_value_policy::move); + m.def( + "read_input_data_provincias", + [](std::vector>& model, mio::Date date, const std::vector& provincias, + const std::vector& scaling_factor_inf, double scaling_factor_icu, const std::string& dir, + int num_days = 0, bool export_time_series = false) { + auto result = mio::osecir::read_input_data_provincias>( + model, date, provincias, scaling_factor_inf, scaling_factor_icu, dir, num_days, export_time_series); + return pymio::check_and_throw(result); + }, + py::return_value_policy::move); #endif // MEMILIO_HAS_JSONCPP m.def("interpolate_simulation_result", From 6487b721106ad70231e1fc13afe7e836e0e04dc4 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 28 Aug 2025 11:26:27 +0200 Subject: [PATCH 50/73] get mobility data --- .../memilio/epidata/getSimulationDataSpain.py | 70 ++++++++++++++++++- 1 file changed, 68 insertions(+), 2 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index feb9507617..744468cfbc 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -3,8 +3,11 @@ import io import requests +from pyspainmobility import Mobility, Zones import defaultDict as dd +from memilio.epidata import getDataIntoPandasDataFrame as gd + def fetch_population_data(): download_url = 'https://servicios.ine.es/wstempus/js/es/DATOS_TABLA/67988?tip=AM&' @@ -22,8 +25,9 @@ def fetch_population_data(): return df[['ID_Provincia', 'Population']] -def remove_islands(df): - df = df[~df['ID_Provincia'].isin([51, 
52, 8, 35, 38])] +def remove_islands(df, column_labels=['ID_Provincia']): + for label in column_labels: + df = df[~df[label].isin([51, 52, 8, 35, 38])] return df @@ -100,12 +104,67 @@ def get_case_data(): return df +def download_mobility_data(data_dir, start_date, end_date, level): + mobility_data = Mobility(version=2, zones=level, + start_date=start_date, end_date=end_date, output_directory=data_dir) + mobility_data.get_od_data() + + +def preprocess_mobility_data(df, data_dir): + df.drop(columns=['hour', 'trips_total_length_km'], inplace=True) + if not os.path.exists(os.path.join(data_dir, 'poblacion.csv')): + download_url = 'https://movilidad-opendata.mitma.es/zonificacion/poblacion.csv' + req = requests.get(download_url) + with open(os.path.join(data_dir, 'poblacion.csv'), 'wb') as f: + f.write(req.content) + + zonification_df = pd.read_csv(os.path.join(data_dir, 'poblacion.csv'), sep='|')[ + ['municipio', 'provincia', 'poblacion']] + poblacion = zonification_df.groupby('provincia')[ + 'poblacion'].sum() + zonification_df.drop_duplicates( + subset=['municipio', 'provincia'], inplace=True) + municipio_to_provincia = dict( + zip(zonification_df['municipio'], zonification_df['provincia'])) + df['id_origin'] = df['id_origin'].map(municipio_to_provincia) + df['id_destination'] = df['id_destination'].map(municipio_to_provincia) + + df.query('id_origin != id_destination', inplace=True) + df = df.groupby(['date', 'id_origin', 'id_destination'], + as_index=False)['n_trips'].sum() + + df['n_trips'] = df.apply( + lambda row: row['n_trips'] / poblacion[row['id_origin']], axis=1) + + df = df.groupby(['id_origin', 'id_destination'], + as_index=False)['n_trips'].mean() + + df = remove_islands(df, ['id_origin', 'id_destination']) + + return df + + +def get_mobility_data(data_dir, start_date='2022-08-01', end_date='2022-08-31', level='municipios'): + filename = f'Viajes_{level}_{start_date}_{end_date}_v2.parquet' + if not os.path.exists(os.path.join(data_dir, filename)): + print( + f"File {os.path.exists(os.path.join(data_dir, filename))} does not exist. 
Downloading mobility data...") + download_mobility_data(data_dir, start_date, end_date, level) + + df = pd.read_parquet(os.path.join(data_dir, filename)) + df = preprocess_mobility_data(df, data_dir) + + return df + + if __name__ == "__main__": data_dir = os.path.join(os.path.dirname( os.path.abspath(__file__)), "../../../../data/Spain") pydata_dir = os.path.join(data_dir, 'pydata') os.makedirs(pydata_dir, exist_ok=True) + mobility_dir = os.path.join(data_dir, 'mobility') + os.makedirs(mobility_dir, exist_ok=True) df = get_population_data() df.to_json(os.path.join( @@ -132,3 +191,10 @@ def get_case_data(): df['ID_County'] = df['ID_County'].astype(int) df.to_json(os.path.join( pydata_dir, 'cases_all_pronvincias.json'), orient='records') + + df = get_mobility_data(mobility_dir) + matrix = df.pivot(index='id_origin', + columns='id_destination', values='n_trips').fillna(0) + + gd.write_dataframe(matrix, data_dir, 'commuter_mobility', 'txt', { + 'sep': ' ', 'index': False, 'header': False}) From fc55a8d4be26dcce117554d3fbf82000b35ca894 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Thu, 28 Aug 2025 12:32:43 +0200 Subject: [PATCH 51/73] [ci skip] accumulated cases --- .../memilio/epidata/getSimulationDataSpain.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index 744468cfbc..8ccde09fc5 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -98,9 +98,11 @@ def preprocess_case_data(df): def get_case_data(): - df = fetch_case_data() - df = preprocess_case_data(df) - + df_raw = fetch_case_data() + df = preprocess_case_data(df_raw) + df = df.groupby(['ID_County', 'Date'], as_index=False)['Confirmed'].sum() + df = df.sort_values(['ID_County', 'Date']) + df['Confirmed'] = df.groupby('ID_County')['Confirmed'].cumsum() return df From 3d74140529f5ce5fe897bd8c5812bbc22a82f63a Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Thu, 28 Aug 2025 12:51:39 +0200 Subject: [PATCH 52/73] no uncertainty in pop --- cpp/memilio/mobility/graph.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cpp/memilio/mobility/graph.h b/cpp/memilio/mobility/graph.h index 5b211fd4a1..1a616be241 100644 --- a/cpp/memilio/mobility/graph.h +++ b/cpp/memilio/mobility/graph.h @@ -307,15 +307,15 @@ IOResult set_nodes(const Parameters& params, Date start_date, Date end_dat }); //uncertainty in populations - for (auto i = mio::AgeGroup(0); i < params.get_num_groups(); i++) { - for (auto j = Index(0); j < Model::Compartments::Count; ++j) { - auto& compartment_value = nodes[node_idx].populations[{i, j}]; - compartment_value = - UncertainValue(0.5 * (1.1 * double(compartment_value) + 0.9 * double(compartment_value))); - compartment_value.set_distribution(mio::ParameterDistributionUniform(0.9 * double(compartment_value), - 1.1 * double(compartment_value))); - } - } + // for (auto i = mio::AgeGroup(0); i < params.get_num_groups(); i++) { + // for (auto j = Index(0); j < Model::Compartments::Count; ++j) { + // auto& compartment_value = nodes[node_idx].populations[{i, j}]; + // compartment_value = + // UncertainValue(0.5 * (1.1 * double(compartment_value) + 0.9 * double(compartment_value))); + // compartment_value.set_distribution(mio::ParameterDistributionUniform(0.9 * 
double(compartment_value), + // 1.1 * double(compartment_value))); + // } + // } params_graph.add_node(node_ids[node_idx], nodes[node_idx]); } From a53c310fc8991919c109b9ff03a4a4e81efd0555 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 1 Oct 2025 10:27:55 +0200 Subject: [PATCH 53/73] add clean version of fitting --- .../simulation/graph_germany_nuts3_clean.py | 745 ++++++++++++++++++ .../memilio/epidata/defaultDict.py | 2 +- 2 files changed, 746 insertions(+), 1 deletion(-) create mode 100644 pycode/examples/simulation/graph_germany_nuts3_clean.py diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py new file mode 100644 index 0000000000..da46f6fe65 --- /dev/null +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -0,0 +1,745 @@ +####################################################################### +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +import os +os.environ["KERAS_BACKEND"] = "tensorflow" + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import datetime +import pickle +from scipy.stats import truncnorm + +from matplotlib.patches import Patch + +import bayesflow as bf +import keras + +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +from memilio.simulation.osecir import Model, Simulation, interpolate_simulation_result +from memilio.epidata import defaultDict as dd + + + +excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, + 11007, 11008, 11009, 11010, 11011, 11012, 16056] +no_icu_ids = [7338, 9374, 9473, 9573] +region_ids = [region_id for region_id in dd.County.keys() + if region_id not in excluded_ids] + +inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', + 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] +summary_vars = ['state'] + [f'fed_state{i}' for i in range(16)] + [ + f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids] + +bounds = { + 't_E': (1.0, 5.2), + 't_ISy': (4.0, 10.0), + 't_ISev': (5.0, 10.0), + 't_Cr': (9.0, 17.0), + 'mu_CR': (0.0, 0.4), + 'mu_IH': (0.0, 0.2), + 'mu_HU': (0.0, 0.4), + 'mu_UD': (0.0, 0.4), + 'transmission_prob': (0.0, 0.2) +} +SPIKE_SCALE = 0.85 +SLAB_SCALE = 0.4 +DATE_TIME = datetime.date(year=2020, month=10, day=1) + +# %% +def plot_region_median_mad( + data: np.ndarray, + region: int, + true_data=None, + ax=None, + label=None, + color="red" +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + n_samples, n_time, n_regions = data.shape + if not (0 <= region < n_regions): + raise IndexError + + x = np.arange(n_time) + vals = data[:, :, region] # (samples, time_points) + med = 
np.median(vals, axis=0) + mad = np.median(np.abs(vals - med), axis=0) + + if ax is None: + fig, ax = plt.subplots() + + line, = ax.plot( + x, med, lw=2, label=label or f"Region {region}", color=color) + band = ax.fill_between(x, med - mad, med + mad, alpha=0.25, color=color) + if true_data is not None: + true_vals = true_data[:, region] # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + ax.set_title(f"Region {region}", fontsize=12) + if label is not None: + ax.legend(fontsize=11, loc="upper right") + return line, band + + +def plot_aggregated_over_regions( + data: np.ndarray, + region_agg=np.sum, + true_data=None, + ax=None, + label=None, + color='red' +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + # Aggregate over regions + agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) + + # Aggregate over samples + agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) + agg_mad = np.median( + np.abs(agg_over_regions - agg_median[None]), + axis=0 + ) + + x = np.arange(agg_median.shape[0]) + if ax is None: + fig, ax = plt.subplots() + + line, = ax.plot(x, agg_median, lw=2, + label=label or "Aggregated over regions", color=color) + band = ax.fill_between( + x, + agg_median - agg_mad, + agg_median + agg_mad, + alpha=0.25, + color=color + ) + if true_data is not None: + true_vals = region_agg(true_data, axis=-1) # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + if label is not None: + ax.legend(fontsize=11) + return line, band + +# %% + + +def calibration_curves_per_region( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + max_regions=None, + cmap=plt.cm.Blues, + linewidth=1.5, + legend=True, + with_ideal=True, +): + """ + Per-region calibration curves, each region in a different shade of a colormap. 
+ + data: (samples, time, regions) + true_data: (time, regions) + max_regions: limit number of regions shown + cmap: matplotlib colormap for line shades + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + if max_regions is None: + R = n_regions + else: + R = min(max_regions, n_regions) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + colors = [cmap(i / (R + 1)) for i in range(1, R + 1)] + + x = np.asarray(levels) + for r, col in zip(range(R), colors): + emp = [] + for nominal in levels: + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data[:, :, r], q_low, axis=0) + hi = np.quantile(data[:, :, r], q_high, axis=0) + hits = (true_data[:, r] >= lo) & (true_data[:, r] <= hi) + emp.append(hits.mean()) + emp = np.asarray(emp) + # , label=f"Region {r+1}") + ax.plot(x, emp, lw=linewidth, color=col, alpha=0.5) + + if with_ideal: + ideal_line = ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal")[0] + else: + ideal_line = None + + if legend: + # Custom legend: one patch for regions, one line for ideal + region_patch = Patch(color=colors[-1], label="Regions") + ax.legend(handles=[region_patch, ideal_line], + frameon=True, ncol=1, fontsize=12) + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level", fontsize=12) + ax.set_ylabel("Empirical coverage", fontsize=12) + ax.set_title("Calibration per region", fontsize=12) + return ax + + +def calibration_median_mad_over_regions( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + color="tab:blue", + alpha=0.25, + linewidth=2.0, + with_ideal=True, +): + """ + Compute per region empirical coverage at each nominal level, + then summarize across regions with median and MAD. + + data: (samples, time, regions) + true_data: (time, regions) + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + L = len(levels) + per_region = np.empty((n_regions, L), dtype=float) + + # per level coverage per region + for j, nominal in enumerate(levels): + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data, q_low, axis=0) # (time, regions) + hi = np.quantile(data, q_high, axis=0) # (time, regions) + hits = (true_data >= lo) & (true_data <= hi) # (time, regions) + # mean over time for each region + per_region[:, j] = hits.mean(axis=0) + + med = np.median(per_region, axis=0) # (levels,) + mad = np.median(np.abs(per_region - med[None, :]), axis=0) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + x = np.asarray(levels) + ax.fill_between(x, med - mad, med + mad, + alpha=alpha, color=color, label=None) + ax.plot(x, med, lw=linewidth, color=color, label="Median across regions") + + if with_ideal: + ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal") + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level", fontsize=12) + ax.set_ylabel("Empirical coverage", fontsize=12) + ax.set_title("Calibration median and MAD across regions", fontsize=12) + ax.legend(fontsize=12) + return ax, {"levels": x, "median": med, "mad": mad} + + +class Simulation: # todo: correct class? 
+ """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + """ + + :param model: + + """ + model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E # todo: correct? + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( + 0)] = transmission_prob + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( + 0)] = mu_CR + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD + + # start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + """ + + :param model: + + """ + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 + minimum = np.zeros((self.num_groups, self.num_groups)) + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_value): + """ + + :param params: + :param end_date: + + """ + + start_damping = DATE_TIME + datetime.timedelta(days=7) + + if start_damping < end_date: + start_date = (start_damping - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_value], t=start_date)) + + def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + """ + + :param end_date: + + """ + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters( + model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + self.set_contact_matrices(model) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + + data_dir_Germany = os.path.join(self.data_dir, "Germany") + mobility_data_file = os.path.join( + data_dir_Germany, "mobility", "commuter_mobility_2022.txt") + pydata_dir = os.path.join(data_dir_Germany, "pydata") + + path_population_data = os.path.join(pydata_dir, + "county_current_population.json") + + print("Setting nodes...") + mio.osecir.set_nodes( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, True, graph, scaling_factor_infected, + scaling_factor_icu, 0, 0, False) + + print("Setting edges...") + mio.osecir.set_edges(mobility_data_file, graph, 1) + + print("Graph created.") + + return graph + + def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): + """ + + :param num_days_sim: + :param num_runs: 
(Default value = 10) + :param save_graph: (Default value = True) + :param create_gif: (Default value = True) + + """ + mio.set_log_level(mio.LogLevel.Warning) + end_date = self.start_date + datetime.timedelta(days=num_days_sim) + + graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, + t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + node = graph.get_node(node_idx) + self.set_npis(node.property.parameters, end_date, + damping_values[node.id // 1000 - 1]) + mobility_graph.add_node(node.id, node.property) + for edge_idx in range(graph.num_edges): + mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + results = {} + for node_idx in range(mobility_sim.graph.num_nodes): + node = mobility_sim.graph.get_node(node_idx) + if node.id in no_icu_ids: + results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result( + node.property.result) + else: + results[f'region{node_idx}'] = osecir.interpolate_simulation_result( + node.property.result) + + return results + + +def run_germany_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=DATE_TIME, + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 60 + + results = sim.run(num_days_sim, damping_values, t_E, t_ISy, + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + return results + +def prior(): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / scale + damping_values = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=16 + ) + return { + 'damping_values': damping_values, + 't_E': np.random.uniform(*bounds['t_E']), + 't_ISy': np.random.uniform(*bounds['t_ISy']), + 't_ISev': np.random.uniform(*bounds['t_ISev']), + 't_Cr': np.random.uniform(*bounds['t_Cr']), + 'mu_CR': np.random.uniform(*bounds['mu_CR']), + 'mu_IH': np.random.uniform(*bounds['mu_IH']), + 'mu_HU': np.random.uniform(*bounds['mu_HU']), + 'mu_UD': np.random.uniform(*bounds['mu_UD']), + 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) + } + + +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Germany/pydata") + + data = pd.read_json(os.path.join(divi_path, "county_divi_ma7.json")) + data = data[data['Date'] >= np.datetime64(DATE_TIME)] + data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] + data = data.sort_values(by=['ID_County', 'Date']) + divi_data = data.pivot(index='Date', columns='ID_County', values='ICU') + divi_dict = {} + for i, region_id in enumerate(region_ids): + if region_id not in no_icu_ids: + divi_dict[f"region{i}"] = divi_data[region_id].to_numpy()[None, :, None] + else: + divi_dict[f"no_icu_region{i}"] = np.zeros((1, divi_data.shape[0], 1)) + return divi_dict + + +def extract_observables(simulation_results, observable_index=7): + for key in simulation_results.keys(): + if key not in inference_params: + simulation_results[key] = simulation_results[key][:, :, observable_index][..., np.newaxis] + return simulation_results 
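+# Shape sketch for extract_observables, assuming index 7 is the InfectedCritical (ICU)
+# compartment in the osecir compartment ordering (hence the DIVI ICU counts as fitting target):
+#
+#   raw = run_germany_nuts3_simulation(**prior())               # dict of TimeSeries per region
+#   raw = {k: np.array(v)[None, ...] for k, v in raw.items()}   # (1, time, compartments)
+#   icu = extract_observables(raw)                              # (1, time, 1) per region key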
+ + +def create_train_data(filename, number_samples=1000): + + simulator = bf.simulators.make_simulator( + [prior, run_germany_nuts3_simulation] + ) + trainings_data = simulator.sample(number_samples) + trainings_data = extract_observables(trainings_data) + with open(filename, 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + +def load_pickle(path): + with open(path, "rb") as f: + return pickle.load(f) + +def is_region_key(k: str) -> bool: + return 'region' in k + +def apply_aug(d: dict, aug) -> dict: + return {k: np.clip(aug(v), 0, None) if is_region_key(k) else v for k, v in d.items()} + +def concat_dicts(base: dict, new: dict) -> dict: + missing = set(base) - set(new) + if missing: + raise KeyError(f"new dict missing keys: {sorted(missing)}") + for k in base: + base[k] = np.concatenate([base[k], new[k]]) + return base + +def region_keys_sorted(d: dict): + def idx(k): + # handles "regionN" and "no_icu_regionN" + return int(k.split("region")[-1]) + return sorted([k for k in d if is_region_key(k)], key=idx) + +def aggregate_states(d: dict) -> None: + n_regions = len(region_ids) + # per state + for state in range(16): + idxs = [ + r for r in range(n_regions) + if region_ids[r] // 1000 == state + 1 + ] + d[f"fed_state{state}"] = np.sum([d[f"region{r}"] if region_ids[r] not in no_icu_ids else d[f"no_icu_region{r}"] for r in idxs], axis=0) + # all allowed regions + d["state"] = np.sum([d[f"fed_state{r}"] for r in range(16)], axis=0) + +def combine_results(dict_list): + combined = {} + for d in dict_list: + combined = concat_dicts(combined, d) if combined else d + return combined + +def get_workflow(): + + simulator = bf.make_simulator( + [prior, run_germany_nuts3_simulation] + ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) + .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) + .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) + .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) + .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) + .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) + .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) + .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) + .concatenate( + ["damping_values", "t_E", "t_ISy", "t_ISev", "t_Cr", + "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], + into="inference_variables", + axis=-1 + ) + .concatenate(summary_vars, into="summary_variables", axis=-1) + ) + + summary_network = bf.networks.FusionTransformer( + summary_dim=(len(bounds)+16)*2, dropout=0.1 + ) + inference_network = bf.networks.FlowMatching() + + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + # augmentations={f'region{i}': aug for i in range(len(region_ids))} + ) + + return workflow + + +def run_training(name, num_training_files=20): + train_template = "trainings_data{i}_"+name+".pickle" + val_path = f"validation_data_{name}.pickle" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, 
per_dimension=False + ) + + # training data + train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] + trainings_data = None + for p in train_files: + d = load_pickle(p) + d = apply_aug(d, aug=aug) # only on region keys + if trainings_data is None: + trainings_data = d + else: + trainings_data = concat_dicts(trainings_data, d) + aggregate_states(trainings_data) + + # validation data + validation_data = apply_aug(load_pickle(val_path), aug=aug) + aggregate_states(validation_data) + + # check data + workflow = get_workflow() + print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) + print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) + + history = workflow.fit_offline( + data=trainings_data, epochs=300, batch_size=64, validation_data=validation_data + ) + + workflow.approximator.save( + filepath=os.path.join(os.path.dirname(__file__), f"model_{name}.keras") + ) + + plots = workflow.plot_default_diagnostics( + test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} + ) + plots['losses'].savefig(f'losses_{name}.png') + plots['recovery'].savefig(f'recovery_{name}.png') + plots['calibration_ecdf'].savefig(f'calibration_ecdf_{name}.png') + #plots['z_score_contraction'].savefig(f'z_score_contraction_{name}.png') + + +def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmentation=True): + val_path = f"validation_data_{name}.pickle" + synthetic = "_synthetic" if on_synthetic_data else "" + with_aug = "_with_aug" if apply_augmentation else "" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + if on_synthetic_data: + # validation data + validation_data = load_pickle(val_path) + validation_data = apply_aug(validation_data, aug=aug) + aggregate_states(validation_data) + divi_dict = validation_data + divi_region_keys = region_keys_sorted(divi_dict) + + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] # only one dataset + else: + divi_dict = load_divi_data() + aggregate_states(divi_dict) + divi_region_keys = region_keys_sorted(divi_dict) + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(os.path.dirname(__file__), f"model_{name}.keras") + ) + + if False: #os.path.exists(f'sims_{name}{synthetic}{with_aug}.pickle'): + simulations = load_pickle(f'sims_{name}{synthetic}{with_aug}.pickle') + print("loaded simulations from file") + else: + samples = workflow.sample(conditions=divi_dict, num_samples=num_samples) + results = [] + for i in range(num_samples): # we only have one dataset for inference here + result = run_germany_nuts3_simulation( + damping_values=samples['damping_values'][0, i], + t_E=samples['t_E'][0, i], t_ISy=samples['t_ISy'][0, i], + t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], + mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], + mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], + transmission_prob=samples['transmission_prob'][0, i] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] 
# add sample axis + results.append(result) + results = combine_results(results) + results = extract_observables(results) + if apply_augmentation: + results = apply_aug(results, aug=aug) + + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + for i in range(num_samples): + simulations[i] = np.concatenate([results[key][i] for key in divi_region_keys], axis=-1) + + # save sims + with open(f'sims_{name}{synthetic}{with_aug}.pickle', 'wb') as f: + pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + + # plot simulations + fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(12, 5), layout="constrained") + ax = ax.flatten() + rand_index = np.random.choice(simulations.shape[-1], replace=False, size=len(ax)) + for i, a in enumerate(ax): + plot_region_median_mad( + simulations, region=rand_index[i], true_data=divi_data, label=r"Median $\pm$ Mad", ax=a + ) + plt.savefig(f'random_regions_{name}{synthetic}{with_aug}.png') + plt.close() + + plot_aggregated_over_regions(simulations, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad") + plt.savefig(f'region_aggregated_{name}{synthetic}{with_aug}.png') + plt.close() + + + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) + plt.savefig(f'calibration_per_region_{name}{synthetic}{with_aug}.png') + plt.close() + + # plot = bf.diagnostics.pairs_posterior(simulations, priors=validation_data, dataset_id=0) + # plot.savefig(f'pairs_posterior_wcovidparams_oct{synthetic}_ma7_noise.png') + + +if __name__ == "__main__": + name = "counties" + + # create_train_data(filename=f'trainings_data11_{name}.pickle', number_samples=1000) + # run_training(name=name, num_training_files=20) + run_inference(name=name, on_synthetic_data=True) + run_inference(name=name, on_synthetic_data=True, apply_augmentation=False) + run_inference(name=name, on_synthetic_data=False) + run_inference(name=name, on_synthetic_data=False, apply_augmentation=False) \ No newline at end of file diff --git a/pycode/memilio-epidata/memilio/epidata/defaultDict.py b/pycode/memilio-epidata/memilio/epidata/defaultDict.py index 281f88c130..4172220849 100644 --- a/pycode/memilio-epidata/memilio/epidata/defaultDict.py +++ b/pycode/memilio-epidata/memilio/epidata/defaultDict.py @@ -47,7 +47,7 @@ 'start_date': date(2020, 1, 1), 'end_date': date(2021, 1, 1), 'split_berlin': False, - 'impute_dates': True, + 'impute_dates': False, 'moving_average': 0, 'file_format': 'json_timeasstring', 'no_raw': False, From 2bd5070e2ae08c5b07f206928d7f9da71322bc26 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 2 Oct 2025 11:16:02 +0200 Subject: [PATCH 54/73] run fitting with skipping the first two weeks --- .../simulation/graph_germany_nuts3_clean.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index da46f6fe65..8b28706e26 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -551,6 +551,9 @@ def combine_results(dict_list): combined = concat_dicts(combined, d) if combined else d return combined +def skip_2weeks(d:dict) -> dict: + return {k: v[:, 14:, :] if is_region_key(k) else v for k, v in d.items()} + def 
get_workflow(): simulator = bf.make_simulator( @@ -598,8 +601,8 @@ def get_workflow(): def run_training(name, num_training_files=20): - train_template = "trainings_data{i}_"+name+".pickle" - val_path = f"validation_data_{name}.pickle" + train_template = "trainings_data{i}_counties.pickle" + val_path = f"validation_data_counties.pickle" aug = bf.augmentations.NNPE( spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False @@ -611,6 +614,7 @@ def run_training(name, num_training_files=20): for p in train_files: d = load_pickle(p) d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) if trainings_data is None: trainings_data = d else: @@ -619,6 +623,7 @@ def run_training(name, num_training_files=20): # validation data validation_data = apply_aug(load_pickle(val_path), aug=aug) + validation_data = skip_2weeks(validation_data) aggregate_states(validation_data) # check data @@ -644,7 +649,7 @@ def run_training(name, num_training_files=20): def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmentation=True): - val_path = f"validation_data_{name}.pickle" + val_path = f"validation_data_counties.pickle" synthetic = "_synthetic" if on_synthetic_data else "" with_aug = "_with_aug" if apply_augmentation else "" @@ -656,7 +661,8 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta # validation data validation_data = load_pickle(val_path) validation_data = apply_aug(validation_data, aug=aug) - aggregate_states(validation_data) + validation_data_skip2w = skip_2weeks(validation_data) + aggregate_states(validation_data_skip2w) divi_dict = validation_data divi_region_keys = region_keys_sorted(divi_dict) @@ -665,7 +671,8 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta )[0] # only one dataset else: divi_dict = load_divi_data() - aggregate_states(divi_dict) + validation_data_skip2w = skip_2weeks(divi_dict) + aggregate_states(validation_data_skip2w) divi_region_keys = region_keys_sorted(divi_dict) divi_data = np.concatenate( [divi_dict[key] for key in divi_region_keys], axis=-1 @@ -680,7 +687,7 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta simulations = load_pickle(f'sims_{name}{synthetic}{with_aug}.pickle') print("loaded simulations from file") else: - samples = workflow.sample(conditions=divi_dict, num_samples=num_samples) + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) results = [] for i in range(num_samples): # we only have one dataset for inference here result = run_germany_nuts3_simulation( @@ -735,7 +742,7 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta if __name__ == "__main__": - name = "counties" + name = "skip2w" # create_train_data(filename=f'trainings_data11_{name}.pickle', number_samples=1000) # run_training(name=name, num_training_files=20) From e0b8d77fe6ec0932f87ee0379437cdcabbf22684 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 8 Oct 2025 07:44:25 +0200 Subject: [PATCH 55/73] save experiments in own folders --- .../simulation/graph_germany_nuts3_clean.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index 8b28706e26..b4ccf5f41f 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -601,8 +601,8 @@ def get_workflow(): def 
run_training(name, num_training_files=20): - train_template = "trainings_data{i}_counties.pickle" - val_path = f"validation_data_counties.pickle" + train_template = name+"/trainings_data{i}_"+name+".pickle" + val_path = f"{name}/validation_data_{name}.pickle" aug = bf.augmentations.NNPE( spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False @@ -636,20 +636,20 @@ def run_training(name, num_training_files=20): ) workflow.approximator.save( - filepath=os.path.join(os.path.dirname(__file__), f"model_{name}.keras") + filepath=os.path.join(f"{name}/model_{name}.keras") ) plots = workflow.plot_default_diagnostics( test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} ) - plots['losses'].savefig(f'losses_{name}.png') - plots['recovery'].savefig(f'recovery_{name}.png') - plots['calibration_ecdf'].savefig(f'calibration_ecdf_{name}.png') - #plots['z_score_contraction'].savefig(f'z_score_contraction_{name}.png') + plots['losses'].savefig(f'{name}/losses_{name}.png') + plots['recovery'].savefig(f'{name}/recovery_{name}.png') + plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png') + #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmentation=True): - val_path = f"validation_data_counties.pickle" + val_path = f"{name}/validation_data_{name}.pickle" synthetic = "_synthetic" if on_synthetic_data else "" with_aug = "_with_aug" if apply_augmentation else "" @@ -680,11 +680,11 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta workflow = get_workflow() workflow.approximator = keras.models.load_model( - filepath=os.path.join(os.path.dirname(__file__), f"model_{name}.keras") + filepath=os.path.join(f"{name}/model_{name}.keras") ) - if False: #os.path.exists(f'sims_{name}{synthetic}{with_aug}.pickle'): - simulations = load_pickle(f'sims_{name}{synthetic}{with_aug}.pickle') + if False: #os.path.exists(f'{name}/sims_{name}{synthetic}{with_aug}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}{with_aug}.pickle') print("loaded simulations from file") else: samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) @@ -712,7 +712,7 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta simulations[i] = np.concatenate([results[key][i] for key in divi_region_keys], axis=-1) # save sims - with open(f'sims_{name}{synthetic}{with_aug}.pickle', 'wb') as f: + with open(f'{name}/sims_{name}{synthetic}{with_aug}.pickle', 'wb') as f: pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) # plot simulations @@ -723,18 +723,18 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta plot_region_median_mad( simulations, region=rand_index[i], true_data=divi_data, label=r"Median $\pm$ Mad", ax=a ) - plt.savefig(f'random_regions_{name}{synthetic}{with_aug}.png') + plt.savefig(f'{name}/random_regions_{name}{synthetic}{with_aug}.png') plt.close() plot_aggregated_over_regions(simulations, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad") - plt.savefig(f'region_aggregated_{name}{synthetic}{with_aug}.png') + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}{with_aug}.png') plt.close() fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) ax, stats = calibration_median_mad_over_regions(simulations, divi_data, 
ax=axis[1]) - plt.savefig(f'calibration_per_region_{name}{synthetic}{with_aug}.png') + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}{with_aug}.png') plt.close() # plot = bf.diagnostics.pairs_posterior(simulations, priors=validation_data, dataset_id=0) @@ -744,7 +744,9 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta if __name__ == "__main__": name = "skip2w" - # create_train_data(filename=f'trainings_data11_{name}.pickle', number_samples=1000) + if not os.path.exists(name): + os.makedirs(name) + # create_train_data(filename=f'{name}/trainings_data1_{name}.pickle', number_samples=10) # run_training(name=name, num_training_files=20) run_inference(name=name, on_synthetic_data=True) run_inference(name=name, on_synthetic_data=True, apply_augmentation=False) From cdbedb2e63a06d984a573882dceafc07dd6df090 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Mon, 13 Oct 2025 10:04:06 +0200 Subject: [PATCH 56/73] fitting on number infecteds --- .../simulation/graph_germany_nuts3_clean.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index b4ccf5f41f..e052ff0960 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -38,7 +38,6 @@ from memilio.epidata import defaultDict as dd - excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, 11007, 11008, 11009, 11010, 11011, 11012, 16056] no_icu_ids = [7338, 9374, 9473, 9573] @@ -491,6 +490,20 @@ def load_divi_data(): return divi_dict +def load_extrapolated_case_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + case_path = os.path.join(file_path, "../../../data/Germany/pydata") + + file = h5py.File(os.path.join(case_path, "Results_rki.h5")) + divi_dict = {} + for i, region_id in enumerate(region_ids): + if region_id not in no_icu_ids: + divi_dict[f"region{i}"] = np.array(file[f"{region_id}"]['Total'][:, 4])[None, :, None] + else: + divi_dict[f"no_icu_region{i}"] = np.zeros((1, np.array(file[f"{region_id}"]['Total']).shape[0], 1)) + return divi_dict + + def extract_observables(simulation_results, observable_index=7): for key in simulation_results.keys(): if key not in inference_params: @@ -504,7 +517,7 @@ def create_train_data(filename, number_samples=1000): [prior, run_germany_nuts3_simulation] ) trainings_data = simulator.sample(number_samples) - trainings_data = extract_observables(trainings_data) + trainings_data = extract_observables(trainings_data, observable_index=4) with open(filename, 'wb') as f: pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) @@ -614,7 +627,6 @@ def run_training(name, num_training_files=20): for p in train_files: d = load_pickle(p) d = apply_aug(d, aug=aug) # only on region keys - d = skip_2weeks(d) if trainings_data is None: trainings_data = d else: @@ -661,8 +673,7 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta # validation data validation_data = load_pickle(val_path) validation_data = apply_aug(validation_data, aug=aug) - validation_data_skip2w = skip_2weeks(validation_data) - aggregate_states(validation_data_skip2w) + aggregate_states(validation_data) divi_dict = validation_data divi_region_keys = region_keys_sorted(divi_dict) @@ -671,7 +682,6 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta )[0] # only one dataset else: divi_dict = load_divi_data() 
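            # load_divi_data returns one (1, time, 1) ICU series per county, keyed "region{i}"
            # (counties without ICU reporting get a zero-filled "no_icu_region{i}" entry instead)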
- validation_data_skip2w = skip_2weeks(divi_dict) aggregate_states(validation_data_skip2w) divi_region_keys = region_keys_sorted(divi_dict) divi_data = np.concatenate( @@ -687,7 +697,7 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta simulations = load_pickle(f'{name}/sims_{name}{synthetic}{with_aug}.pickle') print("loaded simulations from file") else: - samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + samples = workflow.sample(conditions=divi_dict, num_samples=num_samples) results = [] for i in range(num_samples): # we only have one dataset for inference here result = run_germany_nuts3_simulation( @@ -702,7 +712,7 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta result[key] = np.array(result[key])[None, ...] # add sample axis results.append(result) results = combine_results(results) - results = extract_observables(results) + results = extract_observables(results, observable_index=4) if apply_augmentation: results = apply_aug(results, aug=aug) @@ -742,11 +752,11 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta if __name__ == "__main__": - name = "skip2w" + name = "infecteds" if not os.path.exists(name): os.makedirs(name) - # create_train_data(filename=f'{name}/trainings_data1_{name}.pickle', number_samples=10) + # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=1000) # run_training(name=name, num_training_files=20) run_inference(name=name, on_synthetic_data=True) run_inference(name=name, on_synthetic_data=True, apply_augmentation=False) From 2474af0c4fc829af1d6ab74aaca163d7ea71de57 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 16 Oct 2025 14:03:29 +0200 Subject: [PATCH 57/73] add two more dates for dampings --- .../simulation/graph_germany_nuts3_clean.py | 128 +++++++++++++++--- 1 file changed, 106 insertions(+), 22 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index e052ff0960..7a986267fa 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -91,7 +91,7 @@ def plot_region_median_mad( fig, ax = plt.subplots() line, = ax.plot( - x, med, lw=2, label=label or f"Region {region}", color=color) + x, med, lw=2, label=label or f"{dd.County[region_ids[region]]}", color=color) band = ax.fill_between(x, med - mad, med + mad, alpha=0.25, color=color) if true_data is not None: true_vals = true_data[:, region] # (time_points,) @@ -99,7 +99,7 @@ def plot_region_median_mad( ax.set_xlabel("Time", fontsize=12) ax.set_ylabel("ICU", fontsize=12) - ax.set_title(f"Region {region}", fontsize=12) + ax.set_title(f"{dd.County[region_ids[region]]}", fontsize=12) if label is not None: ax.legend(fontsize=11, loc="upper right") return line, band @@ -284,6 +284,65 @@ def calibration_median_mad_over_regions( return ax, {"levels": x, "median": med, "mad": mad} +def plot_damping_values(name, num_samples=100): + divi_dict = load_divi_data() + aggregate_states(divi_dict) + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + samples = workflow.sample(conditions=divi_dict, num_samples=num_samples) + print(samples['damping_values']) + print(samples['damping_values'].reshape((num_samples, 16, 3))) + samples['damping_values'] = samples['damping_values'].reshape((num_samples, 16, 3)) 
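+    # reshaped to (num_samples, 16 federal states, 3 damping factors), one factor per
+    # damping date (days 15, 30 and 45 after the simulation start, as set in set_npis)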
+ + med = np.median(samples['damping_values'], axis=0) + mad = np.median(np.abs(samples['damping_values'] - med), axis=0) + + # Extend for step plotting + med_extended = np.hstack([med, med[:, -1][:, None]]) + mad_extended = np.hstack([mad, mad[:, -1][:, None]]) + + # Plot damping values per region + fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) + axes = axes.flatten() + x = np.arange(15, 61, 15) # Time steps from 15 to 60 + + for i, ax in enumerate(axes): + if i < 16: + ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) + ax.fill_between( + x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], + alpha=0.25, color='red', step='post' + ) + ax.set_title(f"{dd.State[i+1]}", fontsize=10) + ax.set_xlabel("Time", fontsize=8) + ax.set_ylabel("Damping Value", fontsize=8) + else: + ax.axis('off') # Hide unused subplots + + plt.suptitle("Damping Values per Region", fontsize=14) + plt.savefig(f"{name}/damping_values.png", dpi=300) + + # Combined plot for all regions + fig, ax = plt.subplots(figsize=(10, 6)) + cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors + + for i in range(16): + ax.stairs( + med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", + color=cmap(i), baseline=None + ) + + ax.set_title("Damping Values per Region (Combined)", fontsize=14) + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("Damping Value", fontsize=12) + ax.legend(fontsize=10, loc="upper right", ncol=2) + plt.savefig(f"{name}/damping_values_combined.png", dpi=300) + + class Simulation: # todo: correct class? """ """ @@ -336,20 +395,31 @@ def set_contact_matrices(self, model): contact_matrices[0] = mio.ContactMatrix(baseline, minimum) model.parameters.ContactPatterns.cont_freq_mat = contact_matrices - def set_npis(self, params, end_date, damping_value): + def set_npis(self, params, end_date, damping_values): """ :param params: :param end_date: """ + start_damping_1 = DATE_TIME + datetime.timedelta(days=15) + start_damping_2 = DATE_TIME + datetime.timedelta(days=30) + start_damping_3 = DATE_TIME + datetime.timedelta(days=45) + + if start_damping_1 < end_date: + start_date = (start_damping_1 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[0]], t=start_date)) - start_damping = DATE_TIME + datetime.timedelta(days=7) + if start_damping_2 < end_date: + start_date = (start_damping_2 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[1]], t=start_date)) - if start_damping < end_date: - start_date = (start_damping - self.start_date).days + if start_damping_3 < end_date: + start_date = (start_damping_3 - self.start_date).days params.ContactPatterns.cont_freq_mat[0].add_damping( - mio.Damping(np.r_[damping_value], t=start_date)) + mio.Damping(np.r_[damping_values[2]], t=start_date)) def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ @@ -412,8 +482,12 @@ def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_ mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): node = graph.get_node(node_idx) - self.set_npis(node.property.parameters, end_date, - damping_values[node.id // 1000 - 1]) + + self.set_npis( + node.property.parameters, + end_date, + damping_values[node.id // 1000 - 1] + ) mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( @@ -452,14 +526,16 @@ def 
run_germany_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR return results def prior(): - mean = np.random.uniform(0, 1) - scale = 0.1 - a, b = (0 - mean) / scale, (1 - mean) / scale - damping_values = truncnorm.rvs( - a=a, b=b, loc=mean, scale=scale, size=16 - ) + damping_values = np.zeros((3, 16)) + for i in range(3): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / scale + damping_values[i] = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=16 + ) return { - 'damping_values': damping_values, + 'damping_values': np.transpose(damping_values), 't_E': np.random.uniform(*bounds['t_E']), 't_ISy': np.random.uniform(*bounds['t_ISy']), 't_ISev': np.random.uniform(*bounds['t_ISev']), @@ -517,7 +593,7 @@ def create_train_data(filename, number_samples=1000): [prior, run_germany_nuts3_simulation] ) trainings_data = simulator.sample(number_samples) - trainings_data = extract_observables(trainings_data, observable_index=4) + trainings_data = extract_observables(trainings_data) with open(filename, 'wb') as f: pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) @@ -627,6 +703,8 @@ def run_training(name, num_training_files=20): for p in train_files: d = load_pickle(p) d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) + d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) if trainings_data is None: trainings_data = d else: @@ -637,6 +715,7 @@ def run_training(name, num_training_files=20): validation_data = apply_aug(load_pickle(val_path), aug=aug) validation_data = skip_2weeks(validation_data) aggregate_states(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) # check data workflow = get_workflow() @@ -673,7 +752,9 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta # validation data validation_data = load_pickle(val_path) validation_data = apply_aug(validation_data, aug=aug) - aggregate_states(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data_skip2w = skip_2weeks(validation_data) + aggregate_states(validation_data_skip2w) divi_dict = validation_data divi_region_keys = region_keys_sorted(divi_dict) @@ -682,6 +763,7 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta )[0] # only one dataset else: divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) aggregate_states(validation_data_skip2w) divi_region_keys = region_keys_sorted(divi_dict) divi_data = np.concatenate( @@ -697,7 +779,8 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta simulations = load_pickle(f'{name}/sims_{name}{synthetic}{with_aug}.pickle') print("loaded simulations from file") else: - samples = workflow.sample(conditions=divi_dict, num_samples=num_samples) + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, 3)) results = [] for i in range(num_samples): # we only have one dataset for inference here result = run_germany_nuts3_simulation( @@ -712,7 +795,7 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta result[key] = np.array(result[key])[None, ...] 
# add sample axis results.append(result) results = combine_results(results) - results = extract_observables(results, observable_index=4) + results = extract_observables(results) if apply_augmentation: results = apply_aug(results, aug=aug) @@ -752,7 +835,7 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta if __name__ == "__main__": - name = "infecteds" + name = "3dampings" if not os.path.exists(name): os.makedirs(name) @@ -761,4 +844,5 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta run_inference(name=name, on_synthetic_data=True) run_inference(name=name, on_synthetic_data=True, apply_augmentation=False) run_inference(name=name, on_synthetic_data=False) - run_inference(name=name, on_synthetic_data=False, apply_augmentation=False) \ No newline at end of file + run_inference(name=name, on_synthetic_data=False, apply_augmentation=False) + plot_damping_values(name=name, num_samples=100) \ No newline at end of file From 0650d860dd3d5eebb6219675df0bec8c1c6b5229 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 23 Oct 2025 16:20:27 +0200 Subject: [PATCH 58/73] add some plots --- .../simulation/graph_germany_nuts3_clean.py | 126 ++++++++++++++---- 1 file changed, 97 insertions(+), 29 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index 7a986267fa..aa6654a0f8 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -37,6 +37,8 @@ from memilio.simulation.osecir import Model, Simulation, interpolate_simulation_result from memilio.epidata import defaultDict as dd +import geopandas as gpd + excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, 11007, 11008, 11009, 11010, 11011, 11012, 16056] @@ -152,8 +154,80 @@ def plot_aggregated_over_regions( ax.legend(fontsize=11) return line, band +def plot_icu_on_germany(simulations, name, synthetic, with_aug): + med = np.median(simulations, axis=0) + + population = pd.read_json('data/Germany/pydata/county_current_population.json') + values = med / population['Population'].to_numpy()[None, :] * 100000 + + + map_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_KRS.shp')) + fedstate_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_LAN.shp')) + + plot_map(values[0], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_initial_{name}{synthetic}{with_aug}") + plot_map(values[-1], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_final_{name}{synthetic}{with_aug}") + +def plot_map(values, map_data, fedstate_data, label, filename): + map_data[label] = map_data['ARS'].map(dict(zip([f"{region_id:05d}" for region_id in region_ids], values))) + + fig, ax = plt.subplots(figsize=(12, 14)) + map_data.plot( + column=f"{label}", + cmap='Reds', + linewidth=0.5, + ax=ax, + edgecolor='0.6', + legend=True, + legend_kwds={'label': f"{label} per County", 'shrink': 0.6}, + ) + fedstate_data.boundary.plot(ax=ax, color='black', linewidth=1) + ax.set_title(f"{label} per County") + ax.axis('off') + plt.savefig(filename, bbox_inches='tight', dpi=300) + + +def plot_aggregated_to_federal_states(data, true_data, name, synthetic, with_aug): + fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(25, 25), layout="constrained") + ax = ax.flatten() + for state in range(16): + idxs = [i for i, region_id in enumerate(region_ids) if 
region_id // 1000 == state + 1] + state_data = np.sum(data[:, :, idxs], axis=-1) # Aggregate over regions in the state + true_state_data = np.sum(true_data[:, idxs], axis=-1) if true_data is not None else None + plot_aggregated_over_regions( + state_data[:, :, None], # Add a dummy region axis for compatibility + true_data=true_state_data[:, None] if true_state_data is not None else None, + ax=ax[state], + label=f"State {state + 1}", + color=f"C{state % 10}" # Cycle through 10 colors + ) + plt.savefig(f'{name}/federal_states_{name}{synthetic}{with_aug}.png') + plt.close() + + # %% +# plot simulations for all regions in 10x4 blocks +def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): + n_regions = simulations.shape[-1] + n_cols = 4 + n_rows = 10 + n_blocks = (n_regions + n_cols * n_rows - 1) // (n_cols * n_rows) + + for block in range(n_blocks): + start_idx = block * n_cols * n_rows + end_idx = min(start_idx + n_cols * n_rows, n_regions) + fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 25), layout="constrained") + ax = ax.flatten() + for i, region_idx in enumerate(range(start_idx, end_idx)): + plot_region_median_mad( + simulations, region=region_idx, true_data=divi_data, label=r"Median $\pm$ Mad", ax=ax[i] + ) + # Hide unused subplots + for i in range(end_idx - start_idx, len(ax)): + ax[i].axis("off") + plt.savefig(f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png') + plt.close() + def calibration_curves_per_region( data: np.ndarray, @@ -284,22 +358,10 @@ def calibration_median_mad_over_regions( return ax, {"levels": x, "median": med, "mad": mad} -def plot_damping_values(name, num_samples=100): - divi_dict = load_divi_data() - aggregate_states(divi_dict) - - workflow = get_workflow() - workflow.approximator = keras.models.load_model( - filepath=os.path.join(f"{name}/model_{name}.keras") - ) - - samples = workflow.sample(conditions=divi_dict, num_samples=num_samples) - print(samples['damping_values']) - print(samples['damping_values'].reshape((num_samples, 16, 3))) - samples['damping_values'] = samples['damping_values'].reshape((num_samples, 16, 3)) +def plot_damping_values(damping_values, name, synthetic, with_aug): - med = np.median(samples['damping_values'], axis=0) - mad = np.median(np.abs(samples['damping_values'] - med), axis=0) + med = np.median(damping_values, axis=0) + mad = np.median(np.abs(damping_values - med), axis=0) # Extend for step plotting med_extended = np.hstack([med, med[:, -1][:, None]]) @@ -324,7 +386,7 @@ def plot_damping_values(name, num_samples=100): ax.axis('off') # Hide unused subplots plt.suptitle("Damping Values per Region", fontsize=14) - plt.savefig(f"{name}/damping_values.png", dpi=300) + plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) # Combined plot for all regions fig, ax = plt.subplots(figsize=(10, 6)) @@ -340,7 +402,7 @@ def plot_damping_values(name, num_samples=100): ax.set_xlabel("Time", fontsize=12) ax.set_ylabel("Damping Value", fontsize=12) ax.legend(fontsize=10, loc="upper right", ncol=2) - plt.savefig(f"{name}/damping_values_combined.png", dpi=300) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) class Simulation: # todo: correct class? 
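
The federal-state aggregation above groups county columns by integer division of the county identifier (region_id // 1000). A minimal sketch of that grouping, assuming toy identifiers and a random (samples, time, counties) array:

import numpy as np

rng = np.random.default_rng(0)
# toy 5-digit county identifiers; the leading digits (id // 1000) encode the federal state
county_ids = [1001, 1002, 5111, 5112, 9162]
icu = rng.random((50, 61, len(county_ids)))              # (samples, time, counties)

state_series = {}
for state in range(1, 17):
    idxs = [i for i, cid in enumerate(county_ids) if cid // 1000 == state]
    if idxs:
        # sum ICU occupancy of all counties belonging to this federal state
        state_series[state] = icu[:, :, idxs].sum(axis=-1)   # (samples, time)
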
@@ -808,16 +870,9 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta with open(f'{name}/sims_{name}{synthetic}{with_aug}.pickle', 'wb') as f: pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) - # plot simulations - fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(12, 5), layout="constrained") - ax = ax.flatten() - rand_index = np.random.choice(simulations.shape[-1], replace=False, size=len(ax)) - for i, a in enumerate(ax): - plot_region_median_mad( - simulations, region=rand_index[i], true_data=divi_data, label=r"Median $\pm$ Mad", ax=a - ) - plt.savefig(f'{name}/random_regions_{name}{synthetic}{with_aug}.png') - plt.close() + plot_all_regions(simulations, divi_data, name, synthetic, with_aug) + + plot_aggregated_to_federal_states(simulations, divi_data, name, synthetic, with_aug) plot_aggregated_over_regions(simulations, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad") plt.savefig(f'{name}/region_aggregated_{name}{synthetic}{with_aug}.png') @@ -830,6 +885,20 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}{with_aug}.png') plt.close() + + print("Fitted parameters ", name, synthetic, with_aug, ":") + for param in inference_params: + if param == 'damping_values': + plot_damping_values(samples['damping_values'][0].reshape((num_samples, 16, 3)), name, synthetic, with_aug) + else: + med = np.median(samples[param][0], axis=0) + mad = np.median(np.abs(samples[param][0] - med), axis=0) + print(param, ": med: ", med) + print(param, ": mad: ", mad) + + plot_icu_on_germany(simulations, name, synthetic, with_aug) + + # plot = bf.diagnostics.pairs_posterior(simulations, priors=validation_data, dataset_id=0) # plot.savefig(f'pairs_posterior_wcovidparams_oct{synthetic}_ma7_noise.png') @@ -844,5 +913,4 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta run_inference(name=name, on_synthetic_data=True) run_inference(name=name, on_synthetic_data=True, apply_augmentation=False) run_inference(name=name, on_synthetic_data=False) - run_inference(name=name, on_synthetic_data=False, apply_augmentation=False) - plot_damping_values(name=name, num_samples=100) \ No newline at end of file + run_inference(name=name, on_synthetic_data=False, apply_augmentation=False) \ No newline at end of file From b591ec98300a8bbb6f395e1670a5270c479601ac Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Mon, 27 Oct 2025 09:00:09 +0100 Subject: [PATCH 59/73] use the same simulations with and without noise --- .../simulation/graph_germany_nuts3_clean.py | 67 +++++++++++-------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index aa6654a0f8..f087be5cfb 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -801,18 +801,18 @@ def run_training(name, num_training_files=20): #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') -def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmentation=True): +def run_inference(name, num_samples=10, on_synthetic_data=False): val_path = f"{name}/validation_data_{name}.pickle" synthetic = "_synthetic" if on_synthetic_data else "" - with_aug = "_with_aug" if apply_augmentation else "" aug = bf.augmentations.NNPE( spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, 
per_dimension=False ) + # validation data + validation_data = load_pickle(val_path) if on_synthetic_data: # validation data - validation_data = load_pickle(val_path) validation_data = apply_aug(validation_data, aug=aug) validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) validation_data_skip2w = skip_2weeks(validation_data) @@ -858,49 +858,62 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta results.append(result) results = combine_results(results) results = extract_observables(results) - if apply_augmentation: - results = apply_aug(results, aug=aug) + results_aug = apply_aug(results, aug=aug) # get sims in shape (samples, time, regions) simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) for i in range(num_samples): simulations[i] = np.concatenate([results[key][i] for key in divi_region_keys], axis=-1) + simulations_aug[i] = np.concatenate([results_aug[key][i] for key in divi_region_keys], axis=-1) # save sims - with open(f'{name}/sims_{name}{synthetic}{with_aug}.pickle', 'wb') as f: + with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: + pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + - plot_all_regions(simulations, divi_data, name, synthetic, with_aug) + plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") + plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") - plot_aggregated_to_federal_states(simulations, divi_data, name, synthetic, with_aug) + plot_aggregated_to_federal_states(simulations, divi_data, name, synthetic, with_aug="") + plot_aggregated_to_federal_states(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") - plot_aggregated_over_regions(simulations, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad") - plt.savefig(f'{name}/region_aggregated_{name}{synthetic}{with_aug}.png') + fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True, constrained_layout=True) + # Plot without augmentation + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad (No Aug)", ax=axes[0], color="red" + ) + axes[0].set_title("Without Augmentation") + # Plot with augmentation + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad (With Aug)", ax=axes[1], color="red" + ) + axes[1].set_title("With Augmentation") + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') plt.close() fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) - plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}{with_aug}.png') + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png') + plt.close() + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations_aug, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png') plt.close() + 
plot_icu_on_germany(simulations, name, synthetic, with_aug="") + plot_icu_on_germany(simulations, name, synthetic, with_aug="_with_aug") - print("Fitted parameters ", name, synthetic, with_aug, ":") - for param in inference_params: - if param == 'damping_values': - plot_damping_values(samples['damping_values'][0].reshape((num_samples, 16, 3)), name, synthetic, with_aug) - else: - med = np.median(samples[param][0], axis=0) - mad = np.median(np.abs(samples[param][0] - med), axis=0) - print(param, ": med: ", med) - print(param, ": mad: ", mad) - - plot_icu_on_germany(simulations, name, synthetic, with_aug) - + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) - # plot = bf.diagnostics.pairs_posterior(simulations, priors=validation_data, dataset_id=0) - # plot.savefig(f'pairs_posterior_wcovidparams_oct{synthetic}_ma7_noise.png') + plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') if __name__ == "__main__": @@ -911,6 +924,4 @@ def run_inference(name, num_samples=100, on_synthetic_data=False, apply_augmenta # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=1000) # run_training(name=name, num_training_files=20) run_inference(name=name, on_synthetic_data=True) - run_inference(name=name, on_synthetic_data=True, apply_augmentation=False) - run_inference(name=name, on_synthetic_data=False) - run_inference(name=name, on_synthetic_data=False, apply_augmentation=False) \ No newline at end of file + run_inference(name=name, on_synthetic_data=False) \ No newline at end of file From 10bf8430e35c68fdc0ac8677e4f9b278716dda8a Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Mon, 27 Oct 2025 10:50:38 +0100 Subject: [PATCH 60/73] switch to conf bands instead of mad --- .../simulation/graph_germany_nuts3_clean.py | 49 +++++++++++-------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index f087be5cfb..b4af571bb5 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -86,15 +86,25 @@ def plot_region_median_mad( x = np.arange(n_time) vals = data[:, :, region] # (samples, time_points) + + qs_50 = np.quantile(vals, q=[0.25, 0.75], axis=0) + qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) + med = np.median(vals, axis=0) - mad = np.median(np.abs(vals - med), axis=0) if ax is None: fig, ax = plt.subplots() - line, = ax.plot( + ax.plot( x, med, lw=2, label=label or f"{dd.County[region_ids[region]]}", color=color) - band = ax.fill_between(x, med - mad, med + mad, alpha=0.25, color=color) + ax.fill_between(x, qs_50[0], qs_50[1], alpha=0.5, + color=color, label="50% CI") + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + if true_data is not None: true_vals = true_data[:, region] # (time_points,) ax.plot(x, true_vals, lw=2, color="black", label="True data") @@ -104,7 +114,6 @@ def plot_region_median_mad( ax.set_title(f"{dd.County[region_ids[region]]}", fontsize=12) if label is not None: 
ax.legend(fontsize=11, loc="upper right") - return line, band def plot_aggregated_over_regions( @@ -124,26 +133,25 @@ def plot_aggregated_over_regions( # Aggregate over regions agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) + qs_50 = np.quantile(agg_over_regions, q=[0.25, 0.75], axis=0) + qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) + # Aggregate over samples agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) - agg_mad = np.median( - np.abs(agg_over_regions - agg_median[None]), - axis=0 - ) x = np.arange(agg_median.shape[0]) if ax is None: fig, ax = plt.subplots() - line, = ax.plot(x, agg_median, lw=2, + ax.plot(x, agg_median, lw=2, label=label or "Aggregated over regions", color=color) - band = ax.fill_between( - x, - agg_median - agg_mad, - agg_median + agg_mad, - alpha=0.25, - color=color - ) + ax.fill_between(x, qs_50[0], qs_50[1], alpha=0.5, + color=color, label="50% CI") + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") if true_data is not None: true_vals = region_agg(true_data, axis=-1) # (time_points,) ax.plot(x, true_vals, lw=2, color="black", label="True data") @@ -152,7 +160,6 @@ def plot_aggregated_over_regions( ax.set_ylabel("ICU", fontsize=12) if label is not None: ax.legend(fontsize=11) - return line, band def plot_icu_on_germany(simulations, name, synthetic, with_aug): med = np.median(simulations, axis=0) @@ -220,7 +227,7 @@ def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): ax = ax.flatten() for i, region_idx in enumerate(range(start_idx, end_idx)): plot_region_median_mad( - simulations, region=region_idx, true_data=divi_data, label=r"Median $\pm$ Mad", ax=ax[i] + simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color="#132a70" ) # Hide unused subplots for i in range(end_idx - start_idx, len(ax)): @@ -801,7 +808,7 @@ def run_training(name, num_training_files=20): #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') -def run_inference(name, num_samples=10, on_synthetic_data=False): +def run_inference(name, num_samples=100, on_synthetic_data=False): val_path = f"{name}/validation_data_{name}.pickle" synthetic = "_synthetic" if on_synthetic_data else "" @@ -883,12 +890,12 @@ def run_inference(name, num_samples=10, on_synthetic_data=False): fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True, constrained_layout=True) # Plot without augmentation plot_aggregated_over_regions( - simulations, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad (No Aug)", ax=axes[0], color="red" + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0], color="#132a70" ) axes[0].set_title("Without Augmentation") # Plot with augmentation plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad (With Aug)", ax=axes[1], color="red" + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1], color="#132a70" ) axes[1].set_title("With Augmentation") plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') From ecb0360cf22712435765ebc14e02795f879a2e71 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Mon, 27 Oct 2025 15:10:47 +0100 Subject: [PATCH 61/73] add metrics for evaluation --- .../simulation/graph_germany_nuts3_clean.py | 
13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index b4af571bb5..95fb58b666 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -901,7 +901,6 @@ def run_inference(name, num_samples=100, on_synthetic_data=False): plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') plt.close() - fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) @@ -922,13 +921,23 @@ def run_inference(name, num_samples=100, on_synthetic_data=False): plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulations, 0,1), divi_data, normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulations_aug, 0,1), divi_data, normalize=False) + print("Mean RMSE over regions:", rmse["values"].mean()) + print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) + + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulations, 0,1), divi_data) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulations_aug, 0,1), divi_data) + print("Mean Calibration Error over regions:", cal_error["values"].mean()) + print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) + if __name__ == "__main__": name = "3dampings" if not os.path.exists(name): os.makedirs(name) - # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=1000) + # create_train_data(filename=f'{name}/validation_data{name}.pickle', number_samples=1000) # run_training(name=name, num_training_files=20) run_inference(name=name, on_synthetic_data=True) run_inference(name=name, on_synthetic_data=False) \ No newline at end of file From adff2e1d1b4f3873fc61b5ec9b689f7537d02efc Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 30 Oct 2025 11:05:21 +0100 Subject: [PATCH 62/73] rearranging some plots and evaluation --- .../simulation/graph_germany_nuts3_clean.py | 116 ++++++++++-------- 1 file changed, 66 insertions(+), 50 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index 95fb58b666..b4b670ce93 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -34,7 +34,6 @@ import memilio.simulation as mio import memilio.simulation.osecir as osecir -from memilio.simulation.osecir import Model, Simulation, interpolate_simulation_result from memilio.epidata import defaultDict as dd import geopandas as gpd @@ -62,9 +61,10 @@ 'mu_UD': (0.0, 0.4), 'transmission_prob': (0.0, 0.2) } -SPIKE_SCALE = 0.85 -SLAB_SCALE = 0.4 +SPIKE_SCALE = 0.4 +SLAB_SCALE = 0.2 DATE_TIME = datetime.date(year=2020, month=10, day=1) +NUM_DAMPING_POINTS = 3 # %% def plot_region_median_mad( @@ -73,7 +73,8 @@ def plot_region_median_mad( true_data=None, ax=None, label=None, - color="red" + color="red", + only_50q=False ): if data.ndim != 3: raise ValueError("Array not of shape (samples, time_points, regions)") @@ -100,10 +101,11 @@ def 
plot_region_median_mad( x, med, lw=2, label=label or f"{dd.County[region_ids[region]]}", color=color) ax.fill_between(x, qs_50[0], qs_50[1], alpha=0.5, color=color, label="50% CI") - ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, - color=color, label="90% CI") - ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, - color=color, label="95% CI") + if not only_50q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") if true_data is not None: true_vals = true_data[:, region] # (time_points,) @@ -122,7 +124,8 @@ def plot_aggregated_over_regions( true_data=None, ax=None, label=None, - color='red' + color='red', + only_50q=False ): if data.ndim != 3: raise ValueError("Array not of shape (samples, time_points, regions)") @@ -148,10 +151,11 @@ def plot_aggregated_over_regions( label=label or "Aggregated over regions", color=color) ax.fill_between(x, qs_50[0], qs_50[1], alpha=0.5, color=color, label="50% CI") - ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, - color=color, label="90% CI") - ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, - color=color, label="95% CI") + if not only_50q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") if true_data is not None: true_vals = region_agg(true_data, axis=-1) # (time_points,) ax.plot(x, true_vals, lw=2, color="black", label="True data") @@ -198,11 +202,9 @@ def plot_aggregated_to_federal_states(data, true_data, name, synthetic, with_aug ax = ax.flatten() for state in range(16): idxs = [i for i, region_id in enumerate(region_ids) if region_id // 1000 == state + 1] - state_data = np.sum(data[:, :, idxs], axis=-1) # Aggregate over regions in the state - true_state_data = np.sum(true_data[:, idxs], axis=-1) if true_data is not None else None plot_aggregated_over_regions( - state_data[:, :, None], # Add a dummy region axis for compatibility - true_data=true_state_data[:, None] if true_state_data is not None else None, + data[:, :, idxs], # Add a dummy region axis for compatibility + true_data=true_data[:, idxs] if true_data is not None else None, ax=ax[state], label=f"State {state + 1}", color=f"C{state % 10}" # Cycle through 10 colors @@ -412,7 +414,7 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) -class Simulation: # todo: correct class? 
+class Simulation: """ """ def __init__(self, data_dir, start_date, results_dir): @@ -595,8 +597,8 @@ def run_germany_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR return results def prior(): - damping_values = np.zeros((3, 16)) - for i in range(3): + damping_values = np.zeros((NUM_DAMPING_POINTS, 16)) + for i in range(NUM_DAMPING_POINTS): mean = np.random.uniform(0, 1) scale = 0.1 a, b = (0 - mean) / scale, (1 - mean) / scale @@ -741,18 +743,18 @@ def get_workflow(): ) summary_network = bf.networks.FusionTransformer( - summary_dim=(len(bounds)+16)*2, dropout=0.1 + summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 ) - inference_network = bf.networks.FlowMatching() + inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) - # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) workflow = bf.BasicWorkflow( simulator=simulator, adapter=adapter, summary_network=summary_network, inference_network=inference_network, - standardize='all' - # augmentations={f'region{i}': aug for i in range(len(region_ids))} + standardize='all', + augmentations={f'region{i}': aug for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids} ) return workflow @@ -762,16 +764,16 @@ def run_training(name, num_training_files=20): train_template = name+"/trainings_data{i}_"+name+".pickle" val_path = f"{name}/validation_data_{name}.pickle" - aug = bf.augmentations.NNPE( - spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False - ) + # aug = bf.augmentations.NNPE( + # spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + # ) # training data train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] trainings_data = None for p in train_files: d = load_pickle(p) - d = apply_aug(d, aug=aug) # only on region keys + # d = apply_aug(d, aug=aug) # only on region keys d = skip_2weeks(d) d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) if trainings_data is None: @@ -781,7 +783,7 @@ def run_training(name, num_training_files=20): aggregate_states(trainings_data) # validation data - validation_data = apply_aug(load_pickle(val_path), aug=aug) + validation_data = load_pickle(val_path) validation_data = skip_2weeks(validation_data) aggregate_states(validation_data) validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) @@ -792,7 +794,7 @@ def run_training(name, num_training_files=20): print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) history = workflow.fit_offline( - data=trainings_data, epochs=300, batch_size=64, validation_data=validation_data + data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data ) workflow.approximator.save( @@ -808,7 +810,7 @@ def run_training(name, num_training_files=20): #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') -def run_inference(name, num_samples=100, on_synthetic_data=False): +def run_inference(name, num_samples=1000, on_synthetic_data=False): val_path = f"{name}/validation_data_{name}.pickle" synthetic = "_synthetic" if on_synthetic_data else "" @@ -844,12 +846,13 @@ def run_inference(name, num_samples=100, on_synthetic_data=False): filepath=os.path.join(f"{name}/model_{name}.keras") ) - if False: 
#os.path.exists(f'{name}/sims_{name}{synthetic}{with_aug}.pickle'): - simulations = load_pickle(f'{name}/sims_{name}{synthetic}{with_aug}.pickle') + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') + simulations_aug = load_pickle(f'{name}/sims_{name}{synthetic}_with_aug.pickle') print("loaded simulations from file") else: samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) - samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, 3)) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) results = [] for i in range(num_samples): # we only have one dataset for inference here result = run_germany_nuts3_simulation( @@ -880,6 +883,11 @@ def run_inference(name, num_samples=100, on_synthetic_data=False): with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") @@ -887,17 +895,27 @@ def run_inference(name, num_samples=100, on_synthetic_data=False): plot_aggregated_to_federal_states(simulations, divi_data, name, synthetic, with_aug="") plot_aggregated_to_federal_states(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") - fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True, constrained_layout=True) + fig, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey='row', constrained_layout=True) # Plot without augmentation plot_aggregated_over_regions( - simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0], color="#132a70" + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color="#132a70" ) - axes[0].set_title("Without Augmentation") + axes[0, 0].set_title("Without Augmentation") # Plot with augmentation plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1], color="#132a70" + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" ) - axes[1].set_title("With Augmentation") + axes[0, 1].set_title("With Augmentation") + # Plot without augmentation (50% quantile only) + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_50q=True + ) + axes[1, 0].set_title("Without Augmentation (50% Quantile)") + # Plot with augmentation (50% quantile only) + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_50q=True + ) + axes[1, 1].set_title("With Augmentation (50% Quantile)") 
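
The calibration curves used here measure how often the observed series falls inside central posterior-predictive intervals of increasing nominal level. A minimal single-region sketch of that coverage computation, assuming synthetic draws and observations:

import numpy as np

rng = np.random.default_rng(1)
sims = rng.normal(size=(500, 61))        # posterior-predictive draws for one region, (samples, time)
obs = rng.normal(size=61)                # observed series, (time,)

levels = np.linspace(0.01, 0.99, 20)
coverage = []
for nominal in levels:
    q_low = (1.0 - nominal) / 2.0
    lo = np.quantile(sims, q_low, axis=0)
    hi = np.quantile(sims, 1.0 - q_low, axis=0)
    coverage.append(np.mean((obs >= lo) & (obs <= hi)))
# for a well-calibrated model, coverage[j] should track levels[j]
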
plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') plt.close() @@ -915,25 +933,23 @@ def run_inference(name, num_samples=100, on_synthetic_data=False): plot_icu_on_germany(simulations, name, synthetic, with_aug="") plot_icu_on_germany(simulations, name, synthetic, with_aug="_with_aug") - samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) - validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) - - plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) - plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') + simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions + simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) + print(simulation_agg.shape, divi_data.shape) - rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulations, 0,1), divi_data, normalize=False) - rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulations_aug, 0,1), divi_data, normalize=False) + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) print("Mean RMSE over regions:", rmse["values"].mean()) print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) - cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulations, 0,1), divi_data) - cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulations_aug, 0,1), divi_data) + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) print("Mean Calibration Error over regions:", cal_error["values"].mean()) print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) if __name__ == "__main__": - name = "3dampings" + name = "3dampings_lessnoise" if not os.path.exists(name): os.makedirs(name) From a5243044dfd80a2173012da701c7f00b652020b5 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 30 Oct 2025 16:25:53 +0100 Subject: [PATCH 63/73] adjust NUTS1 Simulation and plots --- .../simulation/graph_germany_nuts1.py | 836 +++++++++++++++--- 1 file changed, 718 insertions(+), 118 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts1.py b/pycode/examples/simulation/graph_germany_nuts1.py index 40389ee00f..8fc9a2d99a 100644 --- a/pycode/examples/simulation/graph_germany_nuts1.py +++ b/pycode/examples/simulation/graph_germany_nuts1.py @@ -17,18 +17,363 @@ # See the License for the specific language governing permissions and # limitations under the License. 
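
The NUTS1 script below draws one damping factor per federal state from a normal distribution truncated to [0, 1]. A minimal sketch of that draw, assuming the same mean/scale convention that appears in the diff:

import numpy as np
from scipy.stats import truncnorm

mean = np.random.uniform(0.0, 1.0)   # one shared mean per damping date
scale = 0.1
# truncnorm expects the [0, 1] bounds standardized by loc and scale
a, b = (0.0 - mean) / scale, (1.0 - mean) / scale
damping_values = truncnorm.rvs(a=a, b=b, loc=mean, scale=scale, size=16)  # one factor per federal state
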
############################################################################# -import numpy as np -import datetime import os -import memilio.simulation as mio -import memilio.simulation.osecir as osecir +os.environ["KERAS_BACKEND"] = "tensorflow" + +import numpy as np +import pandas as pd import matplotlib.pyplot as plt +import datetime +import pickle +from scipy.stats import truncnorm -from enum import Enum -from memilio.simulation.osecir import (Model, Simulation, - interpolate_simulation_result) +from matplotlib.patches import Patch -import pickle +import bayesflow as bf +import keras + +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +from memilio.simulation.osecir import Model, interpolate_simulation_result +from memilio.epidata import defaultDict as dd + +import geopandas as gpd + + +region_ids = [region_id for region_id in dd.State.keys()] +inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', + 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] +summary_vars = [f'fed_state{i}' for i in range(16)] + +bounds = { + 't_E': (1.0, 5.2), + 't_ISy': (4.0, 10.0), + 't_ISev': (5.0, 10.0), + 't_Cr': (9.0, 17.0), + 'mu_CR': (0.0, 0.4), + 'mu_IH': (0.0, 0.2), + 'mu_HU': (0.0, 0.4), + 'mu_UD': (0.0, 0.4), + 'transmission_prob': (0.0, 0.2) +} +SPIKE_SCALE = 0.4 +SLAB_SCALE = 0.2 +DATE_TIME = datetime.date(year=2020, month=10, day=1) +NUM_DAMPING_POINTS = 3 + +def plot_region_fit( + data: np.ndarray, + region: int, + true_data=None, + ax=None, + label=None, + color="red", + only_50q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + n_samples, n_time, n_regions = data.shape + if not (0 <= region < n_regions): + raise IndexError + + x = np.arange(n_time) + vals = data[:, :, region] # (samples, time_points) + + qs_50 = np.quantile(vals, q=[0.25, 0.75], axis=0) + qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) + + med = np.median(vals, axis=0) + + if ax is None: + fig, ax = plt.subplots() + + ax.plot( + x, med, lw=2, label=label or f"{dd.State[region_ids[region]]}", color=color) + ax.fill_between(x, qs_50[0], qs_50[1], alpha=0.5, + color=color, label="50% CI") + if not only_50q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + + if true_data is not None: + true_vals = true_data[:, region] # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + ax.set_title(f"{dd.State[region_ids[region]]}", fontsize=12) + if label is not None: + ax.legend(fontsize=11, loc="upper right") + + +def plot_aggregated_over_regions( + data: np.ndarray, + region_agg=np.sum, + true_data=None, + ax=None, + label=None, + color='red', + only_50q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + # Aggregate over regions + agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) + + qs_50 = np.quantile(agg_over_regions, q=[0.25, 0.75], axis=0) + qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) + qs_95 = 
np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) + + # Aggregate over samples + agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) + + x = np.arange(agg_median.shape[0]) + if ax is None: + fig, ax = plt.subplots() + + ax.plot(x, agg_median, lw=2, + label=label or "Aggregated over regions", color=color) + ax.fill_between(x, qs_50[0], qs_50[1], alpha=0.5, + color=color, label="50% CI") + if not only_50q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + if true_data is not None: + true_vals = region_agg(true_data, axis=-1) # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + if label is not None: + ax.legend(fontsize=11) + +def plot_icu_on_germany(simulations, name, synthetic, with_aug): + med = np.median(simulations, axis=0) + + population = pd.read_json('data/Germany/pydata/county_current_population_states.json') + values = med / population['Population'].to_numpy()[None, :] * 100000 + + map_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_KRS.shp')) + fedstate_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_LAN.shp')) + + plot_map(values[0], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_initial_{name}{synthetic}{with_aug}") + plot_map(values[-1], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_final_{name}{synthetic}{with_aug}") + +def plot_map(values, map_data, fedstate_data, label, filename): + fedstate_data[label] = fedstate_data['ARS'].map(dict(zip([f"{region_id:02d}" for region_id in region_ids], values))) + + fig, ax = plt.subplots(figsize=(12, 14)) + fedstate_data.plot( + column=f"{label}", + cmap='Reds', + linewidth=1., + ax=ax, + edgecolor='0.6', + legend=True, + legend_kwds={'label': f"{label} per Federal State", 'shrink': 0.6}, + ) + map_data.boundary.plot(ax=ax, color='black', linewidth=0.2) + ax.set_title(f"{label} per Federal State") + ax.axis('off') + plt.savefig(filename, bbox_inches='tight', dpi=300) + +def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): + n_regions = simulations.shape[-1] + fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(25, 25), layout="constrained") + ax = ax.flatten() + for i in range(n_regions): + plot_region_fit( + simulations, region=i, true_data=divi_data, label="Median", ax=ax[i], color="#132a70" + ) + plt.savefig(f'{name}/federal_states_{name}{synthetic}{with_aug}.png') + plt.close() + +def calibration_curves_per_region( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + max_regions=None, + cmap=plt.cm.Blues, + linewidth=1.5, + legend=True, + with_ideal=True, +): + """ + Per-region calibration curves, each region in a different shade of a colormap. 
+ + data: (samples, time, regions) + true_data: (time, regions) + max_regions: limit number of regions shown + cmap: matplotlib colormap for line shades + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + if max_regions is None: + R = n_regions + else: + R = min(max_regions, n_regions) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + colors = [cmap(i / (R + 1)) for i in range(1, R + 1)] + + x = np.asarray(levels) + for r, col in zip(range(R), colors): + emp = [] + for nominal in levels: + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data[:, :, r], q_low, axis=0) + hi = np.quantile(data[:, :, r], q_high, axis=0) + hits = (true_data[:, r] >= lo) & (true_data[:, r] <= hi) + emp.append(hits.mean()) + emp = np.asarray(emp) + # , label=f"Region {r+1}") + ax.plot(x, emp, lw=linewidth, color=col, alpha=0.5) + + if with_ideal: + ideal_line = ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal")[0] + else: + ideal_line = None + + if legend: + # Custom legend: one patch for regions, one line for ideal + region_patch = Patch(color=colors[-1], label="Regions") + ax.legend(handles=[region_patch, ideal_line], + frameon=True, ncol=1, fontsize=12) + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level", fontsize=12) + ax.set_ylabel("Empirical coverage", fontsize=12) + ax.set_title("Calibration per region", fontsize=12) + return ax + + +def calibration_median_mad_over_regions( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + color="tab:blue", + alpha=0.25, + linewidth=2.0, + with_ideal=True, +): + """ + Compute per region empirical coverage at each nominal level, + then summarize across regions with median and MAD. 
+ + data: (samples, time, regions) + true_data: (time, regions) + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + L = len(levels) + per_region = np.empty((n_regions, L), dtype=float) + + # per level coverage per region + for j, nominal in enumerate(levels): + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data, q_low, axis=0) # (time, regions) + hi = np.quantile(data, q_high, axis=0) # (time, regions) + hits = (true_data >= lo) & (true_data <= hi) # (time, regions) + # mean over time for each region + per_region[:, j] = hits.mean(axis=0) + + med = np.median(per_region, axis=0) # (levels,) + mad = np.median(np.abs(per_region - med[None, :]), axis=0) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + x = np.asarray(levels) + ax.fill_between(x, med - mad, med + mad, + alpha=alpha, color=color, label=None) + ax.plot(x, med, lw=linewidth, color=color, label="Median across regions") + + if with_ideal: + ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal") + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level", fontsize=12) + ax.set_ylabel("Empirical coverage", fontsize=12) + ax.set_title("Calibration median and MAD across regions", fontsize=12) + ax.legend(fontsize=12) + return ax, {"levels": x, "median": med, "mad": mad} + + +def plot_damping_values(damping_values, name, synthetic, with_aug): + + med = np.median(damping_values, axis=0) + mad = np.median(np.abs(damping_values - med), axis=0) + + # Extend for step plotting + med_extended = np.hstack([med, med[:, -1][:, None]]) + mad_extended = np.hstack([mad, mad[:, -1][:, None]]) + + # Plot damping values per region + fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) + axes = axes.flatten() + x = np.arange(15, 61, 15) # Time steps from 15 to 60 + + for i, ax in enumerate(axes): + if i < 16: + ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) + ax.fill_between( + x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], + alpha=0.25, color='red', step='post' + ) + ax.set_title(f"{dd.State[i+1]}", fontsize=10) + ax.set_xlabel("Time", fontsize=8) + ax.set_ylabel("Damping Value", fontsize=8) + else: + ax.axis('off') # Hide unused subplots + + plt.suptitle("Damping Values per Region", fontsize=14) + plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) + + # Combined plot for all regions + fig, ax = plt.subplots(figsize=(10, 6)) + cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors + + for i in range(16): + ax.stairs( + med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", + color=cmap(i), baseline=None + ) + + ax.set_title("Damping Values per Region (Combined)", fontsize=14) + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("Damping Value", fontsize=12) + ax.legend(fontsize=10, loc="upper right", ncol=2) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) class Simulation: @@ -42,28 +387,28 @@ def __init__(self, data_dir, start_date, results_dir): if not os.path.exists(self.results_dir): os.makedirs(self.results_dir) - def set_covid_parameters(self, model): + def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param model: """ - model.parameters.TimeExposed[mio.AgeGroup(0)] = 
3.335 - model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 - model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 - model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 - model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 + model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E # todo: correct? + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr # probabilities model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( - 0)] = 0.07333 + 0)] = transmission_prob model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( - 0)] = 0.2069 - model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 - model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 - model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 + 0)] = mu_CR + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD # start day is set to the n-th day of the year model.parameters.StartDay = self.start_date.timetuple().tm_yday @@ -79,27 +424,37 @@ def set_contact_matrices(self, model): contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 - minimum = np.ones((self.num_groups, self.num_groups)) * 0 + minimum = np.zeros((self.num_groups, self.num_groups)) contact_matrices[0] = mio.ContactMatrix(baseline, minimum) model.parameters.ContactPatterns.cont_freq_mat = contact_matrices - def set_npis(self, params, end_date, damping_value): + def set_npis(self, params, end_date, damping_values): """ :param params: :param end_date: """ + start_damping_1 = DATE_TIME + datetime.timedelta(days=15) + start_damping_2 = DATE_TIME + datetime.timedelta(days=30) + start_damping_3 = DATE_TIME + datetime.timedelta(days=45) - start_damping = datetime.date( - 2020, 12, 18) + if start_damping_1 < end_date: + start_date = (start_damping_1 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[0]], t=start_date)) + + if start_damping_2 < end_date: + start_date = (start_damping_2 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[1]], t=start_date)) - if start_damping < end_date: - start_date = (start_damping - self.start_date).days + if start_damping_3 < end_date: + start_date = (start_damping_3 - self.start_date).days params.ContactPatterns.cont_freq_mat[0].add_damping( - mio.Damping(np.r_[damping_value], t=start_date)) + mio.Damping(np.r_[damping_values[2]], t=start_date)) - def get_graph(self, end_date): + def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param end_date: @@ -107,7 +462,8 @@ def get_graph(self, end_date): """ print("Initializing model...") model = Model(self.num_groups) - self.set_covid_parameters(model) + self.set_covid_parameters( + model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) self.set_contact_matrices(model) print("Model initialized.") @@ -115,7 +471,6 @@ def get_graph(self, end_date): scaling_factor_infected = [2.5] scaling_factor_icu = 1.0 - tnt_capacity_factor = 7.5 
/ 100000. data_dir_Germany = os.path.join(self.data_dir, "Germany") mobility_data_file = os.path.join( @@ -136,14 +491,13 @@ def get_graph(self, end_date): scaling_factor_icu, 0, 0, False) print("Setting edges...") - mio.osecir.set_edges( - mobility_data_file, graph, 1) + mio.osecir.set_edges(mobility_data_file, graph, 1) print("Graph created.") return graph - def run(self, num_days_sim, damping_values, save_graph=True): + def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): """ :param num_days_sim: @@ -155,25 +509,18 @@ def run(self, num_days_sim, damping_values, save_graph=True): mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) - graph = self.get_graph(end_date) - - if save_graph: - path_graph = os.path.join(self.results_dir, "graph") - if not os.path.exists(path_graph): - os.makedirs(path_graph) - osecir.write_graph(graph, path_graph) + graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, + t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): - # if node_idx < 5: node = graph.get_node(node_idx) - self.set_npis(node.property.parameters, - end_date, damping_values[node_idx]) + + self.set_npis( + node.property.parameters, + end_date, + damping_values[node_idx]) mobility_graph.add_node(node.id, node.property) - # else: - # node = graph.get_node(node_idx) - # self.set_npis(node.property.parameters, end_date, 0.5) - # mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( graph.get_edge(edge_idx).start_node_idx, @@ -182,93 +529,346 @@ def run(self, num_days_sim, damping_values, save_graph=True): mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) mobility_sim.advance(num_days_sim) - results = [] - for node_idx in range(graph.num_nodes): - results.append(osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(node_idx).property.result)) + results = {} + for node_idx in range(mobility_sim.graph.num_nodes): + results[f'fed_state{node_idx}'] = osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result) return results -def run_germany_nuts1_simulation(damping_values): +def run_germany_nuts1_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) sim = Simulation( data_dir=os.path.join(file_path, "../../../data"), - start_date=datetime.date(year=2020, month=12, day=12), + start_date=DATE_TIME, results_dir=os.path.join(file_path, "../../../results_osecir")) - num_days_sim = 50 + num_days_sim = 60 - results = sim.run(num_days_sim, damping_values) - results[0].export_csv('test.csv') - - return {f'region{region}': results[region] for region in range(len(results))} + results = sim.run(num_days_sim, damping_values, t_E, t_ISy, + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=False) + results['fed_state0'].export_csv('test.csv') + return results def prior(): - damping_values = np.random.uniform(0.0, 1.0, 16) - return {'damping_values': damping_values} - - -if __name__ == "__main__": - test = prior() - run_germany_nuts1_simulation(test['damping_values']) - - # import os - # os.environ["KERAS_BACKEND"] = "tensorflow" - - # import bayesflow as bf - - # simulator = bf.simulators.make_simulator([prior, 
run_germany_nuts1_simulation]) - # trainings_data = simulator.sample(1) - - # for region in range(16): - # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] - - # with open('validation_data_16param.pickle', 'wb') as f: - # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - - # with open('trainings_data1_16param.pickle', 'rb') as f: - # trainings_data = pickle.load(f) - # trainings_data['damping_values'] = trainings_data['damping_values'][:, :16] - # for i in range(9): - # with open(f'trainings_data{i+2}_16param.pickle', 'rb') as f: - # data = pickle.load(f) - # data['damping_values'] = data['damping_values'][:, :16] - # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - - # with open('validation_data_16param.pickle', 'rb') as f: - # validation_data = pickle.load(f) - # validation_data['damping_values'] = validation_data['damping_values'][:, :16] - - # adapter = ( - # bf.Adapter() - # .to_array() - # .convert_dtype("float64", "float32") - # .constrain("damping_values", lower=0.0, upper=1.0) - # .rename("damping_values", "inference_variables") - # .concatenate([f'region{i}' for i in range(16)], into="summary_variables", axis=-1) - # .log("summary_variables", p1=True) - # ) - - # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=32) - # inference_network = bf.networks.CouplingFlow() - - # workflow = bf.BasicWorkflow( - # simulator=simulator, - # adapter=adapter, - # summary_network=summary_network, - # inference_network=inference_network, - # standardize='all' + damping_values = np.zeros((NUM_DAMPING_POINTS, 16)) + for i in range(NUM_DAMPING_POINTS): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / scale + damping_values[i] = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=16 + ) + return { + 'damping_values': np.transpose(damping_values), + 't_E': np.random.uniform(*bounds['t_E']), + 't_ISy': np.random.uniform(*bounds['t_ISy']), + 't_ISev': np.random.uniform(*bounds['t_ISev']), + 't_Cr': np.random.uniform(*bounds['t_Cr']), + 'mu_CR': np.random.uniform(*bounds['mu_CR']), + 'mu_IH': np.random.uniform(*bounds['mu_IH']), + 'mu_HU': np.random.uniform(*bounds['mu_HU']), + 'mu_UD': np.random.uniform(*bounds['mu_UD']), + 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) + } + + +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Germany/pydata") + + data = pd.read_json(os.path.join(divi_path, "state_divi_ma7.json")) + data = data[data['Date'] >= np.datetime64(DATE_TIME)] + data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] + data = data.sort_values(by=['ID_State', 'Date']) + divi_data = data.pivot(index='Date', columns='ID_State', values='ICU') + divi_dict = {} + for i in range(len(region_ids)): + divi_dict[f"fed_state{i}"] = divi_data[region_ids[i]].to_numpy()[None, :, None] + return divi_dict + + + +def extract_observables(simulation_results, observable_index=7): + for key in simulation_results.keys(): + if key not in inference_params: + simulation_results[key] = simulation_results[key][:, :, observable_index][..., np.newaxis] + return simulation_results + + +def create_train_data(filename, number_samples=1000): + + simulator = bf.simulators.make_simulator( + [prior, run_germany_nuts1_simulation] + ) + trainings_data = simulator.sample(number_samples) + trainings_data = extract_observables(trainings_data) + with 
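The hierarchical damping prior in prior() above draws, per damping point, one shared mean on [0, 1] and then one value per federal state from a normal truncated to [0, 1]. A minimal sketch of that construction using only scipy (the helper name sample_damping_prior is illustrative); note that truncnorm expects its a/b bounds standardized by loc and scale, which is what the (0 - mean) / scale and (1 - mean) / scale terms do.

    import numpy as np
    from scipy.stats import truncnorm

    def sample_damping_prior(num_damping_points=3, num_regions=16, scale=0.1):
        # Per damping point: one shared mean, then per-region draws from
        # Normal(mean, scale) truncated to [0, 1].
        values = np.zeros((num_damping_points, num_regions))
        for i in range(num_damping_points):
            mean = np.random.uniform(0.0, 1.0)
            a, b = (0.0 - mean) / scale, (1.0 - mean) / scale
            values[i] = truncnorm.rvs(a=a, b=b, loc=mean, scale=scale, size=num_regions)
        return values.T  # (num_regions, num_damping_points), as returned by prior()

The shared mean couples the 16 states at each damping point, while the small scale keeps the regional values close to that mean.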
open(filename, 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + +def load_pickle(path): + with open(path, "rb") as f: + return pickle.load(f) + +def is_state_key(k: str) -> bool: + return 'fed_state' in k + +def apply_aug(d: dict, aug) -> dict: + return {k: np.clip(aug(v), 0, None) if is_state_key(k) else v for k, v in d.items()} + +def concat_dicts(base: dict, new: dict) -> dict: + missing = set(base) - set(new) + if missing: + raise KeyError(f"new dict missing keys: {sorted(missing)}") + for k in base: + base[k] = np.concatenate([base[k], new[k]]) + return base + + +def combine_results(dict_list): + combined = {} + for d in dict_list: + combined = concat_dicts(combined, d) if combined else d + return combined + +def skip_2weeks(d:dict) -> dict: + return {k: v[:, 14:, :] if is_state_key(k) else v for k, v in d.items()} + +def get_workflow(): + + simulator = bf.make_simulator( + [prior, run_germany_nuts1_simulation] + ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) + .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) + .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) + .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) + .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) + .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) + .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) + .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) + .concatenate( + ["damping_values", "t_E", "t_ISy", "t_ISev", "t_Cr", + "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], + into="inference_variables", + axis=-1 + ) + .concatenate(summary_vars, into="summary_variables", axis=-1) + ) + + summary_network = bf.networks.FusionTransformer( + summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 + ) + inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) + + aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all', + augmentations={f'fed_state{i}': aug for i in range(len(region_ids))} + ) + + return workflow + + +def run_training(name, num_training_files=20): + train_template = name+"/trainings_data{i}_"+name+".pickle" + val_path = f"{name}/validation_data_{name}.pickle" + + # aug = bf.augmentations.NNPE( + # spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False # ) - # history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) + # training data + train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] + trainings_data = None + for p in train_files: + d = load_pickle(p) + # d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) + d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) + if trainings_data is None: + trainings_data = d + else: + trainings_data = concat_dicts(trainings_data, d) + + # validation data + validation_data = load_pickle(val_path)#apply_aug(load_pickle(val_path), 
aug=aug) + validation_data = skip_2weeks(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + # check data + workflow = get_workflow() + print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) + print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) + + history = workflow.fit_offline( + data=trainings_data, epochs=50, batch_size=64, validation_data=validation_data + ) + + workflow.approximator.save( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + plots = workflow.plot_default_diagnostics( + test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} + ) + plots['losses'].savefig(f'{name}/losses_{name}.png') + plots['recovery'].savefig(f'{name}/recovery_{name}.png') + plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png') + #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') + + +def run_inference(name, num_samples=10, on_synthetic_data=False): + val_path = f"{name}/validation_data_{name}.pickle" + synthetic = "_synthetic" if on_synthetic_data else "" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # validation data + validation_data = load_pickle(val_path) + if on_synthetic_data: + # validation data + validation_data = apply_aug(validation_data, aug=aug) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data_skip2w = skip_2weeks(validation_data) + divi_dict = validation_data + print(np.sum([divi_dict[f'fed_state{i}'] for i in range(len(region_ids))])) + + divi_data = np.concatenate( + [divi_dict[f'fed_state{i}'] for i in range(len(region_ids))], axis=-1 + )[0] # only one dataset + else: + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + divi_data = np.concatenate( + [divi_dict[f'fed_state{i}'] for i in range(len(region_ids))], axis=-1 + )[0] + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') + simulations_aug = load_pickle(f'{name}/sims_{name}{synthetic}_with_aug.pickle') + print("loaded simulations from file") + else: + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) + results = [] + for i in range(num_samples): # we only have one dataset for inference here + result = run_germany_nuts1_simulation( + damping_values=samples['damping_values'][0, i], + t_E=samples['t_E'][0, i], t_ISy=samples['t_ISy'][0, i], + t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], + mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], + mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], + transmission_prob=samples['transmission_prob'][0, i] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] 
# add sample axis + results.append(result) + results = combine_results(results) + results = extract_observables(results) + results_aug = apply_aug(results, aug=aug) + + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + for i in range(num_samples): + simulations[i] = np.concatenate([results[key][i] for key in results.keys()], axis=-1) + simulations_aug[i] = np.concatenate([results_aug[key][i] for key in results.keys()], axis=-1) + + # save sims + with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: + pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') + + plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") + plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + + fig, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey='row', constrained_layout=True) + # Plot without augmentation + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color="#132a70" + ) + axes[0, 0].set_title("Without Augmentation") + # Plot with augmentation + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" + ) + axes[0, 1].set_title("With Augmentation") + # Plot without augmentation (50% quantile only) + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_50q=True + ) + axes[1, 0].set_title("Without Augmentation (50% Quantile)") + # Plot with augmentation (50% quantile only) + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_50q=True + ) + axes[1, 1].set_title("With Augmentation (50% Quantile)") + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') + plt.close() + + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png') + plt.close() + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations_aug, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png') + plt.close() + + plot_icu_on_germany(simulations, name, synthetic, with_aug="") + plot_icu_on_germany(simulations_aug, name, synthetic, with_aug="_with_aug") + + simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions + simulation_aug_agg = 
np.sum(simulations_aug, axis=-1, keepdims=True) + + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + print("Mean RMSE over regions:", rmse["values"].mean()) + print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) + + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + print("Mean Calibration Error over regions:", cal_error["values"].mean()) + print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) - # # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_1params.keras")) - # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) - # plots['losses'].savefig('losses_couplingflow_16param.png') - # plots['recovery'].savefig('recovery_couplingflow_16param.png') - # plots['calibration_ecdf'].savefig('calibration_ecdf_couplingflow_16param.png') - # plots['z_score_contraction'].savefig('z_score_contraction_couplingflow_16param.png') +if __name__ == "__main__": + name = "nuts1" + + if not os.path.exists(name): + os.makedirs(name) + create_train_data(filename=f'{name}/trainings_data1_{name}.pickle', number_samples=1000) + # run_training(name=name, num_training_files=20) + # run_inference(name=name, on_synthetic_data=True) + # run_inference(name=name, on_synthetic_data=False) \ No newline at end of file From 0972ae1181b664862be2c3d806b5e9cf1dfaa5c4 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Sun, 9 Nov 2025 17:45:43 +0100 Subject: [PATCH 64/73] final setting for fitting on county level --- .../simulation/graph_germany_nuts3_clean.py | 141 ++++++++++++++---- 1 file changed, 108 insertions(+), 33 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index b4b670ce93..1e970127d6 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -34,6 +34,7 @@ import memilio.simulation as mio import memilio.simulation.osecir as osecir +from memilio.simulation.osecir import Model, interpolate_simulation_result from memilio.epidata import defaultDict as dd import geopandas as gpd @@ -66,15 +67,14 @@ DATE_TIME = datetime.date(year=2020, month=10, day=1) NUM_DAMPING_POINTS = 3 -# %% -def plot_region_median_mad( +def plot_region_fit( data: np.ndarray, region: int, true_data=None, ax=None, label=None, color="red", - only_50q=False + only_80q=False ): if data.ndim != 3: raise ValueError("Array not of shape (samples, time_points, regions)") @@ -88,7 +88,7 @@ def plot_region_median_mad( x = np.arange(n_time) vals = data[:, :, region] # (samples, time_points) - qs_50 = np.quantile(vals, q=[0.25, 0.75], axis=0) + qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) @@ -99,9 +99,9 @@ def plot_region_median_mad( ax.plot( x, med, lw=2, label=label or f"{dd.County[region_ids[region]]}", color=color) - ax.fill_between(x, qs_50[0], 
qs_50[1], alpha=0.5, - color=color, label="50% CI") - if not only_50q: + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, color=color, label="90% CI") ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, @@ -125,7 +125,7 @@ def plot_aggregated_over_regions( ax=None, label=None, color='red', - only_50q=False + only_80q=False ): if data.ndim != 3: raise ValueError("Array not of shape (samples, time_points, regions)") @@ -136,7 +136,7 @@ def plot_aggregated_over_regions( # Aggregate over regions agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) - qs_50 = np.quantile(agg_over_regions, q=[0.25, 0.75], axis=0) + qs_80 = np.quantile(agg_over_regions, q=[0.1, 0.9], axis=0) qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) qs_95 = np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) @@ -149,9 +149,9 @@ def plot_aggregated_over_regions( ax.plot(x, agg_median, lw=2, label=label or "Aggregated over regions", color=color) - ax.fill_between(x, qs_50[0], qs_50[1], alpha=0.5, - color=color, label="50% CI") - if not only_50q: + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, color=color, label="90% CI") ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, @@ -228,7 +228,7 @@ def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 25), layout="constrained") ax = ax.flatten() for i, region_idx in enumerate(range(start_idx, end_idx)): - plot_region_median_mad( + plot_region_fit( simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color="#132a70" ) # Hide unused subplots @@ -413,8 +413,80 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): ax.legend(fontsize=10, loc="upper right", ncol=2) plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) +def run_prior_predictive_check(name): + validation_data = load_pickle(f'{name}/validation_data_{name}.pickle') # synthetic data + + divi_region_keys = region_keys_sorted(validation_data) + num_samples = validation_data['region0'].shape[0] + time_points = validation_data['region0'].shape[1] + + true_data = load_divi_data() + true_data = np.concatenate( + [true_data[key] for key in divi_region_keys], axis=-1 + )[0] + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, time_points, len(divi_region_keys))) + for i in range(num_samples): + simulations[i] = np.concatenate([validation_data[key][i] for key in divi_region_keys], axis=-1) + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_aggregated_over_regions( + simulations, true_data=true_data, label="Region Aggregated Median (No Aug)", ax=ax, color="#132a70" + ) + ax.set_title("Prior predictive check - Aggregated over regions") + plt.savefig(f'{name}/prior_predictive_check_{name}.png') + +def compare_median_sim_to_mean_param_sim(name): + num_samples = 10 + + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + aggregate_states(validation_data_skip2w) + divi_region_keys = region_keys_sorted(divi_dict) + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] + + simulations_aug = load_pickle(f'{name}/sims_{name}_with_aug.pickle') + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + 
filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + for key in inference_params: + samples[key] = np.median(samples[key], axis=1) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], 16, NUM_DAMPING_POINTS)) + result = run_germany_nuts3_simulation( + damping_values=samples['damping_values'][0], + t_E=samples['t_E'][0], t_ISy=samples['t_ISy'][0], + t_ISev=samples['t_ISev'][0], t_Cr=samples['t_Cr'][0], + mu_CR=samples['mu_CR'][0], mu_IH=samples['mu_IH'][0], + mu_HU=samples['mu_HU'][0], mu_UD=samples['mu_UD'][0], + transmission_prob=samples['transmission_prob'][0] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] # add sample axis + + result = extract_observables(result) + divi_region_keys = region_keys_sorted(result) + result = np.concatenate( + [result[key] if key in result else np.zeros_like(result[divi_region_keys[0]]) for key in divi_region_keys], + axis=-1 + )[0] -class Simulation: + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_aggregated_over_regions( + simulations_aug, true_data=result, label="Compare median of sim to sim of median params", ax=ax, color="#132a70" + ) + ax.set_title("Comparison of Median Simulation to Simulation of Median Parameters") + plt.savefig(f'{name}/compare_median_sim13_to_mean_param_sim_{name}.png') + plt.close() + + +class Simulation: """ """ def __init__(self, data_dir, start_date, results_dir): @@ -747,14 +819,15 @@ def get_workflow(): ) inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) - aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) workflow = bf.BasicWorkflow( simulator=simulator, adapter=adapter, summary_network=summary_network, inference_network=inference_network, - standardize='all', - augmentations={f'region{i}': aug for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids} + standardize='all' + # augmentations={f'region{i}': aug for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids} + # aggregation of the states would need to be recomputed every time a different noise realization is applied ) return workflow @@ -764,16 +837,16 @@ def run_training(name, num_training_files=20): train_template = name+"/trainings_data{i}_"+name+".pickle" val_path = f"{name}/validation_data_{name}.pickle" - # aug = bf.augmentations.NNPE( - # spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False - # ) + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) # training data train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] trainings_data = None for p in train_files: d = load_pickle(p) - # d = apply_aug(d, aug=aug) # only on region keys + d = apply_aug(d, aug=aug) # only on region keys d = skip_2weeks(d) d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) if trainings_data is None: @@ -783,7 +856,7 @@ def run_training(name, num_training_files=20): aggregate_states(trainings_data) # validation data - validation_data = load_pickle(val_path) + validation_data = apply_aug(load_pickle(val_path), aug=aug) validation_data = skip_2weeks(validation_data) aggregate_states(validation_data) validation_data['damping_values'] = 
validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) @@ -819,7 +892,7 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): ) # validation data - validation_data = load_pickle(val_path) + validation_data = load_pickle(val_path) # synthetic data if on_synthetic_data: # validation data validation_data = apply_aug(validation_data, aug=aug) @@ -906,16 +979,16 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" ) axes[0, 1].set_title("With Augmentation") - # Plot without augmentation (50% quantile only) + # Plot without augmentation (80% quantile only) plot_aggregated_over_regions( - simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_50q=True + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_80q=True ) - axes[1, 0].set_title("Without Augmentation (50% Quantile)") - # Plot with augmentation (50% quantile only) + axes[1, 0].set_title("Without Augmentation (80% Quantile)") + # Plot with augmentation (80% quantile only) plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_50q=True + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_80q=True ) - axes[1, 1].set_title("With Augmentation (50% Quantile)") + axes[1, 1].set_title("With Augmentation (80% Quantile)") plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') plt.close() @@ -949,11 +1022,13 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): if __name__ == "__main__": - name = "3dampings_lessnoise" + name = "3dampings_lessnoise_newnetwork" if not os.path.exists(name): os.makedirs(name) - # create_train_data(filename=f'{name}/validation_data{name}.pickle', number_samples=1000) + create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=1000) # run_training(name=name, num_training_files=20) - run_inference(name=name, on_synthetic_data=True) - run_inference(name=name, on_synthetic_data=False) \ No newline at end of file + # run_inference(name=name, on_synthetic_data=True) + # run_inference(name=name, on_synthetic_data=False) + # run_prior_predictive_check(name=name) + # compare_median_sim_to_mean_param_sim(name=name) \ No newline at end of file From 642a07139e14df05d585315008248778b876a4ee Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 11 Nov 2025 10:32:08 +0100 Subject: [PATCH 65/73] adjust germany and spain files to new structure --- cpp/memilio/io/epi_data.cpp | 22 + cpp/memilio/io/epi_data.h | 2 + .../simulation/graph_germany_nuts0.py | 624 ++++++++++--- .../simulation/graph_germany_nuts1.py | 65 +- .../simulation/graph_germany_nuts3_clean.py | 12 +- .../examples/simulation/graph_spain_nuts3.py | 855 +++++++++++++++--- .../memilio/epidata/defaultDict.py | 88 +- .../memilio/epidata/getSimulationDataSpain.py | 6 +- 8 files changed, 1410 insertions(+), 264 deletions(-) diff --git a/cpp/memilio/io/epi_data.cpp b/cpp/memilio/io/epi_data.cpp index e95c6ef809..2741d1f4ff 100644 --- a/cpp/memilio/io/epi_data.cpp +++ b/cpp/memilio/io/epi_data.cpp @@ -72,6 +72,28 @@ IOResult> get_country_id(const std::string& /*path*/, bool /*is std::vector id = {0}; return success(id); } + +IOResult> get_provincia_ids(const 
std::string& path, bool /*is_node_for_county*/, + bool /*rki_age_groups*/) +{ +mio::log_info("Reading node data from {}.\n", path); +BOOST_OUTCOME_TRY(auto&& population_data, read_population_data_spain(path)); +std::vector id; +id.reserve(population_data.size()); +for (auto&& entry : population_data) { +if (entry.provincia_id) { +id.push_back(entry.provincia_id->get()); +} +else { +return failure(StatusCode::InvalidValue, "Population data file is missing provincia ids."); +} +} + +//remove duplicate node ids +id.erase(std::unique(id.begin(), id.end()), id.end()); +mio::log_info("Reading node data completed."); +return success(id); +} } // namespace mio #endif //MEMILIO_HAS_JSONCPP diff --git a/cpp/memilio/io/epi_data.h b/cpp/memilio/io/epi_data.h index 048eb33c12..442b54e4cc 100644 --- a/cpp/memilio/io/epi_data.h +++ b/cpp/memilio/io/epi_data.h @@ -549,6 +549,8 @@ IOResult set_vaccination_data_age_group_names(std::vector nam IOResult> get_node_ids(const std::string& path, bool is_node_for_county, bool rki_age_groups = true); IOResult> get_country_id(const std::string& /*path*/, bool /*is_node_for_county*/, bool /*rki_age_groups*/ = true); +IOResult> get_provincia_ids(const std::string& path, bool /*is_node_for_county*/, + bool /*rki_age_groups*/ = true); /** * Represents an entry in a vaccination data file. diff --git a/pycode/examples/simulation/graph_germany_nuts0.py b/pycode/examples/simulation/graph_germany_nuts0.py index b9d4843430..4850a03902 100644 --- a/pycode/examples/simulation/graph_germany_nuts0.py +++ b/pycode/examples/simulation/graph_germany_nuts0.py @@ -17,18 +17,182 @@ # See the License for the specific language governing permissions and # limitations under the License. ############################################################################# -import numpy as np -import datetime import os -import memilio.simulation as mio -import memilio.simulation.osecir as osecir +os.environ["KERAS_BACKEND"] = "tensorflow" + +import numpy as np +import pandas as pd import matplotlib.pyplot as plt +import datetime +import pickle +from scipy.stats import truncnorm -from enum import Enum -from memilio.simulation.osecir import (Model, Simulation, - interpolate_simulation_result) +from matplotlib.patches import Patch -import pickle +import bayesflow as bf +import keras + +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +from memilio.simulation.osecir import Model, interpolate_simulation_result +from memilio.epidata import defaultDict as dd + +import geopandas as gpd + + +region_ids = [0] +inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', + 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] +summary_vars = ['state0'] + +bounds = { + 't_E': (1.0, 5.2), + 't_ISy': (4.0, 10.0), + 't_ISev': (5.0, 10.0), + 't_Cr': (9.0, 17.0), + 'mu_CR': (0.0, 0.4), + 'mu_IH': (0.0, 0.2), + 'mu_HU': (0.0, 0.4), + 'mu_UD': (0.0, 0.4), + 'transmission_prob': (0.0, 0.2) +} +SPIKE_SCALE = 0.4 +SLAB_SCALE = 0.2 +DATE_TIME = datetime.date(year=2020, month=10, day=1) +NUM_DAMPING_POINTS = 3 + +def plot_region_fit( + data: np.ndarray, + region: int, + true_data=None, + ax=None, + label=None, + color="red", + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + n_samples, n_time, n_regions = data.shape + if not (0 <= region < n_regions): + raise IndexError + + x = 
np.arange(n_time) + vals = data[:, :, region] # (samples, time_points) + + qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) + + med = np.median(vals, axis=0) + + if ax is None: + fig, ax = plt.subplots() + + ax.plot( + x, med, lw=2, label=label or f"Germany", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + + if true_data is not None: + true_vals = true_data[:, region] # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + ax.set_title(f"Germany", fontsize=12) + if label is not None: + ax.legend(fontsize=11, loc="upper right") + + +def plot_icu_on_germany(simulations, name, synthetic, with_aug): + med = np.median(simulations, axis=0) + + print(simulations.shape) + population = pd.read_json('data/Germany/pydata/county_current_population_germany.json') + # print(population['Population'].shape) + values = med / population['Population'][None, :] * 100000 + print(values.shape) + + + map_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_KRS.shp')) + fedstate_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_LAN.shp')) + state_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_STA.shp')) + + plot_map(simulations[0], state_data, map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_initial_{name}{synthetic}{with_aug}") + plot_map(simulations[-1], state_data, map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_final_{name}{synthetic}{with_aug}") + +def plot_map(values, state_data, map_data, fedstate_data, label, filename): + state_data[label] = state_data['ARS'].map(dict(zip([f"{region_id:02d}" for region_id in region_ids], values))) + + fig, ax = plt.subplots(figsize=(12, 14)) + state_data.plot( + column=f"{label}", + cmap='Reds', + linewidth=1., + ax=ax, + edgecolor='0.6', + legend=True, + legend_kwds={'label': f"{label} per State", 'shrink': 0.6}, + ) + map_data.boundary.plot(ax=ax, color='black', linewidth=0.2) + fedstate_data.boundary.plot(ax=ax, color='black', linewidth=0.4) + ax.set_title(f"{label} per Federal State") + ax.axis('off') + plt.savefig(filename, bbox_inches='tight', dpi=300) + +def plot_damping_values(damping_values, name, synthetic, with_aug): + + med = np.median(damping_values, axis=0) + mad = np.median(np.abs(damping_values - med), axis=0) + + # Extend for step plotting + med_extended = np.hstack([med, med[:, -1][:, None]]) + mad_extended = np.hstack([mad, mad[:, -1][:, None]]) + + # Plot damping values per region + fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) + axes = axes.flatten() + x = np.arange(15, 61, 15) # Time steps from 15 to 60 + + for i, ax in enumerate(axes): + if i < 16: + ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) + ax.fill_between( + x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], + alpha=0.25, color='red', step='post' + ) + ax.set_title(f"{dd.State[i+1]}", fontsize=10) + ax.set_xlabel("Time", fontsize=8) + ax.set_ylabel("Damping Value", fontsize=8) + else: + ax.axis('off') # Hide unused subplots + + 
plt.suptitle("Damping Values per Region", fontsize=14) + plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) + + # Combined plot for all regions + fig, ax = plt.subplots(figsize=(10, 6)) + cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors + + for i in range(16): + ax.stairs( + med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", + color=cmap(i), baseline=None + ) + + ax.set_title("Damping Values per Region (Combined)", fontsize=14) + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("Damping Value", fontsize=12) + ax.legend(fontsize=10, loc="upper right", ncol=2) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) class Simulation: @@ -42,28 +206,28 @@ def __init__(self, data_dir, start_date, results_dir): if not os.path.exists(self.results_dir): os.makedirs(self.results_dir) - def set_covid_parameters(self, model): + def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param model: """ - model.parameters.TimeExposed[mio.AgeGroup(0)] = 3.335 - model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 - model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 - model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 - model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 + model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E # todo: correct? + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr # probabilities model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( - 0)] = 0.07333 + 0)] = transmission_prob model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( - 0)] = 0.2069 - model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 - model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 - model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 + 0)] = mu_CR + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD # start day is set to the n-th day of the year model.parameters.StartDay = self.start_date.timetuple().tm_yday @@ -79,27 +243,37 @@ def set_contact_matrices(self, model): contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 - minimum = np.ones((self.num_groups, self.num_groups)) * 0 + minimum = np.zeros((self.num_groups, self.num_groups)) contact_matrices[0] = mio.ContactMatrix(baseline, minimum) model.parameters.ContactPatterns.cont_freq_mat = contact_matrices - def set_npis(self, params, end_date, damping_value): + def set_npis(self, params, end_date, damping_values): """ :param params: :param end_date: """ + start_damping_1 = DATE_TIME + datetime.timedelta(days=15) + start_damping_2 = DATE_TIME + datetime.timedelta(days=30) + start_damping_3 = DATE_TIME + datetime.timedelta(days=45) - start_damping = datetime.date( - 2020, 12, 18) + if start_damping_1 < end_date: + start_date = (start_damping_1 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[0]], t=start_date)) - if start_damping < end_date: - start_date = 
(start_damping - self.start_date).days + if start_damping_2 < end_date: + start_date = (start_damping_2 - self.start_date).days params.ContactPatterns.cont_freq_mat[0].add_damping( - mio.Damping(np.r_[damping_value], t=start_date)) + mio.Damping(np.r_[damping_values[1]], t=start_date)) - def get_graph(self, end_date, damping_value): + if start_damping_3 < end_date: + start_date = (start_damping_3 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[2]], t=start_date)) + + def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param end_date: @@ -107,9 +281,9 @@ def get_graph(self, end_date, damping_value): """ print("Initializing model...") model = Model(self.num_groups) - self.set_covid_parameters(model) + self.set_covid_parameters( + model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) self.set_contact_matrices(model) - self.set_npis(model.parameters, end_date, damping_value) print("Model initialized.") graph = osecir.ModelGraph() @@ -131,13 +305,13 @@ def get_graph(self, end_date, damping_value): mio.Date(end_date.year, end_date.month, end_date.day), pydata_dir, path_population_data, False, graph, scaling_factor_infected, - scaling_factor_icu, 0.9, 0, False) + scaling_factor_icu, 0., 0, False) print("Graph created.") return graph - def run(self, num_days_sim, damping_value, save_graph=True): + def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): """ :param num_days_sim: @@ -148,103 +322,349 @@ def run(self, num_days_sim, damping_value, save_graph=True): """ mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) - num_runs = 10 - - graph = self.get_graph(end_date, damping_value) - if save_graph: - path_graph = os.path.join(self.results_dir, "graph") - if not os.path.exists(path_graph): - os.makedirs(path_graph) - osecir.write_graph(graph, path_graph) + graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, + t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): - mobility_graph.add_node(graph.get_node( - node_idx).id, graph.get_node(node_idx).property) + node = graph.get_node(node_idx) + + self.set_npis( + node.property.parameters, + end_date, + damping_values[node_idx]) + mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( graph.get_edge(edge_idx).start_node_idx, graph.get_edge(edge_idx).end_node_idx, graph.get_edge(edge_idx).property) - mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) mobility_sim.advance(num_days_sim) - results = [] - for node_idx in range(graph.num_nodes): - results.append(osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(node_idx).property.result)) - - osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(0).property.result).export_csv( - 'test.csv') + results = {} + for node_idx in range(mobility_sim.graph.num_nodes): + results[f'state{node_idx}'] = osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result) return results -def run_germany_nuts0_simulation(damping_value): +def run_germany_nuts0_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): mio.set_log_level(mio.LogLevel.Warning) file_path = 
os.path.dirname(os.path.abspath(__file__)) sim = Simulation( data_dir=os.path.join(file_path, "../../../data"), - start_date=datetime.date(year=2020, month=12, day=12), + start_date=DATE_TIME, results_dir=os.path.join(file_path, "../../../results_osecir")) - num_days_sim = 50 + num_days_sim = 60 - results = sim.run(num_days_sim, damping_value) - - return {"region" + str(region): results[region] for region in range(len(results))} + results = sim.run(num_days_sim, damping_values, t_E, t_ISy, + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + return results def prior(): - damping_value = np.random.uniform(0.0, 1.0) - return {"damping_value": damping_value} + damping_values = np.zeros((NUM_DAMPING_POINTS, 1)) + for i in range(NUM_DAMPING_POINTS): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / scale + damping_values[i] = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=1 + ) + return { + 'damping_values': np.transpose(damping_values), + 't_E': np.random.uniform(*bounds['t_E']), + 't_ISy': np.random.uniform(*bounds['t_ISy']), + 't_ISev': np.random.uniform(*bounds['t_ISev']), + 't_Cr': np.random.uniform(*bounds['t_Cr']), + 'mu_CR': np.random.uniform(*bounds['mu_CR']), + 'mu_IH': np.random.uniform(*bounds['mu_IH']), + 'mu_HU': np.random.uniform(*bounds['mu_HU']), + 'mu_UD': np.random.uniform(*bounds['mu_UD']), + 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) + } + + +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Germany/pydata") + + data = pd.read_json(os.path.join(divi_path, "germany_divi_ma7.json")) + data = data[data['Date'] >= np.datetime64(DATE_TIME)] + data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] + data = data.sort_values(by=['Date']) + divi_dict = {} + divi_dict[f"state0"] = data['ICU'].to_numpy()[None, :, None] + return divi_dict + + + +def extract_observables(simulation_results, observable_index=7): + for key in simulation_results.keys(): + if key not in inference_params: + simulation_results[key] = simulation_results[key][:, :, observable_index][..., np.newaxis] + return simulation_results + + +def create_train_data(filename, number_samples=1000): + + simulator = bf.simulators.make_simulator( + [prior, run_germany_nuts0_simulation] + ) + trainings_data = simulator.sample(number_samples) + trainings_data = extract_observables(trainings_data) + with open(filename, 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + +def load_pickle(path): + with open(path, "rb") as f: + return pickle.load(f) + +def is_state_key(k: str) -> bool: + return 'state' in k + +def apply_aug(d: dict, aug) -> dict: + return {k: np.clip(aug(v), 0, None) if is_state_key(k) else v for k, v in d.items()} + +def concat_dicts(base: dict, new: dict) -> dict: + missing = set(base) - set(new) + if missing: + raise KeyError(f"new dict missing keys: {sorted(missing)}") + for k in base: + base[k] = np.concatenate([base[k], new[k]]) + return base + + +def combine_results(dict_list): + combined = {} + for d in dict_list: + combined = concat_dicts(combined, d) if combined else d + return combined + +def skip_2weeks(d:dict) -> dict: + return {k: v[:, 14:, :] if is_state_key(k) else v for k, v in d.items()} + +def get_workflow(): + + simulator = bf.make_simulator( + [prior, run_germany_nuts0_simulation] + ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + 
.constrain("damping_values", lower=0.0, upper=1.0) + .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) + .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) + .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) + .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) + .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) + .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) + .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) + .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) + .concatenate( + ["damping_values", "t_E", "t_ISy", "t_ISev", "t_Cr", + "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], + into="inference_variables", + axis=-1 + ) + .concatenate(summary_vars, into="summary_variables", axis=-1) + ) + + summary_network = bf.networks.FusionTransformer( + summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 + ) + inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) + + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + # augmentations={f'fed_state{i}': aug for i in range(len(region_ids))} + ) + + return workflow + + +def run_training(name, num_training_files=20): + train_template = name+"/trainings_data{i}_"+name+".pickle" + val_path = f"{name}/validation_data_{name}.pickle" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # training data + train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] + trainings_data = None + for p in train_files: + d = load_pickle(p) + d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) + d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) + if trainings_data is None: + trainings_data = d + else: + trainings_data = concat_dicts(trainings_data, d) + + # validation data + validation_data = apply_aug(load_pickle(val_path), aug=aug) + validation_data = skip_2weeks(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + # check data + workflow = get_workflow() + print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) + print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) + + history = workflow.fit_offline( + data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data + ) + + workflow.approximator.save( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + plots = workflow.plot_default_diagnostics( + test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} + ) + plots['losses'].savefig(f'{name}/losses_{name}.png') + plots['recovery'].savefig(f'{name}/recovery_{name}.png') + plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png') + #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') + + +def run_inference(name, num_samples=1000, on_synthetic_data=False): + val_path = f"{name}/validation_data_{name}.pickle" + 
synthetic = "_synthetic" if on_synthetic_data else "" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # validation data + validation_data = load_pickle(val_path) # synthetic data + if on_synthetic_data: + # validation data + validation_data = apply_aug(validation_data, aug=aug) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data_skip2w = skip_2weeks(validation_data) + divi_dict = validation_data + + divi_data = np.concatenate( + [divi_dict[f'state{i}'] for i in range(len(region_ids))], axis=-1 + )[0] # only one dataset + else: + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + divi_data = np.concatenate( + [divi_dict[f'state{i}'] for i in range(len(region_ids))], axis=-1 + )[0] + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') + simulations_aug = load_pickle(f'{name}/sims_{name}{synthetic}_with_aug.pickle') + print("loaded simulations from file") + else: + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 1, NUM_DAMPING_POINTS)) + results = [] + for i in range(num_samples): # we only have one dataset for inference here + result = run_germany_nuts0_simulation( + damping_values=samples['damping_values'][0, i], + t_E=samples['t_E'][0, i], t_ISy=samples['t_ISy'][0, i], + t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], + mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], + mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], + transmission_prob=samples['transmission_prob'][0, i] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] 
# add sample axis + results.append(result) + results = combine_results(results) + results = extract_observables(results) + results_aug = apply_aug(results, aug=aug) + + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + for i in range(num_samples): + simulations[i] = np.concatenate([results[key][i] for key in results.keys()], axis=-1) + simulations_aug[i] = np.concatenate([results_aug[key][i] for key in results.keys()], axis=-1) + + # save sims + with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: + pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') + + fig, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey='row', constrained_layout=True) + # Plot without augmentation + plot_region_fit( + simulations, region=0, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color="#132a70" + ) + axes[0, 0].set_title("Without Augmentation") + # Plot with augmentation + plot_region_fit( + simulations_aug, region=0, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" + ) + axes[0, 1].set_title("With Augmentation") + # Plot without augmentation (80% quantile only) + plot_region_fit( + simulations, region=0, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_80q=True + ) + axes[1, 0].set_title("Without Augmentation (80% Quantile)") + # Plot with augmentation (80% quantile only) + plot_region_fit( + simulations_aug, region=0, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_80q=True + ) + axes[1, 1].set_title("With Augmentation (80% Quantile)") + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') + plt.close() + + plot_icu_on_germany(simulations, name, synthetic, with_aug="") + plot_icu_on_germany(simulations_aug, name, synthetic, with_aug="_with_aug") + + simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions + simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) + + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + print("Mean RMSE over regions:", rmse["values"].mean()) + print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) + + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + print("Mean Calibration Error over regions:", cal_error["values"].mean()) + 
print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) if __name__ == "__main__": - - run_germany_nuts0_simulation(0.5) - # import os - # os.environ["KERAS_BACKEND"] = "jax" - - # import bayesflow as bf - - # simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - # # trainings_data = simulator.sample(5) - - # with open('trainings_data.pickle', 'wb') as f: - # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - - # with open('trainings_data.pickle', 'rb') as f: - # trainings_data = pickle.load(f) - - # # trainings_data = {k:v for k, v in trainings_data.items() if k in ('damping_value', 'region0', 'region1')} - # print("Loaded training data:", trainings_data) - - # trainings_data = simulator.sample(2) - # validation_data = simulator.sample(2) - - # adapter = ( - # bf.Adapter() - # .to_array() - # .convert_dtype("float64", "float32") - # .constrain("damping_value", lower=0.0, upper=1.0) - # .concatenate(["region"+str(region) for region in range(len(trainings_data)-1)], into="summary_variables") - # .rename("damping_value", "inference_variables") - # #.standardize("summary_variables") - # ) - - # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=4) - # inference_network = bf.networks.CouplingFlow() - - # workflow = bf.BasicWorkflow( - # simulator=simulator, - # adapter=adapter, - # summary_network=summary_network, - # inference_network=inference_network - # ) - - # history = workflow.fit_offline(data=trainings_data, epochs=2, batch_size=2, validation_data=trainings_data) - # f = bf.diagnostics.plots.loss(history) + name = "nuts0" + + if not os.path.exists(name): + os.makedirs(name) + # create_train_data(filename=f'{name}/trainings_data10_{name}.pickle', number_samples=2000) + # run_training(name=name, num_training_files=10) + run_inference(name=name, on_synthetic_data=False) + run_inference(name=name, on_synthetic_data=True) \ No newline at end of file diff --git a/pycode/examples/simulation/graph_germany_nuts1.py b/pycode/examples/simulation/graph_germany_nuts1.py index 8fc9a2d99a..041c0f726a 100644 --- a/pycode/examples/simulation/graph_germany_nuts1.py +++ b/pycode/examples/simulation/graph_germany_nuts1.py @@ -68,7 +68,7 @@ def plot_region_fit( ax=None, label=None, color="red", - only_50q=False + only_80q=False ): if data.ndim != 3: raise ValueError("Array not of shape (samples, time_points, regions)") @@ -82,7 +82,7 @@ def plot_region_fit( x = np.arange(n_time) vals = data[:, :, region] # (samples, time_points) - qs_50 = np.quantile(vals, q=[0.25, 0.75], axis=0) + qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) @@ -93,9 +93,9 @@ def plot_region_fit( ax.plot( x, med, lw=2, label=label or f"{dd.State[region_ids[region]]}", color=color) - ax.fill_between(x, qs_50[0], qs_50[1], alpha=0.5, - color=color, label="50% CI") - if not only_50q: + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, color=color, label="90% CI") ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, @@ -119,7 +119,7 @@ def plot_aggregated_over_regions( ax=None, label=None, color='red', - only_50q=False + only_80q=False ): if data.ndim != 3: raise ValueError("Array not of shape (samples, time_points, regions)") @@ -130,7 +130,7 @@ def plot_aggregated_over_regions( # Aggregate over regions agg_over_regions = region_agg(data, axis=-1) # (samples, 
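The plotting helpers above all derive their uncertainty bands the same way: a pointwise median plus 80/90/95% central intervals taken over the sample axis. A minimal numpy sketch of that computation (the helper name quantile_bands is illustrative):

    import numpy as np

    def quantile_bands(vals):
        # vals: array of shape (samples, time_points).
        med = np.median(vals, axis=0)
        qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0)      # 80% central interval
        qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0)    # 90% central interval
        qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0)  # 95% central interval
        return med, qs_80, qs_90, qs_95

The intervals are drawn as nested fill_between bands with decreasing alpha, with the median plotted on top.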
time_points) - qs_50 = np.quantile(agg_over_regions, q=[0.25, 0.75], axis=0) + qs_80 = np.quantile(agg_over_regions, q=[0.1, 0.9], axis=0) qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) qs_95 = np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) @@ -143,9 +143,9 @@ def plot_aggregated_over_regions( ax.plot(x, agg_median, lw=2, label=label or "Aggregated over regions", color=color) - ax.fill_between(x, qs_50[0], qs_50[1], alpha=0.5, - color=color, label="50% CI") - if not only_50q: + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, color=color, label="90% CI") ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, @@ -165,6 +165,7 @@ def plot_icu_on_germany(simulations, name, synthetic, with_aug): population = pd.read_json('data/Germany/pydata/county_current_population_states.json') values = med / population['Population'].to_numpy()[None, :] * 100000 + map_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_KRS.shp')) fedstate_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_LAN.shp')) @@ -548,8 +549,7 @@ def run_germany_nuts1_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR num_days_sim = 60 results = sim.run(num_days_sim, damping_values, t_E, t_ISy, - t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=False) - results['fed_state0'].export_csv('test.csv') + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) return results @@ -671,14 +671,14 @@ def get_workflow(): ) inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) - aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) workflow = bf.BasicWorkflow( simulator=simulator, adapter=adapter, summary_network=summary_network, inference_network=inference_network, - standardize='all', - augmentations={f'fed_state{i}': aug for i in range(len(region_ids))} + standardize='all' + # augmentations={f'fed_state{i}': aug for i in range(len(region_ids))} ) return workflow @@ -688,16 +688,16 @@ def run_training(name, num_training_files=20): train_template = name+"/trainings_data{i}_"+name+".pickle" val_path = f"{name}/validation_data_{name}.pickle" - # aug = bf.augmentations.NNPE( - # spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False - # ) + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) # training data train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] trainings_data = None for p in train_files: d = load_pickle(p) - # d = apply_aug(d, aug=aug) # only on region keys + d = apply_aug(d, aug=aug) # only on region keys d = skip_2weeks(d) d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) if trainings_data is None: @@ -706,7 +706,7 @@ def run_training(name, num_training_files=20): trainings_data = concat_dicts(trainings_data, d) # validation data - validation_data = load_pickle(val_path)#apply_aug(load_pickle(val_path), aug=aug) + validation_data = apply_aug(load_pickle(val_path), aug=aug) validation_data = skip_2weeks(validation_data) validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) @@ -716,7 +716,7 @@ def 
run_training(name, num_training_files=20): print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) history = workflow.fit_offline( - data=trainings_data, epochs=50, batch_size=64, validation_data=validation_data + data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data ) workflow.approximator.save( @@ -732,7 +732,7 @@ def run_training(name, num_training_files=20): #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') -def run_inference(name, num_samples=10, on_synthetic_data=False): +def run_inference(name, num_samples=1000, on_synthetic_data=False): val_path = f"{name}/validation_data_{name}.pickle" synthetic = "_synthetic" if on_synthetic_data else "" @@ -741,14 +741,13 @@ def run_inference(name, num_samples=10, on_synthetic_data=False): ) # validation data - validation_data = load_pickle(val_path) + validation_data = load_pickle(val_path) # synthetic data if on_synthetic_data: # validation data validation_data = apply_aug(validation_data, aug=aug) validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) validation_data_skip2w = skip_2weeks(validation_data) divi_dict = validation_data - print(np.sum([divi_dict[f'fed_state{i}'] for i in range(len(region_ids))])) divi_data = np.concatenate( [divi_dict[f'fed_state{i}'] for i in range(len(region_ids))], axis=-1 @@ -822,16 +821,16 @@ def run_inference(name, num_samples=10, on_synthetic_data=False): simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" ) axes[0, 1].set_title("With Augmentation") - # Plot without augmentation (50% quantile only) + # Plot without augmentation (80% quantile only) plot_aggregated_over_regions( - simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_50q=True + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_80q=True ) - axes[1, 0].set_title("Without Augmentation (50% Quantile)") - # Plot with augmentation (50% quantile only) + axes[1, 0].set_title("Without Augmentation (80% Quantile)") + # Plot with augmentation (80% quantile only) plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_50q=True + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_80q=True ) - axes[1, 1].set_title("With Augmentation (50% Quantile)") + axes[1, 1].set_title("With Augmentation (80% Quantile)") plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') plt.close() @@ -868,7 +867,7 @@ def run_inference(name, num_samples=10, on_synthetic_data=False): if not os.path.exists(name): os.makedirs(name) - create_train_data(filename=f'{name}/trainings_data1_{name}.pickle', number_samples=1000) + # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=100) # run_training(name=name, num_training_files=20) - # run_inference(name=name, on_synthetic_data=True) - # run_inference(name=name, on_synthetic_data=False) \ No newline at end of file + run_inference(name=name, on_synthetic_data=False) + run_inference(name=name, on_synthetic_data=True) \ No newline at end of file diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index 
1e970127d6..ef9ee88026 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -1,7 +1,7 @@ -####################################################################### +############################################################################# # Copyright (C) 2020-2025 MEmilio # -# Authors: Henrik Zunker +# Authors: Carlotta Gerstein # # Contact: Martin J. Kuehn # @@ -212,9 +212,6 @@ def plot_aggregated_to_federal_states(data, true_data, name, synthetic, with_aug plt.savefig(f'{name}/federal_states_{name}{synthetic}{with_aug}.png') plt.close() - -# %% - # plot simulations for all regions in 10x4 blocks def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): n_regions = simulations.shape[-1] @@ -1004,11 +1001,10 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): plt.close() plot_icu_on_germany(simulations, name, synthetic, with_aug="") - plot_icu_on_germany(simulations, name, synthetic, with_aug="_with_aug") + plot_icu_on_germany(simulations_aug, name, synthetic, with_aug="_with_aug") simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) - print(simulation_agg.shape, divi_data.shape) rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) @@ -1026,7 +1022,7 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): if not os.path.exists(name): os.makedirs(name) - create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=1000) + create_train_data(filename=f'{name}/trainings_data1_{name}.pickle', number_samples=1000) # run_training(name=name, num_training_files=20) # run_inference(name=name, on_synthetic_data=True) # run_inference(name=name, on_synthetic_data=False) diff --git a/pycode/examples/simulation/graph_spain_nuts3.py b/pycode/examples/simulation/graph_spain_nuts3.py index 081fb1ea22..c08a2c9b16 100644 --- a/pycode/examples/simulation/graph_spain_nuts3.py +++ b/pycode/examples/simulation/graph_spain_nuts3.py @@ -1,7 +1,7 @@ ############################################################################# # Copyright (C) 2020-2025 MEmilio # -# Authors: Henrik Zunker +# Authors: Carlotta Gerstein # # Contact: Martin J. Kuehn # @@ -17,19 +17,363 @@ # See the License for the specific language governing permissions and # limitations under the License. 
############################################################################# -import numpy as np -import datetime import os -import memilio.simulation as mio -import memilio.simulation.osecir as osecir +os.environ["KERAS_BACKEND"] = "tensorflow" + +import numpy as np +import pandas as pd import matplotlib.pyplot as plt +import datetime +import pickle +from scipy.stats import truncnorm -from enum import Enum -from memilio.simulation.osecir import (Model, Simulation, - interpolate_simulation_result) +from matplotlib.patches import Patch -import pickle +import bayesflow as bf +import keras +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +from memilio.simulation.osecir import Model, interpolate_simulation_result +from memilio.epidata import defaultDict as dd + +import geopandas as gpd + +excluded_ids = [8, 35, 38, 51, 52] +region_ids = [region_id for region_id in dd.Provincias.keys() + if region_id not in excluded_ids] +inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', + 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] +summary_vars = ['state'] + [f'comunidad{i}' for i in range(len(dd.Comunidades))] + [ + f'region{i}' for i in range(len(dd.Provincias))] + +bounds = { + 't_E': (1.0, 5.2), + 't_ISy': (4.0, 10.0), + 't_ISev': (5.0, 10.0), + 't_Cr': (9.0, 17.0), + 'mu_CR': (0.0, 0.4), + 'mu_IH': (0.0, 0.2), + 'mu_HU': (0.0, 0.4), + 'mu_UD': (0.0, 0.4), + 'transmission_prob': (0.0, 0.2) +} +SPIKE_SCALE = 0.4 +SLAB_SCALE = 0.2 +DATE_TIME = datetime.date(year=2020, month=10, day=1) +NUM_DAMPING_POINTS = 3 + +def plot_region_fit( + data: np.ndarray, + region: int, + true_data=None, + ax=None, + label=None, + color="red", + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + n_samples, n_time, n_regions = data.shape + if not (0 <= region < n_regions): + raise IndexError + + x = np.arange(n_time) + vals = data[:, :, region] # (samples, time_points) + + qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) + + med = np.median(vals, axis=0) + + if ax is None: + fig, ax = plt.subplots() + + ax.plot( + x, med, lw=2, label=label or f"{dd.County[region_ids[region]]}", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + + if true_data is not None: + true_vals = true_data[:, region] # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + ax.set_title(f"{dd.County[region_ids[region]]}", fontsize=12) + if label is not None: + ax.legend(fontsize=11, loc="upper right") + + +def plot_aggregated_over_regions( + data: np.ndarray, + region_agg=np.sum, + true_data=None, + ax=None, + label=None, + color='red', + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + # Aggregate over regions + agg_over_regions = region_agg(data, axis=-1) # (samples, 
time_points) + + qs_80 = np.quantile(agg_over_regions, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) + + # Aggregate over samples + agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) + + x = np.arange(agg_median.shape[0]) + if ax is None: + fig, ax = plt.subplots() + + ax.plot(x, agg_median, lw=2, + label=label or "Aggregated over regions", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + if true_data is not None: + true_vals = region_agg(true_data, axis=-1) # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + if label is not None: + ax.legend(fontsize=11) + + +def plot_aggregated_to_comunidades(data, true_data, name, synthetic, with_aug): + fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(25, 25), layout="constrained") + ax = ax.flatten() + for state in range(19): + idxs = [i for i, region_id in enumerate(region_ids) if dd.Provincia_to_Comunidad[region_id] == state + 1] + plot_aggregated_over_regions( + data[:, :, idxs], # Add a dummy region axis for compatibility + true_data=true_data[:, idxs] if true_data is not None else None, + ax=ax[state], + label=f"State {state + 1}", + color=f"C{state % 10}" # Cycle through 10 colors + ) + plt.savefig(f'{name}/comunidades_{name}{synthetic}{with_aug}.png') + plt.close() + +# plot simulations for all regions in 10x4 blocks +def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): + n_regions = simulations.shape[-1] + n_cols = 4 + n_rows = 10 + n_blocks = (n_regions + n_cols * n_rows - 1) // (n_cols * n_rows) + + for block in range(n_blocks): + start_idx = block * n_cols * n_rows + end_idx = min(start_idx + n_cols * n_rows, n_regions) + fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 25), layout="constrained") + ax = ax.flatten() + for i, region_idx in enumerate(range(start_idx, end_idx)): + plot_region_fit( + simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color="#132a70" + ) + # Hide unused subplots + for i in range(end_idx - start_idx, len(ax)): + ax[i].axis("off") + plt.savefig(f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png') + plt.close() + + +def calibration_curves_per_region( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + max_regions=None, + cmap=plt.cm.Blues, + linewidth=1.5, + legend=True, + with_ideal=True, +): + """ + Per-region calibration curves, each region in a different shade of a colormap. 
+ + data: (samples, time, regions) + true_data: (time, regions) + max_regions: limit number of regions shown + cmap: matplotlib colormap for line shades + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + if max_regions is None: + R = n_regions + else: + R = min(max_regions, n_regions) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + colors = [cmap(i / (R + 1)) for i in range(1, R + 1)] + + x = np.asarray(levels) + for r, col in zip(range(R), colors): + emp = [] + for nominal in levels: + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data[:, :, r], q_low, axis=0) + hi = np.quantile(data[:, :, r], q_high, axis=0) + hits = (true_data[:, r] >= lo) & (true_data[:, r] <= hi) + emp.append(hits.mean()) + emp = np.asarray(emp) + # , label=f"Region {r+1}") + ax.plot(x, emp, lw=linewidth, color=col, alpha=0.5) + + if with_ideal: + ideal_line = ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal")[0] + else: + ideal_line = None + + if legend: + # Custom legend: one patch for regions, one line for ideal + region_patch = Patch(color=colors[-1], label="Regions") + ax.legend(handles=[region_patch, ideal_line], + frameon=True, ncol=1, fontsize=12) + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level", fontsize=12) + ax.set_ylabel("Empirical coverage", fontsize=12) + ax.set_title("Calibration per region", fontsize=12) + return ax + + +def calibration_median_mad_over_regions( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + color="tab:blue", + alpha=0.25, + linewidth=2.0, + with_ideal=True, +): + """ + Compute per region empirical coverage at each nominal level, + then summarize across regions with median and MAD. 
+ + data: (samples, time, regions) + true_data: (time, regions) + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + L = len(levels) + per_region = np.empty((n_regions, L), dtype=float) + + # per level coverage per region + for j, nominal in enumerate(levels): + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data, q_low, axis=0) # (time, regions) + hi = np.quantile(data, q_high, axis=0) # (time, regions) + hits = (true_data >= lo) & (true_data <= hi) # (time, regions) + # mean over time for each region + per_region[:, j] = hits.mean(axis=0) + + med = np.median(per_region, axis=0) # (levels,) + mad = np.median(np.abs(per_region - med[None, :]), axis=0) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + x = np.asarray(levels) + ax.fill_between(x, med - mad, med + mad, + alpha=alpha, color=color, label=None) + ax.plot(x, med, lw=linewidth, color=color, label="Median across regions") + + if with_ideal: + ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal") + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level", fontsize=12) + ax.set_ylabel("Empirical coverage", fontsize=12) + ax.set_title("Calibration median and MAD across regions", fontsize=12) + ax.legend(fontsize=12) + return ax, {"levels": x, "median": med, "mad": mad} + + +def plot_damping_values(damping_values, name, synthetic, with_aug): + + med = np.median(damping_values, axis=0) + mad = np.median(np.abs(damping_values - med), axis=0) + + # Extend for step plotting + med_extended = np.hstack([med, med[:, -1][:, None]]) + mad_extended = np.hstack([mad, mad[:, -1][:, None]]) + + # Plot damping values per region + fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) + axes = axes.flatten() + x = np.arange(15, 61, 15) # Time steps from 15 to 60 + + for i, ax in enumerate(axes): + if i < 19: + ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) + ax.fill_between( + x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], + alpha=0.25, color='red', step='post' + ) + ax.set_title(f"{dd.Comunidades[i+2]}", fontsize=10) + ax.set_xlabel("Time", fontsize=8) + ax.set_ylabel("Damping Value", fontsize=8) + else: + ax.axis('off') # Hide unused subplots + + plt.suptitle("Damping Values per Region", fontsize=14) + plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) + + # Combined plot for all regions + fig, ax = plt.subplots(figsize=(10, 6)) + cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors + + for i in range(19): + ax.stairs( + med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", + color=cmap(i), baseline=None + ) + + ax.set_title("Damping Values per Region (Combined)", fontsize=14) + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("Damping Value", fontsize=12) + ax.legend(fontsize=10, loc="upper right", ncol=2) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) class Simulation: """ """ @@ -42,28 +386,28 @@ def __init__(self, data_dir, start_date, results_dir): if not os.path.exists(self.results_dir): os.makedirs(self.results_dir) - def set_covid_parameters(self, model): + def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param model: """ - 
model.parameters.TimeExposed[mio.AgeGroup(0)] = 3.335 - model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 2.58916 - model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = 6.94547 - model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = 7.28196 - model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = 13.066 + model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E # todo: correct? + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr # probabilities model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( - 0)] = 0.07333 + 0)] = transmission_prob model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( - 0)] = 0.2069 - model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 - model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 - model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 + 0)] = mu_CR + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD # start day is set to the n-th day of the year model.parameters.StartDay = self.start_date.timetuple().tm_yday @@ -79,27 +423,37 @@ def set_contact_matrices(self, model): contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) baseline = np.ones((self.num_groups, self.num_groups)) * 12.32 - minimum = np.ones((self.num_groups, self.num_groups)) * 0 + minimum = np.zeros((self.num_groups, self.num_groups)) contact_matrices[0] = mio.ContactMatrix(baseline, minimum) model.parameters.ContactPatterns.cont_freq_mat = contact_matrices - def set_npis(self, params, end_date, damping_value): + def set_npis(self, params, end_date, damping_values): """ :param params: :param end_date: """ + start_damping_1 = DATE_TIME + datetime.timedelta(days=15) + start_damping_2 = DATE_TIME + datetime.timedelta(days=30) + start_damping_3 = DATE_TIME + datetime.timedelta(days=45) - start_damping = datetime.date( - 2020, 12, 18) + if start_damping_1 < end_date: + start_date = (start_damping_1 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[0]], t=start_date)) - if start_damping < end_date: - start_date = (start_damping - self.start_date).days + if start_damping_2 < end_date: + start_date = (start_damping_2 - self.start_date).days params.ContactPatterns.cont_freq_mat[0].add_damping( - mio.Damping(np.r_[damping_value], t=start_date)) + mio.Damping(np.r_[damping_values[1]], t=start_date)) - def get_graph(self, end_date): + if start_damping_3 < end_date: + start_date = (start_damping_3 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[2]], t=start_date)) + + def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param end_date: @@ -107,7 +461,8 @@ def get_graph(self, end_date): """ print("Initializing model...") model = Model(self.num_groups) - self.set_covid_parameters(model) + self.set_covid_parameters( + model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) self.set_contact_matrices(model) print("Model initialized.") @@ -115,11 +470,10 @@ def get_graph(self, end_date): scaling_factor_infected = [2.5] 
scaling_factor_icu = 1.0 - tnt_capacity_factor = 7.5 / 100000. data_dir_Spain = os.path.join(self.data_dir, "Spain") mobility_data_file = os.path.join( - data_dir_Spain, "mobility", "commuter_mobility_2022.txt") + data_dir_Spain, "mobility", "commuter_mobility.txt") pydata_dir = os.path.join(data_dir_Spain, "pydata") path_population_data = os.path.join( @@ -133,17 +487,16 @@ def get_graph(self, end_date): mio.Date(end_date.year, end_date.month, end_date.day), pydata_dir, path_population_data, True, graph, scaling_factor_infected, - scaling_factor_icu, 0, 0, False, False) + scaling_factor_icu, 0, 0, False) print("Setting edges...") - mio.osecir.set_edges( - mobility_data_file, graph, 1) + mio.osecir.set_edges(mobility_data_file, graph, 1) print("Graph created.") return graph - def run(self, num_days_sim, damping_values, save_graph=True): + def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): """ :param num_days_sim: @@ -155,19 +508,14 @@ def run(self, num_days_sim, damping_values, save_graph=True): mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) - graph = self.get_graph(end_date) - - if save_graph: - path_graph = os.path.join(self.results_dir, "graph") - if not os.path.exists(path_graph): - os.makedirs(path_graph) - osecir.write_graph(graph, path_graph) + graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, + t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): node = graph.get_node(node_idx) self.set_npis(node.property.parameters, - end_date, damping_values[node_idx]) + end_date, damping_values[dd.Provincia_to_Comunidad[node.id]-1]) mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( @@ -177,90 +525,371 @@ def run(self, num_days_sim, damping_values, save_graph=True): mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) mobility_sim.advance(num_days_sim) - results = [] - for node_idx in range(graph.num_nodes): - results.append(osecir.interpolate_simulation_result( - mobility_sim.graph.get_node(node_idx).property.result)) + results = {} + for node_idx in range(mobility_sim.graph.num_nodes): + results[f'region{node_idx}'] = osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result) return results -def run_spain_nuts3_simulation(damping_values): +def run_spain_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) sim = Simulation( data_dir=os.path.join(file_path, "../../../data"), - start_date=datetime.date(year=2020, month=12, day=12), + start_date=DATE_TIME, results_dir=os.path.join(file_path, "../../../results_osecir")) - num_days_sim = 50 - - results = sim.run(num_days_sim, damping_values) + num_days_sim = 60 - return {f'region{region}': results[region] for region in range(len(results))} + results = sim.run(num_days_sim, damping_values, t_E, t_ISy, + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + return results def prior(): - damping_values = np.random.uniform(0.0, 1.0, 47) - return {'damping_values': damping_values} + damping_values = np.zeros((NUM_DAMPING_POINTS, 19)) + for i in range(NUM_DAMPING_POINTS): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / 
scale + damping_values[i] = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=19 + ) + return { + 'damping_values': np.transpose(damping_values), + 't_E': np.random.uniform(*bounds['t_E']), + 't_ISy': np.random.uniform(*bounds['t_ISy']), + 't_ISev': np.random.uniform(*bounds['t_ISev']), + 't_Cr': np.random.uniform(*bounds['t_Cr']), + 'mu_CR': np.random.uniform(*bounds['mu_CR']), + 'mu_IH': np.random.uniform(*bounds['mu_IH']), + 'mu_HU': np.random.uniform(*bounds['mu_HU']), + 'mu_UD': np.random.uniform(*bounds['mu_UD']), + 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) + } + + +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Spain/pydata") + + data = pd.read_json(os.path.join(divi_path, "provincia_icu.json")) + # data = data[data['Date'] >= np.datetime64(DATE_TIME)] + # data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] + # data = data.sort_values(by=['ID_County', 'Date']) + # divi_data = data.pivot(index='Date', columns='ID_County', values='ICU') + # divi_dict = {} + # for i, region_id in enumerate(region_ids): + # if region_id not in no_icu_ids: + # divi_dict[f"region{i}"] = divi_data[region_id].to_numpy()[None, :, None] + # else: + # divi_dict[f"no_icu_region{i}"] = np.zeros((1, divi_data.shape[0], 1)) + # return divi_dict +def extract_observables(simulation_results, observable_index=7): + for key in simulation_results.keys(): + if key not in inference_params: + simulation_results[key] = simulation_results[key][:, :, observable_index][..., np.newaxis] + return simulation_results + + +def create_train_data(filename, number_samples=1000): + + simulator = bf.simulators.make_simulator( + [prior, run_spain_nuts3_simulation] + ) + trainings_data = simulator.sample(number_samples) + trainings_data = extract_observables(trainings_data) + with open(filename, 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + +def load_pickle(path): + with open(path, "rb") as f: + return pickle.load(f) + +def is_region_key(k: str) -> bool: + return 'region' in k + +def apply_aug(d: dict, aug) -> dict: + return {k: np.clip(aug(v), 0, None) if is_region_key(k) else v for k, v in d.items()} + +def concat_dicts(base: dict, new: dict) -> dict: + missing = set(base) - set(new) + if missing: + raise KeyError(f"new dict missing keys: {sorted(missing)}") + for k in base: + base[k] = np.concatenate([base[k], new[k]]) + return base + +def aggregate_states(d: dict) -> None: + n_regions = len(region_ids) + # per state + for r in range(n_regions): + print("indx: ", r) + print("region_id: ", region_ids[r]) + print("comunidad: ", dd.Provincia_to_Comunidad[region_ids[r]]) + for state in range(19): + idxs = [ + r for r in range(n_regions) + if dd.Provincia_to_Comunidad[region_ids[r]] == state + 1 + ] + d[f"comunidad{state}"] = np.sum([d[f"region{r}"] for r in idxs], axis=0) + # print(d[f"comunidad{state}"].shape) + # all allowed regions + d["state"] = np.sum([d[f"comunidad{r}"] for r in range(19)], axis=0) + +def combine_results(dict_list): + combined = {} + for d in dict_list: + combined = concat_dicts(combined, d) if combined else d + return combined + +def skip_2weeks(d:dict) -> dict: + return {k: v[:, 14:, :] if is_region_key(k) else v for k, v in d.items()} + +def get_workflow(): + + simulator = bf.make_simulator( + [prior, run_germany_nuts3_simulation] + ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + 
.constrain("damping_values", lower=0.0, upper=1.0) + .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) + .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) + .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) + .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) + .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) + .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) + .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) + .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) + .concatenate( + ["damping_values", "t_E", "t_ISy", "t_ISev", "t_Cr", + "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], + into="inference_variables", + axis=-1 + ) + .concatenate(summary_vars, into="summary_variables", axis=-1) + ) + + summary_network = bf.networks.FusionTransformer( + summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 + ) + inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) + + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + # augmentations={f'region{i}': aug for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids} + # aggregation of the states would need to be recomputed every time a different noise realization is applied + ) + + return workflow + + +def run_training(name, num_training_files=20): + train_template = name+"/trainings_data{i}_"+name+".pickle" + val_path = f"{name}/validation_data_{name}.pickle" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # training data + train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] + trainings_data = None + for p in train_files: + d = load_pickle(p) + d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) + d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) + if trainings_data is None: + trainings_data = d + else: + trainings_data = concat_dicts(trainings_data, d) + print(trainings_data['region0'].shape) + aggregate_states(trainings_data) + + # validation data + validation_data = apply_aug(load_pickle(val_path), aug=aug) + validation_data = skip_2weeks(validation_data) + aggregate_states(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + # check data + workflow = get_workflow() + print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) + print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) + + history = workflow.fit_offline( + data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data + ) + + workflow.approximator.save( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + plots = workflow.plot_default_diagnostics( + test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} + ) + plots['losses'].savefig(f'{name}/losses_{name}.png') + plots['recovery'].savefig(f'{name}/recovery_{name}.png') + 
plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png') + #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') + + +def run_inference(name, num_samples=1000, on_synthetic_data=False): + val_path = f"{name}/validation_data_{name}.pickle" + synthetic = "_synthetic" if on_synthetic_data else "" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # validation data + validation_data = load_pickle(val_path) # synthetic data + if on_synthetic_data: + # validation data + validation_data = apply_aug(validation_data, aug=aug) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data_skip2w = skip_2weeks(validation_data) + aggregate_states(validation_data_skip2w) + divi_dict = validation_data + divi_region_keys = region_keys_sorted(divi_dict) + + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] # only one dataset + else: + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + aggregate_states(validation_data_skip2w) + divi_region_keys = region_keys_sorted(divi_dict) + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') + simulations_aug = load_pickle(f'{name}/sims_{name}{synthetic}_with_aug.pickle') + print("loaded simulations from file") + else: + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 19, NUM_DAMPING_POINTS)) + results = [] + for i in range(num_samples): # we only have one dataset for inference here + result = run_germany_nuts3_simulation( + damping_values=samples['damping_values'][0, i], + t_E=samples['t_E'][0, i], t_ISy=samples['t_ISy'][0, i], + t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], + mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], + mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], + transmission_prob=samples['transmission_prob'][0, i] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] 
# add sample axis + results.append(result) + results = combine_results(results) + results = extract_observables(results) + results_aug = apply_aug(results, aug=aug) + + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + for i in range(num_samples): + simulations[i] = np.concatenate([results[key][i] for key in divi_region_keys], axis=-1) + simulations_aug[i] = np.concatenate([results_aug[key][i] for key in divi_region_keys], axis=-1) + + # save sims + with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: + pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') + + plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") + plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + + plot_aggregated_to_comunidades(simulations, divi_data, name, synthetic, with_aug="") + plot_aggregated_to_comunidades(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + + fig, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey='row', constrained_layout=True) + # Plot without augmentation + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color="#132a70" + ) + axes[0, 0].set_title("Without Augmentation") + # Plot with augmentation + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" + ) + axes[0, 1].set_title("With Augmentation") + # Plot without augmentation (80% quantile only) + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_80q=True + ) + axes[1, 0].set_title("Without Augmentation (80% Quantile)") + # Plot with augmentation (80% quantile only) + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_80q=True + ) + axes[1, 1].set_title("With Augmentation (80% Quantile)") + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') + plt.close() + + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png') + plt.close() + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations_aug, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png') + plt.close() + + # plot_icu_on_spain(simulations, name, synthetic, 
with_aug="") + # plot_icu_on_spain(simulations_aug, name, synthetic, with_aug="_with_aug") + + simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions + simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) + + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + print("Mean RMSE over regions:", rmse["values"].mean()) + print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) + + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + print("Mean Calibration Error over regions:", cal_error["values"].mean()) + print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) if __name__ == "__main__": - test = prior() - run_spain_nuts3_simulation(test['damping_values']) - # import os - # os.environ["KERAS_BACKEND"] = "tensorflow" - - # import bayesflow as bf - - # simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - # # trainings_data = simulator.sample(1000) - - # # for region in range(400): - # # trainings_data[f'region{region}'] = trainings_data[f'region{region}'][:,:, 8][..., np.newaxis] - - # # with open('validation_data_400params.pickle', 'wb') as f: - # # pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - - # with open('trainings_data1_400params.pickle', 'rb') as f: - # trainings_data = pickle.load(f) - # for i in range(9): - # with open(f'trainings_data{i+2}_400params.pickle', 'rb') as f: - # data = pickle.load(f) - # trainings_data = {k: np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - - # with open('validation_data_400params.pickle', 'rb') as f: - # validation_data = pickle.load(f) - - # adapter = ( - # bf.Adapter() - # .to_array() - # .convert_dtype("float64", "float32") - # .constrain("damping_values", lower=0.0, upper=1.0) - # .rename("damping_values", "inference_variables") - # .concatenate([f'region{i}' for i in range(400)], into="summary_variables", axis=-1) - # .log("summary_variables", p1=True) - # ) - - # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) - - # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=700, recurrent_dim=256) - # inference_network = bf.networks.DiffusionModel(subnet_kwargs={'widths': {512, 512, 512, 512, 512}}) - - # workflow = bf.BasicWorkflow( - # simulator=simulator, - # adapter=adapter, - # summary_network=summary_network, - # inference_network=inference_network, - # standardize='all' - # ) - - # history = workflow.fit_offline(data=trainings_data, epochs=1000, batch_size=32, validation_data=validation_data) - - # # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_10params.keras")) - - # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) - # plots['losses'].savefig('losses_diffusionmodel_400params.png') - # plots['recovery'].savefig('recovery_diffusionmodel_400params.png') - # plots['calibration_ecdf'].savefig('calibration_ecdf_diffusionmodel_400params.png') - # 
plots['z_score_contraction'].savefig('z_score_contraction_diffusionmodel_400params.png') + name = "spain_nuts3" + + if not os.path.exists(name): + os.makedirs(name) + # create_train_data(filename=f'{name}/trainings_data1_{name}.pickle', number_samples=2000) + run_training(name=name, num_training_files=10) + # run_inference(name=name, on_synthetic_data=True) + # run_inference(name=name, on_synthetic_data=False) \ No newline at end of file diff --git a/pycode/memilio-epidata/memilio/epidata/defaultDict.py b/pycode/memilio-epidata/memilio/epidata/defaultDict.py index 4172220849..69277a5f53 100644 --- a/pycode/memilio-epidata/memilio/epidata/defaultDict.py +++ b/pycode/memilio-epidata/memilio/epidata/defaultDict.py @@ -722,15 +722,15 @@ def invert_dict(dict_to_invert): 15: 'Córdoba', 16: 'Coruña, A', 17: 'Cuenca', - 18: 'Girona', + 18: 'Gerona', 19: 'Granada', 20: 'Guadalajara', - 21: 'Gipuzkoa', + 21: 'Guipúzcoa', 22: 'Huelva', 23: 'Huesca', 24: 'Jaén', 25: 'León', - 26: 'Lleida', + 26: 'Lérida', 27: 'Rioja, La', 28: 'Lugo', 29: 'Madrid', @@ -752,13 +752,89 @@ def invert_dict(dict_to_invert): 45: 'Toledo', 46: 'Valencia/València', 47: 'Valladolid', - 48: 'Bizkaia', + 48: 'Vizcaya', 49: 'Zamora', 50: 'Zaragoza', 51: 'Ceuta', 52: 'Melilla', - 53: 'Ourense'} + 53: 'Orense'} +Comunidades = { + 1: 'Andalucía', + 2: 'Aragón', + 3: 'Principado de Asturias', + 4: 'Islas Baleares', + 5: 'Canarias', + 6: 'Cantabria', + 7: 'Castilla y León', + 8: 'Castilla-La Mancha', + 9: 'Cataluña', + 10: 'Comunidad Valenciana', + 11: 'Extremadura', + 12: 'Galicia', + 13: 'Comunidad de Madrid', + 14: 'Región de Murcia', + 15: 'Comunidad Foral de Navarra', + 16: 'País Vasco', + 17: 'La Rioja', + 18: 'Ceuta', + 19: 'Melilla' +} + +Provincia_to_Comunidad = { + 2: 16, + 3: 8, + 4: 10, + 5: 1, + 6: 7, + 7: 11, + 8: 4, + 9: 9, + 10: 7, + 11: 11, + 12: 1, + 13: 10, + 14: 8, + 15: 1, + 16: 12, + 17: 8, + 18: 9, + 19: 1, + 20: 8, + 21: 16, + 22: 1, + 23: 2, + 24: 1, + 25: 7, + 26: 9, + 27: 17, + 28: 12, + 29: 13, + 30: 1, + 31: 14, + 32: 15, + 33: 3, + 34: 7, + 35: 5, + 36: 12, + 37: 7, + 38: 5, + 39: 6, + 40: 7, + 41: 1, + 42: 7, + 43: 9, + 44: 2, + 45: 8, + 46: 10, + 47: 7, + 48: 16, + 49: 7, + 50: 2, + 51: 18, + 52: 19, + 53: 12 +} Provincia_ISO_to_ID = {'A': 4, 'AB': 3, @@ -810,5 +886,5 @@ def invert_dict(dict_to_invert): 'V': 46, 'VA': 47, 'VI': 2, - 'ZA': 48, + 'ZA': 49, 'Z': 50} diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index 8ccde09fc5..aa6ffda746 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -39,6 +39,7 @@ def get_population_data(): def fetch_icu_data(): + # https://www.sanidad.gob.es/areas/alertasEmergenciasSanitarias/alertasActuales/nCov/capacidadAsistencial.htm?utm_source=chatgpt.com download_url = 'https://www.sanidad.gob.es/areas/alertasEmergenciasSanitarias/alertasActuales/nCov/documentos/Datos_Capacidad_Asistencial_Historico_14072023.csv' req = requests.get(download_url) req.encoding = 'ISO-8859-1' @@ -75,10 +76,11 @@ def get_icu_data(): def fetch_case_data(): + # https://datos.gob.es/es/catalogo/e05070101-evolucion-de-enfermedad-por-el-coronavirus-covid-19 download_url = 'https://cnecovid.isciii.es/covid19/resources/casos_diagnostico_provincia.csv' req = requests.get(download_url) - df = pd.read_csv(io.StringIO(req.text), sep=',') + df = pd.read_csv(io.StringIO(req.text), sep=',', keep_default_na=False, 
na_values=[]) return df @@ -198,5 +200,5 @@ def get_mobility_data(data_dir, start_date='2022-08-01', end_date='2022-08-31', matrix = df.pivot(index='id_origin', columns='id_destination', values='n_trips').fillna(0) - gd.write_dataframe(matrix, data_dir, 'commuter_mobility', 'txt', { + gd.write_dataframe(matrix, mobility_dir, 'commuter_mobility', 'txt', { 'sep': ' ', 'index': False, 'header': False}) From 96cecda54331da09f9f3276149deda18e402d54b Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Tue, 11 Nov 2025 10:33:58 +0100 Subject: [PATCH 66/73] python bindings for spain io --- .../simulation/bindings/models/osecir.cpp | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp index b0d82a94e9..b7363ced4f 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp @@ -306,6 +306,25 @@ PYBIND11_MODULE(_simulation_osecir, m) return pymio::check_and_throw(result); }, py::return_value_policy::move); + + m.def( + "set_nodes_provincias", + [](const mio::osecir::Parameters& params, mio::Date start_date, mio::Date end_date, + const std::string& data_dir, const std::string& population_data_path, bool is_node_for_county, + mio::Graph, mio::MobilityParameters>& params_graph, + const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor, + int num_days = 0, bool export_time_series = false) { + auto result = mio::set_nodes, + mio::osecir::ContactPatterns, mio::osecir::Model, + mio::MobilityParameters, mio::osecir::Parameters, + decltype(mio::osecir::read_input_data_provincias>), + decltype(mio::get_provincia_ids)>( + params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph, + mio::osecir::read_input_data_provincias>, mio::get_provincia_ids, + scaling_factor_inf, scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series); + return pymio::check_and_throw(result); + }, + py::return_value_policy::move); pymio::iterable_enum(m, "ContactLocation") .value("Home", ContactLocation::Home) @@ -336,6 +355,7 @@ PYBIND11_MODULE(_simulation_osecir, m) #ifdef MEMILIO_HAS_JSONCPP pymio::bind_write_graph>(m); + pymio::bind_read_graph>(m); m.def( "read_input_data_county", [](std::vector>& model, mio::Date date, const std::vector& county, From 96a6edd7ab6d6140b84b374ee30b0ca7b823062c Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 12 Nov 2025 11:25:17 +0100 Subject: [PATCH 67/73] remove island from spain simulation --- .../examples/simulation/graph_spain_nuts3.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/pycode/examples/simulation/graph_spain_nuts3.py b/pycode/examples/simulation/graph_spain_nuts3.py index c08a2c9b16..f92d21b342 100644 --- a/pycode/examples/simulation/graph_spain_nuts3.py +++ b/pycode/examples/simulation/graph_spain_nuts3.py @@ -42,10 +42,12 @@ excluded_ids = [8, 35, 38, 51, 52] region_ids = [region_id for region_id in dd.Provincias.keys() if region_id not in excluded_ids] +excluded_comunidades = [4, 5, 18, 19] +comunidades = [comunidad for comunidad in dd.Comunidades.keys() if comunidad not in excluded_comunidades] inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] -summary_vars = ['state'] + [f'comunidad{i}' for i in 
range(len(dd.Comunidades))] + [ - f'region{i}' for i in range(len(dd.Provincias))] +summary_vars = ['state'] + [f'comunidad{i}' for i in range(len(comunidades))] + [ + f'region{i}' for i in range(len(region_ids))] bounds = { 't_E': (1.0, 5.2), @@ -626,19 +628,14 @@ def concat_dicts(base: dict, new: dict) -> dict: def aggregate_states(d: dict) -> None: n_regions = len(region_ids) # per state - for r in range(n_regions): - print("indx: ", r) - print("region_id: ", region_ids[r]) - print("comunidad: ", dd.Provincia_to_Comunidad[region_ids[r]]) - for state in range(19): + for i, state in enumerate(comunidades): idxs = [ r for r in range(n_regions) - if dd.Provincia_to_Comunidad[region_ids[r]] == state + 1 + if dd.Provincia_to_Comunidad[region_ids[r]] == state ] - d[f"comunidad{state}"] = np.sum([d[f"region{r}"] for r in idxs], axis=0) - # print(d[f"comunidad{state}"].shape) + d[f"comunidad{i}"] = np.sum([d[f"region{r}"] for r in idxs], axis=0) # all allowed regions - d["state"] = np.sum([d[f"comunidad{r}"] for r in range(19)], axis=0) + d["state"] = np.sum([d[f"comunidad{r}"] for r in range(len(comunidades))], axis=0) def combine_results(dict_list): combined = {} @@ -652,7 +649,7 @@ def skip_2weeks(d:dict) -> dict: def get_workflow(): simulator = bf.make_simulator( - [prior, run_germany_nuts3_simulation] + [prior, run_spain_nuts3_simulation] ) adapter = ( bf.Adapter() @@ -716,7 +713,6 @@ def run_training(name, num_training_files=20): trainings_data = d else: trainings_data = concat_dicts(trainings_data, d) - print(trainings_data['region0'].shape) aggregate_states(trainings_data) # validation data @@ -889,7 +885,7 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): if not os.path.exists(name): os.makedirs(name) - # create_train_data(filename=f'{name}/trainings_data1_{name}.pickle', number_samples=2000) + # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=100) run_training(name=name, num_training_files=10) # run_inference(name=name, on_synthetic_data=True) # run_inference(name=name, on_synthetic_data=False) \ No newline at end of file From 801d870de8d622454c0b2a8c7521c3aa8ed0b7dc Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Wed, 12 Nov 2025 13:52:23 +0100 Subject: [PATCH 68/73] fix inference and formatting --- .../examples/simulation/graph_spain_nuts3.py | 270 ++++++++++-------- 1 file changed, 147 insertions(+), 123 deletions(-) diff --git a/pycode/examples/simulation/graph_spain_nuts3.py b/pycode/examples/simulation/graph_spain_nuts3.py index f92d21b342..62e5c3ac00 100644 --- a/pycode/examples/simulation/graph_spain_nuts3.py +++ b/pycode/examples/simulation/graph_spain_nuts3.py @@ -17,33 +17,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
############################################################################# +import geopandas as gpd +from memilio.epidata import defaultDict as dd +from memilio.simulation.osecir import Model, interpolate_simulation_result +import memilio.simulation.osecir as osecir +import memilio.simulation as mio +import keras +import bayesflow as bf +from matplotlib.patches import Patch +from scipy.stats import truncnorm +import pickle +import datetime +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np import os os.environ["KERAS_BACKEND"] = "tensorflow" -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -import datetime -import pickle -from scipy.stats import truncnorm - -from matplotlib.patches import Patch - -import bayesflow as bf -import keras - -import memilio.simulation as mio -import memilio.simulation.osecir as osecir -from memilio.simulation.osecir import Model, interpolate_simulation_result -from memilio.epidata import defaultDict as dd - -import geopandas as gpd excluded_ids = [8, 35, 38, 51, 52] region_ids = [region_id for region_id in dd.Provincias.keys() if region_id not in excluded_ids] excluded_comunidades = [4, 5, 18, 19] -comunidades = [comunidad for comunidad in dd.Comunidades.keys() if comunidad not in excluded_comunidades] +comunidades = [comunidad for comunidad in dd.Comunidades.keys( +) if comunidad not in excluded_comunidades] inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] summary_vars = ['state'] + [f'comunidad{i}' for i in range(len(comunidades))] + [ @@ -65,6 +62,7 @@ DATE_TIME = datetime.date(year=2020, month=10, day=1) NUM_DAMPING_POINTS = 3 + def plot_region_fit( data: np.ndarray, region: int, @@ -111,7 +109,7 @@ def plot_region_fit( ax.set_xlabel("Time", fontsize=12) ax.set_ylabel("ICU", fontsize=12) - ax.set_title(f"{dd.County[region_ids[region]]}", fontsize=12) + ax.set_title(f"{dd.Provincias[region_ids[region]]}", fontsize=12) if label is not None: ax.legend(fontsize=11, loc="upper right") @@ -146,7 +144,7 @@ def plot_aggregated_over_regions( fig, ax = plt.subplots() ax.plot(x, agg_median, lw=2, - label=label or "Aggregated over regions", color=color) + label=label or "Aggregated over regions", color=color) ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, color=color, label="80% CI") if not only_80q: @@ -165,31 +163,36 @@ def plot_aggregated_over_regions( def plot_aggregated_to_comunidades(data, true_data, name, synthetic, with_aug): - fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(25, 25), layout="constrained") + fig, ax = plt.subplots(nrows=4, ncols=4, figsize=( + 25, 25), layout="constrained") ax = ax.flatten() - for state in range(19): - idxs = [i for i, region_id in enumerate(region_ids) if dd.Provincia_to_Comunidad[region_id] == state + 1] + for state in range(16): + idxs = [i for i, region_id in enumerate( + region_ids) if dd.Provincia_to_Comunidad[region_id] == state + 1] plot_aggregated_over_regions( - data[:, :, idxs], # Add a dummy region axis for compatibility - true_data=true_data[:, idxs] if true_data is not None else None, - ax=ax[state], - label=f"State {state + 1}", - color=f"C{state % 10}" # Cycle through 10 colors + data[:, :, idxs], # Add a dummy region axis for compatibility + true_data=true_data[:, idxs] if true_data is not None else None, + ax=ax[state], + label=f"State {state + 1}", + color=f"C{state % 10}" # Cycle through 10 colors ) plt.savefig(f'{name}/comunidades_{name}{synthetic}{with_aug}.png') plt.close() 
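    # A minimal sketch of the provincia -> comunidad grouping that
    # plot_aggregated_to_comunidades above relies on, assuming `region_ids` and
    # `dd.Provincia_to_Comunidad` as defined earlier in this file; the helper name
    # is hypothetical. Precomputing the column indices once avoids rebuilding them
    # inside every plotting loop.
    def comunidad_column_indices():
        groups = {}
        for col, region_id in enumerate(region_ids):
            comunidad = dd.Provincia_to_Comunidad[region_id]
            groups.setdefault(comunidad, []).append(col)
        return groups
    # Example use: data[:, :, comunidad_column_indices()[9]] selects the columns of
    # the (samples, time, regions) array belonging to Cataluña (comunidad id 9 in
    # dd.Comunidades).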
-# plot simulations for all regions in 10x4 blocks +# plot simulations for all regions in 6x4 blocks + + def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): n_regions = simulations.shape[-1] n_cols = 4 - n_rows = 10 + n_rows = 6 n_blocks = (n_regions + n_cols * n_rows - 1) // (n_cols * n_rows) for block in range(n_blocks): start_idx = block * n_cols * n_rows end_idx = min(start_idx + n_cols * n_rows, n_regions) - fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 25), layout="constrained") + fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, + figsize=(20, 25), layout="constrained") ax = ax.flatten() for i, region_idx in enumerate(range(start_idx, end_idx)): plot_region_fit( @@ -198,7 +201,8 @@ def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): # Hide unused subplots for i in range(end_idx - start_idx, len(ax)): ax[i].axis("off") - plt.savefig(f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png') + plt.savefig( + f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png') plt.close() @@ -349,7 +353,8 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): if i < 19: ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) ax.fill_between( - x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], + x, med_extended[i] - + mad_extended[i], med_extended[i] + mad_extended[i], alpha=0.25, color='red', step='post' ) ax.set_title(f"{dd.Comunidades[i+2]}", fontsize=10) @@ -359,7 +364,8 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): ax.axis('off') # Hide unused subplots plt.suptitle("Damping Values per Region", fontsize=14) - plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) + plt.savefig( + f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) # Combined plot for all regions fig, ax = plt.subplots(figsize=(10, 6)) @@ -375,7 +381,9 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): ax.set_xlabel("Time", fontsize=12) ax.set_ylabel("Damping Value", fontsize=12) ax.legend(fontsize=10, loc="upper right", ncol=2) - plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) + plt.savefig( + f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) + class Simulation: """ """ @@ -389,13 +397,9 @@ def __init__(self, data_dir, start_date, results_dir): os.makedirs(self.results_dir) def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): - """ - - :param model: - - """ model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E - model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E # todo: correct? + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup( + 0)] = 5.2 - t_E # todo: correct? 
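        # Annotation on the line above (descriptive comment only, not a code change):
        # the constant 5.2 presumably encodes an assumed mean incubation period in
        # days, so TimeInfectedNoSymptoms is derived as incubation minus the latent
        # time t_E. The Germany script later in this series (PATCH 71/73) replaces
        # this derived value with an independently sampled t_C bounded to (0.1, 4.2).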
model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr @@ -417,11 +421,6 @@ def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu model.parameters.Seasonality = mio.UncertainValue(0.2) def set_contact_matrices(self, model): - """ - - :param model: - - """ contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) baseline = np.ones((self.num_groups, self.num_groups)) * 12.32 @@ -430,12 +429,6 @@ def set_contact_matrices(self, model): model.parameters.ContactPatterns.cont_freq_mat = contact_matrices def set_npis(self, params, end_date, damping_values): - """ - - :param params: - :param end_date: - - """ start_damping_1 = DATE_TIME + datetime.timedelta(days=15) start_damping_2 = DATE_TIME + datetime.timedelta(days=30) start_damping_3 = DATE_TIME + datetime.timedelta(days=45) @@ -456,11 +449,6 @@ def set_npis(self, params, end_date, damping_values): mio.Damping(np.r_[damping_values[2]], t=start_date)) def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): - """ - - :param end_date: - - """ print("Initializing model...") model = Model(self.num_groups) self.set_covid_parameters( @@ -499,14 +487,6 @@ def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_ return graph def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): - """ - - :param num_days_sim: - :param num_runs: (Default value = 10) - :param save_graph: (Default value = True) - :param create_gif: (Default value = True) - - """ mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) @@ -550,6 +530,7 @@ def run_spain_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, return results + def prior(): damping_values = np.zeros((NUM_DAMPING_POINTS, 19)) for i in range(NUM_DAMPING_POINTS): @@ -578,21 +559,23 @@ def load_divi_data(): divi_path = os.path.join(file_path, "../../../data/Spain/pydata") data = pd.read_json(os.path.join(divi_path, "provincia_icu.json")) - # data = data[data['Date'] >= np.datetime64(DATE_TIME)] - # data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] - # data = data.sort_values(by=['ID_County', 'Date']) - # divi_data = data.pivot(index='Date', columns='ID_County', values='ICU') - # divi_dict = {} - # for i, region_id in enumerate(region_ids): - # if region_id not in no_icu_ids: - # divi_dict[f"region{i}"] = divi_data[region_id].to_numpy()[None, :, None] - # else: - # divi_dict[f"no_icu_region{i}"] = np.zeros((1, divi_data.shape[0], 1)) - # return divi_dict + data = data[data['Date'] >= np.datetime64(DATE_TIME)] + data = data[data['Date'] <= np.datetime64( + DATE_TIME + datetime.timedelta(days=60))] + data = data.sort_values(by=['ID_County', 'Date']) + divi_data = data.pivot(index='Date', columns='ID_County', values='ICU') + divi_dict = {} + for i, region_id in enumerate(region_ids): + divi_dict[f"region{i}"] = divi_data[region_id].to_numpy()[ + None, :, None] + return divi_dict + + def extract_observables(simulation_results, observable_index=7): for key in simulation_results.keys(): if key not in inference_params: - simulation_results[key] = simulation_results[key][:, :, observable_index][..., np.newaxis] + simulation_results[key] = simulation_results[key][:, + :, observable_index][..., np.newaxis] return 
simulation_results @@ -611,12 +594,15 @@ def load_pickle(path): with open(path, "rb") as f: return pickle.load(f) + def is_region_key(k: str) -> bool: return 'region' in k + def apply_aug(d: dict, aug) -> dict: return {k: np.clip(aug(v), 0, None) if is_region_key(k) else v for k, v in d.items()} + def concat_dicts(base: dict, new: dict) -> dict: missing = set(base) - set(new) if missing: @@ -625,6 +611,7 @@ def concat_dicts(base: dict, new: dict) -> dict: base[k] = np.concatenate([base[k], new[k]]) return base + def aggregate_states(d: dict) -> None: n_regions = len(region_ids) # per state @@ -635,7 +622,9 @@ def aggregate_states(d: dict) -> None: ] d[f"comunidad{i}"] = np.sum([d[f"region{r}"] for r in idxs], axis=0) # all allowed regions - d["state"] = np.sum([d[f"comunidad{r}"] for r in range(len(comunidades))], axis=0) + d["state"] = np.sum([d[f"comunidad{r}"] + for r in range(len(comunidades))], axis=0) + def combine_results(dict_list): combined = {} @@ -643,9 +632,11 @@ def combine_results(dict_list): combined = concat_dicts(combined, d) if combined else d return combined -def skip_2weeks(d:dict) -> dict: + +def skip_2weeks(d: dict) -> dict: return {k: v[:, 14:, :] if is_region_key(k) else v for k, v in d.items()} + def get_workflow(): simulator = bf.make_simulator( @@ -677,7 +668,8 @@ def get_workflow(): summary_network = bf.networks.FusionTransformer( summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 ) - inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) + inference_network = bf.networks.FlowMatching( + subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) workflow = bf.BasicWorkflow( @@ -702,13 +694,15 @@ def run_training(name, num_training_files=20): ) # training data - train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] + train_files = [train_template.format(i=i) + for i in range(1, 1+num_training_files)] trainings_data = None for p in train_files: d = load_pickle(p) d = apply_aug(d, aug=aug) # only on region keys d = skip_2weeks(d) - d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) + d['damping_values'] = d['damping_values'].reshape( + (d['damping_values'].shape[0], -1)) if trainings_data is None: trainings_data = d else: @@ -719,12 +713,15 @@ def run_training(name, num_training_files=20): validation_data = apply_aug(load_pickle(val_path), aug=aug) validation_data = skip_2weeks(validation_data) aggregate_states(validation_data) - validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) # check data workflow = get_workflow() - print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) - print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) + print("summary_variables shape:", workflow.adapter( + trainings_data)["summary_variables"].shape) + print("inference_variables shape:", workflow.adapter( + trainings_data)["inference_variables"].shape) history = workflow.fit_offline( data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data @@ -735,12 +732,13 @@ def run_training(name, num_training_files=20): ) plots = workflow.plot_default_diagnostics( - 
test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} + test_data=validation_data, calibration_ecdf_kwargs={ + 'difference': True, 'stacked': True} ) plots['losses'].savefig(f'{name}/losses_{name}.png') plots['recovery'].savefig(f'{name}/recovery_{name}.png') plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png') - #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') + # plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') def run_inference(name, num_samples=1000, on_synthetic_data=False): @@ -756,22 +754,21 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): if on_synthetic_data: # validation data validation_data = apply_aug(validation_data, aug=aug) - validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) validation_data_skip2w = skip_2weeks(validation_data) aggregate_states(validation_data_skip2w) divi_dict = validation_data - divi_region_keys = region_keys_sorted(divi_dict) divi_data = np.concatenate( - [divi_dict[key] for key in divi_region_keys], axis=-1 + [divi_dict[f'region{i}'] for i in range(len(region_ids))], axis=-1 )[0] # only one dataset else: divi_dict = load_divi_data() validation_data_skip2w = skip_2weeks(divi_dict) aggregate_states(validation_data_skip2w) - divi_region_keys = region_keys_sorted(divi_dict) divi_data = np.concatenate( - [divi_dict[key] for key in divi_region_keys], axis=-1 + [divi_dict[f'region{i}'] for i in range(len(region_ids))], axis=-1 )[0] workflow = get_workflow() @@ -781,14 +778,17 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') - simulations_aug = load_pickle(f'{name}/sims_{name}{synthetic}_with_aug.pickle') + simulations_aug = load_pickle( + f'{name}/sims_{name}{synthetic}_with_aug.pickle') print("loaded simulations from file") else: - samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) - samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 19, NUM_DAMPING_POINTS)) + samples = workflow.sample( + conditions=validation_data_skip2w, num_samples=num_samples) + samples['damping_values'] = samples['damping_values'].reshape( + (samples['damping_values'].shape[0], num_samples, 19, NUM_DAMPING_POINTS)) results = [] for i in range(num_samples): # we only have one dataset for inference here - result = run_germany_nuts3_simulation( + result = run_spain_nuts3_simulation( damping_values=samples['damping_values'][0, i], t_E=samples['t_E'][0, i], t_ISy=samples['t_ISy'][0, i], t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], @@ -797,18 +797,23 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): transmission_prob=samples['transmission_prob'][0, i] ) for key in result.keys(): - result[key] = np.array(result[key])[None, ...] # add sample axis + result[key] = np.array(result[key])[ + None, ...] 
# add sample axis results.append(result) results = combine_results(results) results = extract_observables(results) results_aug = apply_aug(results, aug=aug) # get sims in shape (samples, time, regions) - simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) - simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations = np.zeros( + (num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros( + (num_samples, divi_data.shape[0], divi_data.shape[1])) for i in range(num_samples): - simulations[i] = np.concatenate([results[key][i] for key in divi_region_keys], axis=-1) - simulations_aug[i] = np.concatenate([results_aug[key][i] for key in divi_region_keys], axis=-1) + simulations[i] = np.concatenate( + [results[f'region{region}'][i] for region in range(len(region_ids))], axis=-1) + simulations_aug[i] = np.concatenate( + [results_aug[f'region{region}'][i] for region in range(len(region_ids))], axis=-1) # save sims with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: @@ -816,19 +821,26 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) - samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) - validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + samples['damping_values'] = samples['damping_values'].reshape( + (samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) - plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) + plot = bf.diagnostics.pairs_posterior( + samples, priors=validation_data, dataset_id=0) plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") - plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + plot_all_regions(simulations_aug, divi_data, name, + synthetic, with_aug="_with_aug") - plot_aggregated_to_comunidades(simulations, divi_data, name, synthetic, with_aug="") - plot_aggregated_to_comunidades(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + plot_aggregated_to_comunidades( + simulations, divi_data, name, synthetic, with_aug="") + plot_aggregated_to_comunidades( + simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") - fig, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey='row', constrained_layout=True) + fig, axes = plt.subplots(2, 2, figsize=( + 12, 12), sharex=True, sharey='row', constrained_layout=True) # Plot without augmentation plot_aggregated_over_regions( simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color="#132a70" @@ -852,32 +864,43 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') plt.close() - fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + fig, axis = plt.subplots(1, 2, figsize=( + 10, 4), sharex=True, layout="constrained") ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) - ax, stats = calibration_median_mad_over_regions(simulations, 
divi_data, ax=axis[1]) + ax, stats = calibration_median_mad_over_regions( + simulations, divi_data, ax=axis[1]) plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png') plt.close() - fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + fig, axis = plt.subplots(1, 2, figsize=( + 10, 4), sharex=True, layout="constrained") ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) - ax, stats = calibration_median_mad_over_regions(simulations_aug, divi_data, ax=axis[1]) - plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png') + ax, stats = calibration_median_mad_over_regions( + simulations_aug, divi_data, ax=axis[1]) + plt.savefig( + f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png') plt.close() # plot_icu_on_spain(simulations, name, synthetic, with_aug="") # plot_icu_on_spain(simulations_aug, name, synthetic, with_aug="_with_aug") - simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions + simulation_agg = np.sum(simulations, axis=-1, + keepdims=True) # sum over regions simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) - rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) - rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes( + simulation_agg, 0, 1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes( + simulation_aug_agg, 0, 1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) print("Mean RMSE over regions:", rmse["values"].mean()) print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) - cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) - cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes( + simulation_agg, 0, 1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes( + simulation_aug_agg, 0, 1), np.sum(divi_data, axis=-1, keepdims=True)) print("Mean Calibration Error over regions:", cal_error["values"].mean()) - print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) + print("Mean Calibration Error over regions (with aug):", + cal_error_aug["values"].mean()) if __name__ == "__main__": @@ -885,7 +908,8 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): if not os.path.exists(name): os.makedirs(name) - # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=100) - run_training(name=name, num_training_files=10) + # create_train_data( + # filename=f'{name}/trainings_data1_{name}.pickle', number_samples=1000) + # run_training(name=name, num_training_files=1) # run_inference(name=name, on_synthetic_data=True) - # run_inference(name=name, on_synthetic_data=False) \ No newline at end of file + run_inference(name=name, on_synthetic_data=False) From 2f51fc23289bf5e5ce7d164f36531cf171f906ba Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Thu, 13 Nov 2025 10:30:01 +0100 Subject: [PATCH 69/73] correct Provincia IDs in ICU 
data --- .../memilio/epidata/getSimulationDataSpain.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index aa6ffda746..ab93f11743 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -39,7 +39,7 @@ def get_population_data(): def fetch_icu_data(): - # https://www.sanidad.gob.es/areas/alertasEmergenciasSanitarias/alertasActuales/nCov/capacidadAsistencial.htm?utm_source=chatgpt.com + # https://www.sanidad.gob.es/areas/alertasEmergenciasSanitarias/alertasActuales/nCov/capacidadAsistencial.htm download_url = 'https://www.sanidad.gob.es/areas/alertasEmergenciasSanitarias/alertasActuales/nCov/documentos/Datos_Capacidad_Asistencial_Historico_14072023.csv' req = requests.get(download_url) req.encoding = 'ISO-8859-1' @@ -69,6 +69,10 @@ def preprocess_icu_data(df): def get_icu_data(): df = fetch_icu_data() df.rename(columns={'Cod_Provincia': 'ID_Provincia'}, inplace=True) + # ensure numeric, drop non-numeric and increment all IDs by 1 + df['ID_Provincia'] = pd.to_numeric(df['ID_Provincia'], errors='coerce') + df = df[df['ID_Provincia'].notna()].copy() + df['ID_Provincia'] = df['ID_Provincia'].astype(int) + 1 df = remove_islands(df) df = preprocess_icu_data(df) @@ -80,7 +84,8 @@ def fetch_case_data(): download_url = 'https://cnecovid.isciii.es/covid19/resources/casos_diagnostico_provincia.csv' req = requests.get(download_url) - df = pd.read_csv(io.StringIO(req.text), sep=',', keep_default_na=False, na_values=[]) + df = pd.read_csv(io.StringIO(req.text), sep=',', + keep_default_na=False, na_values=[]) return df From e50f41854ac30bf0f071ee9bcf7c81ef2d10f0db Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Fri, 14 Nov 2025 09:39:52 +0100 Subject: [PATCH 70/73] adjust some plots for spain --- .../examples/simulation/graph_spain_nuts3.py | 228 ++++++++++----- .../memilio/epidata/defaultDict.py | 268 ++++++------------ .../memilio/epidata/getSimulationDataSpain.py | 18 +- 3 files changed, 263 insertions(+), 251 deletions(-) diff --git a/pycode/examples/simulation/graph_spain_nuts3.py b/pycode/examples/simulation/graph_spain_nuts3.py index 62e5c3ac00..4eea30d073 100644 --- a/pycode/examples/simulation/graph_spain_nuts3.py +++ b/pycode/examples/simulation/graph_spain_nuts3.py @@ -35,10 +35,10 @@ os.environ["KERAS_BACKEND"] = "tensorflow" -excluded_ids = [8, 35, 38, 51, 52] +excluded_ids = [530, 630, 640, 701, 702] region_ids = [region_id for region_id in dd.Provincias.keys() if region_id not in excluded_ids] -excluded_comunidades = [4, 5, 18, 19] +excluded_comunidades = [53, 63, 64, 70] comunidades = [comunidad for comunidad in dd.Comunidades.keys( ) if comunidad not in excluded_comunidades] inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', @@ -46,6 +46,42 @@ summary_vars = ['state'] + [f'comunidad{i}' for i in range(len(comunidades))] + [ f'region{i}' for i in range(len(region_ids))] + +def set_fontsize(base_fontsize=17): + fontsize = base_fontsize + plt.rcParams.update({ + 'font.size': fontsize, + 'axes.titlesize': fontsize * 1, + 'axes.labelsize': fontsize, + 'xtick.labelsize': fontsize * 0.8, + 'ytick.labelsize': fontsize * 0.8, + 'legend.fontsize': fontsize * 0.8, + 'font.family': "Arial" + }) + + +plt.style.use('default') + +dpi = 300 + +colors = {"Blue": "#155489", + "Medium blue": "#64A7DD", + "Light blue": "#B4DCF6", + "Lilac 
blue": "#AECCFF", + "Turquoise": "#76DCEC", + "Light green": "#B6E6B1", + "Medium green": "#54B48C", + "Green": "#5D8A2B", + "Teal": "#20A398", + "Yellow": "#FBD263", + "Orange": "#E89A63", + "Rose": "#CF7768", + "Red": "#A34427", + "Purple": "#741194", + "Grey": "#C0BFBF", + "Dark grey": "#616060", + "Light grey": "#F1F1F1"} + bounds = { 't_E': (1.0, 5.2), 't_ISy': (4.0, 10.0), @@ -94,7 +130,7 @@ def plot_region_fit( fig, ax = plt.subplots() ax.plot( - x, med, lw=2, label=label or f"{dd.County[region_ids[region]]}", color=color) + x, med, lw=2, label=label or f"{dd.Provincias[region_ids[region]]}", color=color) ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, color=color, label="80% CI") if not only_80q: @@ -107,11 +143,11 @@ def plot_region_fit( true_vals = true_data[:, region] # (time_points,) ax.plot(x, true_vals, lw=2, color="black", label="True data") - ax.set_xlabel("Time", fontsize=12) - ax.set_ylabel("ICU", fontsize=12) - ax.set_title(f"{dd.Provincias[region_ids[region]]}", fontsize=12) + ax.set_xlabel("Time") + ax.set_ylabel("ICU") + ax.set_title(f"{dd.Provincias[region_ids[region]]}") if label is not None: - ax.legend(fontsize=11, loc="upper right") + ax.legend(loc="upper right") def plot_aggregated_over_regions( @@ -156,27 +192,27 @@ def plot_aggregated_over_regions( true_vals = region_agg(true_data, axis=-1) # (time_points,) ax.plot(x, true_vals, lw=2, color="black", label="True data") - ax.set_xlabel("Time", fontsize=12) - ax.set_ylabel("ICU", fontsize=12) + ax.set_xlabel("Time") + ax.set_ylabel("ICU") if label is not None: - ax.legend(fontsize=11) + ax.legend() def plot_aggregated_to_comunidades(data, true_data, name, synthetic, with_aug): fig, ax = plt.subplots(nrows=4, ncols=4, figsize=( 25, 25), layout="constrained") ax = ax.flatten() - for state in range(16): + for state_idx, state in enumerate(comunidades): idxs = [i for i, region_id in enumerate( - region_ids) if dd.Provincia_to_Comunidad[region_id] == state + 1] + region_ids) if region_id // 10 == state] plot_aggregated_over_regions( data[:, :, idxs], # Add a dummy region axis for compatibility true_data=true_data[:, idxs] if true_data is not None else None, - ax=ax[state], - label=f"State {state + 1}", - color=f"C{state % 10}" # Cycle through 10 colors + ax=ax[state_idx], + label=f"{dd.Comunidades[state]}", + color=colors["Red"] ) - plt.savefig(f'{name}/comunidades_{name}{synthetic}{with_aug}.png') + plt.savefig(f'{name}/comunidades_{name}{synthetic}{with_aug}.png', dpi=dpi) plt.close() # plot simulations for all regions in 6x4 blocks @@ -196,16 +232,59 @@ def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): ax = ax.flatten() for i, region_idx in enumerate(range(start_idx, end_idx)): plot_region_fit( - simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color="#132a70" + simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color=colors["Red"] ) # Hide unused subplots for i in range(end_idx - start_idx, len(ax)): ax[i].axis("off") plt.savefig( - f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png') + f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png', dpi=dpi) plt.close() +def plot_icu_on_spain(simulations, name, synthetic, with_aug): + med = np.median(simulations, axis=0) + + population = pd.read_json( + 'data/Spain/pydata/provincias_current_population.json') + values = med / population['Population'].to_numpy()[None, :] * 100000 + + map_data = gpd.read_file(os.path.join(os.getcwd( + ), 
'tools/lineas_limite/SHP_ETRS89/recintos_provinciales_inspire_peninbal_etrs89/recintos_provinciales_inspire_peninbal_etrs89.shp')) + map_data['ID_Provincia'] = map_data['NAMEUNIT'].map( + dd.invert_dict(dd.Provincias)) + map_data.dropna(inplace=True, subset=['ID_Provincia']) + map_data["ID_Provincia"] = map_data["ID_Provincia"].astype(int) + map_data = map_data[~map_data["ID_Provincia"].isin(excluded_ids)] + fedstate_data = gpd.read_file(os.path.join( + os.getcwd(), 'tools/lineas_limite/SHP_ETRS89/recintos_autonomicas_inspire_peninbal_etrs89/recintos_autonomicas_inspire_peninbal_etrs89.shp')) + + plot_map(values[0], map_data, fedstate_data, "Median ICU", + f"{name}/median_icu_spain_initial_{name}{synthetic}{with_aug}") + plot_map(values[-1], map_data, fedstate_data, "Median ICU", + f"{name}/median_icu_spain_final_{name}{synthetic}{with_aug}") + + +def plot_map(values, map_data, fedstate_data, label, filename): + map_data[label] = map_data['ID_Provincia'].map( + dict(zip(region_ids, values))) + + fig, ax = plt.subplots(figsize=(12, 14)) + map_data.plot( + column=f"{label}", + cmap='Reds', + linewidth=0.5, + ax=ax, + edgecolor='0.6', + legend=True, + legend_kwds={'label': f"{label} per County", 'shrink': 0.6}, + ) + fedstate_data.boundary.plot(ax=ax, color='black', linewidth=1) + ax.set_title(f"{label} per County") + ax.axis('off') + plt.savefig(filename, bbox_inches='tight', dpi=dpi) + + def calibration_curves_per_region( data: np.ndarray, true_data: np.ndarray, @@ -265,13 +344,13 @@ def calibration_curves_per_region( # Custom legend: one patch for regions, one line for ideal region_patch = Patch(color=colors[-1], label="Regions") ax.legend(handles=[region_patch, ideal_line], - frameon=True, ncol=1, fontsize=12) + frameon=True, ncol=1) ax.set_xlim(0, 1) ax.set_ylim(0, 1) - ax.set_xlabel("Nominal level", fontsize=12) - ax.set_ylabel("Empirical coverage", fontsize=12) - ax.set_title("Calibration per region", fontsize=12) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration per region") return ax @@ -328,14 +407,14 @@ def calibration_median_mad_over_regions( ax.set_xlim(0, 1) ax.set_ylim(0, 1) - ax.set_xlabel("Nominal level", fontsize=12) - ax.set_ylabel("Empirical coverage", fontsize=12) - ax.set_title("Calibration median and MAD across regions", fontsize=12) - ax.legend(fontsize=12) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration median and MAD across regions") + ax.legend() return ax, {"levels": x, "median": med, "mad": mad} -def plot_damping_values(damping_values, name, synthetic, with_aug): +def plot_damping_values(damping_values, name, synthetic): med = np.median(damping_values, axis=0) mad = np.median(np.abs(damping_values - med), axis=0) @@ -350,39 +429,39 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): x = np.arange(15, 61, 15) # Time steps from 15 to 60 for i, ax in enumerate(axes): - if i < 19: + if i < len(comunidades): ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) ax.fill_between( x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], alpha=0.25, color='red', step='post' ) - ax.set_title(f"{dd.Comunidades[i+2]}", fontsize=10) - ax.set_xlabel("Time", fontsize=8) - ax.set_ylabel("Damping Value", fontsize=8) + ax.set_title(f"{dd.Comunidades[comunidades[i]]}") + ax.set_xlabel("Time") + ax.set_ylabel("Damping Value") else: ax.axis('off') # Hide unused subplots - plt.suptitle("Damping Values per Region", fontsize=14) + 
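# Illustrative sketch (not from the repository) of the step-plot idiom used in the
# loop above: three damping values act piecewise-constantly on the intervals
# [15, 30), [30, 45) and [45, 60] days. ax.stairs() consumes the bin edges directly,
# while fill_between(step='post') needs one value per edge, so the last value is
# repeated ("extended"). The numbers below are assumptions for this example only.
import numpy as np
import matplotlib.pyplot as plt

edges = np.arange(15, 61, 15)              # [15, 30, 45, 60]
med_i = np.array([0.2, 0.5, 0.35])         # one median damping value per interval
mad_i = np.array([0.05, 0.08, 0.04])       # median absolute deviation per interval
med_ext = np.append(med_i, med_i[-1])      # repeat last value for step='post'
mad_ext = np.append(mad_i, mad_i[-1])

fig, ax = plt.subplots()
ax.stairs(med_i, edges=edges, lw=2, color='red', baseline=None)
ax.fill_between(edges, med_ext - mad_ext, med_ext + mad_ext,
                alpha=0.25, color='red', step='post')
ax.set_xlabel("Time")
ax.set_ylabel("Damping Value")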
plt.suptitle("Damping Values per Region") plt.savefig( - f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) + f"{name}/damping_values{name}{synthetic}.png", dpi=dpi) # Combined plot for all regions fig, ax = plt.subplots(figsize=(10, 6)) cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors - for i in range(19): + for i, comunidad in enumerate(comunidades): ax.stairs( - med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", + med[i], edges=x, lw=2, label=f"{dd.Comunidades[comunidad]}", color=cmap(i), baseline=None ) - ax.set_title("Damping Values per Region (Combined)", fontsize=14) - ax.set_xlabel("Time", fontsize=12) - ax.set_ylabel("Damping Value", fontsize=12) - ax.legend(fontsize=10, loc="upper right", ncol=2) + ax.set_title("Damping Values per Region (Combined)") + ax.set_xlabel("Time") + ax.set_ylabel("Damping Value") + ax.legend(loc="upper right", ncol=2) plt.savefig( - f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) + f"{name}/damping_values_combined{name}{synthetic}.png", dpi=dpi) class Simulation: @@ -496,8 +575,10 @@ def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_ mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): node = graph.get_node(node_idx) - self.set_npis(node.property.parameters, - end_date, damping_values[dd.Provincia_to_Comunidad[node.id]-1]) + # determine comunidad index for this node's provincia and use its damping values + comunidad_idx = comunidades.index(node.id // 10) + self.set_npis(node.property.parameters, end_date, + damping_values[comunidad_idx]) mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( @@ -532,13 +613,13 @@ def run_spain_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, def prior(): - damping_values = np.zeros((NUM_DAMPING_POINTS, 19)) + damping_values = np.zeros((NUM_DAMPING_POINTS, len(comunidades))) for i in range(NUM_DAMPING_POINTS): mean = np.random.uniform(0, 1) scale = 0.1 a, b = (0 - mean) / scale, (1 - mean) / scale damping_values[i] = truncnorm.rvs( - a=a, b=b, loc=mean, scale=scale, size=19 + a=a, b=b, loc=mean, scale=scale, size=len(comunidades) ) return { 'damping_values': np.transpose(damping_values), @@ -618,12 +699,12 @@ def aggregate_states(d: dict) -> None: for i, state in enumerate(comunidades): idxs = [ r for r in range(n_regions) - if dd.Provincia_to_Comunidad[region_ids[r]] == state + if region_ids[r] // 10 == state ] d[f"comunidad{i}"] = np.sum([d[f"region{r}"] for r in idxs], axis=0) # all allowed regions d["state"] = np.sum([d[f"comunidad{r}"] - for r in range(len(comunidades))], axis=0) + for r in range(len(comunidades))], axis=0) def combine_results(dict_list): @@ -719,12 +800,12 @@ def run_training(name, num_training_files=20): # check data workflow = get_workflow() print("summary_variables shape:", workflow.adapter( - trainings_data)["summary_variables"].shape) + validation_data)["summary_variables"].shape) print("inference_variables shape:", workflow.adapter( - trainings_data)["inference_variables"].shape) + validation_data)["inference_variables"].shape) history = workflow.fit_offline( - data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data + data=trainings_data, epochs=5, batch_size=64, validation_data=validation_data ) workflow.approximator.save( @@ -735,13 +816,15 @@ def run_training(name, num_training_files=20): test_data=validation_data, calibration_ecdf_kwargs={ 'difference': True, 'stacked': 
True} ) - plots['losses'].savefig(f'{name}/losses_{name}.png') - plots['recovery'].savefig(f'{name}/recovery_{name}.png') - plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png') - # plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') + plots['losses'].savefig(f'{name}/losses_{name}.png', dpi=dpi) + plots['recovery'].savefig(f'{name}/recovery_{name}.png', dpi=dpi) + plots['calibration_ecdf'].savefig( + f'{name}/calibration_ecdf_{name}.png', dpi=dpi) + plots['z_score_contraction'].savefig( + f'{name}/z_score_contraction_{name}.png', dpi=dpi) -def run_inference(name, num_samples=1000, on_synthetic_data=False): +def run_inference(name, num_samples=10, on_synthetic_data=False): val_path = f"{name}/validation_data_{name}.pickle" synthetic = "_synthetic" if on_synthetic_data else "" @@ -776,16 +859,19 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): filepath=os.path.join(f"{name}/model_{name}.keras") ) - if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') simulations_aug = load_pickle( f'{name}/sims_{name}{synthetic}_with_aug.pickle') + samples = load_pickle(f'{name}/samples_{name}{synthetic}.pickle') print("loaded simulations from file") else: samples = workflow.sample( conditions=validation_data_skip2w, num_samples=num_samples) + with open(f'{name}/samples_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(samples, f, pickle.HIGHEST_PROTOCOL) samples['damping_values'] = samples['damping_values'].reshape( - (samples['damping_values'].shape[0], num_samples, 19, NUM_DAMPING_POINTS)) + (samples['damping_values'].shape[0], num_samples, len(comunidades), NUM_DAMPING_POINTS)) results = [] for i in range(num_samples): # we only have one dataset for inference here result = run_spain_nuts3_simulation( @@ -821,14 +907,17 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) - samples['damping_values'] = samples['damping_values'].reshape( - (samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) - validation_data['damping_values'] = validation_data['damping_values'].reshape( - (validation_data['damping_values'].shape[0], -1)) + plot_damping_values(samples['damping_values'][0], + name=name, synthetic=synthetic) + + samples['damping_values'] = samples['damping_values'].reshape( + (samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) - plot = bf.diagnostics.pairs_posterior( - samples, priors=validation_data, dataset_id=0) - plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') + plot = bf.diagnostics.pairs_posterior( + samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png', dpi=dpi) plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") plot_all_regions(simulations_aug, divi_data, name, @@ -861,7 +950,7 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): simulations_aug, true_data=divi_data, 
label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_80q=True ) axes[1, 1].set_title("With Augmentation (80% Quantile)") - plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png', dpi=dpi) plt.close() fig, axis = plt.subplots(1, 2, figsize=( @@ -869,7 +958,8 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) ax, stats = calibration_median_mad_over_regions( simulations, divi_data, ax=axis[1]) - plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png') + plt.savefig( + f'{name}/calibration_per_region_{name}{synthetic}.png', dpi=dpi) plt.close() fig, axis = plt.subplots(1, 2, figsize=( 10, 4), sharex=True, layout="constrained") @@ -877,11 +967,11 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): ax, stats = calibration_median_mad_over_regions( simulations_aug, divi_data, ax=axis[1]) plt.savefig( - f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png') + f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png', dpi=dpi) plt.close() - # plot_icu_on_spain(simulations, name, synthetic, with_aug="") - # plot_icu_on_spain(simulations_aug, name, synthetic, with_aug="_with_aug") + plot_icu_on_spain(simulations, name, synthetic, with_aug="") + plot_icu_on_spain(simulations_aug, name, synthetic, with_aug="_with_aug") simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions @@ -909,7 +999,7 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): if not os.path.exists(name): os.makedirs(name) # create_train_data( - # filename=f'{name}/trainings_data1_{name}.pickle', number_samples=1000) + # filename=f'{name}/validation_data_{name}.pickle', number_samples=10) # run_training(name=name, num_training_files=1) # run_inference(name=name, on_synthetic_data=True) run_inference(name=name, on_synthetic_data=False) diff --git a/pycode/memilio-epidata/memilio/epidata/defaultDict.py b/pycode/memilio-epidata/memilio/epidata/defaultDict.py index 69277a5f53..8c57564268 100644 --- a/pycode/memilio-epidata/memilio/epidata/defaultDict.py +++ b/pycode/memilio-epidata/memilio/epidata/defaultDict.py @@ -706,185 +706,99 @@ def invert_dict(dict_to_invert): return {val: key for key, val in dict_to_invert.items()} -Provincias = {2: 'Araba/Álava', - 3: 'Albacete', - 4: 'Alicante/Alacant', - 5: 'Almería', - 6: 'Ávila', - 7: 'Badajoz', - 8: 'Balears, Illes', - 9: 'Barcelona', - 10: 'Burgos', - 11: 'Cáceres', - 12: 'Cádiz', - 13: 'Castellón/Castelló', - 14: 'Ciudad Real', - 15: 'Córdoba', - 16: 'Coruña, A', - 17: 'Cuenca', - 18: 'Gerona', - 19: 'Granada', - 20: 'Guadalajara', - 21: 'Guipúzcoa', - 22: 'Huelva', - 23: 'Huesca', - 24: 'Jaén', - 25: 'León', - 26: 'Lérida', - 27: 'Rioja, La', - 28: 'Lugo', - 29: 'Madrid', - 30: 'Málaga', - 31: 'Murcia', - 32: 'Navarra', - 33: 'Asturias', - 34: 'Palencia', - 35: 'Palmas, Las', - 36: 'Pontevedra', - 37: 'Salamanca', - 38: 'Santa Cruz de Tenerife', - 39: 'Cantabria', - 40: 'Segovia', - 41: 'Sevilla', - 42: 'Soria', - 43: 'Tarragona', - 44: 'Teruel', - 45: 'Toledo', - 46: 'Valencia/València', - 47: 'Valladolid', - 48: 'Vizcaya', - 49: 'Zamora', - 50: 'Zaragoza', - 51: 'Ceuta', - 52: 'Melilla', - 53: 'Orense'} +Provincias = { + 111: 'A Coruña', + 112: 'Lugo', + 113: 'Ourense', + 114: 'Pontevedra', + 120: 'Asturias', + 130: 'Cantabria', + 211: 'Araba/Álava', + 212: 'Gipuzkoa', + 213: 'Bizkaia', + 220: 'Navarra', + 230: 'La 
Rioja', + 241: 'Huesca', + 242: 'Teruel', + 243: 'Zaragoza', + 300: 'Madrid', + 411: 'Ávila', + 412: 'Burgos', + 413: 'León', + 414: 'Palencia', + 415: 'Salamanca', + 416: 'Segovia', + 417: 'Soria', + 418: 'Valladolid', + 419: 'Zamora', + 421: 'Albacete', + 422: 'Ciudad Real', + 423: 'Cuenca', + 424: 'Guadalajara', + 425: 'Toledo', + 431: 'Badajoz', + 432: 'Cáceres', + 511: 'Barcelona', + 512: 'Girona', + 513: 'Lleida', + 514: 'Tarragona', + 521: 'Alacant/Alicante', + 522: 'Castelló/Castellón', + 523: 'València/Valencia', + 530: 'Illes Balears', + 611: 'Almería', + 612: 'Cádiz', + 613: 'Córdoba', + 614: 'Granada', + 615: 'Huelva', + 616: 'Jaén', + 617: 'Málaga', + 618: 'Sevilla', + 620: 'Murcia', + 630: 'Ceuta', + 640: 'Melilla', + 701: 'Las Palmas', + 702: 'Santa Cruz de Tenerife' +} + Comunidades = { - 1: 'Andalucía', - 2: 'Aragón', - 3: 'Principado de Asturias', - 4: 'Islas Baleares', - 5: 'Canarias', - 6: 'Cantabria', - 7: 'Castilla y León', - 8: 'Castilla-La Mancha', - 9: 'Cataluña', - 10: 'Comunidad Valenciana', - 11: 'Extremadura', - 12: 'Galicia', - 13: 'Comunidad de Madrid', - 14: 'Región de Murcia', - 15: 'Comunidad Foral de Navarra', - 16: 'País Vasco', - 17: 'La Rioja', - 18: 'Ceuta', - 19: 'Melilla' + 11: 'Galicia', + 12: 'Principado de Asturias', + 13: 'Cantabria', + 21: 'País Vasco', + 22: 'Comunidad Foral de Navarra', + 23: 'La Rioja', + 24: 'Aragón', + 30: 'Comunidad de Madrid', + 41: 'Castilla y León', + 42: 'Castilla-La Mancha', + 43: 'Extremadura', + 51: 'Cataluña', + 52: 'Comunidad Valenciana', + 53: 'Islas Baleares', + 61: 'Andalucía', + 62: 'Región de Murcia', + 63: 'Ceuta', + 64: 'Melilla', + 70: 'Canarias' } -Provincia_to_Comunidad = { - 2: 16, - 3: 8, - 4: 10, - 5: 1, - 6: 7, - 7: 11, - 8: 4, - 9: 9, - 10: 7, - 11: 11, - 12: 1, - 13: 10, - 14: 8, - 15: 1, - 16: 12, - 17: 8, - 18: 9, - 19: 1, - 20: 8, - 21: 16, - 22: 1, - 23: 2, - 24: 1, - 25: 7, - 26: 9, - 27: 17, - 28: 12, - 29: 13, - 30: 1, - 31: 14, - 32: 15, - 33: 3, - 34: 7, - 35: 5, - 36: 12, - 37: 7, - 38: 5, - 39: 6, - 40: 7, - 41: 1, - 42: 7, - 43: 9, - 44: 2, - 45: 8, - 46: 10, - 47: 7, - 48: 16, - 49: 7, - 50: 2, - 51: 18, - 52: 19, - 53: 12 -} +provincia_id_map = {1: 211, 2: 421, 3: 521, 4: 611, 5: 411, 6: 431, 7: 530, 8: 511, 9: 412, 10: 432, 11: 612, 12: 522, + 13: 422, 14: 613, 15: 111, 16: 423, 17: 512, 18: 614, 19: 424, 20: 212, 21: 615, 22: 241, 23: 616, + 24: 413, 25: 513, 26: 230, 27: 112, 28: 300, 29: 617, 30: 620, 31: 220, 32: 113, 33: 120, 34: 414, + 35: 701, 36: 114, 37: 415, 38: 702, 39: 130, 40: 416, 41: 618, 42: 417, 43: 514, 44: 242, 45: 425, + 46: 523, 47: 418, 48: 213, 49: 419, 50: 243, 51: 630, 52: 640} + +provincia_id_map_census = {2: 211, 3: 421, 4: 521, 5: 611, 6: 411, 7: 431, 8: 530, 9: 511, 10: 412, 11: 432, 12: 612, 13: 522, + 14: 422, 15: 613, 16: 111, 17: 423, 18: 512, 19: 614, 20: 424, 21: 212, 22: 615, 23: 241, 24: 616, + 25: 413, 26: 513, 27: 230, 28: 112, 29: 300, 30: 617, 31: 620, 32: 220, 33: 120, 34: 414, 35: 701, + 36: 114, 37: 415, 38: 702, 39: 130, 40: 416, 41: 618, 42: 417, 43: 514, 44: 242, 45: 425, 46: 523, + 47: 418, 48: 213, 49: 419, 50: 243, 51: 630, 52: 640, 53: 113} -Provincia_ISO_to_ID = {'A': 4, - 'AB': 3, - 'AL': 5, - 'AV': 6, - 'B': 9, - 'BA': 7, - 'BI': 48, - 'BU': 10, - 'C': 16, - 'CA': 12, - 'CC': 11, - 'CE': 51, - 'CO': 15, - 'CR': 14, - 'CS': 13, - 'CU': 17, - 'GC': 35, - 'GI': 18, - 'GR': 19, - 'GU': 20, - 'H': 22, - 'HU': 23, - 'J': 24, - 'L': 26, - 'LE': 25, - 'LO': 27, - 'LU': 28, - 'M': 29, - 'MA': 30, - 'ML': 52, - 'MU': 31, - 'NA': 32, - 'O': 
33, - 'OR': 53, - 'P': 34, - 'PM': 8, - 'PO': 36, - 'SA': 37, - 'S': 39, - 'SE': 41, - 'SG': 40, - 'SO': 42, - 'SS': 21, - 'T': 43, - 'TE': 44, - 'TF': 38, - 'TO': 45, - 'V': 46, - 'VA': 47, - 'VI': 2, - 'ZA': 49, - 'Z': 50} +Provincia_ISO_to_ID = {'A': 521, 'AB': 421, 'AL': 611, 'AV': 411, 'B': 511, 'BA': 431, 'BI': 213, 'BU': 412, 'C': 111, 'CA': 612, + 'CC': 432, 'CE': 630, 'CO': 613, 'CR': 422, 'CS': 522, 'CU': 423, 'GC': 701, 'GI': 512, 'GR': 614, 'GU': 424, + 'H': 615, 'HU': 241, 'J': 616, 'L': 513, 'LE': 413, 'LO': 230, 'LU': 112, 'M': 300, 'MA': 617, 'ML': 640, + 'MU': 620, 'NA': 220, 'O': 120, 'OR': 113, 'P': 414, 'PM': 530, 'PO': 114, 'SA': 415, 'S': 130, 'SE': 618, + 'SG': 416, 'SO': 417, 'SS': 212, 'T': 514, 'TE': 242, 'TF': 702, 'TO': 425, 'V': 523, 'VA': 418, 'VI': 211, + 'ZA': 419, 'Z': 243} diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py index ab93f11743..2e6dbb895c 100644 --- a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -20,19 +20,21 @@ def fetch_population_data(): lambda x: x[0]['T3_Variable'] == 'Provincias')] df = df[df['MetaData'].apply( lambda x: x[1]['Nombre'] == 'Total')] - df['ID_Provincia'] = df['MetaData'].apply(lambda x: x[0]['Id']) + df['ID_Provincia'] = df['MetaData'].apply( + lambda x: dd.provincia_id_map_census[x[0]['Id']]) df['Population'] = df['Data'].apply(lambda x: x[0]['Valor']) return df[['ID_Provincia', 'Population']] def remove_islands(df, column_labels=['ID_Provincia']): for label in column_labels: - df = df[~df[label].isin([51, 52, 8, 35, 38])] + df = df[~df[label].isin([530, 630, 640, 701, 702])] return df def get_population_data(): df = fetch_population_data() + df = df.sort_values(by=['ID_Provincia']) df = remove_islands(df) return df @@ -73,6 +75,7 @@ def get_icu_data(): df['ID_Provincia'] = pd.to_numeric(df['ID_Provincia'], errors='coerce') df = df[df['ID_Provincia'].notna()].copy() df['ID_Provincia'] = df['ID_Provincia'].astype(int) + 1 + df['ID_Provincia'] = df['ID_Provincia'].map(dd.provincia_id_map) df = remove_islands(df) df = preprocess_icu_data(df) @@ -129,14 +132,19 @@ def preprocess_mobility_data(df, data_dir): zonification_df = pd.read_csv(os.path.join(data_dir, 'poblacion.csv'), sep='|')[ ['municipio', 'provincia', 'poblacion']] + zonification_df['provincia'] = zonification_df['provincia'].map( + dd.provincia_id_map) + zonification_df.dropna(inplace=True, subset=['provincia']) poblacion = zonification_df.groupby('provincia')[ 'poblacion'].sum() zonification_df.drop_duplicates( subset=['municipio', 'provincia'], inplace=True) municipio_to_provincia = dict( zip(zonification_df['municipio'], zonification_df['provincia'])) - df['id_origin'] = df['id_origin'].map(municipio_to_provincia) - df['id_destination'] = df['id_destination'].map(municipio_to_provincia) + df['id_origin'] = df['id_origin'].map( + municipio_to_provincia) + df['id_destination'] = df['id_destination'].map( + municipio_to_provincia) df.query('id_origin != id_destination', inplace=True) df = df.groupby(['date', 'id_origin', 'id_destination'], @@ -157,7 +165,7 @@ def get_mobility_data(data_dir, start_date='2022-08-01', end_date='2022-08-31', filename = f'Viajes_{level}_{start_date}_{end_date}_v2.parquet' if not os.path.exists(os.path.join(data_dir, filename)): print( - f"File {os.path.exists(os.path.join(data_dir, filename))} does not exist. 
Downloading mobility data...") + f"File {os.path.join(data_dir, filename)} does not exist. Downloading mobility data...") download_mobility_data(data_dir, start_date, end_date, level) df = pd.read_parquet(os.path.join(data_dir, filename)) From da549c3068f394e36d8b64fa134e3dfc9c9fc751 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Fri, 14 Nov 2025 10:14:27 +0100 Subject: [PATCH 71/73] adjust to paper colors and fontsizes --- .../simulation/graph_germany_nuts3_clean.py | 193 +++++++++++------- 1 file changed, 122 insertions(+), 71 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index ef9ee88026..92d7c1fe77 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -46,13 +46,14 @@ region_ids = [region_id for region_id in dd.County.keys() if region_id not in excluded_ids] -inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', +inference_params = ['damping_values', 't_E', 't_C', 't_ISy', 't_ISev', 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] summary_vars = ['state'] + [f'fed_state{i}' for i in range(16)] + [ f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids] bounds = { 't_E': (1.0, 5.2), + 't_C': (0.1, 4.2), 't_ISy': (4.0, 10.0), 't_ISev': (5.0, 10.0), 't_Cr': (9.0, 17.0), @@ -67,6 +68,42 @@ DATE_TIME = datetime.date(year=2020, month=10, day=1) NUM_DAMPING_POINTS = 3 + +def set_fontsize(base_fontsize=17): + fontsize = base_fontsize + plt.rcParams.update({ + 'font.size': fontsize, + 'axes.titlesize': fontsize * 1, + 'axes.labelsize': fontsize, + 'xtick.labelsize': fontsize * 0.8, + 'ytick.labelsize': fontsize * 0.8, + 'legend.fontsize': fontsize * 0.8, + 'font.family': "Arial" + }) + + +plt.style.use('default') + +dpi = 300 + +colors = {"Blue": "#155489", + "Medium blue": "#64A7DD", + "Light blue": "#B4DCF6", + "Lilac blue": "#AECCFF", + "Turquoise": "#76DCEC", + "Light green": "#B6E6B1", + "Medium green": "#54B48C", + "Green": "#5D8A2B", + "Teal": "#20A398", + "Yellow": "#FBD263", + "Orange": "#E89A63", + "Rose": "#CF7768", + "Red": "#A34427", + "Purple": "#741194", + "Grey": "#C0BFBF", + "Dark grey": "#616060", + "Light grey": "#F1F1F1"} + def plot_region_fit( data: np.ndarray, region: int, @@ -111,11 +148,11 @@ def plot_region_fit( true_vals = true_data[:, region] # (time_points,) ax.plot(x, true_vals, lw=2, color="black", label="True data") - ax.set_xlabel("Time", fontsize=12) - ax.set_ylabel("ICU", fontsize=12) - ax.set_title(f"{dd.County[region_ids[region]]}", fontsize=12) + ax.set_xlabel("Time") + ax.set_ylabel("ICU") + ax.set_title(f"{dd.County[region_ids[region]]}") if label is not None: - ax.legend(fontsize=11, loc="upper right") + ax.legend(loc="upper right") def plot_aggregated_over_regions( @@ -148,7 +185,7 @@ def plot_aggregated_over_regions( fig, ax = plt.subplots() ax.plot(x, agg_median, lw=2, - label=label or "Aggregated over regions", color=color) + label=label or "Aggregated over regions", color=color) ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, color=color, label="80% CI") if not only_80q: @@ -160,10 +197,10 @@ def plot_aggregated_over_regions( true_vals = region_agg(true_data, axis=-1) # (time_points,) ax.plot(x, true_vals, lw=2, color="black", label="True data") - ax.set_xlabel("Time", fontsize=12) - ax.set_ylabel("ICU", fontsize=12) + ax.set_xlabel("Time") + ax.set_ylabel("ICU") if label is not None: - ax.legend(fontsize=11) + 
ax.legend() def plot_icu_on_germany(simulations, name, synthetic, with_aug): med = np.median(simulations, axis=0) @@ -194,7 +231,7 @@ def plot_map(values, map_data, fedstate_data, label, filename): fedstate_data.boundary.plot(ax=ax, color='black', linewidth=1) ax.set_title(f"{label} per County") ax.axis('off') - plt.savefig(filename, bbox_inches='tight', dpi=300) + plt.savefig(filename, bbox_inches='tight', dpi=dpi) def plot_aggregated_to_federal_states(data, true_data, name, synthetic, with_aug): @@ -203,13 +240,13 @@ def plot_aggregated_to_federal_states(data, true_data, name, synthetic, with_aug for state in range(16): idxs = [i for i, region_id in enumerate(region_ids) if region_id // 1000 == state + 1] plot_aggregated_over_regions( - data[:, :, idxs], # Add a dummy region axis for compatibility - true_data=true_data[:, idxs] if true_data is not None else None, - ax=ax[state], - label=f"State {state + 1}", - color=f"C{state % 10}" # Cycle through 10 colors + data[:, :, idxs], # Add a dummy region axis for compatibility + true_data=true_data[:, idxs] if true_data is not None else None, + ax=ax[state], + label=f"State {state + 1}", + color=colors["Red"] ) - plt.savefig(f'{name}/federal_states_{name}{synthetic}{with_aug}.png') + plt.savefig(f'{name}/federal_states_{name}{synthetic}{with_aug}.png', dpi=dpi) plt.close() # plot simulations for all regions in 10x4 blocks @@ -226,12 +263,12 @@ def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): ax = ax.flatten() for i, region_idx in enumerate(range(start_idx, end_idx)): plot_region_fit( - simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color="#132a70" + simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color=colors["Red"] ) # Hide unused subplots for i in range(end_idx - start_idx, len(ax)): ax[i].axis("off") - plt.savefig(f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png') + plt.savefig(f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png', dpi=dpi) plt.close() @@ -294,13 +331,13 @@ def calibration_curves_per_region( # Custom legend: one patch for regions, one line for ideal region_patch = Patch(color=colors[-1], label="Regions") ax.legend(handles=[region_patch, ideal_line], - frameon=True, ncol=1, fontsize=12) + frameon=True, ncol=1) ax.set_xlim(0, 1) ax.set_ylim(0, 1) - ax.set_xlabel("Nominal level", fontsize=12) - ax.set_ylabel("Empirical coverage", fontsize=12) - ax.set_title("Calibration per region", fontsize=12) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration per region") return ax @@ -357,10 +394,10 @@ def calibration_median_mad_over_regions( ax.set_xlim(0, 1) ax.set_ylim(0, 1) - ax.set_xlabel("Nominal level", fontsize=12) - ax.set_ylabel("Empirical coverage", fontsize=12) - ax.set_title("Calibration median and MAD across regions", fontsize=12) - ax.legend(fontsize=12) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration median and MAD across regions") + ax.legend() return ax, {"levels": x, "median": med, "mad": mad} @@ -385,14 +422,14 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], alpha=0.25, color='red', step='post' ) - ax.set_title(f"{dd.State[i+1]}", fontsize=10) - ax.set_xlabel("Time", fontsize=8) - ax.set_ylabel("Damping Value", fontsize=8) + ax.set_title(f"{dd.State[i+1]}") + ax.set_xlabel("Time") + ax.set_ylabel("Damping 
Value") else: ax.axis('off') # Hide unused subplots - plt.suptitle("Damping Values per Region", fontsize=14) - plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) + plt.suptitle("Damping Values per Region") + plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=dpi) # Combined plot for all regions fig, ax = plt.subplots(figsize=(10, 6)) @@ -404,11 +441,11 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): color=cmap(i), baseline=None ) - ax.set_title("Damping Values per Region (Combined)", fontsize=14) - ax.set_xlabel("Time", fontsize=12) - ax.set_ylabel("Damping Value", fontsize=12) - ax.legend(fontsize=10, loc="upper right", ncol=2) - plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) + ax.set_title("Damping Values per Region (Combined)") + ax.set_xlabel("Time") + ax.set_ylabel("Damping Value") + ax.legend(loc="upper right", ncol=2) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=dpi) def run_prior_predictive_check(name): validation_data = load_pickle(f'{name}/validation_data_{name}.pickle') # synthetic data @@ -428,10 +465,10 @@ def run_prior_predictive_check(name): fig, ax = plt.subplots(figsize=(8, 6)) plot_aggregated_over_regions( - simulations, true_data=true_data, label="Region Aggregated Median (No Aug)", ax=ax, color="#132a70" + simulations, true_data=true_data, label="Region Aggregated Median (No Aug)", ax=ax, color=colors["Red"] ) ax.set_title("Prior predictive check - Aggregated over regions") - plt.savefig(f'{name}/prior_predictive_check_{name}.png') + plt.savefig(f'{name}/prior_predictive_check_{name}.png', dpi=dpi) def compare_median_sim_to_mean_param_sim(name): num_samples = 10 @@ -476,10 +513,10 @@ def compare_median_sim_to_mean_param_sim(name): fig, ax = plt.subplots(figsize=(8, 6)) plot_aggregated_over_regions( - simulations_aug, true_data=result, label="Compare median of sim to sim of median params", ax=ax, color="#132a70" + simulations_aug, true_data=result, label="Compare median of sim to sim of median params", ax=ax, color=colors["Red"] ) ax.set_title("Comparison of Median Simulation to Simulation of Median Parameters") - plt.savefig(f'{name}/compare_median_sim13_to_mean_param_sim_{name}.png') + plt.savefig(f'{name}/compare_median_sim13_to_mean_param_sim_{name}.png', dpi=dpi) plt.close() @@ -494,14 +531,14 @@ def __init__(self, data_dir, start_date, results_dir): if not os.path.exists(self.results_dir): os.makedirs(self.results_dir) - def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + def set_covid_parameters(self, model, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param model: """ model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E - model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E # todo: correct? 
+ model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = t_C model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr @@ -561,7 +598,7 @@ def set_npis(self, params, end_date, damping_values): params.ContactPatterns.cont_freq_mat[0].add_damping( mio.Damping(np.r_[damping_values[2]], t=start_date)) - def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + def get_graph(self, end_date, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param end_date: @@ -570,7 +607,7 @@ def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_ print("Initializing model...") model = Model(self.num_groups) self.set_covid_parameters( - model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + model, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) self.set_contact_matrices(model) print("Model initialized.") @@ -604,7 +641,7 @@ def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_ return graph - def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): + def run(self, num_days_sim, damping_values, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): """ :param num_days_sim: @@ -616,7 +653,7 @@ def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_ mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) - graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, + graph = self.get_graph(end_date, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) mobility_graph = osecir.MobilityGraph() @@ -650,7 +687,7 @@ def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_ return results -def run_germany_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): +def run_germany_nuts3_simulation(damping_values, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) @@ -660,7 +697,7 @@ def run_germany_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR results_dir=os.path.join(file_path, "../../../results_osecir")) num_days_sim = 60 - results = sim.run(num_days_sim, damping_values, t_E, t_ISy, + results = sim.run(num_days_sim, damping_values, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) return results @@ -677,6 +714,7 @@ def prior(): return { 'damping_values': np.transpose(damping_values), 't_E': np.random.uniform(*bounds['t_E']), + 't_C': np.random.uniform(*bounds['t_C']), 't_ISy': np.random.uniform(*bounds['t_ISy']), 't_ISev': np.random.uniform(*bounds['t_ISev']), 't_Cr': np.random.uniform(*bounds['t_Cr']), @@ -794,6 +832,7 @@ def get_workflow(): .convert_dtype("float64", "float32") .constrain("damping_values", lower=0.0, upper=1.0) .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_C", lower=bounds["t_C"][0], upper=bounds["t_C"][1]) .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) .constrain("t_Cr", lower=bounds["t_Cr"][0], 
upper=bounds["t_Cr"][1]) @@ -803,7 +842,7 @@ def get_workflow(): .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) .concatenate( - ["damping_values", "t_E", "t_ISy", "t_ISev", "t_Cr", + ["damping_values", "t_E", "t_C", "t_ISy", "t_ISev", "t_Cr", "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], into="inference_variables", axis=-1 @@ -874,10 +913,10 @@ def run_training(name, num_training_files=20): plots = workflow.plot_default_diagnostics( test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} ) - plots['losses'].savefig(f'{name}/losses_{name}.png') - plots['recovery'].savefig(f'{name}/recovery_{name}.png') - plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png') - #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') + plots['losses'].savefig(f'{name}/losses_{name}.png', dpi=dpi) + plots['recovery'].savefig(f'{name}/recovery_{name}.png', dpi=dpi) + plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png', dpi=dpi) + #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png', dpi=dpi) def run_inference(name, num_samples=1000, on_synthetic_data=False): @@ -916,18 +955,22 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): filepath=os.path.join(f"{name}/model_{name}.keras") ) - if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') - simulations_aug = load_pickle(f'{name}/sims_{name}{synthetic}_with_aug.pickle') + simulations_aug = load_pickle( + f'{name}/sims_{name}{synthetic}_with_aug.pickle') + samples = load_pickle(f'{name}/samples_{name}{synthetic}.pickle') print("loaded simulations from file") else: samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + with open(f'{name}/samples_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(samples, f, pickle.HIGHEST_PROTOCOL) samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) results = [] for i in range(num_samples): # we only have one dataset for inference here result = run_germany_nuts3_simulation( damping_values=samples['damping_values'][0, i], - t_E=samples['t_E'][0, i], t_ISy=samples['t_ISy'][0, i], + t_E=samples['t_E'][0, i], t_C=samples['t_C'][0, i], t_ISy=samples['t_ISy'][0, i], t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], @@ -953,11 +996,17 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) - samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) - validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + plot_damping_values(samples['damping_values'][0], + name=name, synthetic=synthetic) - plot 
= bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) - plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') + samples['damping_values'] = samples['damping_values'].reshape( + (samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior( + samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png', dpi=dpi) plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") @@ -968,36 +1017,36 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): fig, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey='row', constrained_layout=True) # Plot without augmentation plot_aggregated_over_regions( - simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color="#132a70" + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color=colors["Red"] ) axes[0, 0].set_title("Without Augmentation") # Plot with augmentation plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color=colors["Red"] ) axes[0, 1].set_title("With Augmentation") # Plot without augmentation (80% quantile only) plot_aggregated_over_regions( - simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_80q=True + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color=colors["Red"], only_80q=True ) axes[1, 0].set_title("Without Augmentation (80% Quantile)") # Plot with augmentation (80% quantile only) plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_80q=True + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color=colors["Red"], only_80q=True ) axes[1, 1].set_title("With Augmentation (80% Quantile)") - plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png', dpi=dpi) plt.close() fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) - plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png') + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png', dpi=dpi) plt.close() fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) ax, stats = calibration_median_mad_over_regions(simulations_aug, divi_data, ax=axis[1]) - plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png') + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png', dpi=dpi) plt.close() plot_icu_on_germany(simulations, name, synthetic, with_aug="") @@ -1020,11 +1069,13 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): if __name__ == "__main__": name 
= "3dampings_lessnoise_newnetwork" + set_fontsize() + if not os.path.exists(name): os.makedirs(name) - create_train_data(filename=f'{name}/trainings_data1_{name}.pickle', number_samples=1000) - # run_training(name=name, num_training_files=20) - # run_inference(name=name, on_synthetic_data=True) - # run_inference(name=name, on_synthetic_data=False) + # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=100) + # run_training(name=name, num_training_files=10) + run_inference(name=name, on_synthetic_data=True) + run_inference(name=name, on_synthetic_data=False) # run_prior_predictive_check(name=name) # compare_median_sim_to_mean_param_sim(name=name) \ No newline at end of file From a8fbad9a772aa024211709f87f2f9c172c347fc6 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Fri, 14 Nov 2025 14:27:47 +0100 Subject: [PATCH 72/73] small adjustments to plots --- .../simulation/graph_germany_nuts3_clean.py | 47 +++++++------------ 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py index 92d7c1fe77..908813cccf 100644 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ b/pycode/examples/simulation/graph_germany_nuts3_clean.py @@ -199,8 +199,6 @@ def plot_aggregated_over_regions( ax.set_xlabel("Time") ax.set_ylabel("ICU") - if label is not None: - ax.legend() def plot_icu_on_germany(simulations, name, synthetic, with_aug): med = np.median(simulations, axis=0) @@ -401,7 +399,7 @@ def calibration_median_mad_over_regions( return ax, {"levels": x, "median": med, "mad": mad} -def plot_damping_values(damping_values, name, synthetic, with_aug): +def plot_damping_values(damping_values, name, synthetic): med = np.median(damping_values, axis=0) mad = np.median(np.abs(damping_values - med), axis=0) @@ -429,7 +427,7 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): ax.axis('off') # Hide unused subplots plt.suptitle("Damping Values per Region") - plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=dpi) + plt.savefig(f"{name}/damping_values{name}{synthetic}.png", dpi=dpi) # Combined plot for all regions fig, ax = plt.subplots(figsize=(10, 6)) @@ -445,7 +443,7 @@ def plot_damping_values(damping_values, name, synthetic, with_aug): ax.set_xlabel("Time") ax.set_ylabel("Damping Value") ax.legend(loc="upper right", ncol=2) - plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=dpi) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}.png", dpi=dpi) def run_prior_predictive_check(name): validation_data = load_pickle(f'{name}/validation_data_{name}.pickle') # synthetic data @@ -955,7 +953,7 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): filepath=os.path.join(f"{name}/model_{name}.keras") ) - if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle') and os.path.exists(f'{name}/samples_{name}{synthetic}.pickle'): simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') simulations_aug = load_pickle( f'{name}/sims_{name}{synthetic}_with_aug.pickle') @@ -996,17 +994,18 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): with 
open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) - plot_damping_values(samples['damping_values'][0], - name=name, synthetic=synthetic) - - samples['damping_values'] = samples['damping_values'].reshape( + samples['damping_values'] = samples['damping_values'].reshape( (samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) - validation_data['damping_values'] = validation_data['damping_values'].reshape( + validation_data['damping_values'] = validation_data['damping_values'].reshape( (validation_data['damping_values'].shape[0], -1)) - plot = bf.diagnostics.pairs_posterior( - samples, priors=validation_data, dataset_id=0) - plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png', dpi=dpi) + plot = bf.diagnostics.pairs_posterior( + samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png', dpi=dpi) + + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) + plot_damping_values(samples['damping_values'][0], + name=name, synthetic=synthetic) plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") @@ -1014,27 +1013,17 @@ def run_inference(name, num_samples=1000, on_synthetic_data=False): plot_aggregated_to_federal_states(simulations, divi_data, name, synthetic, with_aug="") plot_aggregated_to_federal_states(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") - fig, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey='row', constrained_layout=True) - # Plot without augmentation - plot_aggregated_over_regions( - simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color=colors["Red"] - ) - axes[0, 0].set_title("Without Augmentation") + fig, axes = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True) # Plot with augmentation plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color=colors["Red"] - ) - axes[0, 1].set_title("With Augmentation") - # Plot without augmentation (80% quantile only) - plot_aggregated_over_regions( - simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color=colors["Red"], only_80q=True + simulations_aug, true_data=divi_data, label="Region Aggregated Median", ax=axes[0], color=colors["Red"] ) - axes[1, 0].set_title("Without Augmentation (80% Quantile)") # Plot with augmentation (80% quantile only) plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color=colors["Red"], only_80q=True + simulations_aug, true_data=divi_data, label="Region Aggregated Median", ax=axes[1], color=colors["Red"], only_80q=True ) - axes[1, 1].set_title("With Augmentation (80% Quantile)") + lines, labels = axes[0].get_legend_handles_labels() + fig.legend(lines, labels) plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png', dpi=dpi) plt.close() From f2edabfc78964e3d04e4378023253664c26b0027 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein Date: Fri, 14 Nov 2025 14:28:30 +0100 Subject: [PATCH 73/73] renaming --- .../simulation/graph_germany_nuts3.py | 1032 +++++++++++++--- .../simulation/graph_germany_nuts3_clean.py | 1070 ----------------- 2 files changed, 845 insertions(+), 1257 deletions(-) delete mode 
100644 pycode/examples/simulation/graph_germany_nuts3_clean.py diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py index 5153d0f7a7..908813cccf 100644 --- a/pycode/examples/simulation/graph_germany_nuts3.py +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -1,7 +1,7 @@ ############################################################################# # Copyright (C) 2020-2025 MEmilio # -# Authors: Henrik Zunker +# Authors: Carlotta Gerstein # # Contact: Martin J. Kuehn # @@ -17,33 +17,101 @@ # See the License for the specific language governing permissions and # limitations under the License. ############################################################################# +import os +os.environ["KERAS_BACKEND"] = "tensorflow" + import numpy as np +import pandas as pd +import matplotlib.pyplot as plt import datetime -import os +import pickle +from scipy.stats import truncnorm + +from matplotlib.patches import Patch + +import bayesflow as bf +import keras + import memilio.simulation as mio import memilio.simulation.osecir as osecir -import matplotlib.pyplot as plt - -from enum import Enum -from memilio.simulation.osecir import (Model, Simulation, - interpolate_simulation_result) -import pandas as pd +from memilio.simulation.osecir import Model, interpolate_simulation_result from memilio.epidata import defaultDict as dd -import pickle -excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, 11007, 11008, 11009, 11010, 11011, 11012, 16056] -no_icu_ids = [7338, 9374, 9473, 9573] -region_ids = [region_id for region_id in dd.County.keys() if region_id not in excluded_ids] +import geopandas as gpd -inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', 'transmission_prob'] -def plot_region_median_mad( +excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, + 11007, 11008, 11009, 11010, 11011, 11012, 16056] +no_icu_ids = [7338, 9374, 9473, 9573] +region_ids = [region_id for region_id in dd.County.keys() + if region_id not in excluded_ids] + +inference_params = ['damping_values', 't_E', 't_C', 't_ISy', 't_ISev', + 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] +summary_vars = ['state'] + [f'fed_state{i}' for i in range(16)] + [ + f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids] + +bounds = { + 't_E': (1.0, 5.2), + 't_C': (0.1, 4.2), + 't_ISy': (4.0, 10.0), + 't_ISev': (5.0, 10.0), + 't_Cr': (9.0, 17.0), + 'mu_CR': (0.0, 0.4), + 'mu_IH': (0.0, 0.2), + 'mu_HU': (0.0, 0.4), + 'mu_UD': (0.0, 0.4), + 'transmission_prob': (0.0, 0.2) +} +SPIKE_SCALE = 0.4 +SLAB_SCALE = 0.2 +DATE_TIME = datetime.date(year=2020, month=10, day=1) +NUM_DAMPING_POINTS = 3 + + +def set_fontsize(base_fontsize=17): + fontsize = base_fontsize + plt.rcParams.update({ + 'font.size': fontsize, + 'axes.titlesize': fontsize * 1, + 'axes.labelsize': fontsize, + 'xtick.labelsize': fontsize * 0.8, + 'ytick.labelsize': fontsize * 0.8, + 'legend.fontsize': fontsize * 0.8, + 'font.family': "Arial" + }) + + +plt.style.use('default') + +dpi = 300 + +colors = {"Blue": "#155489", + "Medium blue": "#64A7DD", + "Light blue": "#B4DCF6", + "Lilac blue": "#AECCFF", + "Turquoise": "#76DCEC", + "Light green": "#B6E6B1", + "Medium green": "#54B48C", + "Green": "#5D8A2B", + "Teal": "#20A398", + "Yellow": "#FBD263", + "Orange": "#E89A63", + "Rose": "#CF7768", + "Red": "#A34427", + "Purple": "#741194", + "Grey": "#C0BFBF", + "Dark grey": "#616060", + "Light grey": "#F1F1F1"} + +def plot_region_fit( data: np.ndarray, region: int, - 
true_data = None, - ax = None, - label = None, - color = "red" + true_data=None, + ax=None, + label=None, + color="red", + only_80q=False ): if data.ndim != 3: raise ValueError("Array not of shape (samples, time_points, regions)") @@ -56,33 +124,45 @@ def plot_region_median_mad( x = np.arange(n_time) vals = data[:, :, region] # (samples, time_points) + + qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) + med = np.median(vals, axis=0) - mad = np.median(np.abs(vals - med), axis=0) if ax is None: fig, ax = plt.subplots() - line, = ax.plot(x, med, lw=2, label=label or f"Region {region}", color=color) - band = ax.fill_between(x, med - mad, med + mad, alpha=0.25, color=color) + ax.plot( + x, med, lw=2, label=label or f"{dd.County[region_ids[region]]}", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + if true_data is not None: true_vals = true_data[:, region] # (time_points,) ax.plot(x, true_vals, lw=2, color="black", label="True data") ax.set_xlabel("Time") ax.set_ylabel("ICU") - ax.set_title(f"Region {region}") + ax.set_title(f"{dd.County[region_ids[region]]}") if label is not None: - ax.legend() - return line, band + ax.legend(loc="upper right") def plot_aggregated_over_regions( data: np.ndarray, - region_agg = np.sum, - true_data = None, - ax = None, - label = None, - color = 'red' + region_agg=np.sum, + true_data=None, + ax=None, + label=None, + color='red', + only_80q=False ): if data.ndim != 3: raise ValueError("Array not of shape (samples, time_points, regions)") @@ -93,35 +173,350 @@ def plot_aggregated_over_regions( # Aggregate over regions agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) + qs_80 = np.quantile(agg_over_regions, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) + # Aggregate over samples agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) - agg_mad = np.median( - np.abs(agg_over_regions - agg_median[None]), - axis=0 - ) x = np.arange(agg_median.shape[0]) if ax is None: fig, ax = plt.subplots() - line, = ax.plot(x, agg_median, lw=2, label=label or "Aggregated over regions", color=color) - band = ax.fill_between( - x, - agg_median - agg_mad, - agg_median + agg_mad, - alpha=0.25, - color=color - ) + ax.plot(x, agg_median, lw=2, + label=label or "Aggregated over regions", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") if true_data is not None: true_vals = region_agg(true_data, axis=-1) # (time_points,) ax.plot(x, true_vals, lw=2, color="black", label="True data") ax.set_xlabel("Time") ax.set_ylabel("ICU") - if label is not None: - ax.legend() - return line, band +def plot_icu_on_germany(simulations, name, synthetic, with_aug): + med = np.median(simulations, axis=0) + + population = pd.read_json('data/Germany/pydata/county_current_population.json') + values = med / population['Population'].to_numpy()[None, :] * 100000 + + + map_data = gpd.read_file(os.path.join(os.getcwd(), 
'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_KRS.shp')) + fedstate_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_LAN.shp')) + + plot_map(values[0], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_initial_{name}{synthetic}{with_aug}") + plot_map(values[-1], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_final_{name}{synthetic}{with_aug}") + +def plot_map(values, map_data, fedstate_data, label, filename): + map_data[label] = map_data['ARS'].map(dict(zip([f"{region_id:05d}" for region_id in region_ids], values))) + + fig, ax = plt.subplots(figsize=(12, 14)) + map_data.plot( + column=f"{label}", + cmap='Reds', + linewidth=0.5, + ax=ax, + edgecolor='0.6', + legend=True, + legend_kwds={'label': f"{label} per County", 'shrink': 0.6}, + ) + fedstate_data.boundary.plot(ax=ax, color='black', linewidth=1) + ax.set_title(f"{label} per County") + ax.axis('off') + plt.savefig(filename, bbox_inches='tight', dpi=dpi) + + +def plot_aggregated_to_federal_states(data, true_data, name, synthetic, with_aug): + fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(25, 25), layout="constrained") + ax = ax.flatten() + for state in range(16): + idxs = [i for i, region_id in enumerate(region_ids) if region_id // 1000 == state + 1] + plot_aggregated_over_regions( + data[:, :, idxs], # Add a dummy region axis for compatibility + true_data=true_data[:, idxs] if true_data is not None else None, + ax=ax[state], + label=f"State {state + 1}", + color=colors["Red"] + ) + plt.savefig(f'{name}/federal_states_{name}{synthetic}{with_aug}.png', dpi=dpi) + plt.close() + +# plot simulations for all regions in 10x4 blocks +def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): + n_regions = simulations.shape[-1] + n_cols = 4 + n_rows = 10 + n_blocks = (n_regions + n_cols * n_rows - 1) // (n_cols * n_rows) + + for block in range(n_blocks): + start_idx = block * n_cols * n_rows + end_idx = min(start_idx + n_cols * n_rows, n_regions) + fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 25), layout="constrained") + ax = ax.flatten() + for i, region_idx in enumerate(range(start_idx, end_idx)): + plot_region_fit( + simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color=colors["Red"] + ) + # Hide unused subplots + for i in range(end_idx - start_idx, len(ax)): + ax[i].axis("off") + plt.savefig(f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png', dpi=dpi) + plt.close() + + +def calibration_curves_per_region( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + max_regions=None, + cmap=plt.cm.Blues, + linewidth=1.5, + legend=True, + with_ideal=True, +): + """ + Per-region calibration curves, each region in a different shade of a colormap. 
+ + data: (samples, time, regions) + true_data: (time, regions) + max_regions: limit number of regions shown + cmap: matplotlib colormap for line shades + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + if max_regions is None: + R = n_regions + else: + R = min(max_regions, n_regions) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + colors = [cmap(i / (R + 1)) for i in range(1, R + 1)] + + x = np.asarray(levels) + for r, col in zip(range(R), colors): + emp = [] + for nominal in levels: + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data[:, :, r], q_low, axis=0) + hi = np.quantile(data[:, :, r], q_high, axis=0) + hits = (true_data[:, r] >= lo) & (true_data[:, r] <= hi) + emp.append(hits.mean()) + emp = np.asarray(emp) + # , label=f"Region {r+1}") + ax.plot(x, emp, lw=linewidth, color=col, alpha=0.5) + + if with_ideal: + ideal_line = ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal")[0] + else: + ideal_line = None + + if legend: + # Custom legend: one patch for regions, one line for ideal + region_patch = Patch(color=colors[-1], label="Regions") + ax.legend(handles=[region_patch, ideal_line], + frameon=True, ncol=1) + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration per region") + return ax + + +def calibration_median_mad_over_regions( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + color="tab:blue", + alpha=0.25, + linewidth=2.0, + with_ideal=True, +): + """ + Compute per region empirical coverage at each nominal level, + then summarize across regions with median and MAD. 
+ + data: (samples, time, regions) + true_data: (time, regions) + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + L = len(levels) + per_region = np.empty((n_regions, L), dtype=float) + + # per level coverage per region + for j, nominal in enumerate(levels): + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data, q_low, axis=0) # (time, regions) + hi = np.quantile(data, q_high, axis=0) # (time, regions) + hits = (true_data >= lo) & (true_data <= hi) # (time, regions) + # mean over time for each region + per_region[:, j] = hits.mean(axis=0) + + med = np.median(per_region, axis=0) # (levels,) + mad = np.median(np.abs(per_region - med[None, :]), axis=0) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + x = np.asarray(levels) + ax.fill_between(x, med - mad, med + mad, + alpha=alpha, color=color, label=None) + ax.plot(x, med, lw=linewidth, color=color, label="Median across regions") + + if with_ideal: + ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal") + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration median and MAD across regions") + ax.legend() + return ax, {"levels": x, "median": med, "mad": mad} + + +def plot_damping_values(damping_values, name, synthetic): + + med = np.median(damping_values, axis=0) + mad = np.median(np.abs(damping_values - med), axis=0) + + # Extend for step plotting + med_extended = np.hstack([med, med[:, -1][:, None]]) + mad_extended = np.hstack([mad, mad[:, -1][:, None]]) + + # Plot damping values per region + fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) + axes = axes.flatten() + x = np.arange(15, 61, 15) # Time steps from 15 to 60 + + for i, ax in enumerate(axes): + if i < 16: + ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) + ax.fill_between( + x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], + alpha=0.25, color='red', step='post' + ) + ax.set_title(f"{dd.State[i+1]}") + ax.set_xlabel("Time") + ax.set_ylabel("Damping Value") + else: + ax.axis('off') # Hide unused subplots + + plt.suptitle("Damping Values per Region") + plt.savefig(f"{name}/damping_values{name}{synthetic}.png", dpi=dpi) + + # Combined plot for all regions + fig, ax = plt.subplots(figsize=(10, 6)) + cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors + + for i in range(16): + ax.stairs( + med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", + color=cmap(i), baseline=None + ) + + ax.set_title("Damping Values per Region (Combined)") + ax.set_xlabel("Time") + ax.set_ylabel("Damping Value") + ax.legend(loc="upper right", ncol=2) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}.png", dpi=dpi) + +def run_prior_predictive_check(name): + validation_data = load_pickle(f'{name}/validation_data_{name}.pickle') # synthetic data + + divi_region_keys = region_keys_sorted(validation_data) + num_samples = validation_data['region0'].shape[0] + time_points = validation_data['region0'].shape[1] + + true_data = load_divi_data() + true_data = np.concatenate( + [true_data[key] for key in divi_region_keys], axis=-1 + )[0] + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, time_points, len(divi_region_keys))) + for i in range(num_samples): + 
simulations[i] = np.concatenate([validation_data[key][i] for key in divi_region_keys], axis=-1) + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_aggregated_over_regions( + simulations, true_data=true_data, label="Region Aggregated Median (No Aug)", ax=ax, color=colors["Red"] + ) + ax.set_title("Prior predictive check - Aggregated over regions") + plt.savefig(f'{name}/prior_predictive_check_{name}.png', dpi=dpi) + +def compare_median_sim_to_mean_param_sim(name): + num_samples = 10 + + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + aggregate_states(validation_data_skip2w) + divi_region_keys = region_keys_sorted(divi_dict) + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] + + simulations_aug = load_pickle(f'{name}/sims_{name}_with_aug.pickle') + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + for key in inference_params: + samples[key] = np.median(samples[key], axis=1) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], 16, NUM_DAMPING_POINTS)) + result = run_germany_nuts3_simulation( + damping_values=samples['damping_values'][0], + t_E=samples['t_E'][0], t_ISy=samples['t_ISy'][0], + t_ISev=samples['t_ISev'][0], t_Cr=samples['t_Cr'][0], + mu_CR=samples['mu_CR'][0], mu_IH=samples['mu_IH'][0], + mu_HU=samples['mu_HU'][0], mu_UD=samples['mu_UD'][0], + transmission_prob=samples['transmission_prob'][0] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] # add sample axis + + result = extract_observables(result) + divi_region_keys = region_keys_sorted(result) + result = np.concatenate( + [result[key] if key in result else np.zeros_like(result[divi_region_keys[0]]) for key in divi_region_keys], + axis=-1 + )[0] + + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_aggregated_over_regions( + simulations_aug, true_data=result, label="Compare median of sim to sim of median params", ax=ax, color=colors["Red"] + ) + ax.set_title("Comparison of Median Simulation to Simulation of Median Parameters") + plt.savefig(f'{name}/compare_median_sim13_to_mean_param_sim_{name}.png', dpi=dpi) + plt.close() + class Simulation: """ """ @@ -134,26 +529,28 @@ def __init__(self, data_dir, start_date, results_dir): if not os.path.exists(self.results_dir): os.makedirs(self.results_dir) - def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, transmission_prob): + def set_covid_parameters(self, model, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param model: """ model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E - model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = t_C model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr # probabilities - model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup(0)] = transmission_prob + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( + 0)] = transmission_prob model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 - model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup(0)] = 0.2069 - model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = 0.07864 - 
model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = 0.17318 - model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = 0.21718 + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( + 0)] = mu_CR + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD # start day is set to the n-th day of the year model.parameters.StartDay = self.start_date.timetuple().tm_yday @@ -172,23 +569,34 @@ def set_contact_matrices(self, model): minimum = np.zeros((self.num_groups, self.num_groups)) contact_matrices[0] = mio.ContactMatrix(baseline, minimum) model.parameters.ContactPatterns.cont_freq_mat = contact_matrices - - def set_npis(self, params, end_date, damping_value): + + def set_npis(self, params, end_date, damping_values): """ :param params: :param end_date: """ - - start_damping = datetime.date( - year=2020, month=10, day=8) - - if start_damping < end_date: - start_date = (start_damping - self.start_date).days - params.ContactPatterns.cont_freq_mat[0].add_damping(mio.Damping(np.r_[damping_value], t=start_date)) - - def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, transmission_prob): + start_damping_1 = DATE_TIME + datetime.timedelta(days=15) + start_damping_2 = DATE_TIME + datetime.timedelta(days=30) + start_damping_3 = DATE_TIME + datetime.timedelta(days=45) + + if start_damping_1 < end_date: + start_date = (start_damping_1 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[0]], t=start_date)) + + if start_damping_2 < end_date: + start_date = (start_damping_2 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[1]], t=start_date)) + + if start_damping_3 < end_date: + start_date = (start_damping_3 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[2]], t=start_date)) + + def get_graph(self, end_date, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): """ :param end_date: @@ -196,7 +604,8 @@ def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, transmission_prob): """ print("Initializing model...") model = Model(self.num_groups) - self.set_covid_parameters(model, t_E, t_ISy, t_ISev, t_Cr, transmission_prob) + self.set_covid_parameters( + model, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) self.set_contact_matrices(model) print("Model initialized.") @@ -204,7 +613,6 @@ def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, transmission_prob): scaling_factor_infected = [2.5] scaling_factor_icu = 1.0 - tnt_capacity_factor = 7.5 / 100000. 
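+        # Editor's note (added comment, not part of the original commit): the
+        # local tnt_capacity_factor is dropped here; it appears unused in this
+        # no-age-group setup (assumption based on the surrounding diff only).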
data_dir_Germany = os.path.join(self.data_dir, "Germany") mobility_data_file = os.path.join( @@ -225,14 +633,13 @@ def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, transmission_prob): scaling_factor_icu, 0, 0, False) print("Setting edges...") - mio.osecir.set_edges( - mobility_data_file, graph, 1) - + mio.osecir.set_edges(mobility_data_file, graph, 1) + print("Graph created.") return graph - def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, transmission_prob, save_graph=True): + def run(self, num_days_sim, damping_values, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): """ :param num_days_sim: @@ -244,12 +651,18 @@ def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, transmissi mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) - graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, t_Cr, transmission_prob) + graph = self.get_graph(end_date, t_E, t_C, t_ISy, t_ISev, + t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) mobility_graph = osecir.MobilityGraph() for node_idx in range(graph.num_nodes): node = graph.get_node(node_idx) - self.set_npis(node.property.parameters, end_date, damping_values[node.id // 1000 -1]) + + self.set_npis( + node.property.parameters, + end_date, + damping_values[node.id // 1000 - 1] + ) mobility_graph.add_node(node.id, node.property) for edge_idx in range(graph.num_edges): mobility_graph.add_edge( @@ -263,150 +676,395 @@ def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, transmissi for node_idx in range(mobility_sim.graph.num_nodes): node = mobility_sim.graph.get_node(node_idx) if node.id in no_icu_ids: - results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result) + results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result( + node.property.result) else: - results[f'region{node_idx}'] = osecir.interpolate_simulation_result(node.property.result) - + results[f'region{node_idx}'] = osecir.interpolate_simulation_result( + node.property.result) + return results -def run_germany_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, transmission_prob): + +def run_germany_nuts3_simulation(damping_values, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): mio.set_log_level(mio.LogLevel.Warning) file_path = os.path.dirname(os.path.abspath(__file__)) sim = Simulation( data_dir=os.path.join(file_path, "../../../data"), - start_date=datetime.date(year=2020, month=10, day=1), + start_date=DATE_TIME, results_dir=os.path.join(file_path, "../../../results_osecir")) num_days_sim = 60 - - results = sim.run(num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, transmission_prob) + + results = sim.run(num_days_sim, damping_values, t_E, t_C, t_ISy, + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) return results def prior(): - damping_values = np.random.uniform(0.0, 1.0, 16) - t_E = np.random.uniform(1., 5.2) - t_ISy = np.random.uniform(4., 10.) - t_ISev = np.random.uniform(5., 10.) - t_Cr = np.random.uniform(9., 17.) 
- transmission_prob = np.random.uniform(0., 0.2) - return {'damping_values': damping_values, - 't_E': t_E, - 't_ISy': t_ISy, - 't_ISev': t_ISev, - 't_Cr': t_Cr, - 'transmission_prob': transmission_prob} + damping_values = np.zeros((NUM_DAMPING_POINTS, 16)) + for i in range(NUM_DAMPING_POINTS): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / scale + damping_values[i] = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=16 + ) + return { + 'damping_values': np.transpose(damping_values), + 't_E': np.random.uniform(*bounds['t_E']), + 't_C': np.random.uniform(*bounds['t_C']), + 't_ISy': np.random.uniform(*bounds['t_ISy']), + 't_ISev': np.random.uniform(*bounds['t_ISev']), + 't_Cr': np.random.uniform(*bounds['t_Cr']), + 'mu_CR': np.random.uniform(*bounds['mu_CR']), + 'mu_IH': np.random.uniform(*bounds['mu_IH']), + 'mu_HU': np.random.uniform(*bounds['mu_HU']), + 'mu_UD': np.random.uniform(*bounds['mu_UD']), + 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) + } + def load_divi_data(): file_path = os.path.dirname(os.path.abspath(__file__)) divi_path = os.path.join(file_path, "../../../data/Germany/pydata") - data = pd.read_json(os.path.join(divi_path, "county_divi_all_dates.json")) - data = data[data['Date']>= np.datetime64(datetime.date(2020, 10, 1))] - data = data[data['Date'] <= np.datetime64(datetime.date(2020, 10, 1) + datetime.timedelta(days=50))] + data = pd.read_json(os.path.join(divi_path, "county_divi_ma7.json")) + data = data[data['Date'] >= np.datetime64(DATE_TIME)] + data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] data = data.sort_values(by=['ID_County', 'Date']) divi_data = data.pivot(index='Date', columns='ID_County', values='ICU') - divi_dict = {f"region{i}": divi_data[region_id].to_numpy()[None, :, None] for i, region_id in enumerate(region_ids) if region_id not in no_icu_ids} + divi_dict = {} + for i, region_id in enumerate(region_ids): + if region_id not in no_icu_ids: + divi_dict[f"region{i}"] = divi_data[region_id].to_numpy()[None, :, None] + else: + divi_dict[f"no_icu_region{i}"] = np.zeros((1, divi_data.shape[0], 1)) + return divi_dict - return divi_data.to_numpy(), divi_dict +def load_extrapolated_case_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + case_path = os.path.join(file_path, "../../../data/Germany/pydata") -if __name__ == "__main__": - - import os - os.environ["KERAS_BACKEND"] = "tensorflow" - - import bayesflow as bf - from tensorflow import keras + file = h5py.File(os.path.join(case_path, "Results_rki.h5")) + divi_dict = {} + for i, region_id in enumerate(region_ids): + if region_id not in no_icu_ids: + divi_dict[f"region{i}"] = np.array(file[f"{region_id}"]['Total'][:, 4])[None, :, None] + else: + divi_dict[f"no_icu_region{i}"] = np.zeros((1, np.array(file[f"{region_id}"]['Total']).shape[0], 1)) + return divi_dict - simulator = bf.simulators.make_simulator([prior, run_germany_nuts3_simulation]) - trainings_data = simulator.sample(1000) - for key in trainings_data.keys(): +def extract_observables(simulation_results, observable_index=7): + for key in simulation_results.keys(): if key not in inference_params: - trainings_data[key] = trainings_data[key][:, :, 7][..., np.newaxis] + simulation_results[key] = simulation_results[key][:, :, observable_index][..., np.newaxis] + return simulation_results - with open('trainings_data10_counties_wcovidparams_oct.pickle', 'wb') as f: + +def create_train_data(filename, number_samples=1000): + + simulator = 
bf.simulators.make_simulator( + [prior, run_germany_nuts3_simulation] + ) + trainings_data = simulator.sample(number_samples) + trainings_data = extract_observables(trainings_data) + with open(filename, 'wb') as f: pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - # with open('trainings_data1_counties_wcovidparams_oct.pickle', 'rb') as f: - # trainings_data = pickle.load(f) - # trainings_data = {k: np.round(v) if ('region' in k) else v for k, v in trainings_data.items()} - # for i in range(19): - # with open(f'trainings_data{i+2}_counties_wcovidparams_oct.pickle', 'rb') as f: - # data = pickle.load(f) - # trainings_data = {k: np.concatenate([trainings_data[k], np.round(data[k])]) if ('region' in k) else np.concatenate([trainings_data[k], data[k]]) for k in trainings_data.keys()} - - # with open('validation_data_counties_wcovidparams_oct.pickle', 'rb') as f: - # validation_data = pickle.load(f) - # divi_dict = {k: np.round(v) if ('region' in k) else v for k, v in validation_data.items()} - - # adapter = ( - # bf.Adapter() - # .to_array() - # .convert_dtype("float64", "float32") - # .constrain("damping_values", lower=0.0, upper=1.0) - # .constrain("t_E", lower=1.0, upper=6.0) - # .constrain("t_ISy", lower=5.0, upper=10.0) - # .constrain("t_ISev", lower=2.0, upper=8.0) - # .concatenate(["damping_values", "t_E", "t_ISy", "t_ISev"], into="inference_variables", axis=-1) - # .concatenate([f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids], into="summary_variables", axis=-1) - # .log("summary_variables", p1=True) - # ) - - # print("summary_variables shape:", adapter(trainings_data)["summary_variables"].shape) - # print("inference_variables shape:", adapter(trainings_data)["inference_variables"].shape) - - # summary_network = bf.networks.TimeSeriesNetwork(summary_dim=38) - # inference_network = bf.networks.CouplingFlow(depth=7, transform='spline') - - # workflow = bf.BasicWorkflow( - # simulator=simulator, - # adapter=adapter, - # summary_network=summary_network, - # inference_network=inference_network, - # standardize='all' - # ) - - # history = workflow.fit_offline(data=trainings_data, epochs=100, batch_size=32, validation_data=validation_data) - - # workflow.approximator.save(filepath=os.path.join(os.path.dirname(__file__), "model_countylvl_wcovidparams_oct.keras")) - - # plots = workflow.plot_default_diagnostics(test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True}) - # plots['losses'].savefig('losses_countylvl_wcovidparams2_oct.png') - # plots['recovery'].savefig('recovery_countylvl_wcovidparams2_oct.png') - # plots['calibration_ecdf'].savefig('calibration_ecdf_countylvl_wcovidparams2_oct.png') - # plots['z_score_contraction'].savefig('z_score_contraction_countylvl_wcovidparams2_oct.png') - - # divi_data, divi_dict = load_divi_data() - # # divi_data = np.concatenate( - # # [validation_data[f'region{i}'] for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids], - # # axis=-1 - # # ) - # workflow.approximator = keras.models.load_model(os.path.join(os.path.dirname(__file__), "model_countylvl_wcovidparams_oct.keras")) - - # samples = workflow.sample(conditions=divi_dict, num_samples=1000) - # samples = np.concatenate([samples[key] for key in inference_params], axis=-1) - # samples = np.squeeze(samples) - # sims = [] - # for i in range(samples.shape[0]): - # result = run_germany_nuts3_simulation(samples[i][:16], *samples[i][16:]) - # for key in result.keys(): - # result[key] = np.array(result[key])[:, 7, None] - # 
sims.append(np.concatenate([result[key] for key in result.keys() if key.startswith('region')], axis=-1)) - # sims = np.array(sims) - # sims = np.floor(sims) - - # np.random.seed(42) - # fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(12, 5), layout="constrained") - # ax = ax.flatten() - # rand_index = np.random.choice(sims.shape[-1], replace=False, size=len(ax)) - # for i, a in enumerate(ax): - # plot_region_median_mad(sims, region=rand_index[i], true_data=divi_data, label=r"Median $\pm$ Mad", ax=a) - # plt.savefig('random_regions_wcovidparams_oct.png') - # # plt.show() - # #%% - # plot_aggregated_over_regions(sims, true_data=divi_data, label="Region Aggregated Median $\pm$ Mad") - # plt.savefig('region_aggregated_wcovidparams_oct.png') - # # plt.show() - # # %% \ No newline at end of file + +def load_pickle(path): + with open(path, "rb") as f: + return pickle.load(f) + +def is_region_key(k: str) -> bool: + return 'region' in k + +def apply_aug(d: dict, aug) -> dict: + return {k: np.clip(aug(v), 0, None) if is_region_key(k) else v for k, v in d.items()} + +def concat_dicts(base: dict, new: dict) -> dict: + missing = set(base) - set(new) + if missing: + raise KeyError(f"new dict missing keys: {sorted(missing)}") + for k in base: + base[k] = np.concatenate([base[k], new[k]]) + return base + +def region_keys_sorted(d: dict): + def idx(k): + # handles "regionN" and "no_icu_regionN" + return int(k.split("region")[-1]) + return sorted([k for k in d if is_region_key(k)], key=idx) + +def aggregate_states(d: dict) -> None: + n_regions = len(region_ids) + # per state + for state in range(16): + idxs = [ + r for r in range(n_regions) + if region_ids[r] // 1000 == state + 1 + ] + d[f"fed_state{state}"] = np.sum([d[f"region{r}"] if region_ids[r] not in no_icu_ids else d[f"no_icu_region{r}"] for r in idxs], axis=0) + # all allowed regions + d["state"] = np.sum([d[f"fed_state{r}"] for r in range(16)], axis=0) + +def combine_results(dict_list): + combined = {} + for d in dict_list: + combined = concat_dicts(combined, d) if combined else d + return combined + +def skip_2weeks(d:dict) -> dict: + return {k: v[:, 14:, :] if is_region_key(k) else v for k, v in d.items()} + +def get_workflow(): + + simulator = bf.make_simulator( + [prior, run_germany_nuts3_simulation] + ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_C", lower=bounds["t_C"][0], upper=bounds["t_C"][1]) + .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) + .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) + .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) + .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) + .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) + .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) + .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) + .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) + .concatenate( + ["damping_values", "t_E", "t_C", "t_ISy", "t_ISev", "t_Cr", + "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], + into="inference_variables", + axis=-1 + ) + .concatenate(summary_vars, into="summary_variables", axis=-1) + ) + + summary_network = bf.networks.FusionTransformer( + summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, 
dropout=0.1 + ) + inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) + + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + # augmentations={f'region{i}': aug for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids} + # aggregation of the states would need to be recomputed every time a different noise realization is applied + ) + + return workflow + + +def run_training(name, num_training_files=20): + train_template = name+"/trainings_data{i}_"+name+".pickle" + val_path = f"{name}/validation_data_{name}.pickle" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # training data + train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] + trainings_data = None + for p in train_files: + d = load_pickle(p) + d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) + d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) + if trainings_data is None: + trainings_data = d + else: + trainings_data = concat_dicts(trainings_data, d) + aggregate_states(trainings_data) + + # validation data + validation_data = apply_aug(load_pickle(val_path), aug=aug) + validation_data = skip_2weeks(validation_data) + aggregate_states(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + # check data + workflow = get_workflow() + print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) + print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) + + history = workflow.fit_offline( + data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data + ) + + workflow.approximator.save( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + plots = workflow.plot_default_diagnostics( + test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} + ) + plots['losses'].savefig(f'{name}/losses_{name}.png', dpi=dpi) + plots['recovery'].savefig(f'{name}/recovery_{name}.png', dpi=dpi) + plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png', dpi=dpi) + #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png', dpi=dpi) + + +def run_inference(name, num_samples=1000, on_synthetic_data=False): + val_path = f"{name}/validation_data_{name}.pickle" + synthetic = "_synthetic" if on_synthetic_data else "" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # validation data + validation_data = load_pickle(val_path) # synthetic data + if on_synthetic_data: + # validation data + validation_data = apply_aug(validation_data, aug=aug) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data_skip2w = skip_2weeks(validation_data) + aggregate_states(validation_data_skip2w) + divi_dict = validation_data + divi_region_keys = region_keys_sorted(divi_dict) + + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] # only one dataset + else: + divi_dict = load_divi_data() + validation_data_skip2w = 
skip_2weeks(divi_dict) + aggregate_states(validation_data_skip2w) + divi_region_keys = region_keys_sorted(divi_dict) + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle') and os.path.exists(f'{name}/samples_{name}{synthetic}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') + simulations_aug = load_pickle( + f'{name}/sims_{name}{synthetic}_with_aug.pickle') + samples = load_pickle(f'{name}/samples_{name}{synthetic}.pickle') + print("loaded simulations from file") + else: + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + with open(f'{name}/samples_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(samples, f, pickle.HIGHEST_PROTOCOL) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) + results = [] + for i in range(num_samples): # we only have one dataset for inference here + result = run_germany_nuts3_simulation( + damping_values=samples['damping_values'][0, i], + t_E=samples['t_E'][0, i], t_C=samples['t_C'][0, i], t_ISy=samples['t_ISy'][0, i], + t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], + mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], + mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], + transmission_prob=samples['transmission_prob'][0, i] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] # add sample axis + results.append(result) + results = combine_results(results) + results = extract_observables(results) + results_aug = apply_aug(results, aug=aug) + + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + for i in range(num_samples): + simulations[i] = np.concatenate([results[key][i] for key in divi_region_keys], axis=-1) + simulations_aug[i] = np.concatenate([results_aug[key][i] for key in divi_region_keys], axis=-1) + + # save sims + with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: + pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + + samples['damping_values'] = samples['damping_values'].reshape( + (samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior( + samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png', dpi=dpi) + + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) + plot_damping_values(samples['damping_values'][0], + name=name, synthetic=synthetic) + + plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") + plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + + plot_aggregated_to_federal_states(simulations, divi_data, name, synthetic, with_aug="") + 
plot_aggregated_to_federal_states(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + + fig, axes = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True) + # Plot with augmentation + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median", ax=axes[0], color=colors["Red"] + ) + # Plot with augmentation (80% quantile only) + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median", ax=axes[1], color=colors["Red"], only_80q=True + ) + lines, labels = axes[0].get_legend_handles_labels() + fig.legend(lines, labels) + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png', dpi=dpi) + plt.close() + + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png', dpi=dpi) + plt.close() + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations_aug, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png', dpi=dpi) + plt.close() + + plot_icu_on_germany(simulations, name, synthetic, with_aug="") + plot_icu_on_germany(simulations_aug, name, synthetic, with_aug="_with_aug") + + simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions + simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) + + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + print("Mean RMSE over regions:", rmse["values"].mean()) + print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) + + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + print("Mean Calibration Error over regions:", cal_error["values"].mean()) + print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) + + +if __name__ == "__main__": + name = "3dampings_lessnoise_newnetwork" + + set_fontsize() + + if not os.path.exists(name): + os.makedirs(name) + # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=100) + # run_training(name=name, num_training_files=10) + run_inference(name=name, on_synthetic_data=True) + run_inference(name=name, on_synthetic_data=False) + # run_prior_predictive_check(name=name) + # compare_median_sim_to_mean_param_sim(name=name) \ No newline at end of file diff --git a/pycode/examples/simulation/graph_germany_nuts3_clean.py b/pycode/examples/simulation/graph_germany_nuts3_clean.py deleted file mode 100644 index 908813cccf..0000000000 --- a/pycode/examples/simulation/graph_germany_nuts3_clean.py +++ /dev/null @@ -1,1070 +0,0 @@ -############################################################################# -# Copyright (C) 2020-2025 MEmilio -# -# Authors: Carlotta Gerstein -# -# Contact: Martin 
J. Kuehn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -############################################################################# -import os -os.environ["KERAS_BACKEND"] = "tensorflow" - -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -import datetime -import pickle -from scipy.stats import truncnorm - -from matplotlib.patches import Patch - -import bayesflow as bf -import keras - -import memilio.simulation as mio -import memilio.simulation.osecir as osecir -from memilio.simulation.osecir import Model, interpolate_simulation_result -from memilio.epidata import defaultDict as dd - -import geopandas as gpd - - -excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, - 11007, 11008, 11009, 11010, 11011, 11012, 16056] -no_icu_ids = [7338, 9374, 9473, 9573] -region_ids = [region_id for region_id in dd.County.keys() - if region_id not in excluded_ids] - -inference_params = ['damping_values', 't_E', 't_C', 't_ISy', 't_ISev', - 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] -summary_vars = ['state'] + [f'fed_state{i}' for i in range(16)] + [ - f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids] - -bounds = { - 't_E': (1.0, 5.2), - 't_C': (0.1, 4.2), - 't_ISy': (4.0, 10.0), - 't_ISev': (5.0, 10.0), - 't_Cr': (9.0, 17.0), - 'mu_CR': (0.0, 0.4), - 'mu_IH': (0.0, 0.2), - 'mu_HU': (0.0, 0.4), - 'mu_UD': (0.0, 0.4), - 'transmission_prob': (0.0, 0.2) -} -SPIKE_SCALE = 0.4 -SLAB_SCALE = 0.2 -DATE_TIME = datetime.date(year=2020, month=10, day=1) -NUM_DAMPING_POINTS = 3 - - -def set_fontsize(base_fontsize=17): - fontsize = base_fontsize - plt.rcParams.update({ - 'font.size': fontsize, - 'axes.titlesize': fontsize * 1, - 'axes.labelsize': fontsize, - 'xtick.labelsize': fontsize * 0.8, - 'ytick.labelsize': fontsize * 0.8, - 'legend.fontsize': fontsize * 0.8, - 'font.family': "Arial" - }) - - -plt.style.use('default') - -dpi = 300 - -colors = {"Blue": "#155489", - "Medium blue": "#64A7DD", - "Light blue": "#B4DCF6", - "Lilac blue": "#AECCFF", - "Turquoise": "#76DCEC", - "Light green": "#B6E6B1", - "Medium green": "#54B48C", - "Green": "#5D8A2B", - "Teal": "#20A398", - "Yellow": "#FBD263", - "Orange": "#E89A63", - "Rose": "#CF7768", - "Red": "#A34427", - "Purple": "#741194", - "Grey": "#C0BFBF", - "Dark grey": "#616060", - "Light grey": "#F1F1F1"} - -def plot_region_fit( - data: np.ndarray, - region: int, - true_data=None, - ax=None, - label=None, - color="red", - only_80q=False -): - if data.ndim != 3: - raise ValueError("Array not of shape (samples, time_points, regions)") - if true_data is not None: - if true_data.shape != data.shape[1:]: - raise ValueError("True data shape does not match data shape") - n_samples, n_time, n_regions = data.shape - if not (0 <= region < n_regions): - raise IndexError - - x = np.arange(n_time) - vals = data[:, :, region] # (samples, time_points) - - qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) - qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) - qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) - 
- med = np.median(vals, axis=0) - - if ax is None: - fig, ax = plt.subplots() - - ax.plot( - x, med, lw=2, label=label or f"{dd.County[region_ids[region]]}", color=color) - ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, - color=color, label="80% CI") - if not only_80q: - ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, - color=color, label="90% CI") - ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, - color=color, label="95% CI") - - if true_data is not None: - true_vals = true_data[:, region] # (time_points,) - ax.plot(x, true_vals, lw=2, color="black", label="True data") - - ax.set_xlabel("Time") - ax.set_ylabel("ICU") - ax.set_title(f"{dd.County[region_ids[region]]}") - if label is not None: - ax.legend(loc="upper right") - - -def plot_aggregated_over_regions( - data: np.ndarray, - region_agg=np.sum, - true_data=None, - ax=None, - label=None, - color='red', - only_80q=False -): - if data.ndim != 3: - raise ValueError("Array not of shape (samples, time_points, regions)") - if true_data is not None: - if true_data.shape != data.shape[1:]: - raise ValueError("True data shape does not match data shape") - - # Aggregate over regions - agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) - - qs_80 = np.quantile(agg_over_regions, q=[0.1, 0.9], axis=0) - qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) - qs_95 = np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) - - # Aggregate over samples - agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) - - x = np.arange(agg_median.shape[0]) - if ax is None: - fig, ax = plt.subplots() - - ax.plot(x, agg_median, lw=2, - label=label or "Aggregated over regions", color=color) - ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, - color=color, label="80% CI") - if not only_80q: - ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, - color=color, label="90% CI") - ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, - color=color, label="95% CI") - if true_data is not None: - true_vals = region_agg(true_data, axis=-1) # (time_points,) - ax.plot(x, true_vals, lw=2, color="black", label="True data") - - ax.set_xlabel("Time") - ax.set_ylabel("ICU") - -def plot_icu_on_germany(simulations, name, synthetic, with_aug): - med = np.median(simulations, axis=0) - - population = pd.read_json('data/Germany/pydata/county_current_population.json') - values = med / population['Population'].to_numpy()[None, :] * 100000 - - - map_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_KRS.shp')) - fedstate_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_LAN.shp')) - - plot_map(values[0], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_initial_{name}{synthetic}{with_aug}") - plot_map(values[-1], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_final_{name}{synthetic}{with_aug}") - -def plot_map(values, map_data, fedstate_data, label, filename): - map_data[label] = map_data['ARS'].map(dict(zip([f"{region_id:05d}" for region_id in region_ids], values))) - - fig, ax = plt.subplots(figsize=(12, 14)) - map_data.plot( - column=f"{label}", - cmap='Reds', - linewidth=0.5, - ax=ax, - edgecolor='0.6', - legend=True, - legend_kwds={'label': f"{label} per County", 'shrink': 0.6}, - ) - fedstate_data.boundary.plot(ax=ax, color='black', linewidth=1) - ax.set_title(f"{label} per County") - ax.axis('off') - plt.savefig(filename, bbox_inches='tight', dpi=dpi) - - -def plot_aggregated_to_federal_states(data, true_data, name, 
synthetic, with_aug): - fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(25, 25), layout="constrained") - ax = ax.flatten() - for state in range(16): - idxs = [i for i, region_id in enumerate(region_ids) if region_id // 1000 == state + 1] - plot_aggregated_over_regions( - data[:, :, idxs], # Add a dummy region axis for compatibility - true_data=true_data[:, idxs] if true_data is not None else None, - ax=ax[state], - label=f"State {state + 1}", - color=colors["Red"] - ) - plt.savefig(f'{name}/federal_states_{name}{synthetic}{with_aug}.png', dpi=dpi) - plt.close() - -# plot simulations for all regions in 10x4 blocks -def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): - n_regions = simulations.shape[-1] - n_cols = 4 - n_rows = 10 - n_blocks = (n_regions + n_cols * n_rows - 1) // (n_cols * n_rows) - - for block in range(n_blocks): - start_idx = block * n_cols * n_rows - end_idx = min(start_idx + n_cols * n_rows, n_regions) - fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 25), layout="constrained") - ax = ax.flatten() - for i, region_idx in enumerate(range(start_idx, end_idx)): - plot_region_fit( - simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color=colors["Red"] - ) - # Hide unused subplots - for i in range(end_idx - start_idx, len(ax)): - ax[i].axis("off") - plt.savefig(f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png', dpi=dpi) - plt.close() - - -def calibration_curves_per_region( - data: np.ndarray, - true_data: np.ndarray, - levels=np.linspace(0.01, 0.99, 20), - ax=None, - max_regions=None, - cmap=plt.cm.Blues, - linewidth=1.5, - legend=True, - with_ideal=True, -): - """ - Per-region calibration curves, each region in a different shade of a colormap. - - data: (samples, time, regions) - true_data: (time, regions) - max_regions: limit number of regions shown - cmap: matplotlib colormap for line shades - """ - if data.ndim != 3: - raise ValueError("Array not of shape (samples, time_points, regions)") - if true_data.shape != data.shape[1:]: - raise ValueError("True data shape does not match data shape") - - n_samples, n_time, n_regions = data.shape - if max_regions is None: - R = n_regions - else: - R = min(max_regions, n_regions) - - if ax is None: - fig, ax = plt.subplots(figsize=(5, 4)) - - colors = [cmap(i / (R + 1)) for i in range(1, R + 1)] - - x = np.asarray(levels) - for r, col in zip(range(R), colors): - emp = [] - for nominal in levels: - q_low = (1.0 - nominal) / 2.0 - q_high = 1.0 - q_low - lo = np.quantile(data[:, :, r], q_low, axis=0) - hi = np.quantile(data[:, :, r], q_high, axis=0) - hits = (true_data[:, r] >= lo) & (true_data[:, r] <= hi) - emp.append(hits.mean()) - emp = np.asarray(emp) - # , label=f"Region {r+1}") - ax.plot(x, emp, lw=linewidth, color=col, alpha=0.5) - - if with_ideal: - ideal_line = ax.plot([0, 1], [0, 1], linestyle="--", - lw=1.2, color="black", label="Ideal")[0] - else: - ideal_line = None - - if legend: - # Custom legend: one patch for regions, one line for ideal - region_patch = Patch(color=colors[-1], label="Regions") - ax.legend(handles=[region_patch, ideal_line], - frameon=True, ncol=1) - - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) - ax.set_xlabel("Nominal level") - ax.set_ylabel("Empirical coverage") - ax.set_title("Calibration per region") - return ax - - -def calibration_median_mad_over_regions( - data: np.ndarray, - true_data: np.ndarray, - levels=np.linspace(0.01, 0.99, 20), - ax=None, - color="tab:blue", - alpha=0.25, - linewidth=2.0, - with_ideal=True, -): 
- """ - Compute per region empirical coverage at each nominal level, - then summarize across regions with median and MAD. - - data: (samples, time, regions) - true_data: (time, regions) - """ - if data.ndim != 3: - raise ValueError("Array not of shape (samples, time_points, regions)") - if true_data.shape != data.shape[1:]: - raise ValueError("True data shape does not match data shape") - - n_samples, n_time, n_regions = data.shape - L = len(levels) - per_region = np.empty((n_regions, L), dtype=float) - - # per level coverage per region - for j, nominal in enumerate(levels): - q_low = (1.0 - nominal) / 2.0 - q_high = 1.0 - q_low - lo = np.quantile(data, q_low, axis=0) # (time, regions) - hi = np.quantile(data, q_high, axis=0) # (time, regions) - hits = (true_data >= lo) & (true_data <= hi) # (time, regions) - # mean over time for each region - per_region[:, j] = hits.mean(axis=0) - - med = np.median(per_region, axis=0) # (levels,) - mad = np.median(np.abs(per_region - med[None, :]), axis=0) - - if ax is None: - fig, ax = plt.subplots(figsize=(5, 4)) - - x = np.asarray(levels) - ax.fill_between(x, med - mad, med + mad, - alpha=alpha, color=color, label=None) - ax.plot(x, med, lw=linewidth, color=color, label="Median across regions") - - if with_ideal: - ax.plot([0, 1], [0, 1], linestyle="--", - lw=1.2, color="black", label="Ideal") - - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) - ax.set_xlabel("Nominal level") - ax.set_ylabel("Empirical coverage") - ax.set_title("Calibration median and MAD across regions") - ax.legend() - return ax, {"levels": x, "median": med, "mad": mad} - - -def plot_damping_values(damping_values, name, synthetic): - - med = np.median(damping_values, axis=0) - mad = np.median(np.abs(damping_values - med), axis=0) - - # Extend for step plotting - med_extended = np.hstack([med, med[:, -1][:, None]]) - mad_extended = np.hstack([mad, mad[:, -1][:, None]]) - - # Plot damping values per region - fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) - axes = axes.flatten() - x = np.arange(15, 61, 15) # Time steps from 15 to 60 - - for i, ax in enumerate(axes): - if i < 16: - ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) - ax.fill_between( - x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], - alpha=0.25, color='red', step='post' - ) - ax.set_title(f"{dd.State[i+1]}") - ax.set_xlabel("Time") - ax.set_ylabel("Damping Value") - else: - ax.axis('off') # Hide unused subplots - - plt.suptitle("Damping Values per Region") - plt.savefig(f"{name}/damping_values{name}{synthetic}.png", dpi=dpi) - - # Combined plot for all regions - fig, ax = plt.subplots(figsize=(10, 6)) - cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors - - for i in range(16): - ax.stairs( - med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", - color=cmap(i), baseline=None - ) - - ax.set_title("Damping Values per Region (Combined)") - ax.set_xlabel("Time") - ax.set_ylabel("Damping Value") - ax.legend(loc="upper right", ncol=2) - plt.savefig(f"{name}/damping_values_combined{name}{synthetic}.png", dpi=dpi) - -def run_prior_predictive_check(name): - validation_data = load_pickle(f'{name}/validation_data_{name}.pickle') # synthetic data - - divi_region_keys = region_keys_sorted(validation_data) - num_samples = validation_data['region0'].shape[0] - time_points = validation_data['region0'].shape[1] - - true_data = load_divi_data() - true_data = np.concatenate( - [true_data[key] for key in divi_region_keys], axis=-1 - )[0] - # get sims in shape (samples, 
time, regions) - simulations = np.zeros((num_samples, time_points, len(divi_region_keys))) - for i in range(num_samples): - simulations[i] = np.concatenate([validation_data[key][i] for key in divi_region_keys], axis=-1) - - fig, ax = plt.subplots(figsize=(8, 6)) - plot_aggregated_over_regions( - simulations, true_data=true_data, label="Region Aggregated Median (No Aug)", ax=ax, color=colors["Red"] - ) - ax.set_title("Prior predictive check - Aggregated over regions") - plt.savefig(f'{name}/prior_predictive_check_{name}.png', dpi=dpi) - -def compare_median_sim_to_mean_param_sim(name): - num_samples = 10 - - divi_dict = load_divi_data() - validation_data_skip2w = skip_2weeks(divi_dict) - aggregate_states(validation_data_skip2w) - divi_region_keys = region_keys_sorted(divi_dict) - divi_data = np.concatenate( - [divi_dict[key] for key in divi_region_keys], axis=-1 - )[0] - - simulations_aug = load_pickle(f'{name}/sims_{name}_with_aug.pickle') - - workflow = get_workflow() - workflow.approximator = keras.models.load_model( - filepath=os.path.join(f"{name}/model_{name}.keras") - ) - - samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) - for key in inference_params: - samples[key] = np.median(samples[key], axis=1) - samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], 16, NUM_DAMPING_POINTS)) - result = run_germany_nuts3_simulation( - damping_values=samples['damping_values'][0], - t_E=samples['t_E'][0], t_ISy=samples['t_ISy'][0], - t_ISev=samples['t_ISev'][0], t_Cr=samples['t_Cr'][0], - mu_CR=samples['mu_CR'][0], mu_IH=samples['mu_IH'][0], - mu_HU=samples['mu_HU'][0], mu_UD=samples['mu_UD'][0], - transmission_prob=samples['transmission_prob'][0] - ) - for key in result.keys(): - result[key] = np.array(result[key])[None, ...] 
# add sample axis - - result = extract_observables(result) - divi_region_keys = region_keys_sorted(result) - result = np.concatenate( - [result[key] if key in result else np.zeros_like(result[divi_region_keys[0]]) for key in divi_region_keys], - axis=-1 - )[0] - - - fig, ax = plt.subplots(figsize=(8, 6)) - plot_aggregated_over_regions( - simulations_aug, true_data=result, label="Compare median of sim to sim of median params", ax=ax, color=colors["Red"] - ) - ax.set_title("Comparison of Median Simulation to Simulation of Median Parameters") - plt.savefig(f'{name}/compare_median_sim13_to_mean_param_sim_{name}.png', dpi=dpi) - plt.close() - - -class Simulation: - """ """ - - def __init__(self, data_dir, start_date, results_dir): - self.num_groups = 1 - self.data_dir = data_dir - self.start_date = start_date - self.results_dir = results_dir - if not os.path.exists(self.results_dir): - os.makedirs(self.results_dir) - - def set_covid_parameters(self, model, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): - """ - - :param model: - - """ - model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E - model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = t_C - model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy - model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev - model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr - - # probabilities - model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( - 0)] = transmission_prob - model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 - - model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( - 0)] = mu_CR - model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH - model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU - model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD - - # start day is set to the n-th day of the year - model.parameters.StartDay = self.start_date.timetuple().tm_yday - - model.parameters.Seasonality = mio.UncertainValue(0.2) - - def set_contact_matrices(self, model): - """ - - :param model: - - """ - contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) - - baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 - minimum = np.zeros((self.num_groups, self.num_groups)) - contact_matrices[0] = mio.ContactMatrix(baseline, minimum) - model.parameters.ContactPatterns.cont_freq_mat = contact_matrices - - def set_npis(self, params, end_date, damping_values): - """ - - :param params: - :param end_date: - - """ - start_damping_1 = DATE_TIME + datetime.timedelta(days=15) - start_damping_2 = DATE_TIME + datetime.timedelta(days=30) - start_damping_3 = DATE_TIME + datetime.timedelta(days=45) - - if start_damping_1 < end_date: - start_date = (start_damping_1 - self.start_date).days - params.ContactPatterns.cont_freq_mat[0].add_damping( - mio.Damping(np.r_[damping_values[0]], t=start_date)) - - if start_damping_2 < end_date: - start_date = (start_damping_2 - self.start_date).days - params.ContactPatterns.cont_freq_mat[0].add_damping( - mio.Damping(np.r_[damping_values[1]], t=start_date)) - - if start_damping_3 < end_date: - start_date = (start_damping_3 - self.start_date).days - params.ContactPatterns.cont_freq_mat[0].add_damping( - mio.Damping(np.r_[damping_values[2]], t=start_date)) - - def get_graph(self, end_date, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): - """ - - :param end_date: - - """ - print("Initializing model...") - model = Model(self.num_groups) - self.set_covid_parameters( - 
model, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) - self.set_contact_matrices(model) - print("Model initialized.") - - graph = osecir.ModelGraph() - - scaling_factor_infected = [2.5] - scaling_factor_icu = 1.0 - - data_dir_Germany = os.path.join(self.data_dir, "Germany") - mobility_data_file = os.path.join( - data_dir_Germany, "mobility", "commuter_mobility_2022.txt") - pydata_dir = os.path.join(data_dir_Germany, "pydata") - - path_population_data = os.path.join(pydata_dir, - "county_current_population.json") - - print("Setting nodes...") - mio.osecir.set_nodes( - model.parameters, - mio.Date(self.start_date.year, - self.start_date.month, self.start_date.day), - mio.Date(end_date.year, - end_date.month, end_date.day), pydata_dir, - path_population_data, True, graph, scaling_factor_infected, - scaling_factor_icu, 0, 0, False) - - print("Setting edges...") - mio.osecir.set_edges(mobility_data_file, graph, 1) - - print("Graph created.") - - return graph - - def run(self, num_days_sim, damping_values, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): - """ - - :param num_days_sim: - :param num_runs: (Default value = 10) - :param save_graph: (Default value = True) - :param create_gif: (Default value = True) - - """ - mio.set_log_level(mio.LogLevel.Warning) - end_date = self.start_date + datetime.timedelta(days=num_days_sim) - - graph = self.get_graph(end_date, t_E, t_C, t_ISy, t_ISev, - t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) - - mobility_graph = osecir.MobilityGraph() - for node_idx in range(graph.num_nodes): - node = graph.get_node(node_idx) - - self.set_npis( - node.property.parameters, - end_date, - damping_values[node.id // 1000 - 1] - ) - mobility_graph.add_node(node.id, node.property) - for edge_idx in range(graph.num_edges): - mobility_graph.add_edge( - graph.get_edge(edge_idx).start_node_idx, - graph.get_edge(edge_idx).end_node_idx, - graph.get_edge(edge_idx).property) - mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) - mobility_sim.advance(num_days_sim) - - results = {} - for node_idx in range(mobility_sim.graph.num_nodes): - node = mobility_sim.graph.get_node(node_idx) - if node.id in no_icu_ids: - results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result( - node.property.result) - else: - results[f'region{node_idx}'] = osecir.interpolate_simulation_result( - node.property.result) - - return results - - -def run_germany_nuts3_simulation(damping_values, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): - mio.set_log_level(mio.LogLevel.Warning) - file_path = os.path.dirname(os.path.abspath(__file__)) - - sim = Simulation( - data_dir=os.path.join(file_path, "../../../data"), - start_date=DATE_TIME, - results_dir=os.path.join(file_path, "../../../results_osecir")) - num_days_sim = 60 - - results = sim.run(num_days_sim, damping_values, t_E, t_C, t_ISy, - t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) - - return results - -def prior(): - damping_values = np.zeros((NUM_DAMPING_POINTS, 16)) - for i in range(NUM_DAMPING_POINTS): - mean = np.random.uniform(0, 1) - scale = 0.1 - a, b = (0 - mean) / scale, (1 - mean) / scale - damping_values[i] = truncnorm.rvs( - a=a, b=b, loc=mean, scale=scale, size=16 - ) - return { - 'damping_values': np.transpose(damping_values), - 't_E': np.random.uniform(*bounds['t_E']), - 't_C': np.random.uniform(*bounds['t_C']), - 't_ISy': np.random.uniform(*bounds['t_ISy']), - 't_ISev': 
np.random.uniform(*bounds['t_ISev']), - 't_Cr': np.random.uniform(*bounds['t_Cr']), - 'mu_CR': np.random.uniform(*bounds['mu_CR']), - 'mu_IH': np.random.uniform(*bounds['mu_IH']), - 'mu_HU': np.random.uniform(*bounds['mu_HU']), - 'mu_UD': np.random.uniform(*bounds['mu_UD']), - 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) - } - - -def load_divi_data(): - file_path = os.path.dirname(os.path.abspath(__file__)) - divi_path = os.path.join(file_path, "../../../data/Germany/pydata") - - data = pd.read_json(os.path.join(divi_path, "county_divi_ma7.json")) - data = data[data['Date'] >= np.datetime64(DATE_TIME)] - data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] - data = data.sort_values(by=['ID_County', 'Date']) - divi_data = data.pivot(index='Date', columns='ID_County', values='ICU') - divi_dict = {} - for i, region_id in enumerate(region_ids): - if region_id not in no_icu_ids: - divi_dict[f"region{i}"] = divi_data[region_id].to_numpy()[None, :, None] - else: - divi_dict[f"no_icu_region{i}"] = np.zeros((1, divi_data.shape[0], 1)) - return divi_dict - - -def load_extrapolated_case_data(): - file_path = os.path.dirname(os.path.abspath(__file__)) - case_path = os.path.join(file_path, "../../../data/Germany/pydata") - - file = h5py.File(os.path.join(case_path, "Results_rki.h5")) - divi_dict = {} - for i, region_id in enumerate(region_ids): - if region_id not in no_icu_ids: - divi_dict[f"region{i}"] = np.array(file[f"{region_id}"]['Total'][:, 4])[None, :, None] - else: - divi_dict[f"no_icu_region{i}"] = np.zeros((1, np.array(file[f"{region_id}"]['Total']).shape[0], 1)) - return divi_dict - - -def extract_observables(simulation_results, observable_index=7): - for key in simulation_results.keys(): - if key not in inference_params: - simulation_results[key] = simulation_results[key][:, :, observable_index][..., np.newaxis] - return simulation_results - - -def create_train_data(filename, number_samples=1000): - - simulator = bf.simulators.make_simulator( - [prior, run_germany_nuts3_simulation] - ) - trainings_data = simulator.sample(number_samples) - trainings_data = extract_observables(trainings_data) - with open(filename, 'wb') as f: - pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) - - -def load_pickle(path): - with open(path, "rb") as f: - return pickle.load(f) - -def is_region_key(k: str) -> bool: - return 'region' in k - -def apply_aug(d: dict, aug) -> dict: - return {k: np.clip(aug(v), 0, None) if is_region_key(k) else v for k, v in d.items()} - -def concat_dicts(base: dict, new: dict) -> dict: - missing = set(base) - set(new) - if missing: - raise KeyError(f"new dict missing keys: {sorted(missing)}") - for k in base: - base[k] = np.concatenate([base[k], new[k]]) - return base - -def region_keys_sorted(d: dict): - def idx(k): - # handles "regionN" and "no_icu_regionN" - return int(k.split("region")[-1]) - return sorted([k for k in d if is_region_key(k)], key=idx) - -def aggregate_states(d: dict) -> None: - n_regions = len(region_ids) - # per state - for state in range(16): - idxs = [ - r for r in range(n_regions) - if region_ids[r] // 1000 == state + 1 - ] - d[f"fed_state{state}"] = np.sum([d[f"region{r}"] if region_ids[r] not in no_icu_ids else d[f"no_icu_region{r}"] for r in idxs], axis=0) - # all allowed regions - d["state"] = np.sum([d[f"fed_state{r}"] for r in range(16)], axis=0) - -def combine_results(dict_list): - combined = {} - for d in dict_list: - combined = concat_dicts(combined, d) if combined else d - return combined - 
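The aggregation helpers above rely on the fact that the leading digits of a German county id (AGS) encode the federal state, so region_id // 1000 maps, for example, 9162 (München) to state 9 (Bayern). A small self-contained sketch with a hypothetical four-county list (the toy ids and values are illustrative only, not the real region_ids list):

import numpy as np

toy_region_ids = [1001, 1002, 9162, 9184]  # Flensburg, Kiel, Muenchen (Stadt), Muenchen (Landkreis)
toy = {f"region{i}": np.full((1, 5, 1), 10.0 * (i + 1)) for i in range(len(toy_region_ids))}
for state in (1, 9):
    # same id-to-state mapping as in aggregate_states
    idxs = [i for i, rid in enumerate(toy_region_ids) if rid // 1000 == state]
    toy[f"fed_state{state - 1}"] = np.sum([toy[f"region{i}"] for i in idxs], axis=0)
print(toy["fed_state0"][0, 0, 0], toy["fed_state8"][0, 0, 0])  # 30.0 70.0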
-def skip_2weeks(d:dict) -> dict: - return {k: v[:, 14:, :] if is_region_key(k) else v for k, v in d.items()} - -def get_workflow(): - - simulator = bf.make_simulator( - [prior, run_germany_nuts3_simulation] - ) - adapter = ( - bf.Adapter() - .to_array() - .convert_dtype("float64", "float32") - .constrain("damping_values", lower=0.0, upper=1.0) - .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) - .constrain("t_C", lower=bounds["t_C"][0], upper=bounds["t_C"][1]) - .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) - .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) - .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) - .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) - .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) - .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) - .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) - .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) - .concatenate( - ["damping_values", "t_E", "t_C", "t_ISy", "t_ISev", "t_Cr", - "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], - into="inference_variables", - axis=-1 - ) - .concatenate(summary_vars, into="summary_variables", axis=-1) - ) - - summary_network = bf.networks.FusionTransformer( - summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 - ) - inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) - - # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) - workflow = bf.BasicWorkflow( - simulator=simulator, - adapter=adapter, - summary_network=summary_network, - inference_network=inference_network, - standardize='all' - # augmentations={f'region{i}': aug for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids} - # aggregation of the states would need to be recomputed every time a different noise realization is applied - ) - - return workflow - - -def run_training(name, num_training_files=20): - train_template = name+"/trainings_data{i}_"+name+".pickle" - val_path = f"{name}/validation_data_{name}.pickle" - - aug = bf.augmentations.NNPE( - spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False - ) - - # training data - train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] - trainings_data = None - for p in train_files: - d = load_pickle(p) - d = apply_aug(d, aug=aug) # only on region keys - d = skip_2weeks(d) - d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) - if trainings_data is None: - trainings_data = d - else: - trainings_data = concat_dicts(trainings_data, d) - aggregate_states(trainings_data) - - # validation data - validation_data = apply_aug(load_pickle(val_path), aug=aug) - validation_data = skip_2weeks(validation_data) - aggregate_states(validation_data) - validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) - - # check data - workflow = get_workflow() - print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) - print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) - - history = workflow.fit_offline( - data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data - ) - - workflow.approximator.save( - 
filepath=os.path.join(f"{name}/model_{name}.keras") - ) - - plots = workflow.plot_default_diagnostics( - test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} - ) - plots['losses'].savefig(f'{name}/losses_{name}.png', dpi=dpi) - plots['recovery'].savefig(f'{name}/recovery_{name}.png', dpi=dpi) - plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png', dpi=dpi) - #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png', dpi=dpi) - - -def run_inference(name, num_samples=1000, on_synthetic_data=False): - val_path = f"{name}/validation_data_{name}.pickle" - synthetic = "_synthetic" if on_synthetic_data else "" - - aug = bf.augmentations.NNPE( - spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False - ) - - # validation data - validation_data = load_pickle(val_path) # synthetic data - if on_synthetic_data: - # validation data - validation_data = apply_aug(validation_data, aug=aug) - validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) - validation_data_skip2w = skip_2weeks(validation_data) - aggregate_states(validation_data_skip2w) - divi_dict = validation_data - divi_region_keys = region_keys_sorted(divi_dict) - - divi_data = np.concatenate( - [divi_dict[key] for key in divi_region_keys], axis=-1 - )[0] # only one dataset - else: - divi_dict = load_divi_data() - validation_data_skip2w = skip_2weeks(divi_dict) - aggregate_states(validation_data_skip2w) - divi_region_keys = region_keys_sorted(divi_dict) - divi_data = np.concatenate( - [divi_dict[key] for key in divi_region_keys], axis=-1 - )[0] - - workflow = get_workflow() - workflow.approximator = keras.models.load_model( - filepath=os.path.join(f"{name}/model_{name}.keras") - ) - - if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle') and os.path.exists(f'{name}/samples_{name}{synthetic}.pickle'): - simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') - simulations_aug = load_pickle( - f'{name}/sims_{name}{synthetic}_with_aug.pickle') - samples = load_pickle(f'{name}/samples_{name}{synthetic}.pickle') - print("loaded simulations from file") - else: - samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) - with open(f'{name}/samples_{name}{synthetic}.pickle', 'wb') as f: - pickle.dump(samples, f, pickle.HIGHEST_PROTOCOL) - samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) - results = [] - for i in range(num_samples): # we only have one dataset for inference here - result = run_germany_nuts3_simulation( - damping_values=samples['damping_values'][0, i], - t_E=samples['t_E'][0, i], t_C=samples['t_C'][0, i], t_ISy=samples['t_ISy'][0, i], - t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], - mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], - mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], - transmission_prob=samples['transmission_prob'][0, i] - ) - for key in result.keys(): - result[key] = np.array(result[key])[None, ...] 
# add sample axis - results.append(result) - results = combine_results(results) - results = extract_observables(results) - results_aug = apply_aug(results, aug=aug) - - # get sims in shape (samples, time, regions) - simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) - simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) - for i in range(num_samples): - simulations[i] = np.concatenate([results[key][i] for key in divi_region_keys], axis=-1) - simulations_aug[i] = np.concatenate([results_aug[key][i] for key in divi_region_keys], axis=-1) - - # save sims - with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: - pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) - with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: - pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) - - samples['damping_values'] = samples['damping_values'].reshape( - (samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) - validation_data['damping_values'] = validation_data['damping_values'].reshape( - (validation_data['damping_values'].shape[0], -1)) - - plot = bf.diagnostics.pairs_posterior( - samples, priors=validation_data, dataset_id=0) - plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png', dpi=dpi) - - samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) - plot_damping_values(samples['damping_values'][0], - name=name, synthetic=synthetic) - - plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") - plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") - - plot_aggregated_to_federal_states(simulations, divi_data, name, synthetic, with_aug="") - plot_aggregated_to_federal_states(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") - - fig, axes = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True) - # Plot with augmentation - plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median", ax=axes[0], color=colors["Red"] - ) - # Plot with augmentation (80% quantile only) - plot_aggregated_over_regions( - simulations_aug, true_data=divi_data, label="Region Aggregated Median", ax=axes[1], color=colors["Red"], only_80q=True - ) - lines, labels = axes[0].get_legend_handles_labels() - fig.legend(lines, labels) - plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png', dpi=dpi) - plt.close() - - fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") - ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) - ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) - plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png', dpi=dpi) - plt.close() - fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") - ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) - ax, stats = calibration_median_mad_over_regions(simulations_aug, divi_data, ax=axis[1]) - plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png', dpi=dpi) - plt.close() - - plot_icu_on_germany(simulations, name, synthetic, with_aug="") - plot_icu_on_germany(simulations_aug, name, synthetic, with_aug="_with_aug") - - simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions - simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) - - rmse = 
bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) - rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) - print("Mean RMSE over regions:", rmse["values"].mean()) - print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) - - cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) - cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) - print("Mean Calibration Error over regions:", cal_error["values"].mean()) - print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) - - -if __name__ == "__main__": - name = "3dampings_lessnoise_newnetwork" - - set_fontsize() - - if not os.path.exists(name): - os.makedirs(name) - # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=100) - # run_training(name=name, num_training_files=10) - run_inference(name=name, on_synthetic_data=True) - run_inference(name=name, on_synthetic_data=False) - # run_prior_predictive_check(name=name) - # compare_median_sim_to_mean_param_sim(name=name) \ No newline at end of file
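The NNPE augmentation used throughout (spike_scale=0.4, slab_scale=0.2) comes directly from bayesflow. The sketch below is only a rough numpy stand-in for a spike-and-slab noise augmentation followed by the non-negativity clipping that apply_aug performs; it is not the actual bf.augmentations.NNPE implementation, whose noise model and scale handling may differ.

import numpy as np

def spike_and_slab_noise(x, spike_scale=0.4, slab_scale=0.2, p_spike=0.5, rng=None):
    """Add Gaussian noise whose scale is drawn per element from {spike_scale, slab_scale}."""
    rng = np.random.default_rng() if rng is None else rng
    scale = np.where(rng.random(x.shape) < p_spike, spike_scale, slab_scale)
    noisy = x + rng.normal(size=x.shape) * scale
    return np.clip(noisy, 0.0, None)  # ICU counts cannot become negative, as in apply_aug

noisy_icu = spike_and_slab_noise(np.zeros((1, 60, 1)))  # shape (datasets, time, regions)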