diff --git a/examples/analysis.py b/examples/analysis.py
new file mode 100644
index 00000000..e47e11f2
--- /dev/null
+++ b/examples/analysis.py
@@ -0,0 +1,13 @@
+from progpy.analysis import show_heatmap
+from progpy.datasets import nasa_cmapss
+
+(training, testing, eol) = nasa_cmapss.load_data(1)
+
+show_heatmap(training)
+
+# Notice that some values have no color - this is because they are constant. Let's drop these.
+# DataFrame.drop returns a new frame rather than modifying in place, so rebind the result.
+training = training.drop(columns=['setting3', 'sensor1', 'sensor5', 'sensor10', 'sensor16', 'sensor18', 'sensor19'])
+show_heatmap(training)
+
+# Here you can see high correlations between sensor 14 and 9
diff --git a/src/progpy/analysis/__init__.py b/src/progpy/analysis/__init__.py
new file mode 100644
index 00000000..494a52ce
--- /dev/null
+++ b/src/progpy/analysis/__init__.py
@@ -0,0 +1,4 @@
+# Copyright © 2021 United States Government as represented by the Administrator of the
+# National Aeronautics and Space Administration. All Rights Reserved.
+
+from progpy.analysis.heatmap import show_heatmap
diff --git a/src/progpy/analysis/heatmap.py b/src/progpy/analysis/heatmap.py
new file mode 100644
index 00000000..9bbd4684
--- /dev/null
+++ b/src/progpy/analysis/heatmap.py
@@ -0,0 +1,31 @@
+# Copyright © 2021 United States Government as represented by the Administrator of the
+# National Aeronautics and Space Administration. All Rights Reserved.
+
+import matplotlib.pyplot as plt
+
+
+def show_heatmap(data):
+    """
+    Generate a heatmap showing the pairwise correlation between parameters.
+
+    Code from: https://github.com/keras-team/keras-io/blob/13d513d7375656a14698ba4827ebbb4177efcf43/examples/timeseries/timeseries_weather_forecasting.py#L152
+
+    Args:
+        data (pandas.DataFrame): Table of data where each column is a variable.
+            Must support ``corr()``; tick labels come from the columns of the result.
+    """
+    # Label the axes from the correlation matrix itself so the ticks always
+    # line up with what is drawn, even if corr() leaves out columns of data.
+    corr = data.corr()
+    labels = corr.columns
+    positions = range(len(labels))
+
+    plt.matshow(corr)
+    plt.xticks(positions, labels, fontsize=14, rotation=90)
+    plt.gca().xaxis.tick_bottom()  # matshow places the x ticks on top by default
+    plt.yticks(positions, labels, fontsize=14)
+
+    cb = plt.colorbar()
+    cb.ax.tick_params(labelsize=14)
+    plt.title("Feature Correlation Heatmap", fontsize=14)
+    plt.show()