1057 data generation for GNNs #1090
Status: Open

AgathaSchmidt wants to merge 34 commits into main from 1057-GNN-datageneration.
Commits (34):

  e50eff8  AgathaSchmidt  data gerneration graphODE without dampings
  21e6df5  AgathaSchmidt  GNN data generator with dampings - graphODE
  95a9d15  AgathaSchmidt  add damping information to output
  300b956  AgathaSchmidt  adjust to code guidelines
  75ce80d  AgathaSchmidt  delete file
  10b6e92  AgathaSchmidt  adjust to code guidelines
  9115678  AgathaSchmidt  create a make_graph function
  c9b26cf  AgathaSchmidt  create make graph function
  d97a1f7  AgathaSchmidt  fix population number at 400
  e20fe2d  AgathaSchmidt  fix number of population to 400 and rename damping information
  d9110e5  AgathaSchmidt  add "return data" for case without saving
  8823aa1  AgathaSchmidt  tests for simulation run and datageneration for models with and witho…
  74da6d8  AgathaSchmidt  Apply suggestions from code review
  47f99eb  AgathaSchmidt  apply suggestions from code review
  58fc8e8  AgathaSchmidt  a python file with functions frequently used for our GNNs
  ffa076e  AgathaSchmidt  add get population and get minimum matrix to utils
  0cbda2b  AgathaSchmidt  remove get population function from this file and import it from utils
  e8fb580  AgathaSchmidt  remove functons and import them from utlis
  de35ca6  AgathaSchmidt  add comments as proposed by review
  3d5a644  AgathaSchmidt  pre commit
  2dfbe4f  AgathaSchmidt  mock transform mobility function
  86a2e89  AgathaSchmidt  adjust utils: add surrogate utils and add scling to GNN utis
  77af0ab  AgathaSchmidt  add test for saving mechanism
  ce4a494  AgathaSchmidt  import functions from new utils file
  63980d9  AgathaSchmidt  adjust import
  cabf5a1  AgathaSchmidt  adjust imports
  06fb8dd  AgathaSchmidt  put function which read files outside if the run_simulation function
  f2345c7  AgathaSchmidt  add directory as parameter
  d26528e  AgathaSchmidt  set edges only one time
  86c76d5  HenrZu  Merge branch 'main' into 1057-GNN-datageneration
  e39fd71  HenrZu  new structure for no damp
  ff6c571  HenrZu  timing graph sim (Delete before Merge!)
  ac705a4  HenrZu  with_dampings
  10d5a98  HenrZu  [ci skip] damping correctly setted and reseted after each run
pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation_nodamp.py (328 additions, 0 deletions)
import copy
import os
import pickle
import random
import json
from datetime import date

import numpy as np
from progress.bar import Bar
from sklearn.preprocessing import FunctionTransformer

from memilio.simulation import AgeGroup, LogLevel, set_log_level
from memilio.simulation.osecir import (
    Index_InfectionState, InfectionState, Model, ModelGraph,
    ParameterStudy, interpolate_simulation_result, set_edges)
from memilio.epidata import geoModificationGermany as geoger
from memilio.epidata import transformMobilityData as tmd
from memilio.epidata import getDataIntoPandasDataFrame as gd
def run_secir_groups_simulation(days, populations):
    """! Uses an ODE SECIR model allowing for asymptomatic infection with 6
    different age groups. The model is stratified by region via a graph of
    county nodes. Virus-specific parameters are fixed and the initial numbers
    of persons in the particular infection states are chosen randomly from
    defined ranges.
    @param days Describes how many days we simulate within a single run.
    @param populations List containing the population of each region,
        stratified by age group.
    @return List containing the interpolated simulation results per region.
    """
    set_log_level(LogLevel.Off)

    start_day = 1
    start_month = 1
    start_year = 2019
    dt = 0.1

    # get county ids
    countykey_list = geoger.get_county_ids(merge_eisenach=True, zfill=True)

    # define age groups
    groups = ['0-4', '5-14', '15-34', '35-59', '60-79', '80+']
    num_groups = len(groups)
    num_regions = len(populations)
    models = []

    # initialize one model per region
    for region in range(num_regions):
        model = Model(num_groups)

        # set parameters
        for i in range(num_groups):
            # compartment transition durations
            model.parameters.TimeExposed[AgeGroup(i)] = 3.2
            model.parameters.TimeInfectedNoSymptoms[AgeGroup(i)] = 2.
            model.parameters.TimeInfectedSymptoms[AgeGroup(i)] = 6.
            model.parameters.TimeInfectedSevere[AgeGroup(i)] = 12.
            model.parameters.TimeInfectedCritical[AgeGroup(i)] = 8.

            # initial number of people in each compartment, drawn randomly
            model.populations[AgeGroup(i), Index_InfectionState(
                InfectionState.Exposed)] = random.uniform(
                0.00025, 0.005) * populations[region][i]
            model.populations[AgeGroup(i), Index_InfectionState(
                InfectionState.InfectedNoSymptoms)] = random.uniform(
                0.0001, 0.0035) * populations[region][i]
            model.populations[AgeGroup(i), Index_InfectionState(
                InfectionState.InfectedNoSymptomsConfirmed)] = 0
            model.populations[AgeGroup(i), Index_InfectionState(
                InfectionState.InfectedSymptoms)] = random.uniform(
                0.00007, 0.001) * populations[region][i]
            model.populations[AgeGroup(i), Index_InfectionState(
                InfectionState.InfectedSymptomsConfirmed)] = 0
            model.populations[AgeGroup(i), Index_InfectionState(
                InfectionState.InfectedSevere)] = random.uniform(
                0.00003, 0.0006) * populations[region][i]
            model.populations[AgeGroup(i), Index_InfectionState(
                InfectionState.InfectedCritical)] = random.uniform(
                0.00001, 0.0002) * populations[region][i]
            model.populations[AgeGroup(i), Index_InfectionState(
                InfectionState.Recovered)] = random.uniform(
                0.002, 0.08) * populations[region][i]
            model.populations[AgeGroup(i), Index_InfectionState(
                InfectionState.Dead)] = random.uniform(
                0, 0.0003) * populations[region][i]
            model.populations.set_difference_from_group_total_AgeGroup(
                (AgeGroup(i), Index_InfectionState(InfectionState.Susceptible)),
                populations[region][i])

            # compartment transition probabilities
            model.parameters.RelativeTransmissionNoSymptoms[AgeGroup(i)] = 0.5
            model.parameters.TransmissionProbabilityOnContact[AgeGroup(i)] = 0.1
            model.parameters.RecoveredPerInfectedNoSymptoms[AgeGroup(i)] = 0.09
            model.parameters.RiskOfInfectionFromSymptomatic[AgeGroup(i)] = 0.25
            model.parameters.SeverePerInfectedSymptoms[AgeGroup(i)] = 0.2
            model.parameters.CriticalPerSevere[AgeGroup(i)] = 0.25
            model.parameters.DeathsPerCritical[AgeGroup(i)] = 0.3
            # twice the value of RiskOfInfectionFromSymptomatic
            model.parameters.MaxRiskOfInfectionFromSymptomatic[AgeGroup(i)] = 0.5

        # StartDay is the n-th day of the year
        model.parameters.StartDay = (
            date(start_year, start_month, start_day) - date(start_year, 1, 1)).days

        # load the baseline contact matrix and assign it to the model;
        # the minimum contact matrix is set to zero
        baseline = getBaselineMatrix()
        model.parameters.ContactPatterns.cont_freq_mat[0].baseline = baseline
        model.parameters.ContactPatterns.cont_freq_mat[0].minimum = np.zeros(
            (num_groups, num_groups))

        # apply mathematical constraints to parameters
        model.apply_constraints()
        models.append(model)

    graph = ModelGraph()
    for i in range(num_regions):
        graph.add_node(int(countykey_list[i]), models[i])

    # get mobility data directory
    arg_dict = gd.cli("commuter_official")
    directory = arg_dict['out_folder'].split('/pydata')[0]
    directory = os.path.join(directory, 'mobility/')

    # merge Eisenach and Wartburgkreis in the input data
    tmd.updateMobility2022(directory, mobility_file='twitter_scaled_1252')
    tmd.updateMobility2022(directory, mobility_file='commuter_migration_scaled')

    num_locations = 4
    set_edges(os.path.abspath(os.path.join(directory, os.pardir)),
              graph, num_locations)

    study = ParameterStudy(graph, 0, days, dt=dt, num_runs=1)
    graph_run = study.run()[0]
    results = interpolate_simulation_result(graph_run)

    # omit the time points (first row of each ndarray), as they are not of
    # interest here, and merge the confirmed compartments
    for result_indx in range(len(results)):
        results[result_indx] = remove_confirmed_compartments(
            np.transpose(results[result_indx].as_ndarray()[1:, :]), num_groups)

    dataset_entry = copy.deepcopy(results)

    return dataset_entry
def remove_confirmed_compartments(dataset_entries, num_groups):
    """! The compartments which contain confirmed cases are not needed and are
    therefore omitted by summing each confirmed compartment into the
    corresponding original compartment.
    @param dataset_entries Array that contains the compartmental data with
        confirmed compartments.
    @param num_groups Number of age groups.
    @return Array that contains the compartmental data without confirmed
        compartments.
    """
    new_dataset_entries = []
    for i in dataset_entries:
        dataset_entries_reshaped = i.reshape(
            [num_groups, int(np.asarray(dataset_entries).shape[1] / num_groups)])
        sum_inf_no_symp = np.sum(dataset_entries_reshaped[:, [2, 3]], axis=1)
        sum_inf_symp = np.sum(dataset_entries_reshaped[:, [4, 5]], axis=1)
        dataset_entries_reshaped[:, 2] = sum_inf_no_symp
        dataset_entries_reshaped[:, 4] = sum_inf_symp
        new_dataset_entries.append(
            np.delete(dataset_entries_reshaped, [3, 5], axis=1).flatten())
    return new_dataset_entries
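A minimal sketch of what this merge does, on a single synthetic entry with two age groups and the 10 osecir compartments in their flattened order (the values are made up; only the index arithmetic matches the function above):

```python
import numpy as np

num_groups = 2
# One flattened entry: 2 age groups x 10 compartments
# (S, E, INS, INSConf, ISy, ISyConf, ISev, ICri, R, D).
entry = np.arange(20, dtype=float)

reshaped = entry.reshape([num_groups, 10])
reshaped[:, 2] = reshaped[:, [2, 3]].sum(axis=1)  # fold InfectedNoSymptomsConfirmed in
reshaped[:, 4] = reshaped[:, [4, 5]].sum(axis=1)  # fold InfectedSymptomsConfirmed in
merged = np.delete(reshaped, [3, 5], axis=1).flatten()

print(merged.shape)  # (16,) -> 2 groups x 8 compartments
```

Columns 3 and 5 (the confirmed compartments) disappear, so each age group shrinks from 10 to 8 values while the totals per infection state are preserved.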
def get_population(path="data/pydata/Germany/county_population.json"):
    """! Loads population data for all counties.
    @param path Path to the population file.
    @return List with the populations of all 400 counties, each split into
        6 age groups.
    """
    with open(path) as f:
        data = json.load(f)
    population = []
    for data_entry in data:
        population_county = []
        population_county.append(
            data_entry['<3 years'] + data_entry['3-5 years'] / 2)
        population_county.append(data_entry['6-14 years'])
        population_county.append(
            data_entry['15-17 years'] + data_entry['18-24 years'] +
            data_entry['25-29 years'] + data_entry['30-39 years'] / 2)
        population_county.append(
            data_entry['30-39 years'] / 2 + data_entry['40-49 years'] +
            data_entry['50-64 years'] * 2 / 3)
        population_county.append(
            data_entry['65-74 years'] + data_entry['>74 years'] * 0.2 +
            data_entry['50-64 years'] * 1 / 3)
        population_county.append(
            data_entry['>74 years'] * 0.8)
        population.append(population_county)
    return population
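To illustrate the age-group aggregation, here is the same arithmetic applied to one made-up county entry (the key names match the ones the function reads; the population numbers are invented for the example):

```python
# Hypothetical county entry with the age brackets of the population file.
data_entry = {
    '<3 years': 300, '3-5 years': 200, '6-14 years': 900,
    '15-17 years': 300, '18-24 years': 700, '25-29 years': 600,
    '30-39 years': 1200, '40-49 years': 1400, '50-64 years': 2100,
    '65-74 years': 1100, '>74 years': 1000,
}

# Map the file's brackets onto the model's 6 age groups, splitting brackets
# that straddle a group boundary, exactly as in get_population above.
county = [
    data_entry['<3 years'] + data_entry['3-5 years'] / 2,          # 0-4
    data_entry['6-14 years'],                                      # 5-14
    data_entry['15-17 years'] + data_entry['18-24 years']
    + data_entry['25-29 years'] + data_entry['30-39 years'] / 2,   # 15-34
    data_entry['30-39 years'] / 2 + data_entry['40-49 years']
    + data_entry['50-64 years'] * 2 / 3,                           # 35-59
    data_entry['65-74 years'] + data_entry['>74 years'] * 0.2
    + data_entry['50-64 years'] * 1 / 3,                           # 60-79
    data_entry['>74 years'] * 0.8,                                 # 80+
]
print([round(x) for x in county])  # [400, 900, 2200, 3400, 2000, 800]
```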
def getBaselineMatrix():
    """! Loads the baseline contact matrix as the sum of the four
    location-specific matrices (home, school, work, other).
    """
    baseline_contact_matrix0 = os.path.join(
        "./data/contacts/baseline_home.txt")
    baseline_contact_matrix1 = os.path.join(
        "./data/contacts/baseline_school_pf_eig.txt")
    baseline_contact_matrix2 = os.path.join(
        "./data/contacts/baseline_work.txt")
    baseline_contact_matrix3 = os.path.join(
        "./data/contacts/baseline_other.txt")

    baseline = np.loadtxt(baseline_contact_matrix0) \
        + np.loadtxt(baseline_contact_matrix1) \
        + np.loadtxt(baseline_contact_matrix2) \
        + np.loadtxt(baseline_contact_matrix3)

    return baseline
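Independently of the data files, the summation can be checked with in-memory matrices, since np.loadtxt accepts any file-like object (the two matrices here are invented stand-ins for the home/work files):

```python
import io
import numpy as np

# Stand-ins for two of the contact matrix files read by getBaselineMatrix.
home = io.StringIO("1.0 0.5\n0.5 1.0")
work = io.StringIO("0.2 0.1\n0.1 0.2")

# The baseline is simply the elementwise sum of the per-location matrices.
baseline = np.loadtxt(home) + np.loadtxt(work)
print(baseline)  # [[1.2 0.6]
                 #  [0.6 1.2]]
```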
def generate_data(
        num_runs, path, input_width, days, save_data=True):
    """! Generate a dataset by calling run_secir_groups_simulation num_runs times.
    @param num_runs Number of times run_secir_groups_simulation is called.
    @param path Path where the dataset is stored.
    @param input_width Number of time steps used as model input.
    @param days Number of time steps (days) used as model output/label.
    @param save_data Option to deactivate saving of the dataset. True per default.
    """
    population = get_population()
    days_sum = days + input_width - 1

    data = {"inputs": [],
            "labels": [],
            }

    # show progress in terminal for longer runs;
    # due to the random structure, there is currently no need to shuffle the data
    bar = Bar('Number of Runs done', max=num_runs)

    for _ in range(num_runs):
        data_run = run_secir_groups_simulation(
            days_sum, population)

        inputs = np.asarray(data_run).transpose(1, 2, 0)[:input_width]
        data["inputs"].append(inputs)
        data["labels"].append(
            np.asarray(data_run).transpose(1, 2, 0)[input_width:])

        bar.next()
    bar.finish()

    if save_data:
        num_groups = int(np.asarray(data['inputs']).shape[2] / 8)
        transformer = FunctionTransformer(np.log1p, validate=True)

        # scale inputs
        inputs = np.asarray(
            data['inputs']).transpose(2, 0, 1, 3).reshape(num_groups * 8, -1)
        scaled_inputs = transformer.transform(inputs)
        original_shape_input = np.asarray(data['inputs']).shape

        # step 1: reverse the reshape
        reshaped_back = scaled_inputs.reshape(original_shape_input[2],
                                              original_shape_input[0],
                                              original_shape_input[1],
                                              original_shape_input[3])

        # step 2: reverse the transpose
        original_inputs = reshaped_back.transpose(1, 2, 0, 3)
        scaled_inputs = original_inputs.transpose(0, 3, 1, 2)

        # scale labels
        labels = np.asarray(
            data['labels']).transpose(2, 0, 1, 3).reshape(num_groups * 8, -1)
        scaled_labels = transformer.transform(labels)
        original_shape_labels = np.asarray(data['labels']).shape

        # step 1: reverse the reshape
        reshaped_back = scaled_labels.reshape(original_shape_labels[2],
                                              original_shape_labels[0],
                                              original_shape_labels[1],
                                              original_shape_labels[3])

        # step 2: reverse the transpose
        original_labels = reshaped_back.transpose(1, 2, 0, 3)
        scaled_labels = original_labels.transpose(0, 3, 1, 2)

        all_data = {"inputs": scaled_inputs,
                    "labels": scaled_labels,
                    }

        # check if the data directory exists; if necessary, create it
        if not os.path.isdir(path):
            os.mkdir(path)

        # save dict to a pickle file
        with open(os.path.join(path, 'data_secir_age_groups.pickle'), 'wb') as f:
            pickle.dump(all_data, f)

    # return the unscaled data, also for the case without saving
    return data
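The transpose/reshape pair around the transformer only rearranges the data so that FunctionTransformer sees one row per compartment feature; the whole scaling step is equivalent to applying np.log1p elementwise followed by one axis swap. A small shape check (the dimensions are made up, and np.log1p stands in for the transformer):

```python
import numpy as np

# Hypothetical dataset layout: (runs, time steps, 6 groups x 8 compartments, regions).
num_runs, input_width, features, num_regions = 3, 5, 48, 4
data_inputs = np.random.rand(num_runs, input_width, features, num_regions)

# Flatten to (features, everything else), scale, then undo the flattening,
# mirroring the "reverse the reshape" / "reverse the transpose" steps above.
flat = data_inputs.transpose(2, 0, 1, 3).reshape(features, -1)
scaled = np.log1p(flat)
restored = scaled.reshape(features, num_runs, input_width,
                          num_regions).transpose(1, 2, 0, 3)
scaled_inputs = restored.transpose(0, 3, 1, 2)

# The round trip is an elementwise log1p plus one final axis reordering.
assert scaled_inputs.shape == (num_runs, num_regions, input_width, features)
assert np.allclose(scaled_inputs, np.log1p(data_inputs).transpose(0, 3, 1, 2))
```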
if __name__ == "__main__":
    path = os.path.dirname(os.path.realpath(__file__))
    path_data = os.path.join(
        os.path.dirname(
            os.path.realpath(os.path.dirname(os.path.realpath(path)))),
        'data_GNN_nodamp')

    input_width = 5
    days = 30
    num_runs = 1000
    # The number of regions (400 counties) is fixed by the population data,
    # so it is not passed as a parameter here.
    generate_data(num_runs, path_data, input_width, days, save_data=True)
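The saving mechanism can be exercised end to end without running a simulation. A sketch with dummy arrays (the shapes follow the runs x regions x time x features layout produced above, but are otherwise invented):

```python
import os
import pickle
import tempfile

import numpy as np

# Dummy stand-ins for the scaled inputs/labels written by generate_data.
all_data = {"inputs": np.zeros((2, 400, 5, 48)),
            "labels": np.zeros((2, 400, 30, 48))}

with tempfile.TemporaryDirectory() as path:
    file_name = os.path.join(path, 'data_secir_age_groups.pickle')
    with open(file_name, 'wb') as f:
        pickle.dump(all_data, f)
    # reading the dataset back restores the dict unchanged
    with open(file_name, 'rb') as f:
        loaded = pickle.load(f)

print(loaded['inputs'].shape, loaded['labels'].shape)
```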