Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions examples/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from progpy.analysis import show_heatmap
from progpy.datasets import nasa_cmapss

(training, testing, eol) = nasa_cmapss.load_data(1)

show_heatmap(training)

# Notice that some values have no color- this is because they are constant. Let's drop these
for feature in ['setting3', 'sensor1', 'sensor5', 'sensor10', 'sensor16', 'sensor18', 'sensor19']:
training.drop(feature, axis=1)
show_heatmap(training)

# Here you can see high correlations between sensor 14 and 9
4 changes: 4 additions & 0 deletions src/progpy/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright © 2021 United States Government as represented by the Administrator of the
# National Aeronautics and Space Administration. All Rights Reserved.

from progpy.analysis.heatmap import show_heatmap
23 changes: 23 additions & 0 deletions src/progpy/analysis/heatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright © 2021 United States Government as represented by the Administrator of the
# National Aeronautics and Space Administration. All Rights Reserved.

import matplotlib.pyplot as plt

def show_heatmap(data):
"""
Generate a heatmap showing correlation between parameters.

Code from: https://github.com/keras-team/keras-io/blob/13d513d7375656a14698ba4827ebbb4177efcf43/examples/timeseries/timeseries_weather_forecasting.py#L152

Args:
data (np.ndarray): Array of data where each column is a variable.
"""
plt.matshow(data.corr())
plt.xticks(range(data.shape[1]), data.columns, fontsize=14, rotation=90)
plt.gca().xaxis.tick_bottom()
plt.yticks(range(data.shape[1]), data.columns, fontsize=14)

cb = plt.colorbar()
cb.ax.tick_params(labelsize=14)
plt.title("Feature Correlation Heatmap", fontsize=14)
plt.show()
Loading