diff --git a/cpp/memilio/geography/regions.h b/cpp/memilio/geography/regions.h index c8dbeb5978..a359d5bce9 100644 --- a/cpp/memilio/geography/regions.h +++ b/cpp/memilio/geography/regions.h @@ -77,6 +77,8 @@ DECL_TYPESAFE(int, CountyId); DECL_TYPESAFE(int, DistrictId); +DECL_TYPESAFE(int, ProvinciaId); + /** * get the id of the state that the specified county is in. * @param[in, out] county a county id. diff --git a/cpp/memilio/io/epi_data.cpp b/cpp/memilio/io/epi_data.cpp index 24e86a00e0..2741d1f4ff 100644 --- a/cpp/memilio/io/epi_data.cpp +++ b/cpp/memilio/io/epi_data.cpp @@ -27,10 +27,10 @@ namespace mio std::vector ConfirmedCasesDataEntry::age_group_names = {"A00-A04", "A05-A14", "A15-A34", "A35-A59", "A60-A79", "A80+"}; - -std::vector PopulationDataEntry::age_group_names = { +std::vector PopulationDataEntry::age_group_names = { "<3 years", "3-5 years", "6-14 years", "15-17 years", "18-24 years", "25-29 years", "30-39 years", "40-49 years", "50-64 years", "65-74 years", ">74 years"}; +std::vector PopulationDataEntrySpain::age_group_names = {"Population"}; std::vector VaccinationDataEntry::age_group_names = {"0-4", "5-14", "15-34", "35-59", "60-79", "80-99"}; @@ -49,11 +49,14 @@ IOResult> get_node_ids(const std::string& path, bool is_node_fo } } else { - if (entry.district_id) { + if (entry.state_id) { + id.push_back(entry.state_id->get()); + } + else if (entry.district_id) { id.push_back(entry.district_id->get()); } else { - return failure(StatusCode::InvalidValue, "Population data file is missing district ids."); + return failure(StatusCode::InvalidValue, "Population data file is missing district and state ids."); } } } @@ -62,6 +65,35 @@ IOResult> get_node_ids(const std::string& path, bool is_node_fo id.erase(std::unique(id.begin(), id.end()), id.end()); return success(id); } + +IOResult> get_country_id(const std::string& /*path*/, bool /*is_node_for_county*/, + bool /*rki_age_groups*/) +{ + std::vector id = {0}; + return success(id); +} + +IOResult> get_provincia_ids(const std::string& path, bool /*is_node_for_county*/, + bool /*rki_age_groups*/) +{ +mio::log_info("Reading node data from {}.\n", path); +BOOST_OUTCOME_TRY(auto&& population_data, read_population_data_spain(path)); +std::vector id; +id.reserve(population_data.size()); +for (auto&& entry : population_data) { +if (entry.provincia_id) { +id.push_back(entry.provincia_id->get()); +} +else { +return failure(StatusCode::InvalidValue, "Population data file is missing provincia ids."); +} +} + +//remove duplicate node ids +id.erase(std::unique(id.begin(), id.end()), id.end()); +mio::log_info("Reading node data completed."); +return success(id); +} } // namespace mio #endif //MEMILIO_HAS_JSONCPP diff --git a/cpp/memilio/io/epi_data.h b/cpp/memilio/io/epi_data.h index 3e4572b870..442b54e4cc 100644 --- a/cpp/memilio/io/epi_data.h +++ b/cpp/memilio/io/epi_data.h @@ -77,6 +77,9 @@ class ConfirmedCasesNoAgeEntry double num_recovered; double num_deaths; Date date; + boost::optional state_id; + boost::optional county_id; + boost::optional district_id; template static IOResult deserialize(IOContext& io) @@ -86,12 +89,15 @@ class ConfirmedCasesNoAgeEntry auto num_recovered = obj.expect_element("Recovered", Tag{}); auto num_deaths = obj.expect_element("Deaths", Tag{}); auto date = obj.expect_element("Date", Tag{}); + auto state_id = obj.expect_optional("ID_State", Tag{}); + auto county_id = obj.expect_optional("ID_County", Tag{}); + auto district_id = obj.expect_optional("ID_District", Tag{}); return apply( io, - [](auto&& nc, auto&& 
nr, auto&& nd, auto&& d) { - return ConfirmedCasesNoAgeEntry{nc, nr, nd, d}; + [](auto&& nc, auto&& nr, auto&& nd, auto&& d, auto&& sid, auto&& cid, auto&& did) { + return ConfirmedCasesNoAgeEntry{nc, nr, nd, d, sid, cid, did}; }, - num_confirmed, num_recovered, num_deaths, date); + num_confirmed, num_recovered, num_deaths, date, state_id, county_id, district_id); } }; @@ -142,29 +148,42 @@ class ConfirmedCasesDataEntry { auto obj = io.expect_object("ConfirmedCasesDataEntry"); auto num_confirmed = obj.expect_element("Confirmed", Tag{}); - auto num_recovered = obj.expect_element("Recovered", Tag{}); - auto num_deaths = obj.expect_element("Deaths", Tag{}); + auto num_recovered = obj.expect_optional("Recovered", Tag{}); + auto num_deaths = obj.expect_optional("Deaths", Tag{}); auto date = obj.expect_element("Date", Tag{}); - auto age_group_str = obj.expect_element("Age_RKI", Tag{}); + auto age_group_str = obj.expect_optional("Age_RKI", Tag{}); auto state_id = obj.expect_optional("ID_State", Tag{}); auto county_id = obj.expect_optional("ID_County", Tag{}); auto district_id = obj.expect_optional("ID_District", Tag{}); return apply( io, - [](auto&& nc, auto&& nr, auto&& nd, auto&& d, auto&& a_str, auto&& sid, auto&& cid, + [](auto&& nc, auto&& nr, auto&& nd, auto&& d, auto&& a_str_opt, auto&& sid, auto&& cid, auto&& did) -> IOResult { - auto a = AgeGroup(0); - auto it = std::find(age_group_names.begin(), age_group_names.end(), a_str); - if (it != age_group_names.end()) { - a = AgeGroup(size_t(it - age_group_names.begin())); - } - else if (a_str == "unknown") { - a = AgeGroup(age_group_names.size()); + auto a = AgeGroup(0); // default if Age_RKI missing + + // if no age group is given, use "total" as name + if (!a_str_opt) { + if (age_group_names.size() != 1) { + age_group_names.clear(); + age_group_names.push_back("total"); + } } else { - return failure(StatusCode::InvalidValue, "Invalid confirmed cases data age group."); + const auto& a_str = *a_str_opt; + auto it = std::find(age_group_names.begin(), age_group_names.end(), a_str); + if (it != age_group_names.end()) { + a = AgeGroup(size_t(it - age_group_names.begin())); + } + else if (a_str == "unknown") { + a = AgeGroup(age_group_names.size()); + } + else { + return failure(StatusCode::InvalidValue, "Invalid confirmed cases data age group."); + } } - return success(ConfirmedCasesDataEntry{nc, nr, nd, d, a, sid, cid, did}); + double nrec = nr ? *nr : 0.0; + double ndead = nd ? 
*nd : 0.0;
+                    return success(ConfirmedCasesDataEntry{nc, nrec, ndead, d, a, sid, cid, did});
             },
             num_confirmed, num_recovered, num_deaths, date, age_group_str, state_id, county_id, district_id);
     }
@@ -331,6 +350,35 @@ class PopulationDataEntry
     }
 };
 
+class PopulationDataEntrySpain
+{
+public:
+    static std::vector age_group_names;
+
+    CustomIndexArray population;
+    boost::optional provincia_id;
+
+    template
+    static IOResult deserialize(IoContext& io)
+    {
+        auto obj       = io.expect_object("PopulationDataEntrySpain");
+        auto provincia = obj.expect_optional("ID_Provincia", Tag{});
+        std::vector> age_groups;
+        age_groups.reserve(age_group_names.size());
+        std::transform(age_group_names.begin(), age_group_names.end(), std::back_inserter(age_groups),
+                       [&obj](auto&& age_name) {
+                           return obj.expect_element(age_name, Tag{});
+                       });
+        return apply(
+            io,
+            [](auto&& ag, auto&& pid) {
+                return PopulationDataEntrySpain{
+                    CustomIndexArray(AgeGroup(ag.size()), ag.begin(), ag.end()), pid};
+            },
+            details::unpack_all(age_groups), provincia);
+    }
+};
+
 namespace details
 {
 inline void get_rki_age_interpolation_coefficients(const std::vector& age_ranges,
@@ -435,6 +483,19 @@ inline IOResult> deserialize_population_data(co
     }
 }
 
+/**
+ * Deserialize Spanish population data from a JSON value.
+ * The data contains a single aggregated age group per provincia.
+ * No interpolation to RKI age groups is performed.
+ * @param jsvalue JSON value that contains the population data.
+ * @return list of population data.
+ */
+inline IOResult> deserialize_population_data_spain(const Json::Value& jsvalue)
+{
+    BOOST_OUTCOME_TRY(auto&& population_data, deserialize_json(jsvalue, Tag>{}));
+    return success(population_data);
+}
+
 /**
  * Deserialize population data from a JSON file.
  * Age groups are interpolated to RKI age groups.
@@ -448,6 +509,18 @@ inline IOResult> read_population_data(const std
     return deserialize_population_data(jsvalue, rki_age_group);
 }
 
+/**
+ * Deserialize Spanish population data from a JSON file.
+ * The data contains a single aggregated age group per provincia; no interpolation to RKI age groups is performed.
+ * @param filename JSON file that contains the population data.
+ * @return list of population data.
+ */
+inline IOResult> read_population_data_spain(const std::string& filename)
+{
+    BOOST_OUTCOME_TRY(auto&& jsvalue, read_json(filename));
+    return deserialize_population_data_spain(jsvalue);
+}
+
 /**
  * @brief Sets the age groups' names for the ConfirmedCasesDataEntry%s.
  * @param[in] names age group names
@@ -474,6 +547,10 @@ IOResult set_vaccination_data_age_group_names(std::vector nam
 * @return list of node ids.
 */
 IOResult> get_node_ids(const std::string& path, bool is_node_for_county, bool rki_age_groups = true);
+IOResult> get_country_id(const std::string& /*path*/, bool /*is_node_for_county*/,
+                         bool /*rki_age_groups*/ = true);
+IOResult> get_provincia_ids(const std::string& path, bool /*is_node_for_county*/,
+                            bool /*rki_age_groups*/ = true);
 
 /**
 * Represents an entry in a vaccination data file.
diff --git a/cpp/memilio/io/parameters_io.cpp b/cpp/memilio/io/parameters_io.cpp index f9f9f8b3c0..450a38f902 100644 --- a/cpp/memilio/io/parameters_io.cpp +++ b/cpp/memilio/io/parameters_io.cpp @@ -62,11 +62,43 @@ IOResult>> read_population_data(const std::vecto return success(vnum_population); } +IOResult>> +read_population_data_spain(const std::vector& population_data, + const std::vector& vregion) +{ + std::vector> vnum_population(vregion.size(), std::vector(1, 0.0)); + + for (auto&& provincia_entry : population_data) { + if (provincia_entry.provincia_id) { + for (size_t idx = 0; idx < vregion.size(); ++idx) { + if (vregion[idx] == provincia_entry.provincia_id->get()) { + vnum_population[idx][0] += provincia_entry.population[AgeGroup(0)]; + } + } + } + // id 0 means the whole country + for (size_t idx = 0; idx < vregion.size(); ++idx) { + if (vregion[idx] == 0) { + vnum_population[idx][0] += provincia_entry.population[AgeGroup(0)]; + } + } + } + + return success(vnum_population); +} + IOResult>> read_population_data(const std::string& path, const std::vector& vregion) { BOOST_OUTCOME_TRY(auto&& population_data, mio::read_population_data(path)); return read_population_data(population_data, vregion); } + +IOResult>> read_population_data_spain(const std::string& path, + const std::vector& vregion) +{ + BOOST_OUTCOME_TRY(auto&& population_data, mio::read_population_data_spain(path)); + return read_population_data_spain(population_data, vregion); +} } // namespace mio #endif //MEMILIO_HAS_JSONCPP diff --git a/cpp/memilio/io/parameters_io.h b/cpp/memilio/io/parameters_io.h index 3ea00fc73b..bbb0769d51 100644 --- a/cpp/memilio/io/parameters_io.h +++ b/cpp/memilio/io/parameters_io.h @@ -136,6 +136,24 @@ IOResult>> read_population_data(const std::vecto IOResult>> read_population_data(const std::string& path, const std::vector& vregion); +/** + * @brief Reads (Spain) population data from a vector of Spain population data entries. + * Uses provincias IDs and a single aggregated age group. + * @param[in] population_data Vector of Spain population data entries. + * @param[in] vregion Vector of keys representing the provincias (or country = 0) of interest. + */ +IOResult>> +read_population_data_spain(const std::vector& population_data, + const std::vector& vregion); + +/** + * @brief Reads (Spain) population data from census data file with provincias. + * @param[in] path Path to the provincias population data file. + * @param[in] vregion Vector of keys representing the provincias (or country = 0) of interest. 
+ */ +IOResult>> read_population_data_spain(const std::string& path, + const std::vector& vregion); + } // namespace mio #endif //MEMILIO_HAS_JSONCPP diff --git a/cpp/memilio/mobility/graph.h b/cpp/memilio/mobility/graph.h index eb8bd35990..1a616be241 100644 --- a/cpp/memilio/mobility/graph.h +++ b/cpp/memilio/mobility/graph.h @@ -307,15 +307,15 @@ IOResult set_nodes(const Parameters& params, Date start_date, Date end_dat }); //uncertainty in populations - for (auto i = mio::AgeGroup(0); i < params.get_num_groups(); i++) { - for (auto j = Index(0); j < Model::Compartments::Count; ++j) { - auto& compartment_value = nodes[node_idx].populations[{i, j}]; - compartment_value = - UncertainValue(0.5 * (1.1 * double(compartment_value) + 0.9 * double(compartment_value))); - compartment_value.set_distribution(mio::ParameterDistributionUniform(0.9 * double(compartment_value), - 1.1 * double(compartment_value))); - } - } + // for (auto i = mio::AgeGroup(0); i < params.get_num_groups(); i++) { + // for (auto j = Index(0); j < Model::Compartments::Count; ++j) { + // auto& compartment_value = nodes[node_idx].populations[{i, j}]; + // compartment_value = + // UncertainValue(0.5 * (1.1 * double(compartment_value) + 0.9 * double(compartment_value))); + // compartment_value.set_distribution(mio::ParameterDistributionUniform(0.9 * double(compartment_value), + // 1.1 * double(compartment_value))); + // } + // } params_graph.add_node(node_ids[node_idx], nodes[node_idx]); } @@ -368,7 +368,7 @@ IOResult set_edges(const fs::path& mobility_data_file, Graph(); ++age) { for (auto compartment : mobile_compartments) { auto coeff_index = populations.get_flat_index({age, compartment}); - mobility_coeffs[size_t(ContactLocation::Work)].get_baseline()[coeff_index] = + mobility_coeffs[size_t(ContactLocation::Home)].get_baseline()[coeff_index] = commuter_coeff_ij * commuting_weights[size_t(age)]; } } diff --git a/cpp/models/ode_secir/parameters_io.h b/cpp/models/ode_secir/parameters_io.h index ab334d922d..6f9a89acda 100644 --- a/cpp/models/ode_secir/parameters_io.h +++ b/cpp/models/ode_secir/parameters_io.h @@ -75,7 +75,17 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect const std::vector& scaling_factor_inf) { const size_t num_age_groups = ConfirmedCasesDataEntry::age_group_names.size(); - assert(scaling_factor_inf.size() == num_age_groups); + // allow single scalar scaling that is broadcast to all age groups + assert(scaling_factor_inf.size() == 1 || scaling_factor_inf.size() == num_age_groups); + + // Set scaling factors to match num age groups + std::vector scaling_factor_inf_full; + if (scaling_factor_inf.size() == 1) { + scaling_factor_inf_full.assign(num_age_groups, scaling_factor_inf[0]); + } + else { + scaling_factor_inf_full = scaling_factor_inf; + } std::vector> t_InfectedNoSymptoms{model.size()}; std::vector> t_Exposed{model.size()}; @@ -89,7 +99,12 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect std::vector> mu_U_D{model.size()}; for (size_t node = 0; node < model.size(); ++node) { - for (size_t group = 0; group < num_age_groups; group++) { + const size_t model_groups = (size_t)model[node].parameters.get_num_groups(); + assert(model_groups == 1 || model_groups == num_age_groups); + for (size_t ag = 0; ag < num_age_groups; ag++) { + // If the model has fewer groups than casedata entries available, + // reuse group 0 parameters for all RKI age groups + const size_t group = (model_groups == num_age_groups) ? 
ag : 0; t_Exposed[node].push_back( static_cast(std::round(model[node].parameters.template get>()[(AgeGroup)group]))); @@ -121,26 +136,49 @@ IOResult set_confirmed_cases_data(std::vector>& model, std::vect BOOST_OUTCOME_TRY(read_confirmed_cases_data(case_data, region, date, num_Exposed, num_InfectedNoSymptoms, num_InfectedSymptoms, num_InfectedSevere, num_icu, num_death, num_rec, t_Exposed, t_InfectedNoSymptoms, t_InfectedSymptoms, t_InfectedSevere, - t_InfectedCritical, mu_C_R, mu_I_H, mu_H_U, scaling_factor_inf)); + t_InfectedCritical, mu_C_R, mu_I_H, mu_H_U, scaling_factor_inf_full)); for (size_t node = 0; node < model.size(); node++) { if (std::accumulate(num_InfectedSymptoms[node].begin(), num_InfectedSymptoms[node].end(), 0.0) > 0) { size_t num_groups = (size_t)model[node].parameters.get_num_groups(); - for (size_t i = 0; i < num_groups; i++) { - model[node].populations[{AgeGroup(i), InfectionState::Exposed}] = num_Exposed[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptoms}] = - num_InfectedNoSymptoms[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptomsConfirmed}] = 0; - model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptoms}] = - num_InfectedSymptoms[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptomsConfirmed}] = 0; - model[node].populations[{AgeGroup(i), InfectionState::InfectedSevere}] = num_InfectedSevere[node][i]; - // Only set the number of ICU patients here, if the date is not available in the data. + if (num_groups == num_age_groups) { + for (size_t i = 0; i < num_groups; i++) { + model[node].populations[{AgeGroup(i), InfectionState::Exposed}] = num_Exposed[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptoms}] = + num_InfectedNoSymptoms[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptoms}] = + num_InfectedSymptoms[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i), InfectionState::InfectedSevere}] = + num_InfectedSevere[node][i]; + // Only set the number of ICU patients here, if the date is not available in the data. 
+ if (!is_divi_data_available(date)) { + model[node].populations[{AgeGroup(i), InfectionState::InfectedCritical}] = num_icu[node][i]; + } + model[node].populations[{AgeGroup(i), InfectionState::Dead}] = num_death[node][i]; + model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i]; + } + } + else { + const auto sum_vec = [](const std::vector& v) { + return std::accumulate(v.begin(), v.end(), 0.0); + }; + const size_t i0 = 0; + model[node].populations[{AgeGroup(i0), InfectionState::Exposed}] = sum_vec(num_Exposed[node]); + model[node].populations[{AgeGroup(i0), InfectionState::InfectedNoSymptoms}] = + sum_vec(num_InfectedNoSymptoms[node]); + model[node].populations[{AgeGroup(i0), InfectionState::InfectedNoSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i0), InfectionState::InfectedSymptoms}] = + sum_vec(num_InfectedSymptoms[node]); + model[node].populations[{AgeGroup(i0), InfectionState::InfectedSymptomsConfirmed}] = 0; + model[node].populations[{AgeGroup(i0), InfectionState::InfectedSevere}] = + sum_vec(num_InfectedSevere[node]); if (!is_divi_data_available(date)) { - model[node].populations[{AgeGroup(i), InfectionState::InfectedCritical}] = num_icu[node][i]; + model[node].populations[{AgeGroup(i0), InfectionState::InfectedCritical}] = sum_vec(num_icu[node]); } - model[node].populations[{AgeGroup(i), InfectionState::Dead}] = num_death[node][i]; - model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i]; + model[node].populations[{AgeGroup(i0), InfectionState::Dead}] = sum_vec(num_death[node]); + model[node].populations[{AgeGroup(i0), InfectionState::Recovered}] = sum_vec(num_rec[node]); } } else { @@ -231,10 +269,20 @@ IOResult set_population_data(std::vector>& model, assert(num_population.size() == vregion.size()); assert(model.size() == vregion.size()); for (size_t region = 0; region < vregion.size(); region++) { - auto num_groups = model[region].parameters.get_num_groups(); - for (auto i = AgeGroup(0); i < num_groups; i++) { + const auto model_groups = (size_t)model[region].parameters.get_num_groups(); + const auto data_groups = num_population[region].size(); + assert(data_groups == model_groups || (model_groups == 1 && data_groups >= 1)); + + if (data_groups == model_groups) { + for (auto i = AgeGroup(0); i < model[region].parameters.get_num_groups(); i++) { + model[region].populations.template set_difference_from_group_total( + {i, InfectionState::Susceptible}, num_population[region][(size_t)i]); + } + } + else if (model_groups == 1 && data_groups >= 1) { + const double total = std::accumulate(num_population[region].begin(), num_population[region].end(), 0.0); model[region].populations.template set_difference_from_group_total( - {i, InfectionState::Susceptible}, num_population[region][size_t(i)]); + {AgeGroup(0), InfectionState::Susceptible}, total); } } return success(); @@ -256,6 +304,15 @@ IOResult set_population_data(std::vector>& model, const std::str return success(); } +template +IOResult set_population_data_provincias(std::vector>& model, const std::string& path, + const std::vector& vregion) +{ + BOOST_OUTCOME_TRY(const auto&& num_population, read_population_data_spain(path, vregion)); + BOOST_OUTCOME_TRY(set_population_data(model, num_population, vregion)); + return success(); +} + } //namespace details #ifdef MEMILIO_HAS_HDF5 @@ -283,8 +340,8 @@ IOResult export_input_data_county_timeseries( const std::string& divi_data_path, const std::string& confirmed_cases_path, const std::string& population_data_path) 
{
     const auto num_age_groups = (size_t)models[0].parameters.get_num_groups();
-    assert(scaling_factor_inf.size() == num_age_groups);
-    assert(num_age_groups == ConfirmedCasesDataEntry::age_group_names.size());
+    // allow scalar scaling factor as convenience for 1-group models
+    assert(scaling_factor_inf.size() == 1 || scaling_factor_inf.size() == num_age_groups);
     assert(models.size() == region.size());
     std::vector> extrapolated_data(
         region.size(), TimeSeries::zero(num_days + 1, (size_t)InfectionState::Count * num_age_groups));
@@ -333,9 +390,10 @@ IOResult export_input_data_county_timeseries(
  * @param[in] pydata_dir Directory of files.
  */
 template
-IOResult read_input_data_germany(std::vector& model, Date date,
+IOResult read_input_data_germany(std::vector& model, Date date, std::vector& /*state*/,
                                  const std::vector& scaling_factor_inf, double scaling_factor_icu,
-                                 const std::string& pydata_dir)
+                                 const std::string& pydata_dir, int /*num_days*/ = 0,
+                                 bool /*export_time_series*/ = false)
 {
     BOOST_OUTCOME_TRY(
         details::set_divi_data(model, path_join(pydata_dir, "germany_divi.json"), {0}, date, scaling_factor_icu));
@@ -358,7 +416,8 @@ IOResult read_input_data_germany(std::vector& model, Date date,
 template
 IOResult read_input_data_state(std::vector& model, Date date, std::vector& state,
                                const std::vector& scaling_factor_inf, double scaling_factor_icu,
-                               const std::string& pydata_dir)
+                               const std::string& pydata_dir, int /*num_days*/ = 0,
+                               bool /*export_time_series*/ = false)
 {
 
     BOOST_OUTCOME_TRY(
@@ -407,6 +466,33 @@ IOResult read_input_data_county(std::vector& model, Date date, cons
     return success();
 }
 
+/**
+ * @brief Reads input data (ICU, confirmed cases and population) for the specified provincias.
+ * @param[in, out] model Vector of model in which the data is set.
+ * @param[in] date Date for which the data should be read.
+ * @param[in] provincias Vector of region keys of the provincias of interest.
+ * @param[in] scaling_factor_inf Factors by which to scale the confirmed cases of rki data.
+ * @param[in] scaling_factor_icu Factor by which to scale the icu cases of divi data.
+ * @param[in] pydata_dir Directory of files.
+ * @param[in] num_days [Default: 0] Number of days to be simulated; required to extrapolate real data.
+ * @param[in] export_time_series [Default: false] If true, reads data for each day of simulation and writes it in the same directory as the input files.
+ */ +template +IOResult read_input_data_provincias(std::vector& model, Date date, const std::vector& provincias, + const std::vector& scaling_factor_inf, double scaling_factor_icu, + const std::string& pydata_dir, int /*num_days*/ = 0, + bool /*export_time_series*/ = false) +{ + BOOST_OUTCOME_TRY(details::set_divi_data(model, path_join(pydata_dir, "provincia_icu.json"), provincias, date, + scaling_factor_icu)); + + BOOST_OUTCOME_TRY(details::set_confirmed_cases_data(model, path_join(pydata_dir, "cases_all_pronvincias.json"), + provincias, date, scaling_factor_inf)); + BOOST_OUTCOME_TRY(details::set_population_data_provincias( + model, path_join(pydata_dir, "provincias_current_population.json"), provincias)); + return success(); +} + /** * @brief reads population data from population files for the specified nodes * @param[in, out] model vector of model in which the data is set diff --git a/cpp/tests/test_odesecir.cpp b/cpp/tests/test_odesecir.cpp index cb4ad98a2f..b59da559b1 100755 --- a/cpp/tests/test_odesecir.cpp +++ b/cpp/tests/test_odesecir.cpp @@ -29,6 +29,7 @@ #include "memilio/io/parameters_io.h" #include "memilio/data/analyze_result.h" #include "memilio/math/adapt_rk.h" +#include "memilio/geography/regions.h" #include @@ -1516,5 +1517,184 @@ TEST(TestOdeSecir, read_population_data_failure) EXPECT_EQ(result.error().message(), "File with county population expected."); } +TEST(TestOdeSecirIO, read_input_data_county_aggregates_one_group) +{ + // Set up two models with different number of age groups. + const size_t num_age_groups = 6; + std::vector> models6{mio::osecir::Model((int)num_age_groups)}; + std::vector> models1{mio::osecir::Model(1)}; + + // Relevant parameters for model with 6 age groups + for (auto i = mio::AgeGroup(0); i < (mio::AgeGroup)num_age_groups; ++i) { + models6[0].parameters.get>()[i] = 0.2; + models6[0].parameters.get>()[i] = 0.25; + } + + // Relevant parameters for model with 1 age group + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; + models1[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; + + const auto pydata_dir_Germany = mio::path_join(TEST_DATA_DIR, "Germany", "pydata"); + const std::vector counties{1002}; + const auto date = mio::Date(2020, 12, 1); + + std::vector scale6(num_age_groups, 1.0); + std::vector scale1{1.0}; + + // Initialize both models + ASSERT_THAT(mio::osecir::read_input_data_county(models6, date, counties, scale6, 1.0, pydata_dir_Germany), + IsSuccess()); + ASSERT_THAT(mio::osecir::read_input_data_county(models1, date, counties, scale1, 1.0, pydata_dir_Germany), + IsSuccess()); + + // Aggreagate the results from the model with 6 age groups and compare with the model with 1 age group + const auto& m6 = models6[0]; + const auto& m1 = models1[0]; + const double tol = 1e-10; + for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) { + double sum6 = 0.0; + for (size_t ag = 0; ag < num_age_groups; ++ag) { + sum6 += m6.populations[{mio::AgeGroup(ag), (mio::osecir::InfectionState)s}].value(); + } + const double v1 = m1.populations[{mio::AgeGroup(0), (mio::osecir::InfectionState)s}].value(); + EXPECT_NEAR(sum6, v1, tol); + } + + // Total population + EXPECT_NEAR(m6.populations.get_total(), m1.populations.get_total(), tol); +} + +TEST(TestOdeSecirIO, set_population_data_single_age_group) +{ + const size_t num_age_groups = 6; + + // Create two models: one with 6 age groups, one with 1 age group + std::vector> models6{mio::osecir::Model((int)num_age_groups)}; + std::vector> models1{mio::osecir::Model(1)}; + + // Test population data 
with 6 different values for age groups + std::vector> population_data6 = {{10000.0, 20000.0, 30000.0, 25000.0, 15000.0, 8000.0}}; + std::vector> population_data1 = {{108000.0}}; // sum of all age groups + std::vector regions = {1002}; + + // Set population data for both models + EXPECT_THAT(mio::osecir::details::set_population_data(models6, population_data6, regions), IsSuccess()); + EXPECT_THAT(mio::osecir::details::set_population_data(models1, population_data1, regions), IsSuccess()); + + // Sum all compartments across age groups in 6-group model and compare 1-group model + const double tol = 1e-10; + for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) { + double sum6 = 0.0; + for (size_t ag = 0; ag < num_age_groups; ++ag) { + sum6 += models6[0].populations[{mio::AgeGroup(ag), (mio::osecir::InfectionState)s}].value(); + } + double val1 = models1[0].populations[{mio::AgeGroup(0), (mio::osecir::InfectionState)s}].value(); + + EXPECT_NEAR(sum6, val1, tol); + } + + // Total population should also match + EXPECT_NEAR(models6[0].populations.get_total(), models1[0].populations.get_total(), tol); +} + +TEST(TestOdeSecirIO, set_confirmed_cases_data_single_age_group) +{ + const size_t num_age_groups = 6; + + // Create two models: one with 6 age groups, one with 1 age group + std::vector> models6{mio::osecir::Model((int)num_age_groups)}; + std::vector> models1{mio::osecir::Model(1)}; + + // Create case data for all 6 age groups over multiple days (current day + 6 days back) + std::vector case_data; + + for (int day_offset = -6; day_offset <= 0; ++day_offset) { + mio::Date current_date = mio::offset_date_by_days(mio::Date(2020, 12, 1), day_offset); + + for (int age_group = 0; age_group < 6; ++age_group) { + double base_confirmed = 80.0 + age_group * 8.0 + (day_offset + 6) * 5.0; + double base_recovered = 40.0 + age_group * 4.0 + (day_offset + 6) * 3.0; + double base_deaths = 3.0 + age_group * 0.5 + (day_offset + 6) * 0.5; + + mio::ConfirmedCasesDataEntry entry{base_confirmed, + base_recovered, + base_deaths, + current_date, + mio::AgeGroup(age_group), + {}, + mio::regions::CountyId(1002), + {}}; + case_data.push_back(entry); + } + } + + std::vector regions = {1002}; + std::vector scaling_factors = {1.0}; + + // Set confirmed cases data for both models + EXPECT_THAT(mio::osecir::details::set_confirmed_cases_data(models6, case_data, regions, mio::Date(2020, 12, 1), + scaling_factors), + IsSuccess()); + EXPECT_THAT(mio::osecir::details::set_confirmed_cases_data(models1, case_data, regions, mio::Date(2020, 12, 1), + scaling_factors), + IsSuccess()); + + // Sum all compartments across age groups in 6-group model should be equal to 1-group model + for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) { + double sum6 = 0.0; + for (size_t ag = 0; ag < num_age_groups; ++ag) { + sum6 += models6[0].populations[{mio::AgeGroup(ag), (mio::osecir::InfectionState)s}].value(); + } + + double val1 = models1[0].populations[{mio::AgeGroup(0), (mio::osecir::InfectionState)s}].value(); + + EXPECT_NEAR(sum6, val1, 1e-10); + } + + // Total population + EXPECT_NEAR(models6[0].populations.get_total(), models1[0].populations.get_total(), 1e-10); +} + +TEST(TestOdeSecirIO, set_divi_data_single_age_group) +{ + // Create models with 6 age groups and 1 age group + std::vector> models_6_groups{mio::osecir::Model(6)}; + std::vector> models_1_group{mio::osecir::Model(1)}; + + // Set relevant parameters for all age groups + for (int i = 0; i < 6; i++) { + 
models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 0.2; + models_6_groups[0].parameters.get>()[mio::AgeGroup(i)] = 0.25; + } + + // Set relevant parameters for 1 age group model + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.2; + models_1_group[0].parameters.get>()[mio::AgeGroup(0)] = 0.25; + + // Apply DIVI data to both models + std::vector regions = {1002}; + double scaling_factor_icu = 1.0; + mio::Date date(2020, 12, 1); + std::string divi_data_path = mio::path_join(TEST_DATA_DIR, "Germany", "pydata", "county_divi_ma7.json"); + auto result_6_groups = + mio::osecir::details::set_divi_data(models_6_groups, divi_data_path, regions, date, scaling_factor_icu); + auto result_1_group = + mio::osecir::details::set_divi_data(models_1_group, divi_data_path, regions, date, scaling_factor_icu); + + EXPECT_THAT(result_6_groups, IsSuccess()); + EXPECT_THAT(result_1_group, IsSuccess()); + + // Calculate totals after applying DIVI data + double total_icu_6_groups_after = 0.0; + for (int i = 0; i < 6; i++) { + total_icu_6_groups_after += + models_6_groups[0].populations[{mio::AgeGroup(i), mio::osecir::InfectionState::InfectedCritical}].value(); + } + double icu_1_group_after = + models_1_group[0].populations[{mio::AgeGroup(0), mio::osecir::InfectionState::InfectedCritical}].value(); + + EXPECT_NEAR(total_icu_6_groups_after, icu_1_group_after, 1e-10); +} + #endif #endif diff --git a/pycode/examples/simulation/graph_germany_nuts0.py b/pycode/examples/simulation/graph_germany_nuts0.py new file mode 100644 index 0000000000..4850a03902 --- /dev/null +++ b/pycode/examples/simulation/graph_germany_nuts0.py @@ -0,0 +1,670 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+############################################################################# +import os +os.environ["KERAS_BACKEND"] = "tensorflow" + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import datetime +import pickle +from scipy.stats import truncnorm + +from matplotlib.patches import Patch + +import bayesflow as bf +import keras + +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +from memilio.simulation.osecir import Model, interpolate_simulation_result +from memilio.epidata import defaultDict as dd + +import geopandas as gpd + + +region_ids = [0] +inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', + 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] +summary_vars = ['state0'] + +bounds = { + 't_E': (1.0, 5.2), + 't_ISy': (4.0, 10.0), + 't_ISev': (5.0, 10.0), + 't_Cr': (9.0, 17.0), + 'mu_CR': (0.0, 0.4), + 'mu_IH': (0.0, 0.2), + 'mu_HU': (0.0, 0.4), + 'mu_UD': (0.0, 0.4), + 'transmission_prob': (0.0, 0.2) +} +SPIKE_SCALE = 0.4 +SLAB_SCALE = 0.2 +DATE_TIME = datetime.date(year=2020, month=10, day=1) +NUM_DAMPING_POINTS = 3 + +def plot_region_fit( + data: np.ndarray, + region: int, + true_data=None, + ax=None, + label=None, + color="red", + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + n_samples, n_time, n_regions = data.shape + if not (0 <= region < n_regions): + raise IndexError + + x = np.arange(n_time) + vals = data[:, :, region] # (samples, time_points) + + qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) + + med = np.median(vals, axis=0) + + if ax is None: + fig, ax = plt.subplots() + + ax.plot( + x, med, lw=2, label=label or f"Germany", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + + if true_data is not None: + true_vals = true_data[:, region] # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + ax.set_title(f"Germany", fontsize=12) + if label is not None: + ax.legend(fontsize=11, loc="upper right") + + +def plot_icu_on_germany(simulations, name, synthetic, with_aug): + med = np.median(simulations, axis=0) + + print(simulations.shape) + population = pd.read_json('data/Germany/pydata/county_current_population_germany.json') + # print(population['Population'].shape) + values = med / population['Population'][None, :] * 100000 + print(values.shape) + + + map_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_KRS.shp')) + fedstate_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_LAN.shp')) + state_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_STA.shp')) + + plot_map(simulations[0], state_data, map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_initial_{name}{synthetic}{with_aug}") + plot_map(simulations[-1], state_data, map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_final_{name}{synthetic}{with_aug}") + 
+def plot_map(values, state_data, map_data, fedstate_data, label, filename): + state_data[label] = state_data['ARS'].map(dict(zip([f"{region_id:02d}" for region_id in region_ids], values))) + + fig, ax = plt.subplots(figsize=(12, 14)) + state_data.plot( + column=f"{label}", + cmap='Reds', + linewidth=1., + ax=ax, + edgecolor='0.6', + legend=True, + legend_kwds={'label': f"{label} per State", 'shrink': 0.6}, + ) + map_data.boundary.plot(ax=ax, color='black', linewidth=0.2) + fedstate_data.boundary.plot(ax=ax, color='black', linewidth=0.4) + ax.set_title(f"{label} per Federal State") + ax.axis('off') + plt.savefig(filename, bbox_inches='tight', dpi=300) + +def plot_damping_values(damping_values, name, synthetic, with_aug): + + med = np.median(damping_values, axis=0) + mad = np.median(np.abs(damping_values - med), axis=0) + + # Extend for step plotting + med_extended = np.hstack([med, med[:, -1][:, None]]) + mad_extended = np.hstack([mad, mad[:, -1][:, None]]) + + # Plot damping values per region + fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) + axes = axes.flatten() + x = np.arange(15, 61, 15) # Time steps from 15 to 60 + + for i, ax in enumerate(axes): + if i < 16: + ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) + ax.fill_between( + x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], + alpha=0.25, color='red', step='post' + ) + ax.set_title(f"{dd.State[i+1]}", fontsize=10) + ax.set_xlabel("Time", fontsize=8) + ax.set_ylabel("Damping Value", fontsize=8) + else: + ax.axis('off') # Hide unused subplots + + plt.suptitle("Damping Values per Region", fontsize=14) + plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) + + # Combined plot for all regions + fig, ax = plt.subplots(figsize=(10, 6)) + cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors + + for i in range(16): + ax.stairs( + med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", + color=cmap(i), baseline=None + ) + + ax.set_title("Damping Values per Region (Combined)", fontsize=14) + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("Damping Value", fontsize=12) + ax.legend(fontsize=10, loc="upper right", ncol=2) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) + + +class Simulation: + """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + """ + + :param model: + + """ + model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E # todo: correct? 
+ model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( + 0)] = transmission_prob + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( + 0)] = mu_CR + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD + + # start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + """ + + :param model: + + """ + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 + minimum = np.zeros((self.num_groups, self.num_groups)) + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_values): + """ + + :param params: + :param end_date: + + """ + start_damping_1 = DATE_TIME + datetime.timedelta(days=15) + start_damping_2 = DATE_TIME + datetime.timedelta(days=30) + start_damping_3 = DATE_TIME + datetime.timedelta(days=45) + + if start_damping_1 < end_date: + start_date = (start_damping_1 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[0]], t=start_date)) + + if start_damping_2 < end_date: + start_date = (start_damping_2 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[1]], t=start_date)) + + if start_damping_3 < end_date: + start_date = (start_damping_3 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[2]], t=start_date)) + + def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + """ + + :param end_date: + + """ + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters( + model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + self.set_contact_matrices(model) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + + data_dir_Germany = os.path.join(self.data_dir, "Germany") + pydata_dir = os.path.join(data_dir_Germany, "pydata") + + path_population_data = os.path.join( + pydata_dir, "county_current_population_germany.json") + + print("Setting nodes...") + mio.osecir.set_node_germany( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, False, graph, scaling_factor_infected, + scaling_factor_icu, 0., 0, False) + + print("Graph created.") + + return graph + + def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): + """ + + :param num_days_sim: + :param num_runs: (Default value = 10) + :param save_graph: (Default value = True) + :param create_gif: (Default value = True) + + """ + mio.set_log_level(mio.LogLevel.Warning) + end_date = 
self.start_date + datetime.timedelta(days=num_days_sim) + + graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, + t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + node = graph.get_node(node_idx) + + self.set_npis( + node.property.parameters, + end_date, + damping_values[node_idx]) + mobility_graph.add_node(node.id, node.property) + for edge_idx in range(graph.num_edges): + mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + results = {} + for node_idx in range(mobility_sim.graph.num_nodes): + results[f'state{node_idx}'] = osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result) + + return results + + +def run_germany_nuts0_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=DATE_TIME, + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 60 + + results = sim.run(num_days_sim, damping_values, t_E, t_ISy, + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + return results + +def prior(): + damping_values = np.zeros((NUM_DAMPING_POINTS, 1)) + for i in range(NUM_DAMPING_POINTS): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / scale + damping_values[i] = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=1 + ) + return { + 'damping_values': np.transpose(damping_values), + 't_E': np.random.uniform(*bounds['t_E']), + 't_ISy': np.random.uniform(*bounds['t_ISy']), + 't_ISev': np.random.uniform(*bounds['t_ISev']), + 't_Cr': np.random.uniform(*bounds['t_Cr']), + 'mu_CR': np.random.uniform(*bounds['mu_CR']), + 'mu_IH': np.random.uniform(*bounds['mu_IH']), + 'mu_HU': np.random.uniform(*bounds['mu_HU']), + 'mu_UD': np.random.uniform(*bounds['mu_UD']), + 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) + } + + +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Germany/pydata") + + data = pd.read_json(os.path.join(divi_path, "germany_divi_ma7.json")) + data = data[data['Date'] >= np.datetime64(DATE_TIME)] + data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] + data = data.sort_values(by=['Date']) + divi_dict = {} + divi_dict[f"state0"] = data['ICU'].to_numpy()[None, :, None] + return divi_dict + + + +def extract_observables(simulation_results, observable_index=7): + for key in simulation_results.keys(): + if key not in inference_params: + simulation_results[key] = simulation_results[key][:, :, observable_index][..., np.newaxis] + return simulation_results + + +def create_train_data(filename, number_samples=1000): + + simulator = bf.simulators.make_simulator( + [prior, run_germany_nuts0_simulation] + ) + trainings_data = simulator.sample(number_samples) + trainings_data = extract_observables(trainings_data) + with open(filename, 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + +def load_pickle(path): + with open(path, "rb") as f: + return pickle.load(f) + +def is_state_key(k: str) -> bool: + return 
'state' in k + +def apply_aug(d: dict, aug) -> dict: + return {k: np.clip(aug(v), 0, None) if is_state_key(k) else v for k, v in d.items()} + +def concat_dicts(base: dict, new: dict) -> dict: + missing = set(base) - set(new) + if missing: + raise KeyError(f"new dict missing keys: {sorted(missing)}") + for k in base: + base[k] = np.concatenate([base[k], new[k]]) + return base + + +def combine_results(dict_list): + combined = {} + for d in dict_list: + combined = concat_dicts(combined, d) if combined else d + return combined + +def skip_2weeks(d:dict) -> dict: + return {k: v[:, 14:, :] if is_state_key(k) else v for k, v in d.items()} + +def get_workflow(): + + simulator = bf.make_simulator( + [prior, run_germany_nuts0_simulation] + ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) + .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) + .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) + .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) + .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) + .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) + .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) + .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) + .concatenate( + ["damping_values", "t_E", "t_ISy", "t_ISev", "t_Cr", + "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], + into="inference_variables", + axis=-1 + ) + .concatenate(summary_vars, into="summary_variables", axis=-1) + ) + + summary_network = bf.networks.FusionTransformer( + summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 + ) + inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) + + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + # augmentations={f'fed_state{i}': aug for i in range(len(region_ids))} + ) + + return workflow + + +def run_training(name, num_training_files=20): + train_template = name+"/trainings_data{i}_"+name+".pickle" + val_path = f"{name}/validation_data_{name}.pickle" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # training data + train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] + trainings_data = None + for p in train_files: + d = load_pickle(p) + d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) + d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) + if trainings_data is None: + trainings_data = d + else: + trainings_data = concat_dicts(trainings_data, d) + + # validation data + validation_data = apply_aug(load_pickle(val_path), aug=aug) + validation_data = skip_2weeks(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + # check data + workflow = get_workflow() + print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) + print("inference_variables 
shape:", workflow.adapter(trainings_data)["inference_variables"].shape) + + history = workflow.fit_offline( + data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data + ) + + workflow.approximator.save( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + plots = workflow.plot_default_diagnostics( + test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} + ) + plots['losses'].savefig(f'{name}/losses_{name}.png') + plots['recovery'].savefig(f'{name}/recovery_{name}.png') + plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png') + #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') + + +def run_inference(name, num_samples=1000, on_synthetic_data=False): + val_path = f"{name}/validation_data_{name}.pickle" + synthetic = "_synthetic" if on_synthetic_data else "" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # validation data + validation_data = load_pickle(val_path) # synthetic data + if on_synthetic_data: + # validation data + validation_data = apply_aug(validation_data, aug=aug) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data_skip2w = skip_2weeks(validation_data) + divi_dict = validation_data + + divi_data = np.concatenate( + [divi_dict[f'state{i}'] for i in range(len(region_ids))], axis=-1 + )[0] # only one dataset + else: + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + divi_data = np.concatenate( + [divi_dict[f'state{i}'] for i in range(len(region_ids))], axis=-1 + )[0] + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') + simulations_aug = load_pickle(f'{name}/sims_{name}{synthetic}_with_aug.pickle') + print("loaded simulations from file") + else: + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 1, NUM_DAMPING_POINTS)) + results = [] + for i in range(num_samples): # we only have one dataset for inference here + result = run_germany_nuts0_simulation( + damping_values=samples['damping_values'][0, i], + t_E=samples['t_E'][0, i], t_ISy=samples['t_ISy'][0, i], + t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], + mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], + mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], + transmission_prob=samples['transmission_prob'][0, i] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] 
# add sample axis + results.append(result) + results = combine_results(results) + results = extract_observables(results) + results_aug = apply_aug(results, aug=aug) + + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + for i in range(num_samples): + simulations[i] = np.concatenate([results[key][i] for key in results.keys()], axis=-1) + simulations_aug[i] = np.concatenate([results_aug[key][i] for key in results.keys()], axis=-1) + + # save sims + with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: + pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') + + fig, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey='row', constrained_layout=True) + # Plot without augmentation + plot_region_fit( + simulations, region=0, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color="#132a70" + ) + axes[0, 0].set_title("Without Augmentation") + # Plot with augmentation + plot_region_fit( + simulations_aug, region=0, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" + ) + axes[0, 1].set_title("With Augmentation") + # Plot without augmentation (80% quantile only) + plot_region_fit( + simulations, region=0, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_80q=True + ) + axes[1, 0].set_title("Without Augmentation (80% Quantile)") + # Plot with augmentation (80% quantile only) + plot_region_fit( + simulations_aug, region=0, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_80q=True + ) + axes[1, 1].set_title("With Augmentation (80% Quantile)") + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') + plt.close() + + plot_icu_on_germany(simulations, name, synthetic, with_aug="") + plot_icu_on_germany(simulations_aug, name, synthetic, with_aug="_with_aug") + + simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions + simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) + + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + print("Mean RMSE over regions:", rmse["values"].mean()) + print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) + + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + print("Mean Calibration Error over regions:", cal_error["values"].mean()) + 
print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) + + +if __name__ == "__main__": + name = "nuts0" + + if not os.path.exists(name): + os.makedirs(name) + # create_train_data(filename=f'{name}/trainings_data10_{name}.pickle', number_samples=2000) + # run_training(name=name, num_training_files=10) + run_inference(name=name, on_synthetic_data=False) + run_inference(name=name, on_synthetic_data=True) \ No newline at end of file diff --git a/pycode/examples/simulation/graph_germany_nuts1.py b/pycode/examples/simulation/graph_germany_nuts1.py new file mode 100644 index 0000000000..041c0f726a --- /dev/null +++ b/pycode/examples/simulation/graph_germany_nuts1.py @@ -0,0 +1,873 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +import os +os.environ["KERAS_BACKEND"] = "tensorflow" + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import datetime +import pickle +from scipy.stats import truncnorm + +from matplotlib.patches import Patch + +import bayesflow as bf +import keras + +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +from memilio.simulation.osecir import Model, interpolate_simulation_result +from memilio.epidata import defaultDict as dd + +import geopandas as gpd + + +region_ids = [region_id for region_id in dd.State.keys()] +inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', + 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] +summary_vars = [f'fed_state{i}' for i in range(16)] + +bounds = { + 't_E': (1.0, 5.2), + 't_ISy': (4.0, 10.0), + 't_ISev': (5.0, 10.0), + 't_Cr': (9.0, 17.0), + 'mu_CR': (0.0, 0.4), + 'mu_IH': (0.0, 0.2), + 'mu_HU': (0.0, 0.4), + 'mu_UD': (0.0, 0.4), + 'transmission_prob': (0.0, 0.2) +} +SPIKE_SCALE = 0.4 +SLAB_SCALE = 0.2 +DATE_TIME = datetime.date(year=2020, month=10, day=1) +NUM_DAMPING_POINTS = 3 + +def plot_region_fit( + data: np.ndarray, + region: int, + true_data=None, + ax=None, + label=None, + color="red", + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + n_samples, n_time, n_regions = data.shape + if not (0 <= region < n_regions): + raise IndexError + + x = np.arange(n_time) + vals = data[:, :, region] # (samples, time_points) + + qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) + + med = np.median(vals, axis=0) + + if ax is None: + fig, ax = plt.subplots() + + ax.plot( + x, med, lw=2, label=label or f"{dd.State[region_ids[region]]}", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], 
alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + + if true_data is not None: + true_vals = true_data[:, region] # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + ax.set_title(f"{dd.State[region_ids[region]]}", fontsize=12) + if label is not None: + ax.legend(fontsize=11, loc="upper right") + + +def plot_aggregated_over_regions( + data: np.ndarray, + region_agg=np.sum, + true_data=None, + ax=None, + label=None, + color='red', + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + # Aggregate over regions + agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) + + qs_80 = np.quantile(agg_over_regions, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) + + # Aggregate over samples + agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) + + x = np.arange(agg_median.shape[0]) + if ax is None: + fig, ax = plt.subplots() + + ax.plot(x, agg_median, lw=2, + label=label or "Aggregated over regions", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + if true_data is not None: + true_vals = region_agg(true_data, axis=-1) # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("ICU", fontsize=12) + if label is not None: + ax.legend(fontsize=11) + +def plot_icu_on_germany(simulations, name, synthetic, with_aug): + med = np.median(simulations, axis=0) + + population = pd.read_json('data/Germany/pydata/county_current_population_states.json') + values = med / population['Population'].to_numpy()[None, :] * 100000 + + + map_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_KRS.shp')) + fedstate_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_LAN.shp')) + + plot_map(values[0], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_initial_{name}{synthetic}{with_aug}") + plot_map(values[-1], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_final_{name}{synthetic}{with_aug}") + +def plot_map(values, map_data, fedstate_data, label, filename): + fedstate_data[label] = fedstate_data['ARS'].map(dict(zip([f"{region_id:02d}" for region_id in region_ids], values))) + + fig, ax = plt.subplots(figsize=(12, 14)) + fedstate_data.plot( + column=f"{label}", + cmap='Reds', + linewidth=1., + ax=ax, + edgecolor='0.6', + legend=True, + legend_kwds={'label': f"{label} per Federal State", 'shrink': 0.6}, + ) + map_data.boundary.plot(ax=ax, color='black', linewidth=0.2) + ax.set_title(f"{label} per Federal State") + ax.axis('off') + plt.savefig(filename, bbox_inches='tight', dpi=300) + +def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): + n_regions = simulations.shape[-1] + fig, ax = 
plt.subplots(nrows=4, ncols=4, figsize=(25, 25), layout="constrained") + ax = ax.flatten() + for i in range(n_regions): + plot_region_fit( + simulations, region=i, true_data=divi_data, label="Median", ax=ax[i], color="#132a70" + ) + plt.savefig(f'{name}/federal_states_{name}{synthetic}{with_aug}.png') + plt.close() + +def calibration_curves_per_region( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + max_regions=None, + cmap=plt.cm.Blues, + linewidth=1.5, + legend=True, + with_ideal=True, +): + """ + Per-region calibration curves, each region in a different shade of a colormap. + + data: (samples, time, regions) + true_data: (time, regions) + max_regions: limit number of regions shown + cmap: matplotlib colormap for line shades + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + if max_regions is None: + R = n_regions + else: + R = min(max_regions, n_regions) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + colors = [cmap(i / (R + 1)) for i in range(1, R + 1)] + + x = np.asarray(levels) + for r, col in zip(range(R), colors): + emp = [] + for nominal in levels: + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data[:, :, r], q_low, axis=0) + hi = np.quantile(data[:, :, r], q_high, axis=0) + hits = (true_data[:, r] >= lo) & (true_data[:, r] <= hi) + emp.append(hits.mean()) + emp = np.asarray(emp) + # , label=f"Region {r+1}") + ax.plot(x, emp, lw=linewidth, color=col, alpha=0.5) + + if with_ideal: + ideal_line = ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal")[0] + else: + ideal_line = None + + if legend: + # Custom legend: one patch for regions, one line for ideal + region_patch = Patch(color=colors[-1], label="Regions") + ax.legend(handles=[region_patch, ideal_line], + frameon=True, ncol=1, fontsize=12) + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level", fontsize=12) + ax.set_ylabel("Empirical coverage", fontsize=12) + ax.set_title("Calibration per region", fontsize=12) + return ax + + +def calibration_median_mad_over_regions( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + color="tab:blue", + alpha=0.25, + linewidth=2.0, + with_ideal=True, +): + """ + Compute per region empirical coverage at each nominal level, + then summarize across regions with median and MAD. 
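+    MAD denotes the median absolute deviation of the per-region coverages
+    around their median, computed separately at every nominal level.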
+ + data: (samples, time, regions) + true_data: (time, regions) + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + L = len(levels) + per_region = np.empty((n_regions, L), dtype=float) + + # per level coverage per region + for j, nominal in enumerate(levels): + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data, q_low, axis=0) # (time, regions) + hi = np.quantile(data, q_high, axis=0) # (time, regions) + hits = (true_data >= lo) & (true_data <= hi) # (time, regions) + # mean over time for each region + per_region[:, j] = hits.mean(axis=0) + + med = np.median(per_region, axis=0) # (levels,) + mad = np.median(np.abs(per_region - med[None, :]), axis=0) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + x = np.asarray(levels) + ax.fill_between(x, med - mad, med + mad, + alpha=alpha, color=color, label=None) + ax.plot(x, med, lw=linewidth, color=color, label="Median across regions") + + if with_ideal: + ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal") + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level", fontsize=12) + ax.set_ylabel("Empirical coverage", fontsize=12) + ax.set_title("Calibration median and MAD across regions", fontsize=12) + ax.legend(fontsize=12) + return ax, {"levels": x, "median": med, "mad": mad} + + +def plot_damping_values(damping_values, name, synthetic, with_aug): + + med = np.median(damping_values, axis=0) + mad = np.median(np.abs(damping_values - med), axis=0) + + # Extend for step plotting + med_extended = np.hstack([med, med[:, -1][:, None]]) + mad_extended = np.hstack([mad, mad[:, -1][:, None]]) + + # Plot damping values per region + fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) + axes = axes.flatten() + x = np.arange(15, 61, 15) # Time steps from 15 to 60 + + for i, ax in enumerate(axes): + if i < 16: + ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) + ax.fill_between( + x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], + alpha=0.25, color='red', step='post' + ) + ax.set_title(f"{dd.State[i+1]}", fontsize=10) + ax.set_xlabel("Time", fontsize=8) + ax.set_ylabel("Damping Value", fontsize=8) + else: + ax.axis('off') # Hide unused subplots + + plt.suptitle("Damping Values per Region", fontsize=14) + plt.savefig(f"{name}/damping_values{name}{synthetic}{with_aug}.png", dpi=300) + + # Combined plot for all regions + fig, ax = plt.subplots(figsize=(10, 6)) + cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors + + for i in range(16): + ax.stairs( + med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", + color=cmap(i), baseline=None + ) + + ax.set_title("Damping Values per Region (Combined)", fontsize=14) + ax.set_xlabel("Time", fontsize=12) + ax.set_ylabel("Damping Value", fontsize=12) + ax.legend(fontsize=10, loc="upper right", ncol=2) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}{with_aug}.png", dpi=300) + + +class Simulation: + """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, 
transmission_prob): + """ + + :param model: + + """ + model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = 5.2 - t_E # todo: correct? + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( + 0)] = transmission_prob + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( + 0)] = mu_CR + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD + + # start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + """ + + :param model: + + """ + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 + minimum = np.zeros((self.num_groups, self.num_groups)) + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_values): + """ + + :param params: + :param end_date: + + """ + start_damping_1 = DATE_TIME + datetime.timedelta(days=15) + start_damping_2 = DATE_TIME + datetime.timedelta(days=30) + start_damping_3 = DATE_TIME + datetime.timedelta(days=45) + + if start_damping_1 < end_date: + start_date = (start_damping_1 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[0]], t=start_date)) + + if start_damping_2 < end_date: + start_date = (start_damping_2 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[1]], t=start_date)) + + if start_damping_3 < end_date: + start_date = (start_damping_3 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[2]], t=start_date)) + + def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + """ + + :param end_date: + + """ + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters( + model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + self.set_contact_matrices(model) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + + data_dir_Germany = os.path.join(self.data_dir, "Germany") + mobility_data_file = os.path.join( + data_dir_Germany, "mobility", "commuter_mobility_2022_states.txt") + pydata_dir = os.path.join(data_dir_Germany, "pydata") + + path_population_data = os.path.join( + pydata_dir, "county_current_population_states.json") + + print("Setting nodes...") + mio.osecir.set_nodes_states( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, False, graph, scaling_factor_infected, + scaling_factor_icu, 0, 0, False) + + print("Setting edges...") + mio.osecir.set_edges(mobility_data_file, graph, 1) + + print("Graph created.") + + 
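+        # At this point the graph holds one node per federal state, initialized from the
+        # population and case data loaded above, with commuter-mobility edges between them.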
return graph + + def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): + """ + + :param num_days_sim: + :param num_runs: (Default value = 10) + :param save_graph: (Default value = True) + :param create_gif: (Default value = True) + + """ + mio.set_log_level(mio.LogLevel.Warning) + end_date = self.start_date + datetime.timedelta(days=num_days_sim) + + graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, + t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + node = graph.get_node(node_idx) + + self.set_npis( + node.property.parameters, + end_date, + damping_values[node_idx]) + mobility_graph.add_node(node.id, node.property) + for edge_idx in range(graph.num_edges): + mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + results = {} + for node_idx in range(mobility_sim.graph.num_nodes): + results[f'fed_state{node_idx}'] = osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result) + + return results + + +def run_germany_nuts1_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=DATE_TIME, + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 60 + + results = sim.run(num_days_sim, damping_values, t_E, t_ISy, + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + return results + +def prior(): + damping_values = np.zeros((NUM_DAMPING_POINTS, 16)) + for i in range(NUM_DAMPING_POINTS): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / scale + damping_values[i] = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=16 + ) + return { + 'damping_values': np.transpose(damping_values), + 't_E': np.random.uniform(*bounds['t_E']), + 't_ISy': np.random.uniform(*bounds['t_ISy']), + 't_ISev': np.random.uniform(*bounds['t_ISev']), + 't_Cr': np.random.uniform(*bounds['t_Cr']), + 'mu_CR': np.random.uniform(*bounds['mu_CR']), + 'mu_IH': np.random.uniform(*bounds['mu_IH']), + 'mu_HU': np.random.uniform(*bounds['mu_HU']), + 'mu_UD': np.random.uniform(*bounds['mu_UD']), + 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) + } + + +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Germany/pydata") + + data = pd.read_json(os.path.join(divi_path, "state_divi_ma7.json")) + data = data[data['Date'] >= np.datetime64(DATE_TIME)] + data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] + data = data.sort_values(by=['ID_State', 'Date']) + divi_data = data.pivot(index='Date', columns='ID_State', values='ICU') + divi_dict = {} + for i in range(len(region_ids)): + divi_dict[f"fed_state{i}"] = divi_data[region_ids[i]].to_numpy()[None, :, None] + return divi_dict + + + +def extract_observables(simulation_results, observable_index=7): + for key in simulation_results.keys(): + if key not in inference_params: + simulation_results[key] = simulation_results[key][:, :, observable_index][..., 
np.newaxis] + return simulation_results + + +def create_train_data(filename, number_samples=1000): + + simulator = bf.simulators.make_simulator( + [prior, run_germany_nuts1_simulation] + ) + trainings_data = simulator.sample(number_samples) + trainings_data = extract_observables(trainings_data) + with open(filename, 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + +def load_pickle(path): + with open(path, "rb") as f: + return pickle.load(f) + +def is_state_key(k: str) -> bool: + return 'fed_state' in k + +def apply_aug(d: dict, aug) -> dict: + return {k: np.clip(aug(v), 0, None) if is_state_key(k) else v for k, v in d.items()} + +def concat_dicts(base: dict, new: dict) -> dict: + missing = set(base) - set(new) + if missing: + raise KeyError(f"new dict missing keys: {sorted(missing)}") + for k in base: + base[k] = np.concatenate([base[k], new[k]]) + return base + + +def combine_results(dict_list): + combined = {} + for d in dict_list: + combined = concat_dicts(combined, d) if combined else d + return combined + +def skip_2weeks(d:dict) -> dict: + return {k: v[:, 14:, :] if is_state_key(k) else v for k, v in d.items()} + +def get_workflow(): + + simulator = bf.make_simulator( + [prior, run_germany_nuts1_simulation] + ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) + .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) + .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) + .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) + .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) + .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) + .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) + .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) + .concatenate( + ["damping_values", "t_E", "t_ISy", "t_ISev", "t_Cr", + "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], + into="inference_variables", + axis=-1 + ) + .concatenate(summary_vars, into="summary_variables", axis=-1) + ) + + summary_network = bf.networks.FusionTransformer( + summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 + ) + inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) + + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + # augmentations={f'fed_state{i}': aug for i in range(len(region_ids))} + ) + + return workflow + + +def run_training(name, num_training_files=20): + train_template = name+"/trainings_data{i}_"+name+".pickle" + val_path = f"{name}/validation_data_{name}.pickle" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # training data + train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] + trainings_data = None + for p in train_files: + d = load_pickle(p) + d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) + d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) + if trainings_data is 
None: + trainings_data = d + else: + trainings_data = concat_dicts(trainings_data, d) + + # validation data + validation_data = apply_aug(load_pickle(val_path), aug=aug) + validation_data = skip_2weeks(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + # check data + workflow = get_workflow() + print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) + print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) + + history = workflow.fit_offline( + data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data + ) + + workflow.approximator.save( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + plots = workflow.plot_default_diagnostics( + test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} + ) + plots['losses'].savefig(f'{name}/losses_{name}.png') + plots['recovery'].savefig(f'{name}/recovery_{name}.png') + plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png') + #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png') + + +def run_inference(name, num_samples=1000, on_synthetic_data=False): + val_path = f"{name}/validation_data_{name}.pickle" + synthetic = "_synthetic" if on_synthetic_data else "" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # validation data + validation_data = load_pickle(val_path) # synthetic data + if on_synthetic_data: + # validation data + validation_data = apply_aug(validation_data, aug=aug) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data_skip2w = skip_2weeks(validation_data) + divi_dict = validation_data + + divi_data = np.concatenate( + [divi_dict[f'fed_state{i}'] for i in range(len(region_ids))], axis=-1 + )[0] # only one dataset + else: + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + divi_data = np.concatenate( + [divi_dict[f'fed_state{i}'] for i in range(len(region_ids))], axis=-1 + )[0] + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') + simulations_aug = load_pickle(f'{name}/sims_{name}{synthetic}_with_aug.pickle') + print("loaded simulations from file") + else: + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) + results = [] + for i in range(num_samples): # we only have one dataset for inference here + result = run_germany_nuts1_simulation( + damping_values=samples['damping_values'][0, i], + t_E=samples['t_E'][0, i], t_ISy=samples['t_ISy'][0, i], + t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], + mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], + mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], + transmission_prob=samples['transmission_prob'][0, i] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] 
# add sample axis + results.append(result) + results = combine_results(results) + results = extract_observables(results) + results_aug = apply_aug(results, aug=aug) + + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + for i in range(num_samples): + simulations[i] = np.concatenate([results[key][i] for key in results.keys()], axis=-1) + simulations_aug[i] = np.concatenate([results_aug[key][i] for key in results.keys()], axis=-1) + + # save sims + with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: + pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior(samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png') + + plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") + plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + + fig, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey='row', constrained_layout=True) + # Plot without augmentation + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color="#132a70" + ) + axes[0, 0].set_title("Without Augmentation") + # Plot with augmentation + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" + ) + axes[0, 1].set_title("With Augmentation") + # Plot without augmentation (80% quantile only) + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_80q=True + ) + axes[1, 0].set_title("Without Augmentation (80% Quantile)") + # Plot with augmentation (80% quantile only) + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[1, 1], color="#132a70", only_80q=True + ) + axes[1, 1].set_title("With Augmentation (80% Quantile)") + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png') + plt.close() + + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png') + plt.close() + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations_aug, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png') + plt.close() + + plot_icu_on_germany(simulations, name, synthetic, with_aug="") + plot_icu_on_germany(simulations_aug, name, synthetic, with_aug="_with_aug") + + simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions + simulation_aug_agg = 
np.sum(simulations_aug, axis=-1, keepdims=True) + + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + print("Mean RMSE over regions:", rmse["values"].mean()) + print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) + + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + print("Mean Calibration Error over regions:", cal_error["values"].mean()) + print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) + + +if __name__ == "__main__": + name = "nuts1" + + if not os.path.exists(name): + os.makedirs(name) + # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=100) + # run_training(name=name, num_training_files=20) + run_inference(name=name, on_synthetic_data=False) + run_inference(name=name, on_synthetic_data=True) \ No newline at end of file diff --git a/pycode/examples/simulation/graph_germany_nuts3.py b/pycode/examples/simulation/graph_germany_nuts3.py new file mode 100644 index 0000000000..908813cccf --- /dev/null +++ b/pycode/examples/simulation/graph_germany_nuts3.py @@ -0,0 +1,1070 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Carlotta Gerstein +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+############################################################################# +import os +os.environ["KERAS_BACKEND"] = "tensorflow" + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import datetime +import pickle +from scipy.stats import truncnorm + +from matplotlib.patches import Patch + +import bayesflow as bf +import keras + +import memilio.simulation as mio +import memilio.simulation.osecir as osecir +from memilio.simulation.osecir import Model, interpolate_simulation_result +from memilio.epidata import defaultDict as dd + +import geopandas as gpd + + +excluded_ids = [11001, 11002, 11003, 11004, 11005, 11006, + 11007, 11008, 11009, 11010, 11011, 11012, 16056] +no_icu_ids = [7338, 9374, 9473, 9573] +region_ids = [region_id for region_id in dd.County.keys() + if region_id not in excluded_ids] + +inference_params = ['damping_values', 't_E', 't_C', 't_ISy', 't_ISev', + 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] +summary_vars = ['state'] + [f'fed_state{i}' for i in range(16)] + [ + f'region{i}' for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids] + +bounds = { + 't_E': (1.0, 5.2), + 't_C': (0.1, 4.2), + 't_ISy': (4.0, 10.0), + 't_ISev': (5.0, 10.0), + 't_Cr': (9.0, 17.0), + 'mu_CR': (0.0, 0.4), + 'mu_IH': (0.0, 0.2), + 'mu_HU': (0.0, 0.4), + 'mu_UD': (0.0, 0.4), + 'transmission_prob': (0.0, 0.2) +} +SPIKE_SCALE = 0.4 +SLAB_SCALE = 0.2 +DATE_TIME = datetime.date(year=2020, month=10, day=1) +NUM_DAMPING_POINTS = 3 + + +def set_fontsize(base_fontsize=17): + fontsize = base_fontsize + plt.rcParams.update({ + 'font.size': fontsize, + 'axes.titlesize': fontsize * 1, + 'axes.labelsize': fontsize, + 'xtick.labelsize': fontsize * 0.8, + 'ytick.labelsize': fontsize * 0.8, + 'legend.fontsize': fontsize * 0.8, + 'font.family': "Arial" + }) + + +plt.style.use('default') + +dpi = 300 + +colors = {"Blue": "#155489", + "Medium blue": "#64A7DD", + "Light blue": "#B4DCF6", + "Lilac blue": "#AECCFF", + "Turquoise": "#76DCEC", + "Light green": "#B6E6B1", + "Medium green": "#54B48C", + "Green": "#5D8A2B", + "Teal": "#20A398", + "Yellow": "#FBD263", + "Orange": "#E89A63", + "Rose": "#CF7768", + "Red": "#A34427", + "Purple": "#741194", + "Grey": "#C0BFBF", + "Dark grey": "#616060", + "Light grey": "#F1F1F1"} + +def plot_region_fit( + data: np.ndarray, + region: int, + true_data=None, + ax=None, + label=None, + color="red", + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + n_samples, n_time, n_regions = data.shape + if not (0 <= region < n_regions): + raise IndexError + + x = np.arange(n_time) + vals = data[:, :, region] # (samples, time_points) + + qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) + + med = np.median(vals, axis=0) + + if ax is None: + fig, ax = plt.subplots() + + ax.plot( + x, med, lw=2, label=label or f"{dd.County[region_ids[region]]}", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + + if true_data is not None: + true_vals = true_data[:, region] # (time_points,) + ax.plot(x, true_vals, lw=2, 
color="black", label="True data") + + ax.set_xlabel("Time") + ax.set_ylabel("ICU") + ax.set_title(f"{dd.County[region_ids[region]]}") + if label is not None: + ax.legend(loc="upper right") + + +def plot_aggregated_over_regions( + data: np.ndarray, + region_agg=np.sum, + true_data=None, + ax=None, + label=None, + color='red', + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + # Aggregate over regions + agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) + + qs_80 = np.quantile(agg_over_regions, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) + + # Aggregate over samples + agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) + + x = np.arange(agg_median.shape[0]) + if ax is None: + fig, ax = plt.subplots() + + ax.plot(x, agg_median, lw=2, + label=label or "Aggregated over regions", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + if true_data is not None: + true_vals = region_agg(true_data, axis=-1) # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time") + ax.set_ylabel("ICU") + +def plot_icu_on_germany(simulations, name, synthetic, with_aug): + med = np.median(simulations, axis=0) + + population = pd.read_json('data/Germany/pydata/county_current_population.json') + values = med / population['Population'].to_numpy()[None, :] * 100000 + + + map_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_KRS.shp')) + fedstate_data = gpd.read_file(os.path.join(os.getcwd(), 'tools/vg2500_12-31.utm32s.shape/vg2500/VG2500_LAN.shp')) + + plot_map(values[0], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_initial_{name}{synthetic}{with_aug}") + plot_map(values[-1], map_data, fedstate_data, "Median ICU", f"{name}/median_icu_germany_final_{name}{synthetic}{with_aug}") + +def plot_map(values, map_data, fedstate_data, label, filename): + map_data[label] = map_data['ARS'].map(dict(zip([f"{region_id:05d}" for region_id in region_ids], values))) + + fig, ax = plt.subplots(figsize=(12, 14)) + map_data.plot( + column=f"{label}", + cmap='Reds', + linewidth=0.5, + ax=ax, + edgecolor='0.6', + legend=True, + legend_kwds={'label': f"{label} per County", 'shrink': 0.6}, + ) + fedstate_data.boundary.plot(ax=ax, color='black', linewidth=1) + ax.set_title(f"{label} per County") + ax.axis('off') + plt.savefig(filename, bbox_inches='tight', dpi=dpi) + + +def plot_aggregated_to_federal_states(data, true_data, name, synthetic, with_aug): + fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(25, 25), layout="constrained") + ax = ax.flatten() + for state in range(16): + idxs = [i for i, region_id in enumerate(region_ids) if region_id // 1000 == state + 1] + plot_aggregated_over_regions( + data[:, :, idxs], # Add a dummy region axis for compatibility + true_data=true_data[:, idxs] if true_data is not None else None, + ax=ax[state], + label=f"State {state + 1}", + color=colors["Red"] + ) + plt.savefig(f'{name}/federal_states_{name}{synthetic}{with_aug}.png', 
dpi=dpi) + plt.close() + +# plot simulations for all regions in 10x4 blocks +def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): + n_regions = simulations.shape[-1] + n_cols = 4 + n_rows = 10 + n_blocks = (n_regions + n_cols * n_rows - 1) // (n_cols * n_rows) + + for block in range(n_blocks): + start_idx = block * n_cols * n_rows + end_idx = min(start_idx + n_cols * n_rows, n_regions) + fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 25), layout="constrained") + ax = ax.flatten() + for i, region_idx in enumerate(range(start_idx, end_idx)): + plot_region_fit( + simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color=colors["Red"] + ) + # Hide unused subplots + for i in range(end_idx - start_idx, len(ax)): + ax[i].axis("off") + plt.savefig(f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png', dpi=dpi) + plt.close() + + +def calibration_curves_per_region( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + max_regions=None, + cmap=plt.cm.Blues, + linewidth=1.5, + legend=True, + with_ideal=True, +): + """ + Per-region calibration curves, each region in a different shade of a colormap. + + data: (samples, time, regions) + true_data: (time, regions) + max_regions: limit number of regions shown + cmap: matplotlib colormap for line shades + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + if max_regions is None: + R = n_regions + else: + R = min(max_regions, n_regions) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + colors = [cmap(i / (R + 1)) for i in range(1, R + 1)] + + x = np.asarray(levels) + for r, col in zip(range(R), colors): + emp = [] + for nominal in levels: + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data[:, :, r], q_low, axis=0) + hi = np.quantile(data[:, :, r], q_high, axis=0) + hits = (true_data[:, r] >= lo) & (true_data[:, r] <= hi) + emp.append(hits.mean()) + emp = np.asarray(emp) + # , label=f"Region {r+1}") + ax.plot(x, emp, lw=linewidth, color=col, alpha=0.5) + + if with_ideal: + ideal_line = ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal")[0] + else: + ideal_line = None + + if legend: + # Custom legend: one patch for regions, one line for ideal + region_patch = Patch(color=colors[-1], label="Regions") + ax.legend(handles=[region_patch, ideal_line], + frameon=True, ncol=1) + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration per region") + return ax + + +def calibration_median_mad_over_regions( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + color="tab:blue", + alpha=0.25, + linewidth=2.0, + with_ideal=True, +): + """ + Compute per region empirical coverage at each nominal level, + then summarize across regions with median and MAD. 
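+    MAD denotes the median absolute deviation of the per-region coverages
+    around their median, computed separately at every nominal level.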
+ + data: (samples, time, regions) + true_data: (time, regions) + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + L = len(levels) + per_region = np.empty((n_regions, L), dtype=float) + + # per level coverage per region + for j, nominal in enumerate(levels): + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data, q_low, axis=0) # (time, regions) + hi = np.quantile(data, q_high, axis=0) # (time, regions) + hits = (true_data >= lo) & (true_data <= hi) # (time, regions) + # mean over time for each region + per_region[:, j] = hits.mean(axis=0) + + med = np.median(per_region, axis=0) # (levels,) + mad = np.median(np.abs(per_region - med[None, :]), axis=0) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + x = np.asarray(levels) + ax.fill_between(x, med - mad, med + mad, + alpha=alpha, color=color, label=None) + ax.plot(x, med, lw=linewidth, color=color, label="Median across regions") + + if with_ideal: + ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal") + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration median and MAD across regions") + ax.legend() + return ax, {"levels": x, "median": med, "mad": mad} + + +def plot_damping_values(damping_values, name, synthetic): + + med = np.median(damping_values, axis=0) + mad = np.median(np.abs(damping_values - med), axis=0) + + # Extend for step plotting + med_extended = np.hstack([med, med[:, -1][:, None]]) + mad_extended = np.hstack([mad, mad[:, -1][:, None]]) + + # Plot damping values per region + fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) + axes = axes.flatten() + x = np.arange(15, 61, 15) # Time steps from 15 to 60 + + for i, ax in enumerate(axes): + if i < 16: + ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) + ax.fill_between( + x, med_extended[i] - mad_extended[i], med_extended[i] + mad_extended[i], + alpha=0.25, color='red', step='post' + ) + ax.set_title(f"{dd.State[i+1]}") + ax.set_xlabel("Time") + ax.set_ylabel("Damping Value") + else: + ax.axis('off') # Hide unused subplots + + plt.suptitle("Damping Values per Region") + plt.savefig(f"{name}/damping_values{name}{synthetic}.png", dpi=dpi) + + # Combined plot for all regions + fig, ax = plt.subplots(figsize=(10, 6)) + cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors + + for i in range(16): + ax.stairs( + med[i], edges=x, lw=2, label=f"{dd.State[i+1]}", + color=cmap(i), baseline=None + ) + + ax.set_title("Damping Values per Region (Combined)") + ax.set_xlabel("Time") + ax.set_ylabel("Damping Value") + ax.legend(loc="upper right", ncol=2) + plt.savefig(f"{name}/damping_values_combined{name}{synthetic}.png", dpi=dpi) + +def run_prior_predictive_check(name): + validation_data = load_pickle(f'{name}/validation_data_{name}.pickle') # synthetic data + + divi_region_keys = region_keys_sorted(validation_data) + num_samples = validation_data['region0'].shape[0] + time_points = validation_data['region0'].shape[1] + + true_data = load_divi_data() + true_data = np.concatenate( + [true_data[key] for key in divi_region_keys], axis=-1 + )[0] + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, time_points, len(divi_region_keys))) + for i in range(num_samples): + 
simulations[i] = np.concatenate([validation_data[key][i] for key in divi_region_keys], axis=-1) + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_aggregated_over_regions( + simulations, true_data=true_data, label="Region Aggregated Median (No Aug)", ax=ax, color=colors["Red"] + ) + ax.set_title("Prior predictive check - Aggregated over regions") + plt.savefig(f'{name}/prior_predictive_check_{name}.png', dpi=dpi) + +def compare_median_sim_to_mean_param_sim(name): + num_samples = 10 + + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + aggregate_states(validation_data_skip2w) + divi_region_keys = region_keys_sorted(divi_dict) + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] + + simulations_aug = load_pickle(f'{name}/sims_{name}_with_aug.pickle') + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + for key in inference_params: + samples[key] = np.median(samples[key], axis=1) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], 16, NUM_DAMPING_POINTS)) + result = run_germany_nuts3_simulation( + damping_values=samples['damping_values'][0], + t_E=samples['t_E'][0], t_ISy=samples['t_ISy'][0], + t_ISev=samples['t_ISev'][0], t_Cr=samples['t_Cr'][0], + mu_CR=samples['mu_CR'][0], mu_IH=samples['mu_IH'][0], + mu_HU=samples['mu_HU'][0], mu_UD=samples['mu_UD'][0], + transmission_prob=samples['transmission_prob'][0] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] # add sample axis + + result = extract_observables(result) + divi_region_keys = region_keys_sorted(result) + result = np.concatenate( + [result[key] if key in result else np.zeros_like(result[divi_region_keys[0]]) for key in divi_region_keys], + axis=-1 + )[0] + + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_aggregated_over_regions( + simulations_aug, true_data=result, label="Compare median of sim to sim of median params", ax=ax, color=colors["Red"] + ) + ax.set_title("Comparison of Median Simulation to Simulation of Median Parameters") + plt.savefig(f'{name}/compare_median_sim13_to_mean_param_sim_{name}.png', dpi=dpi) + plt.close() + + +class Simulation: + """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + """ + + :param model: + + """ + model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E + model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup(0)] = t_C + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( + 0)] = transmission_prob + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( + 0)] = mu_CR + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD + + 
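+        # Note: in this county-level script the asymptomatic duration t_C is sampled as its
+        # own prior parameter, unlike the NUTS-1 script where TimeInfectedNoSymptoms is
+        # derived as 5.2 - t_E.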
# start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + """ + + :param model: + + """ + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 7.95 + minimum = np.zeros((self.num_groups, self.num_groups)) + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_values): + """ + + :param params: + :param end_date: + + """ + start_damping_1 = DATE_TIME + datetime.timedelta(days=15) + start_damping_2 = DATE_TIME + datetime.timedelta(days=30) + start_damping_3 = DATE_TIME + datetime.timedelta(days=45) + + if start_damping_1 < end_date: + start_date = (start_damping_1 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[0]], t=start_date)) + + if start_damping_2 < end_date: + start_date = (start_damping_2 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[1]], t=start_date)) + + if start_damping_3 < end_date: + start_date = (start_damping_3 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[2]], t=start_date)) + + def get_graph(self, end_date, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + """ + + :param end_date: + + """ + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters( + model, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + self.set_contact_matrices(model) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + + data_dir_Germany = os.path.join(self.data_dir, "Germany") + mobility_data_file = os.path.join( + data_dir_Germany, "mobility", "commuter_mobility_2022.txt") + pydata_dir = os.path.join(data_dir_Germany, "pydata") + + path_population_data = os.path.join(pydata_dir, + "county_current_population.json") + + print("Setting nodes...") + mio.osecir.set_nodes( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, True, graph, scaling_factor_infected, + scaling_factor_icu, 0, 0, False) + + print("Setting edges...") + mio.osecir.set_edges(mobility_data_file, graph, 1) + + print("Graph created.") + + return graph + + def run(self, num_days_sim, damping_values, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): + """ + + :param num_days_sim: + :param num_runs: (Default value = 10) + :param save_graph: (Default value = True) + :param create_gif: (Default value = True) + + """ + mio.set_log_level(mio.LogLevel.Warning) + end_date = self.start_date + datetime.timedelta(days=num_days_sim) + + graph = self.get_graph(end_date, t_E, t_C, t_ISy, t_ISev, + t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + node = graph.get_node(node_idx) + + self.set_npis( + node.property.parameters, + end_date, + damping_values[node.id // 1000 - 1] + ) + mobility_graph.add_node(node.id, node.property) + for edge_idx in range(graph.num_edges): + 
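+            # copy the mobility edges over unchanged; the per-state dampings were already
+            # applied to each county node above (node.id // 1000 - 1 maps a county id to its
+            # federal state index)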
mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + results = {} + for node_idx in range(mobility_sim.graph.num_nodes): + node = mobility_sim.graph.get_node(node_idx) + if node.id in no_icu_ids: + results[f'no_icu_region{node_idx}'] = osecir.interpolate_simulation_result( + node.property.result) + else: + results[f'region{node_idx}'] = osecir.interpolate_simulation_result( + node.property.result) + + return results + + +def run_germany_nuts3_simulation(damping_values, t_E, t_C, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=DATE_TIME, + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 60 + + results = sim.run(num_days_sim, damping_values, t_E, t_C, t_ISy, + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + return results + +def prior(): + damping_values = np.zeros((NUM_DAMPING_POINTS, 16)) + for i in range(NUM_DAMPING_POINTS): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / scale + damping_values[i] = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=16 + ) + return { + 'damping_values': np.transpose(damping_values), + 't_E': np.random.uniform(*bounds['t_E']), + 't_C': np.random.uniform(*bounds['t_C']), + 't_ISy': np.random.uniform(*bounds['t_ISy']), + 't_ISev': np.random.uniform(*bounds['t_ISev']), + 't_Cr': np.random.uniform(*bounds['t_Cr']), + 'mu_CR': np.random.uniform(*bounds['mu_CR']), + 'mu_IH': np.random.uniform(*bounds['mu_IH']), + 'mu_HU': np.random.uniform(*bounds['mu_HU']), + 'mu_UD': np.random.uniform(*bounds['mu_UD']), + 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) + } + + +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Germany/pydata") + + data = pd.read_json(os.path.join(divi_path, "county_divi_ma7.json")) + data = data[data['Date'] >= np.datetime64(DATE_TIME)] + data = data[data['Date'] <= np.datetime64(DATE_TIME + datetime.timedelta(days=60))] + data = data.sort_values(by=['ID_County', 'Date']) + divi_data = data.pivot(index='Date', columns='ID_County', values='ICU') + divi_dict = {} + for i, region_id in enumerate(region_ids): + if region_id not in no_icu_ids: + divi_dict[f"region{i}"] = divi_data[region_id].to_numpy()[None, :, None] + else: + divi_dict[f"no_icu_region{i}"] = np.zeros((1, divi_data.shape[0], 1)) + return divi_dict + + +def load_extrapolated_case_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + case_path = os.path.join(file_path, "../../../data/Germany/pydata") + + file = h5py.File(os.path.join(case_path, "Results_rki.h5")) + divi_dict = {} + for i, region_id in enumerate(region_ids): + if region_id not in no_icu_ids: + divi_dict[f"region{i}"] = np.array(file[f"{region_id}"]['Total'][:, 4])[None, :, None] + else: + divi_dict[f"no_icu_region{i}"] = np.zeros((1, np.array(file[f"{region_id}"]['Total']).shape[0], 1)) + return divi_dict + + +def extract_observables(simulation_results, observable_index=7): + for key in simulation_results.keys(): + if key not in inference_params: + simulation_results[key] = simulation_results[key][:, :, 
observable_index][..., np.newaxis] + return simulation_results + + +def create_train_data(filename, number_samples=1000): + + simulator = bf.simulators.make_simulator( + [prior, run_germany_nuts3_simulation] + ) + trainings_data = simulator.sample(number_samples) + trainings_data = extract_observables(trainings_data) + with open(filename, 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + +def load_pickle(path): + with open(path, "rb") as f: + return pickle.load(f) + +def is_region_key(k: str) -> bool: + return 'region' in k + +def apply_aug(d: dict, aug) -> dict: + return {k: np.clip(aug(v), 0, None) if is_region_key(k) else v for k, v in d.items()} + +def concat_dicts(base: dict, new: dict) -> dict: + missing = set(base) - set(new) + if missing: + raise KeyError(f"new dict missing keys: {sorted(missing)}") + for k in base: + base[k] = np.concatenate([base[k], new[k]]) + return base + +def region_keys_sorted(d: dict): + def idx(k): + # handles "regionN" and "no_icu_regionN" + return int(k.split("region")[-1]) + return sorted([k for k in d if is_region_key(k)], key=idx) + +def aggregate_states(d: dict) -> None: + n_regions = len(region_ids) + # per state + for state in range(16): + idxs = [ + r for r in range(n_regions) + if region_ids[r] // 1000 == state + 1 + ] + d[f"fed_state{state}"] = np.sum([d[f"region{r}"] if region_ids[r] not in no_icu_ids else d[f"no_icu_region{r}"] for r in idxs], axis=0) + # all allowed regions + d["state"] = np.sum([d[f"fed_state{r}"] for r in range(16)], axis=0) + +def combine_results(dict_list): + combined = {} + for d in dict_list: + combined = concat_dicts(combined, d) if combined else d + return combined + +def skip_2weeks(d:dict) -> dict: + return {k: v[:, 14:, :] if is_region_key(k) else v for k, v in d.items()} + +def get_workflow(): + + simulator = bf.make_simulator( + [prior, run_germany_nuts3_simulation] + ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_C", lower=bounds["t_C"][0], upper=bounds["t_C"][1]) + .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) + .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) + .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) + .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) + .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) + .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) + .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) + .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) + .concatenate( + ["damping_values", "t_E", "t_C", "t_ISy", "t_ISev", "t_Cr", + "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], + into="inference_variables", + axis=-1 + ) + .concatenate(summary_vars, into="summary_variables", axis=-1) + ) + + summary_network = bf.networks.FusionTransformer( + summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 + ) + inference_network = bf.networks.FlowMatching(subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) + + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + # 
augmentations={f'region{i}': aug for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids} + # aggregation of the states would need to be recomputed every time a different noise realization is applied + ) + + return workflow + + +def run_training(name, num_training_files=20): + train_template = name+"/trainings_data{i}_"+name+".pickle" + val_path = f"{name}/validation_data_{name}.pickle" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # training data + train_files = [train_template.format(i=i) for i in range(1, 1+num_training_files)] + trainings_data = None + for p in train_files: + d = load_pickle(p) + d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) + d['damping_values'] = d['damping_values'].reshape((d['damping_values'].shape[0], -1)) + if trainings_data is None: + trainings_data = d + else: + trainings_data = concat_dicts(trainings_data, d) + aggregate_states(trainings_data) + + # validation data + validation_data = apply_aug(load_pickle(val_path), aug=aug) + validation_data = skip_2weeks(validation_data) + aggregate_states(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + + # check data + workflow = get_workflow() + print("summary_variables shape:", workflow.adapter(trainings_data)["summary_variables"].shape) + print("inference_variables shape:", workflow.adapter(trainings_data)["inference_variables"].shape) + + history = workflow.fit_offline( + data=trainings_data, epochs=500, batch_size=64, validation_data=validation_data + ) + + workflow.approximator.save( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + plots = workflow.plot_default_diagnostics( + test_data=validation_data, calibration_ecdf_kwargs={'difference': True, 'stacked': True} + ) + plots['losses'].savefig(f'{name}/losses_{name}.png', dpi=dpi) + plots['recovery'].savefig(f'{name}/recovery_{name}.png', dpi=dpi) + plots['calibration_ecdf'].savefig(f'{name}/calibration_ecdf_{name}.png', dpi=dpi) + #plots['z_score_contraction'].savefig(f'{name}/z_score_contraction_{name}.png', dpi=dpi) + + +def run_inference(name, num_samples=1000, on_synthetic_data=False): + val_path = f"{name}/validation_data_{name}.pickle" + synthetic = "_synthetic" if on_synthetic_data else "" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # validation data + validation_data = load_pickle(val_path) # synthetic data + if on_synthetic_data: + # validation data + validation_data = apply_aug(validation_data, aug=aug) + validation_data['damping_values'] = validation_data['damping_values'].reshape((validation_data['damping_values'].shape[0], -1)) + validation_data_skip2w = skip_2weeks(validation_data) + aggregate_states(validation_data_skip2w) + divi_dict = validation_data + divi_region_keys = region_keys_sorted(divi_dict) + + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] # only one dataset + else: + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + aggregate_states(validation_data_skip2w) + divi_region_keys = region_keys_sorted(divi_dict) + divi_data = np.concatenate( + [divi_dict[key] for key in divi_region_keys], axis=-1 + )[0] + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + if 
os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle') and os.path.exists(f'{name}/samples_{name}{synthetic}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') + simulations_aug = load_pickle( + f'{name}/sims_{name}{synthetic}_with_aug.pickle') + samples = load_pickle(f'{name}/samples_{name}{synthetic}.pickle') + print("loaded simulations from file") + else: + samples = workflow.sample(conditions=validation_data_skip2w, num_samples=num_samples) + with open(f'{name}/samples_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(samples, f, pickle.HIGHEST_PROTOCOL) + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) + results = [] + for i in range(num_samples): # we only have one dataset for inference here + result = run_germany_nuts3_simulation( + damping_values=samples['damping_values'][0, i], + t_E=samples['t_E'][0, i], t_C=samples['t_C'][0, i], t_ISy=samples['t_ISy'][0, i], + t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], + mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], + mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], + transmission_prob=samples['transmission_prob'][0, i] + ) + for key in result.keys(): + result[key] = np.array(result[key])[None, ...] # add sample axis + results.append(result) + results = combine_results(results) + results = extract_observables(results) + results_aug = apply_aug(results, aug=aug) + + # get sims in shape (samples, time, regions) + simulations = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros((num_samples, divi_data.shape[0], divi_data.shape[1])) + for i in range(num_samples): + simulations[i] = np.concatenate([results[key][i] for key in divi_region_keys], axis=-1) + simulations_aug[i] = np.concatenate([results_aug[key][i] for key in divi_region_keys], axis=-1) + + # save sims + with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: + pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + + samples['damping_values'] = samples['damping_values'].reshape( + (samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior( + samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png', dpi=dpi) + + samples['damping_values'] = samples['damping_values'].reshape((samples['damping_values'].shape[0], num_samples, 16, NUM_DAMPING_POINTS)) + plot_damping_values(samples['damping_values'][0], + name=name, synthetic=synthetic) + + plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") + plot_all_regions(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + + plot_aggregated_to_federal_states(simulations, divi_data, name, synthetic, with_aug="") + plot_aggregated_to_federal_states(simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + + fig, axes = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True) + # Plot with augmentation + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median", ax=axes[0], color=colors["Red"] + ) + # Plot with 
augmentation (80% quantile only) + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median", ax=axes[1], color=colors["Red"], only_80q=True + ) + lines, labels = axes[0].get_legend_handles_labels() + fig.legend(lines, labels) + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png', dpi=dpi) + plt.close() + + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}.png', dpi=dpi) + plt.close() + fig, axis = plt.subplots(1, 2, figsize=(10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions(simulations_aug, divi_data, ax=axis[1]) + plt.savefig(f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png', dpi=dpi) + plt.close() + + plot_icu_on_germany(simulations, name, synthetic, with_aug="") + plot_icu_on_germany(simulations_aug, name, synthetic, with_aug="_with_aug") + + simulation_agg = np.sum(simulations, axis=-1, keepdims=True) # sum over regions + simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) + + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + print("Mean RMSE over regions:", rmse["values"].mean()) + print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) + + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes(simulation_aug_agg, 0,1), np.sum(divi_data, axis=-1, keepdims=True)) + print("Mean Calibration Error over regions:", cal_error["values"].mean()) + print("Mean Calibration Error over regions (with aug):", cal_error_aug["values"].mean()) + + +if __name__ == "__main__": + name = "3dampings_lessnoise_newnetwork" + + set_fontsize() + + if not os.path.exists(name): + os.makedirs(name) + # create_train_data(filename=f'{name}/validation_data_{name}.pickle', number_samples=100) + # run_training(name=name, num_training_files=10) + run_inference(name=name, on_synthetic_data=True) + run_inference(name=name, on_synthetic_data=False) + # run_prior_predictive_check(name=name) + # compare_median_sim_to_mean_param_sim(name=name) \ No newline at end of file diff --git a/pycode/examples/simulation/graph_spain_nuts3.py b/pycode/examples/simulation/graph_spain_nuts3.py new file mode 100644 index 0000000000..4eea30d073 --- /dev/null +++ b/pycode/examples/simulation/graph_spain_nuts3.py @@ -0,0 +1,1005 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Carlotta Gerstein +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +import geopandas as gpd +from memilio.epidata import defaultDict as dd +from memilio.simulation.osecir import Model, interpolate_simulation_result +import memilio.simulation.osecir as osecir +import memilio.simulation as mio +import keras +import bayesflow as bf +from matplotlib.patches import Patch +from scipy.stats import truncnorm +import pickle +import datetime +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +import os +os.environ["KERAS_BACKEND"] = "tensorflow" + + +excluded_ids = [530, 630, 640, 701, 702] +region_ids = [region_id for region_id in dd.Provincias.keys() + if region_id not in excluded_ids] +excluded_comunidades = [53, 63, 64, 70] +comunidades = [comunidad for comunidad in dd.Comunidades.keys( +) if comunidad not in excluded_comunidades] +inference_params = ['damping_values', 't_E', 't_ISy', 't_ISev', + 't_Cr', 'mu_CR', 'mu_IH', 'mu_HU', 'mu_UD', 'transmission_prob'] +summary_vars = ['state'] + [f'comunidad{i}' for i in range(len(comunidades))] + [ + f'region{i}' for i in range(len(region_ids))] + + +def set_fontsize(base_fontsize=17): + fontsize = base_fontsize + plt.rcParams.update({ + 'font.size': fontsize, + 'axes.titlesize': fontsize * 1, + 'axes.labelsize': fontsize, + 'xtick.labelsize': fontsize * 0.8, + 'ytick.labelsize': fontsize * 0.8, + 'legend.fontsize': fontsize * 0.8, + 'font.family': "Arial" + }) + + +plt.style.use('default') + +dpi = 300 + +colors = {"Blue": "#155489", + "Medium blue": "#64A7DD", + "Light blue": "#B4DCF6", + "Lilac blue": "#AECCFF", + "Turquoise": "#76DCEC", + "Light green": "#B6E6B1", + "Medium green": "#54B48C", + "Green": "#5D8A2B", + "Teal": "#20A398", + "Yellow": "#FBD263", + "Orange": "#E89A63", + "Rose": "#CF7768", + "Red": "#A34427", + "Purple": "#741194", + "Grey": "#C0BFBF", + "Dark grey": "#616060", + "Light grey": "#F1F1F1"} + +bounds = { + 't_E': (1.0, 5.2), + 't_ISy': (4.0, 10.0), + 't_ISev': (5.0, 10.0), + 't_Cr': (9.0, 17.0), + 'mu_CR': (0.0, 0.4), + 'mu_IH': (0.0, 0.2), + 'mu_HU': (0.0, 0.4), + 'mu_UD': (0.0, 0.4), + 'transmission_prob': (0.0, 0.2) +} +SPIKE_SCALE = 0.4 +SLAB_SCALE = 0.2 +DATE_TIME = datetime.date(year=2020, month=10, day=1) +NUM_DAMPING_POINTS = 3 + + +def plot_region_fit( + data: np.ndarray, + region: int, + true_data=None, + ax=None, + label=None, + color="red", + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + n_samples, n_time, n_regions = data.shape + if not (0 <= region < n_regions): + raise IndexError + + x = np.arange(n_time) + vals = data[:, :, region] # (samples, time_points) + + qs_80 = np.quantile(vals, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(vals, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(vals, q=[0.025, 0.975], axis=0) + + med = np.median(vals, axis=0) + + if ax is None: + fig, ax = plt.subplots() + + ax.plot( + x, med, lw=2, label=label or 
f"{dd.Provincias[region_ids[region]]}", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + + if true_data is not None: + true_vals = true_data[:, region] # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time") + ax.set_ylabel("ICU") + ax.set_title(f"{dd.Provincias[region_ids[region]]}") + if label is not None: + ax.legend(loc="upper right") + + +def plot_aggregated_over_regions( + data: np.ndarray, + region_agg=np.sum, + true_data=None, + ax=None, + label=None, + color='red', + only_80q=False +): + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data is not None: + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + # Aggregate over regions + agg_over_regions = region_agg(data, axis=-1) # (samples, time_points) + + qs_80 = np.quantile(agg_over_regions, q=[0.1, 0.9], axis=0) + qs_90 = np.quantile(agg_over_regions, q=[0.05, 0.95], axis=0) + qs_95 = np.quantile(agg_over_regions, q=[0.025, 0.975], axis=0) + + # Aggregate over samples + agg_median = np.median(agg_over_regions, axis=0) # (time_points, ) + + x = np.arange(agg_median.shape[0]) + if ax is None: + fig, ax = plt.subplots() + + ax.plot(x, agg_median, lw=2, + label=label or "Aggregated over regions", color=color) + ax.fill_between(x, qs_80[0], qs_80[1], alpha=0.5, + color=color, label="80% CI") + if not only_80q: + ax.fill_between(x, qs_90[0], qs_90[1], alpha=0.3, + color=color, label="90% CI") + ax.fill_between(x, qs_95[0], qs_95[1], alpha=0.1, + color=color, label="95% CI") + if true_data is not None: + true_vals = region_agg(true_data, axis=-1) # (time_points,) + ax.plot(x, true_vals, lw=2, color="black", label="True data") + + ax.set_xlabel("Time") + ax.set_ylabel("ICU") + if label is not None: + ax.legend() + + +def plot_aggregated_to_comunidades(data, true_data, name, synthetic, with_aug): + fig, ax = plt.subplots(nrows=4, ncols=4, figsize=( + 25, 25), layout="constrained") + ax = ax.flatten() + for state_idx, state in enumerate(comunidades): + idxs = [i for i, region_id in enumerate( + region_ids) if region_id // 10 == state] + plot_aggregated_over_regions( + data[:, :, idxs], # Add a dummy region axis for compatibility + true_data=true_data[:, idxs] if true_data is not None else None, + ax=ax[state_idx], + label=f"{dd.Comunidades[state]}", + color=colors["Red"] + ) + plt.savefig(f'{name}/comunidades_{name}{synthetic}{with_aug}.png', dpi=dpi) + plt.close() + +# plot simulations for all regions in 6x4 blocks + + +def plot_all_regions(simulations, divi_data, name, synthetic, with_aug): + n_regions = simulations.shape[-1] + n_cols = 4 + n_rows = 6 + n_blocks = (n_regions + n_cols * n_rows - 1) // (n_cols * n_rows) + + for block in range(n_blocks): + start_idx = block * n_cols * n_rows + end_idx = min(start_idx + n_cols * n_rows, n_regions) + fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, + figsize=(20, 25), layout="constrained") + ax = ax.flatten() + for i, region_idx in enumerate(range(start_idx, end_idx)): + plot_region_fit( + simulations, region=region_idx, true_data=divi_data, label="Median", ax=ax[i], color=colors["Red"] + ) + # Hide unused subplots + for i in range(end_idx - start_idx, len(ax)): + ax[i].axis("off") + plt.savefig( + 
f'{name}/regions_block_{block + 1}_{name}{synthetic}{with_aug}.png', dpi=dpi) + plt.close() + + +def plot_icu_on_spain(simulations, name, synthetic, with_aug): + med = np.median(simulations, axis=0) + + population = pd.read_json( + 'data/Spain/pydata/provincias_current_population.json') + values = med / population['Population'].to_numpy()[None, :] * 100000 + + map_data = gpd.read_file(os.path.join(os.getcwd( + ), 'tools/lineas_limite/SHP_ETRS89/recintos_provinciales_inspire_peninbal_etrs89/recintos_provinciales_inspire_peninbal_etrs89.shp')) + map_data['ID_Provincia'] = map_data['NAMEUNIT'].map( + dd.invert_dict(dd.Provincias)) + map_data.dropna(inplace=True, subset=['ID_Provincia']) + map_data["ID_Provincia"] = map_data["ID_Provincia"].astype(int) + map_data = map_data[~map_data["ID_Provincia"].isin(excluded_ids)] + fedstate_data = gpd.read_file(os.path.join( + os.getcwd(), 'tools/lineas_limite/SHP_ETRS89/recintos_autonomicas_inspire_peninbal_etrs89/recintos_autonomicas_inspire_peninbal_etrs89.shp')) + + plot_map(values[0], map_data, fedstate_data, "Median ICU", + f"{name}/median_icu_spain_initial_{name}{synthetic}{with_aug}") + plot_map(values[-1], map_data, fedstate_data, "Median ICU", + f"{name}/median_icu_spain_final_{name}{synthetic}{with_aug}") + + +def plot_map(values, map_data, fedstate_data, label, filename): + map_data[label] = map_data['ID_Provincia'].map( + dict(zip(region_ids, values))) + + fig, ax = plt.subplots(figsize=(12, 14)) + map_data.plot( + column=f"{label}", + cmap='Reds', + linewidth=0.5, + ax=ax, + edgecolor='0.6', + legend=True, + legend_kwds={'label': f"{label} per County", 'shrink': 0.6}, + ) + fedstate_data.boundary.plot(ax=ax, color='black', linewidth=1) + ax.set_title(f"{label} per County") + ax.axis('off') + plt.savefig(filename, bbox_inches='tight', dpi=dpi) + + +def calibration_curves_per_region( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + max_regions=None, + cmap=plt.cm.Blues, + linewidth=1.5, + legend=True, + with_ideal=True, +): + """ + Per-region calibration curves, each region in a different shade of a colormap. 
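+ Coverage at each nominal level is the fraction of time points whose true value lies inside the central interval spanned by the corresponding sample quantiles.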
+ + data: (samples, time, regions) + true_data: (time, regions) + max_regions: limit number of regions shown + cmap: matplotlib colormap for line shades + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + if max_regions is None: + R = n_regions + else: + R = min(max_regions, n_regions) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + colors = [cmap(i / (R + 1)) for i in range(1, R + 1)] + + x = np.asarray(levels) + for r, col in zip(range(R), colors): + emp = [] + for nominal in levels: + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data[:, :, r], q_low, axis=0) + hi = np.quantile(data[:, :, r], q_high, axis=0) + hits = (true_data[:, r] >= lo) & (true_data[:, r] <= hi) + emp.append(hits.mean()) + emp = np.asarray(emp) + # , label=f"Region {r+1}") + ax.plot(x, emp, lw=linewidth, color=col, alpha=0.5) + + if with_ideal: + ideal_line = ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal")[0] + else: + ideal_line = None + + if legend: + # Custom legend: one patch for regions, one line for ideal + region_patch = Patch(color=colors[-1], label="Regions") + ax.legend(handles=[region_patch, ideal_line], + frameon=True, ncol=1) + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration per region") + return ax + + +def calibration_median_mad_over_regions( + data: np.ndarray, + true_data: np.ndarray, + levels=np.linspace(0.01, 0.99, 20), + ax=None, + color="tab:blue", + alpha=0.25, + linewidth=2.0, + with_ideal=True, +): + """ + Compute per region empirical coverage at each nominal level, + then summarize across regions with median and MAD. 
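+ Returns the matplotlib axis and a dict with the evaluated levels and the median and MAD of the per-region coverages.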
+ + data: (samples, time, regions) + true_data: (time, regions) + """ + if data.ndim != 3: + raise ValueError("Array not of shape (samples, time_points, regions)") + if true_data.shape != data.shape[1:]: + raise ValueError("True data shape does not match data shape") + + n_samples, n_time, n_regions = data.shape + L = len(levels) + per_region = np.empty((n_regions, L), dtype=float) + + # per level coverage per region + for j, nominal in enumerate(levels): + q_low = (1.0 - nominal) / 2.0 + q_high = 1.0 - q_low + lo = np.quantile(data, q_low, axis=0) # (time, regions) + hi = np.quantile(data, q_high, axis=0) # (time, regions) + hits = (true_data >= lo) & (true_data <= hi) # (time, regions) + # mean over time for each region + per_region[:, j] = hits.mean(axis=0) + + med = np.median(per_region, axis=0) # (levels,) + mad = np.median(np.abs(per_region - med[None, :]), axis=0) + + if ax is None: + fig, ax = plt.subplots(figsize=(5, 4)) + + x = np.asarray(levels) + ax.fill_between(x, med - mad, med + mad, + alpha=alpha, color=color, label=None) + ax.plot(x, med, lw=linewidth, color=color, label="Median across regions") + + if with_ideal: + ax.plot([0, 1], [0, 1], linestyle="--", + lw=1.2, color="black", label="Ideal") + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xlabel("Nominal level") + ax.set_ylabel("Empirical coverage") + ax.set_title("Calibration median and MAD across regions") + ax.legend() + return ax, {"levels": x, "median": med, "mad": mad} + + +def plot_damping_values(damping_values, name, synthetic): + + med = np.median(damping_values, axis=0) + mad = np.median(np.abs(damping_values - med), axis=0) + + # Extend for step plotting + med_extended = np.hstack([med, med[:, -1][:, None]]) + mad_extended = np.hstack([mad, mad[:, -1][:, None]]) + + # Plot damping values per region + fig, axes = plt.subplots(4, 4, figsize=(15, 10), constrained_layout=True) + axes = axes.flatten() + x = np.arange(15, 61, 15) # Time steps from 15 to 60 + + for i, ax in enumerate(axes): + if i < len(comunidades): + ax.stairs(med[i], edges=x, lw=2, color='red', baseline=None) + ax.fill_between( + x, med_extended[i] - + mad_extended[i], med_extended[i] + mad_extended[i], + alpha=0.25, color='red', step='post' + ) + ax.set_title(f"{dd.Comunidades[comunidades[i]]}") + ax.set_xlabel("Time") + ax.set_ylabel("Damping Value") + else: + ax.axis('off') # Hide unused subplots + + plt.suptitle("Damping Values per Region") + plt.savefig( + f"{name}/damping_values{name}{synthetic}.png", dpi=dpi) + + # Combined plot for all regions + fig, ax = plt.subplots(figsize=(10, 6)) + cmap = plt.cm.get_cmap("viridis", 16) # Colormap with 16 distinct colors + + for i, comunidad in enumerate(comunidades): + ax.stairs( + med[i], edges=x, lw=2, label=f"{dd.Comunidades[comunidad]}", + color=cmap(i), baseline=None + ) + + ax.set_title("Damping Values per Region (Combined)") + ax.set_xlabel("Time") + ax.set_ylabel("Damping Value") + ax.legend(loc="upper right", ncol=2) + plt.savefig( + f"{name}/damping_values_combined{name}{synthetic}.png", dpi=dpi) + + +class Simulation: + """ """ + + def __init__(self, data_dir, start_date, results_dir): + self.num_groups = 1 + self.data_dir = data_dir + self.start_date = start_date + self.results_dir = results_dir + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + + def set_covid_parameters(self, model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + model.parameters.TimeExposed[mio.AgeGroup(0)] = t_E + 
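# note: t_E and TimeInfectedNoSymptoms are coupled so that their sum stays at 5.2 days; + # only t_E is inferred and the non-symptomatic time follows as the remainder +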
model.parameters.TimeInfectedNoSymptoms[mio.AgeGroup( + 0)] = 5.2 - t_E # todo: correct? + model.parameters.TimeInfectedSymptoms[mio.AgeGroup(0)] = t_ISy + model.parameters.TimeInfectedSevere[mio.AgeGroup(0)] = t_ISev + model.parameters.TimeInfectedCritical[mio.AgeGroup(0)] = t_Cr + + # probabilities + model.parameters.TransmissionProbabilityOnContact[mio.AgeGroup( + 0)] = transmission_prob + model.parameters.RelativeTransmissionNoSymptoms[mio.AgeGroup(0)] = 1 + + model.parameters.RecoveredPerInfectedNoSymptoms[mio.AgeGroup( + 0)] = mu_CR + model.parameters.SeverePerInfectedSymptoms[mio.AgeGroup(0)] = mu_IH + model.parameters.CriticalPerSevere[mio.AgeGroup(0)] = mu_HU + model.parameters.DeathsPerCritical[mio.AgeGroup(0)] = mu_UD + + # start day is set to the n-th day of the year + model.parameters.StartDay = self.start_date.timetuple().tm_yday + + model.parameters.Seasonality = mio.UncertainValue(0.2) + + def set_contact_matrices(self, model): + contact_matrices = mio.ContactMatrixGroup(1, self.num_groups) + + baseline = np.ones((self.num_groups, self.num_groups)) * 12.32 + minimum = np.zeros((self.num_groups, self.num_groups)) + contact_matrices[0] = mio.ContactMatrix(baseline, minimum) + model.parameters.ContactPatterns.cont_freq_mat = contact_matrices + + def set_npis(self, params, end_date, damping_values): + start_damping_1 = DATE_TIME + datetime.timedelta(days=15) + start_damping_2 = DATE_TIME + datetime.timedelta(days=30) + start_damping_3 = DATE_TIME + datetime.timedelta(days=45) + + if start_damping_1 < end_date: + start_date = (start_damping_1 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[0]], t=start_date)) + + if start_damping_2 < end_date: + start_date = (start_damping_2 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[1]], t=start_date)) + + if start_damping_3 < end_date: + start_date = (start_damping_3 - self.start_date).days + params.ContactPatterns.cont_freq_mat[0].add_damping( + mio.Damping(np.r_[damping_values[2]], t=start_date)) + + def get_graph(self, end_date, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + print("Initializing model...") + model = Model(self.num_groups) + self.set_covid_parameters( + model, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + self.set_contact_matrices(model) + print("Model initialized.") + + graph = osecir.ModelGraph() + + scaling_factor_infected = [2.5] + scaling_factor_icu = 1.0 + + data_dir_Spain = os.path.join(self.data_dir, "Spain") + mobility_data_file = os.path.join( + data_dir_Spain, "mobility", "commuter_mobility.txt") + pydata_dir = os.path.join(data_dir_Spain, "pydata") + + path_population_data = os.path.join( + pydata_dir, "provincias_current_population.json") + + print("Setting nodes...") + mio.osecir.set_nodes_provincias( + model.parameters, + mio.Date(self.start_date.year, + self.start_date.month, self.start_date.day), + mio.Date(end_date.year, + end_date.month, end_date.day), pydata_dir, + path_population_data, True, graph, scaling_factor_infected, + scaling_factor_icu, 0, 0, False) + + print("Setting edges...") + mio.osecir.set_edges(mobility_data_file, graph, 1) + + print("Graph created.") + + return graph + + def run(self, num_days_sim, damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob, save_graph=True): + mio.set_log_level(mio.LogLevel.Warning) + end_date = self.start_date + 
datetime.timedelta(days=num_days_sim) + + graph = self.get_graph(end_date, t_E, t_ISy, t_ISev, + t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + mobility_graph = osecir.MobilityGraph() + for node_idx in range(graph.num_nodes): + node = graph.get_node(node_idx) + # determine comunidad index for this node's provincia and use its damping values + comunidad_idx = comunidades.index(node.id // 10) + self.set_npis(node.property.parameters, end_date, + damping_values[comunidad_idx]) + mobility_graph.add_node(node.id, node.property) + for edge_idx in range(graph.num_edges): + mobility_graph.add_edge( + graph.get_edge(edge_idx).start_node_idx, + graph.get_edge(edge_idx).end_node_idx, + graph.get_edge(edge_idx).property) + mobility_sim = osecir.MobilitySimulation(mobility_graph, t0=0, dt=0.5) + mobility_sim.advance(num_days_sim) + + results = {} + for node_idx in range(mobility_sim.graph.num_nodes): + results[f'region{node_idx}'] = osecir.interpolate_simulation_result( + mobility_sim.graph.get_node(node_idx).property.result) + + return results + + +def run_spain_nuts3_simulation(damping_values, t_E, t_ISy, t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob): + mio.set_log_level(mio.LogLevel.Warning) + file_path = os.path.dirname(os.path.abspath(__file__)) + + sim = Simulation( + data_dir=os.path.join(file_path, "../../../data"), + start_date=DATE_TIME, + results_dir=os.path.join(file_path, "../../../results_osecir")) + num_days_sim = 60 + + results = sim.run(num_days_sim, damping_values, t_E, t_ISy, + t_ISev, t_Cr, mu_CR, mu_IH, mu_HU, mu_UD, transmission_prob) + + return results + + +def prior(): + damping_values = np.zeros((NUM_DAMPING_POINTS, len(comunidades))) + for i in range(NUM_DAMPING_POINTS): + mean = np.random.uniform(0, 1) + scale = 0.1 + a, b = (0 - mean) / scale, (1 - mean) / scale + damping_values[i] = truncnorm.rvs( + a=a, b=b, loc=mean, scale=scale, size=len(comunidades) + ) + return { + 'damping_values': np.transpose(damping_values), + 't_E': np.random.uniform(*bounds['t_E']), + 't_ISy': np.random.uniform(*bounds['t_ISy']), + 't_ISev': np.random.uniform(*bounds['t_ISev']), + 't_Cr': np.random.uniform(*bounds['t_Cr']), + 'mu_CR': np.random.uniform(*bounds['mu_CR']), + 'mu_IH': np.random.uniform(*bounds['mu_IH']), + 'mu_HU': np.random.uniform(*bounds['mu_HU']), + 'mu_UD': np.random.uniform(*bounds['mu_UD']), + 'transmission_prob': np.random.uniform(*bounds['transmission_prob']) + } + + +def load_divi_data(): + file_path = os.path.dirname(os.path.abspath(__file__)) + divi_path = os.path.join(file_path, "../../../data/Spain/pydata") + + data = pd.read_json(os.path.join(divi_path, "provincia_icu.json")) + data = data[data['Date'] >= np.datetime64(DATE_TIME)] + data = data[data['Date'] <= np.datetime64( + DATE_TIME + datetime.timedelta(days=60))] + data = data.sort_values(by=['ID_County', 'Date']) + divi_data = data.pivot(index='Date', columns='ID_County', values='ICU') + divi_dict = {} + for i, region_id in enumerate(region_ids): + divi_dict[f"region{i}"] = divi_data[region_id].to_numpy()[ + None, :, None] + return divi_dict + + +def extract_observables(simulation_results, observable_index=7): + for key in simulation_results.keys(): + if key not in inference_params: + simulation_results[key] = simulation_results[key][:, + :, observable_index][..., np.newaxis] + return simulation_results + + +def create_train_data(filename, number_samples=1000): + + simulator = bf.simulators.make_simulator( + [prior, run_spain_nuts3_simulation] + ) + trainings_data = 
simulator.sample(number_samples) + trainings_data = extract_observables(trainings_data) + with open(filename, 'wb') as f: + pickle.dump(trainings_data, f, pickle.HIGHEST_PROTOCOL) + + +def load_pickle(path): + with open(path, "rb") as f: + return pickle.load(f) + + +def is_region_key(k: str) -> bool: + return 'region' in k + + +def apply_aug(d: dict, aug) -> dict: + return {k: np.clip(aug(v), 0, None) if is_region_key(k) else v for k, v in d.items()} + + +def concat_dicts(base: dict, new: dict) -> dict: + missing = set(base) - set(new) + if missing: + raise KeyError(f"new dict missing keys: {sorted(missing)}") + for k in base: + base[k] = np.concatenate([base[k], new[k]]) + return base + + +def aggregate_states(d: dict) -> None: + n_regions = len(region_ids) + # per state + for i, state in enumerate(comunidades): + idxs = [ + r for r in range(n_regions) + if region_ids[r] // 10 == state + ] + d[f"comunidad{i}"] = np.sum([d[f"region{r}"] for r in idxs], axis=0) + # all allowed regions + d["state"] = np.sum([d[f"comunidad{r}"] + for r in range(len(comunidades))], axis=0) + + +def combine_results(dict_list): + combined = {} + for d in dict_list: + combined = concat_dicts(combined, d) if combined else d + return combined + + +def skip_2weeks(d: dict) -> dict: + return {k: v[:, 14:, :] if is_region_key(k) else v for k, v in d.items()} + + +def get_workflow(): + + simulator = bf.make_simulator( + [prior, run_spain_nuts3_simulation] + ) + adapter = ( + bf.Adapter() + .to_array() + .convert_dtype("float64", "float32") + .constrain("damping_values", lower=0.0, upper=1.0) + .constrain("t_E", lower=bounds["t_E"][0], upper=bounds["t_E"][1]) + .constrain("t_ISy", lower=bounds["t_ISy"][0], upper=bounds["t_ISy"][1]) + .constrain("t_ISev", lower=bounds["t_ISev"][0], upper=bounds["t_ISev"][1]) + .constrain("t_Cr", lower=bounds["t_Cr"][0], upper=bounds["t_Cr"][1]) + .constrain("mu_CR", lower=bounds["mu_CR"][0], upper=bounds["mu_CR"][1]) + .constrain("mu_IH", lower=bounds["mu_IH"][0], upper=bounds["mu_IH"][1]) + .constrain("mu_HU", lower=bounds["mu_HU"][0], upper=bounds["mu_HU"][1]) + .constrain("mu_UD", lower=bounds["mu_UD"][0], upper=bounds["mu_UD"][1]) + .constrain("transmission_prob", lower=bounds["transmission_prob"][0], upper=bounds["transmission_prob"][1]) + .concatenate( + ["damping_values", "t_E", "t_ISy", "t_ISev", "t_Cr", + "mu_CR", "mu_IH", "mu_HU", "mu_UD", "transmission_prob"], + into="inference_variables", + axis=-1 + ) + .concatenate(summary_vars, into="summary_variables", axis=-1) + ) + + summary_network = bf.networks.FusionTransformer( + summary_dim=(len(bounds)+16*NUM_DAMPING_POINTS)*2, dropout=0.1 + ) + inference_network = bf.networks.FlowMatching( + subnet_kwargs={'widths': (512, 512, 512, 512, 512)}) + + # aug = bf.augmentations.NNPE(spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False) + workflow = bf.BasicWorkflow( + simulator=simulator, + adapter=adapter, + summary_network=summary_network, + inference_network=inference_network, + standardize='all' + # augmentations={f'region{i}': aug for i in range(len(region_ids)) if region_ids[i] not in no_icu_ids} + # aggregation of the states would need to be recomputed every time a different noise realization is applied + ) + + return workflow + + +def run_training(name, num_training_files=20): + train_template = name+"/trainings_data{i}_"+name+".pickle" + val_path = f"{name}/validation_data_{name}.pickle" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # training data 
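+ # each training pickle (as produced by create_train_data) holds arrays of shape (n_sims, time, 1) per region key plus the sampled parameters; + # damping_values are flattened below to (n_sims, len(comunidades) * NUM_DAMPING_POINTS) before fitting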
+ train_files = [train_template.format(i=i) + for i in range(1, 1+num_training_files)] + trainings_data = None + for p in train_files: + d = load_pickle(p) + d = apply_aug(d, aug=aug) # only on region keys + d = skip_2weeks(d) + d['damping_values'] = d['damping_values'].reshape( + (d['damping_values'].shape[0], -1)) + if trainings_data is None: + trainings_data = d + else: + trainings_data = concat_dicts(trainings_data, d) + aggregate_states(trainings_data) + + # validation data + validation_data = apply_aug(load_pickle(val_path), aug=aug) + validation_data = skip_2weeks(validation_data) + aggregate_states(validation_data) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) + + # check data + workflow = get_workflow() + print("summary_variables shape:", workflow.adapter( + validation_data)["summary_variables"].shape) + print("inference_variables shape:", workflow.adapter( + validation_data)["inference_variables"].shape) + + history = workflow.fit_offline( + data=trainings_data, epochs=5, batch_size=64, validation_data=validation_data + ) + + workflow.approximator.save( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + plots = workflow.plot_default_diagnostics( + test_data=validation_data, calibration_ecdf_kwargs={ + 'difference': True, 'stacked': True} + ) + plots['losses'].savefig(f'{name}/losses_{name}.png', dpi=dpi) + plots['recovery'].savefig(f'{name}/recovery_{name}.png', dpi=dpi) + plots['calibration_ecdf'].savefig( + f'{name}/calibration_ecdf_{name}.png', dpi=dpi) + plots['z_score_contraction'].savefig( + f'{name}/z_score_contraction_{name}.png', dpi=dpi) + + +def run_inference(name, num_samples=10, on_synthetic_data=False): + val_path = f"{name}/validation_data_{name}.pickle" + synthetic = "_synthetic" if on_synthetic_data else "" + + aug = bf.augmentations.NNPE( + spike_scale=SPIKE_SCALE, slab_scale=SLAB_SCALE, per_dimension=False + ) + + # validation data + validation_data = load_pickle(val_path) # synthetic data + if on_synthetic_data: + # validation data + validation_data = apply_aug(validation_data, aug=aug) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) + validation_data_skip2w = skip_2weeks(validation_data) + aggregate_states(validation_data_skip2w) + divi_dict = validation_data + + divi_data = np.concatenate( + [divi_dict[f'region{i}'] for i in range(len(region_ids))], axis=-1 + )[0] # only one dataset + else: + divi_dict = load_divi_data() + validation_data_skip2w = skip_2weeks(divi_dict) + aggregate_states(validation_data_skip2w) + divi_data = np.concatenate( + [divi_dict[f'region{i}'] for i in range(len(region_ids))], axis=-1 + )[0] + + workflow = get_workflow() + workflow.approximator = keras.models.load_model( + filepath=os.path.join(f"{name}/model_{name}.keras") + ) + + if os.path.exists(f'{name}/sims_{name}{synthetic}_with_aug.pickle') and os.path.exists(f'{name}/sims_{name}{synthetic}.pickle') and os.path.exists(f'{name}/samples_{name}{synthetic}.pickle'): + simulations = load_pickle(f'{name}/sims_{name}{synthetic}.pickle') + simulations_aug = load_pickle( + f'{name}/sims_{name}{synthetic}_with_aug.pickle') + samples = load_pickle(f'{name}/samples_{name}{synthetic}.pickle') + print("loaded simulations from file") + else: + samples = workflow.sample( + conditions=validation_data_skip2w, num_samples=num_samples) + with open(f'{name}/samples_{name}{synthetic}.pickle', 'wb') as f: +
pickle.dump(samples, f, pickle.HIGHEST_PROTOCOL) + samples['damping_values'] = samples['damping_values'].reshape( + (samples['damping_values'].shape[0], num_samples, len(comunidades), NUM_DAMPING_POINTS)) + results = [] + for i in range(num_samples): # we only have one dataset for inference here + result = run_spain_nuts3_simulation( + damping_values=samples['damping_values'][0, i], + t_E=samples['t_E'][0, i], t_ISy=samples['t_ISy'][0, i], + t_ISev=samples['t_ISev'][0, i], t_Cr=samples['t_Cr'][0, i], + mu_CR=samples['mu_CR'][0, i], mu_IH=samples['mu_IH'][0, i], + mu_HU=samples['mu_HU'][0, i], mu_UD=samples['mu_UD'][0, i], + transmission_prob=samples['transmission_prob'][0, i] + ) + for key in result.keys(): + result[key] = np.array(result[key])[ + None, ...] # add sample axis + results.append(result) + results = combine_results(results) + results = extract_observables(results) + results_aug = apply_aug(results, aug=aug) + + # get sims in shape (samples, time, regions) + simulations = np.zeros( + (num_samples, divi_data.shape[0], divi_data.shape[1])) + simulations_aug = np.zeros( + (num_samples, divi_data.shape[0], divi_data.shape[1])) + for i in range(num_samples): + simulations[i] = np.concatenate( + [results[f'region{region}'][i] for region in range(len(region_ids))], axis=-1) + simulations_aug[i] = np.concatenate( + [results_aug[f'region{region}'][i] for region in range(len(region_ids))], axis=-1) + + # save sims + with open(f'{name}/sims_{name}{synthetic}.pickle', 'wb') as f: + pickle.dump(simulations, f, pickle.HIGHEST_PROTOCOL) + with open(f'{name}/sims_{name}{synthetic}_with_aug.pickle', 'wb') as f: + pickle.dump(simulations_aug, f, pickle.HIGHEST_PROTOCOL) + + plot_damping_values(samples['damping_values'][0], + name=name, synthetic=synthetic) + + samples['damping_values'] = samples['damping_values'].reshape( + (samples['damping_values'].shape[0], samples['damping_values'].shape[1], -1)) + validation_data['damping_values'] = validation_data['damping_values'].reshape( + (validation_data['damping_values'].shape[0], -1)) + + plot = bf.diagnostics.pairs_posterior( + samples, priors=validation_data, dataset_id=0) + plot.savefig(f'{name}/pairs_posterior_{name}{synthetic}.png', dpi=dpi) + + plot_all_regions(simulations, divi_data, name, synthetic, with_aug="") + plot_all_regions(simulations_aug, divi_data, name, + synthetic, with_aug="_with_aug") + + plot_aggregated_to_comunidades( + simulations, divi_data, name, synthetic, with_aug="") + plot_aggregated_to_comunidades( + simulations_aug, divi_data, name, synthetic, with_aug="_with_aug") + + fig, axes = plt.subplots(2, 2, figsize=( + 12, 12), sharex=True, sharey='row', constrained_layout=True) + # Plot without augmentation + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[0, 0], color="#132a70" + ) + axes[0, 0].set_title("Without Augmentation") + # Plot with augmentation + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated Median (With Aug)", ax=axes[0, 1], color="#132a70" + ) + axes[0, 1].set_title("With Augmentation") + # Plot without augmentation (80% quantile only) + plot_aggregated_over_regions( + simulations, true_data=divi_data, label="Region Aggregated Median (No Aug)", ax=axes[1, 0], color="#132a70", only_80q=True + ) + axes[1, 0].set_title("Without Augmentation (80% Quantile)") + # Plot with augmentation (80% quantile only) + plot_aggregated_over_regions( + simulations_aug, true_data=divi_data, label="Region Aggregated 
Median (With Aug)", ax=axes[1, 1], color="#132a70", only_80q=True + ) + axes[1, 1].set_title("With Augmentation (80% Quantile)") + plt.savefig(f'{name}/region_aggregated_{name}{synthetic}.png', dpi=dpi) + plt.close() + + fig, axis = plt.subplots(1, 2, figsize=( + 10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions( + simulations, divi_data, ax=axis[1]) + plt.savefig( + f'{name}/calibration_per_region_{name}{synthetic}.png', dpi=dpi) + plt.close() + fig, axis = plt.subplots(1, 2, figsize=( + 10, 4), sharex=True, layout="constrained") + ax = calibration_curves_per_region(simulations_aug, divi_data, ax=axis[0]) + ax, stats = calibration_median_mad_over_regions( + simulations_aug, divi_data, ax=axis[1]) + plt.savefig( + f'{name}/calibration_per_region_{name}{synthetic}_with_aug.png', dpi=dpi) + plt.close() + + plot_icu_on_spain(simulations, name, synthetic, with_aug="") + plot_icu_on_spain(simulations_aug, name, synthetic, with_aug="_with_aug") + + simulation_agg = np.sum(simulations, axis=-1, + keepdims=True) # sum over regions + simulation_aug_agg = np.sum(simulations_aug, axis=-1, keepdims=True) + + rmse = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes( + simulation_agg, 0, 1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + rmse_aug = bf.diagnostics.metrics.root_mean_squared_error(np.swapaxes( + simulation_aug_agg, 0, 1), np.sum(divi_data, axis=-1, keepdims=True), normalize=False) + print("Mean RMSE over regions:", rmse["values"].mean()) + print("Mean RMSE over regions (with aug):", rmse_aug["values"].mean()) + + cal_error = bf.diagnostics.metrics.calibration_error(np.swapaxes( + simulation_agg, 0, 1), np.sum(divi_data, axis=-1, keepdims=True)) + cal_error_aug = bf.diagnostics.metrics.calibration_error(np.swapaxes( + simulation_aug_agg, 0, 1), np.sum(divi_data, axis=-1, keepdims=True)) + print("Mean Calibration Error over regions:", cal_error["values"].mean()) + print("Mean Calibration Error over regions (with aug):", + cal_error_aug["values"].mean()) + + +if __name__ == "__main__": + name = "spain_nuts3" + + if not os.path.exists(name): + os.makedirs(name) + # create_train_data( + # filename=f'{name}/validation_data_{name}.pickle', number_samples=10) + # run_training(name=name, num_training_files=1) + # run_inference(name=name, on_synthetic_data=True) + run_inference(name=name, on_synthetic_data=False) diff --git a/pycode/memilio-epidata/memilio/epidata/defaultDict.py b/pycode/memilio-epidata/memilio/epidata/defaultDict.py index 73389b82d2..8c57564268 100644 --- a/pycode/memilio-epidata/memilio/epidata/defaultDict.py +++ b/pycode/memilio-epidata/memilio/epidata/defaultDict.py @@ -45,7 +45,7 @@ 'out_folder': default_file_path, 'update_data': False, 'start_date': date(2020, 1, 1), - 'end_date': date.today(), + 'end_date': date(2021, 1, 1), 'split_berlin': False, 'impute_dates': False, 'moving_average': 0, @@ -704,3 +704,101 @@ def invert_dict(dict_to_invert): """ return {val: key for key, val in dict_to_invert.items()} + + +Provincias = { + 111: 'A Coruña', + 112: 'Lugo', + 113: 'Ourense', + 114: 'Pontevedra', + 120: 'Asturias', + 130: 'Cantabria', + 211: 'Araba/Álava', + 212: 'Gipuzkoa', + 213: 'Bizkaia', + 220: 'Navarra', + 230: 'La Rioja', + 241: 'Huesca', + 242: 'Teruel', + 243: 'Zaragoza', + 300: 'Madrid', + 411: 'Ávila', + 412: 'Burgos', + 413: 'León', + 414: 'Palencia', + 415: 'Salamanca', + 416: 'Segovia', + 417: 'Soria', + 418: 'Valladolid', + 
419: 'Zamora', + 421: 'Albacete', + 422: 'Ciudad Real', + 423: 'Cuenca', + 424: 'Guadalajara', + 425: 'Toledo', + 431: 'Badajoz', + 432: 'Cáceres', + 511: 'Barcelona', + 512: 'Girona', + 513: 'Lleida', + 514: 'Tarragona', + 521: 'Alacant/Alicante', + 522: 'Castelló/Castellón', + 523: 'València/Valencia', + 530: 'Illes Balears', + 611: 'Almería', + 612: 'Cádiz', + 613: 'Córdoba', + 614: 'Granada', + 615: 'Huelva', + 616: 'Jaén', + 617: 'Málaga', + 618: 'Sevilla', + 620: 'Murcia', + 630: 'Ceuta', + 640: 'Melilla', + 701: 'Las Palmas', + 702: 'Santa Cruz de Tenerife' +} + + +Comunidades = { + 11: 'Galicia', + 12: 'Principado de Asturias', + 13: 'Cantabria', + 21: 'País Vasco', + 22: 'Comunidad Foral de Navarra', + 23: 'La Rioja', + 24: 'Aragón', + 30: 'Comunidad de Madrid', + 41: 'Castilla y León', + 42: 'Castilla-La Mancha', + 43: 'Extremadura', + 51: 'Cataluña', + 52: 'Comunidad Valenciana', + 53: 'Islas Baleares', + 61: 'Andalucía', + 62: 'Región de Murcia', + 63: 'Ceuta', + 64: 'Melilla', + 70: 'Canarias' +} + +provincia_id_map = {1: 211, 2: 421, 3: 521, 4: 611, 5: 411, 6: 431, 7: 530, 8: 511, 9: 412, 10: 432, 11: 612, 12: 522, + 13: 422, 14: 613, 15: 111, 16: 423, 17: 512, 18: 614, 19: 424, 20: 212, 21: 615, 22: 241, 23: 616, + 24: 413, 25: 513, 26: 230, 27: 112, 28: 300, 29: 617, 30: 620, 31: 220, 32: 113, 33: 120, 34: 414, + 35: 701, 36: 114, 37: 415, 38: 702, 39: 130, 40: 416, 41: 618, 42: 417, 43: 514, 44: 242, 45: 425, + 46: 523, 47: 418, 48: 213, 49: 419, 50: 243, 51: 630, 52: 640} + +provincia_id_map_census = {2: 211, 3: 421, 4: 521, 5: 611, 6: 411, 7: 431, 8: 530, 9: 511, 10: 412, 11: 432, 12: 612, 13: 522, + 14: 422, 15: 613, 16: 111, 17: 423, 18: 512, 19: 614, 20: 424, 21: 212, 22: 615, 23: 241, 24: 616, + 25: 413, 26: 513, 27: 230, 28: 112, 29: 300, 30: 617, 31: 620, 32: 220, 33: 120, 34: 414, 35: 701, + 36: 114, 37: 415, 38: 702, 39: 130, 40: 416, 41: 618, 42: 417, 43: 514, 44: 242, 45: 425, 46: 523, + 47: 418, 48: 213, 49: 419, 50: 243, 51: 630, 52: 640, 53: 113} + +Provincia_ISO_to_ID = {'A': 521, 'AB': 421, 'AL': 611, 'AV': 411, 'B': 511, 'BA': 431, 'BI': 213, 'BU': 412, 'C': 111, 'CA': 612, + 'CC': 432, 'CE': 630, 'CO': 613, 'CR': 422, 'CS': 522, 'CU': 423, 'GC': 701, 'GI': 512, 'GR': 614, 'GU': 424, + 'H': 615, 'HU': 241, 'J': 616, 'L': 513, 'LE': 413, 'LO': 230, 'LU': 112, 'M': 300, 'MA': 617, 'ML': 640, + 'MU': 620, 'NA': 220, 'O': 120, 'OR': 113, 'P': 414, 'PM': 530, 'PO': 114, 'SA': 415, 'S': 130, 'SE': 618, + 'SG': 416, 'SO': 417, 'SS': 212, 'T': 514, 'TE': 242, 'TF': 702, 'TO': 425, 'V': 523, 'VA': 418, 'VI': 211, + 'ZA': 419, 'Z': 243} diff --git a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py index c7d0b401de..07e8ab7415 100644 --- a/pycode/memilio-epidata/memilio/epidata/getPopulationData.py +++ b/pycode/memilio-epidata/memilio/epidata/getPopulationData.py @@ -55,8 +55,9 @@ def read_population_data(ref_year): req = requests.get(download_url) df_pop_raw = pd.read_csv(io.StringIO(req.text), sep=';', header=5) except pd.errors.ParserError: - gd.default_print('Warning', 'Data for year '+str(ref_year) + - ' is not available; downloading newest data instead.') + gd.default_print( + 'Warning', 'Data for year ' + str(ref_year) + + ' is not available; downloading newest data instead.') ref_year = None if ref_year is None: download_url = 'https://www.regionalstatistik.de/genesis/online?operation=download&code=12411-02-03-4&option=csv' @@ -66,7 +67,9 @@ def read_population_data(ref_year): return 
df_pop_raw, ref_year -def export_population_dataframe(df_pop: pd.DataFrame, directory: str, file_format: str, merge_eisenach: bool, ref_year): +def export_population_dataframe( + df_pop: pd.DataFrame, directory: str, file_format: str, + merge_eisenach: bool, ref_year): """ Writes population dataframe into directory with new column names and age groups :param df_pop: Population data DataFrame to be exported pd.DataFrame @@ -140,6 +143,10 @@ def export_population_dataframe(df_pop: pd.DataFrame, directory: str, file_forma columns=dd.EngEng["idCounty"]) gd.write_dataframe(df_pop_export, directory, filename, file_format) + gd.write_dataframe(aggregate_to_state_level(df_pop_export), + directory, filename + '_states', file_format) + gd.write_dataframe(aggregate_to_country_level(df_pop_export), + directory, filename + '_germany', file_format) return df_pop_export @@ -190,8 +197,8 @@ def assign_population_data(df_pop_raw, counties, age_cols, idCounty_idx): # direct assignment of population data found df_pop.loc[df_pop[dd.EngEng['idCounty']] == df_pop_raw.loc [start_idx, dd.EngEng['idCounty']], - age_cols] = df_pop_raw.loc[start_idx: start_idx + num_age_groups - 1, dd.EngEng - ['number']].values.astype(int) + age_cols] = df_pop_raw.loc[start_idx: start_idx + + num_age_groups - 1, dd.EngEng['number']].values.astype(int) # Berlin and Hamburg elif county_id + '000' in counties[:, 1]: # direct assignment of population data found @@ -262,7 +269,8 @@ def fetch_population_data(read_data: bool = dd.defaultDict['read_data'], if read_data == True: gd.default_print( - 'Warning', 'Read_data is not supportet for getPopulationData.py. Setting read_data = False') + 'Warning', + 'Read_data is not supportet for getPopulationData.py. Setting read_data = False') read_data = False directory = os.path.join(out_folder, 'Germany', 'pydata') @@ -444,6 +452,28 @@ def get_population_data(read_data: bool = dd.defaultDict['read_data'], return df_pop_export +def aggregate_to_state_level(df_pop: pd.DataFrame): + + countyIDtostateID = geoger.get_countyid_to_stateid_map() + + df_pop['ID_State'] = df_pop[dd.EngEng['idCounty']].map(countyIDtostateID) + df_pop = df_pop.drop( + columns='ID_County').groupby( + 'ID_State', as_index=True).sum() + df_pop['ID_State'] = df_pop.index + return df_pop + + +def aggregate_to_country_level(df_pop: pd.DataFrame): + + df_pop['ID_Country'] = 0 + df_pop = df_pop.drop( + columns=['ID_County', 'ID_State']).groupby( + 'ID_Country', as_index=True).sum() + df_pop['ID_Country'] = df_pop.index + return df_pop + + def main(): """ Main program entry.""" diff --git a/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py new file mode 100644 index 0000000000..2e6dbb895c --- /dev/null +++ b/pycode/memilio-epidata/memilio/epidata/getSimulationDataSpain.py @@ -0,0 +1,217 @@ +import pandas as pd +import os +import io +import requests + +from pyspainmobility import Mobility, Zones +import defaultDict as dd + +from memilio.epidata import getDataIntoPandasDataFrame as gd + + +def fetch_population_data(): + download_url = 'https://servicios.ine.es/wstempus/js/es/DATOS_TABLA/67988?tip=AM&' + req = requests.get(download_url) + req.encoding = 'ISO-8859-1' + + df = pd.read_json(io.StringIO(req.text)) + df = df[['MetaData', 'Data']] + df = df[df['MetaData'].apply( + lambda x: x[0]['T3_Variable'] == 'Provincias')] + df = df[df['MetaData'].apply( + lambda x: x[1]['Nombre'] == 'Total')] + df['ID_Provincia'] = df['MetaData'].apply( + lambda x: 
dd.provincia_id_map_census[x[0]['Id']]) + df['Population'] = df['Data'].apply(lambda x: x[0]['Valor']) + return df[['ID_Provincia', 'Population']] + + +def remove_islands(df, column_labels=['ID_Provincia']): + for label in column_labels: + df = df[~df[label].isin([530, 630, 640, 701, 702])] + return df + + +def get_population_data(): + df = fetch_population_data() + df = df.sort_values(by=['ID_Provincia']) + df = remove_islands(df) + + return df + + +def fetch_icu_data(): + # https://www.sanidad.gob.es/areas/alertasEmergenciasSanitarias/alertasActuales/nCov/capacidadAsistencial.htm + download_url = 'https://www.sanidad.gob.es/areas/alertasEmergenciasSanitarias/alertasActuales/nCov/documentos/Datos_Capacidad_Asistencial_Historico_14072023.csv' + req = requests.get(download_url) + req.encoding = 'ISO-8859-1' + + df = pd.read_csv(io.StringIO(req.text), sep=';') + return df + + +def preprocess_icu_data(df): + df_icu = df[df["Unidad"] == "U. Críticas SIN respirador"] + df_icu_vent = df[df["Unidad"] == "U. Críticas CON respirador"] + + df_icu = df_icu[['Fecha', 'ID_Provincia', 'OCUPADAS_COVID19']].rename( + columns={'OCUPADAS_COVID19': 'ICU'}) + df_icu_vent = df_icu_vent[['Fecha', 'ID_Provincia', 'OCUPADAS_COVID19']].rename( + columns={'OCUPADAS_COVID19': 'ICU_ventilated'}) + + df_merged = pd.merge(df_icu, df_icu_vent, on=[ + 'Fecha', 'ID_Provincia'], how='outer') + df_merged['Fecha'] = pd.to_datetime( + df_merged['Fecha'], format='%d/%m/%Y').dt.strftime('%Y-%m-%d') + df_merged.rename(columns={'Fecha': 'Date'}, inplace=True) + + return df_merged + + +def get_icu_data(): + df = fetch_icu_data() + df.rename(columns={'Cod_Provincia': 'ID_Provincia'}, inplace=True) + # ensure numeric, drop non-numeric and increment all IDs by 1 + df['ID_Provincia'] = pd.to_numeric(df['ID_Provincia'], errors='coerce') + df = df[df['ID_Provincia'].notna()].copy() + df['ID_Provincia'] = df['ID_Provincia'].astype(int) + 1 + df['ID_Provincia'] = df['ID_Provincia'].map(dd.provincia_id_map) + df = remove_islands(df) + df = preprocess_icu_data(df) + + return df + + +def fetch_case_data(): + # https://datos.gob.es/es/catalogo/e05070101-evolucion-de-enfermedad-por-el-coronavirus-covid-19 + download_url = 'https://cnecovid.isciii.es/covid19/resources/casos_diagnostico_provincia.csv' + req = requests.get(download_url) + + df = pd.read_csv(io.StringIO(req.text), sep=',', + keep_default_na=False, na_values=[]) + + return df + + +def preprocess_case_data(df): + df['provincia_iso'] = df['provincia_iso'].map(dd.Provincia_ISO_to_ID) + df = df.rename( + columns={'provincia_iso': 'ID_County', 'fecha': 'Date', + 'num_casos': 'Confirmed'})[ + ['ID_County', 'Date', 'Confirmed']] + # ensure correct types + df['ID_County'] = pd.to_numeric( + df['ID_County'], errors='coerce').astype('Int64') + df = df[df['ID_County'].notna()].copy() + df['ID_County'] = df['ID_County'].astype(int) + return df + + +def get_case_data(): + df_raw = fetch_case_data() + df = preprocess_case_data(df_raw) + df = df.groupby(['ID_County', 'Date'], as_index=False)['Confirmed'].sum() + df = df.sort_values(['ID_County', 'Date']) + df['Confirmed'] = df.groupby('ID_County')['Confirmed'].cumsum() + return df + + +def download_mobility_data(data_dir, start_date, end_date, level): + mobility_data = Mobility(version=2, zones=level, + start_date=start_date, end_date=end_date, output_directory=data_dir) + mobility_data.get_od_data() + + +def preprocess_mobility_data(df, data_dir): + df.drop(columns=['hour', 'trips_total_length_km'], inplace=True) + if not 
+    if not os.path.exists(os.path.join(data_dir, 'poblacion.csv')):
+        download_url = 'https://movilidad-opendata.mitma.es/zonificacion/poblacion.csv'
+        req = requests.get(download_url)
+        with open(os.path.join(data_dir, 'poblacion.csv'), 'wb') as f:
+            f.write(req.content)
+
+    zonification_df = pd.read_csv(os.path.join(data_dir, 'poblacion.csv'), sep='|')[
+        ['municipio', 'provincia', 'poblacion']]
+    zonification_df['provincia'] = zonification_df['provincia'].map(
+        dd.provincia_id_map)
+    zonification_df.dropna(inplace=True, subset=['provincia'])
+    poblacion = zonification_df.groupby('provincia')[
+        'poblacion'].sum()
+    zonification_df.drop_duplicates(
+        subset=['municipio', 'provincia'], inplace=True)
+    municipio_to_provincia = dict(
+        zip(zonification_df['municipio'], zonification_df['provincia']))
+    df['id_origin'] = df['id_origin'].map(
+        municipio_to_provincia)
+    df['id_destination'] = df['id_destination'].map(
+        municipio_to_provincia)
+
+    df.query('id_origin != id_destination', inplace=True)
+    df = df.groupby(['date', 'id_origin', 'id_destination'],
+                    as_index=False)['n_trips'].sum()
+
+    df['n_trips'] = df.apply(
+        lambda row: row['n_trips'] / poblacion[row['id_origin']], axis=1)
+
+    df = df.groupby(['id_origin', 'id_destination'],
+                    as_index=False)['n_trips'].mean()
+
+    df = remove_islands(df, ['id_origin', 'id_destination'])
+
+    return df
+
+
+def get_mobility_data(data_dir, start_date='2022-08-01', end_date='2022-08-31', level='municipios'):
+    filename = f'Viajes_{level}_{start_date}_{end_date}_v2.parquet'
+    if not os.path.exists(os.path.join(data_dir, filename)):
+        print(
+            f"File {os.path.join(data_dir, filename)} does not exist. Downloading mobility data...")
+        download_mobility_data(data_dir, start_date, end_date, level)
+
+    df = pd.read_parquet(os.path.join(data_dir, filename))
+    df = preprocess_mobility_data(df, data_dir)
+
+    return df
+
+
+if __name__ == "__main__":
+
+    data_dir = os.path.join(os.path.dirname(
+        os.path.abspath(__file__)), "../../../../data/Spain")
+    pydata_dir = os.path.join(data_dir, 'pydata')
+    os.makedirs(pydata_dir, exist_ok=True)
+    mobility_dir = os.path.join(data_dir, 'mobility')
+    os.makedirs(mobility_dir, exist_ok=True)
+
+    df = get_population_data()
+    df.to_json(os.path.join(
+        pydata_dir, 'provincias_current_population.json'), orient='records')
+
+    df = get_icu_data()
+    # rename to match expected column name in parameters_io.h
+    if 'ID_Provincia' in df.columns:
+        df.rename(columns={'ID_Provincia': 'ID_County'}, inplace=True)
+    # and should also be int
+    if 'ID_County' in df.columns:
+        df['ID_County'] = pd.to_numeric(df['ID_County'], errors='coerce')
+        df = df[df['ID_County'].notna()].copy()
+        df['ID_County'] = df['ID_County'].astype(int)
+    df.to_json(os.path.join(pydata_dir, 'provincia_icu.json'), orient='records')
+
+    df = get_case_data()
+    # same for case data
+    if 'ID_Provincia' in df.columns:
+        df.rename(columns={'ID_Provincia': 'ID_County'}, inplace=True)
+    if 'ID_County' in df.columns:
+        df['ID_County'] = pd.to_numeric(df['ID_County'], errors='coerce')
+        df = df[df['ID_County'].notna()].copy()
+        df['ID_County'] = df['ID_County'].astype(int)
+    df.to_json(os.path.join(
+        pydata_dir, 'cases_all_pronvincias.json'), orient='records')
+
+    df = get_mobility_data(mobility_dir)
+    matrix = df.pivot(index='id_origin',
+                      columns='id_destination', values='n_trips').fillna(0)
+
+    gd.write_dataframe(matrix, mobility_dir, 'commuter_mobility', 'txt', {
+        'sep': ' ', 'index': False, 'header': False})
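+    # The pivoted provincia-to-provincia matrix is written as plain
+    # whitespace-separated text without header or index; this is assumed to
+    # match the commuter-matrix format read by mio::read_mobility_plain in
+    # the set_edges binding.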
diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp
index 50ea9b65b8..b7363ced4f 100644
--- a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp
+++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp
@@ -200,7 +200,13 @@ PYBIND11_MODULE(_simulation_osecir, m)
                mio::osecir::ParametersBase>(m, "Parameters")
        .def(py::init())
        .def("check_constraints", &mio::osecir::Parameters::check_constraints)
-       .def("apply_constraints", &mio::osecir::Parameters::apply_constraints);
+       .def("apply_constraints", &mio::osecir::Parameters::apply_constraints)
+       .def_property(
+           "end_commuter_detection",
+           [](const mio::osecir::Parameters& self) -> auto { return self.get_end_commuter_detection(); },
+           [](mio::osecir::Parameters& self, double p) {
+               self.get_end_commuter_detection() = p;
+           });

     using Populations = mio::Populations;
     pymio::bind_Population(m, "Populations", mio::Tag::Populations>{});
@@ -264,6 +270,62 @@ PYBIND11_MODULE(_simulation_osecir, m)
         },
         py::return_value_policy::move);

+    m.def(
+        "set_nodes_states",
+        [](const mio::osecir::Parameters& params, mio::Date start_date, mio::Date end_date,
+           const std::string& data_dir, const std::string& population_data_path, bool is_node_for_county,
+           mio::Graph, mio::MobilityParameters>& params_graph,
+           const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor,
+           int num_days = 0, bool export_time_series = false) {
+            auto result = mio::set_nodes<
+                mio::osecir::TestAndTraceCapacity, mio::osecir::ContactPatterns,
+                mio::osecir::Model, mio::MobilityParameters, mio::osecir::Parameters,
+                decltype(mio::osecir::read_input_data_state>), decltype(mio::get_node_ids)>(
+                params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph,
+                mio::osecir::read_input_data_state>, mio::get_node_ids, scaling_factor_inf,
+                scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series);
+            return pymio::check_and_throw(result);
+        },
+        py::return_value_policy::move);
+
+    m.def(
+        "set_node_germany",
+        [](const mio::osecir::Parameters& params, mio::Date start_date, mio::Date end_date,
+           const std::string& data_dir, const std::string& population_data_path, bool is_node_for_county,
+           mio::Graph, mio::MobilityParameters>& params_graph,
+           const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor,
+           int num_days = 0, bool export_time_series = false) {
+            auto result = mio::set_nodes,
+                mio::osecir::ContactPatterns, mio::osecir::Model,
+                mio::MobilityParameters, mio::osecir::Parameters,
+                decltype(mio::osecir::read_input_data_germany>),
+                decltype(mio::get_country_id)>(
+                params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph,
+                mio::osecir::read_input_data_germany>, mio::get_country_id,
+                scaling_factor_inf, scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series);
+            return pymio::check_and_throw(result);
+        },
+        py::return_value_policy::move);
+
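+    // Analogous to set_nodes_states and set_node_germany above, but with one
+    // graph node per Spanish provincia; the provincia population, case and
+    // ICU files (assumed to be the ones produced by getSimulationDataSpain.py)
+    // serve as input.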
+    m.def(
+        "set_nodes_provincias",
+        [](const mio::osecir::Parameters& params, mio::Date start_date, mio::Date end_date,
+           const std::string& data_dir, const std::string& population_data_path, bool is_node_for_county,
+           mio::Graph, mio::MobilityParameters>& params_graph,
+           const std::vector& scaling_factor_inf, double scaling_factor_icu, double tnt_capacity_factor,
+           int num_days = 0, bool export_time_series = false) {
+            auto result = mio::set_nodes,
+                mio::osecir::ContactPatterns, mio::osecir::Model,
+                mio::MobilityParameters, mio::osecir::Parameters,
+                decltype(mio::osecir::read_input_data_provincias>),
+                decltype(mio::get_provincia_ids)>(
+                params, start_date, end_date, data_dir, population_data_path, is_node_for_county, params_graph,
+                mio::osecir::read_input_data_provincias>, mio::get_provincia_ids,
+                scaling_factor_inf, scaling_factor_icu, tnt_capacity_factor, num_days, export_time_series);
+            return pymio::check_and_throw(result);
+        },
+        py::return_value_policy::move);
+
     pymio::iterable_enum(m, "ContactLocation")
         .value("Home", ContactLocation::Home)
         .value("School", ContactLocation::School)
@@ -278,12 +340,11 @@ PYBIND11_MODULE(_simulation_osecir, m)
             auto mobile_comp = {mio::osecir::InfectionState::Susceptible, mio::osecir::InfectionState::Exposed,
                                 mio::osecir::InfectionState::InfectedNoSymptoms, mio::osecir::InfectionState::InfectedSymptoms,
                                 mio::osecir::InfectionState::Recovered};
-            auto weights = std::vector{0., 0., 1.0, 1.0, 0.33, 0., 0.};
+            // auto weights = std::vector{1.0};
             auto result = mio::set_edges,
                                          mio::MobilityParameters, mio::MobilityCoefficientGroup,
                                          mio::osecir::InfectionState,
-                                         decltype(mio::read_mobility_plain)>(mobility_data_file, params_graph,
-                                                                             mobile_comp, contact_locations_size,
-                                                                             mio::read_mobility_plain, weights);
+                                         decltype(mio::read_mobility_plain)>(
+                mobility_data_file, params_graph, mobile_comp, contact_locations_size, mio::read_mobility_plain, {1.0});
             return pymio::check_and_throw(result);
         },
         py::return_value_policy::move);
@@ -294,6 +355,7 @@ PYBIND11_MODULE(_simulation_osecir, m)
 #ifdef MEMILIO_HAS_JSONCPP
     pymio::bind_write_graph>(m);
+    pymio::bind_read_graph>(m);
     m.def(
         "read_input_data_county",
         [](std::vector>& model, mio::Date date, const std::vector& county,
@@ -304,6 +366,16 @@ PYBIND11_MODULE(_simulation_osecir, m)
             return pymio::check_and_throw(result);
         },
         py::return_value_policy::move);
+    m.def(
+        "read_input_data_provincias",
+        [](std::vector>& model, mio::Date date, const std::vector& provincias,
+           const std::vector& scaling_factor_inf, double scaling_factor_icu, const std::string& dir,
+           int num_days = 0, bool export_time_series = false) {
+            auto result = mio::osecir::read_input_data_provincias>(
+                model, date, provincias, scaling_factor_inf, scaling_factor_icu, dir, num_days, export_time_series);
+            return pymio::check_and_throw(result);
+        },
+        py::return_value_policy::move);
 #endif // MEMILIO_HAS_JSONCPP

     m.def("interpolate_simulation_result",
diff --git a/shellscripts/fitting_graphmodel.sh b/shellscripts/fitting_graphmodel.sh
new file mode 100644
index 0000000000..fa53bc085b
--- /dev/null
+++ b/shellscripts/fitting_graphmodel.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+#SBATCH -N 1
+#SBATCH -n 1
+#SBATCH -c 1
+#SBATCH -t 5-0:00:00
+#SBATCH --output=shellscripts/train_countylvl-%A.out
+#SBATCH --error=shellscripts/train_countylvl-%A.err
+#SBATCH --job-name=train_countylvl
+#SBATCH --partition=gpu
+#SBATCH --gpus=1
+
+module load PrgEnv/gcc13-openmpi-python
+module load cuda/12.9.0-none-none-6wnenm2
+source venv/bin/activate
+srun --cpu-bind=core python pycode/examples/simulation/graph_germany_nuts3.py
\ No newline at end of file