Commit 013b7ab

Add more SHAP explanation plots and embedding anomalies
1 parent 61f7a84 commit 013b7ab

File tree

1 file changed: +360 additions, -1 deletion

domains/anomaly-detection/explore/AnomalyDetectionIsolationForestExploration.ipynb

Lines changed: 360 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -931,7 +931,6 @@
931931
"outputs": [],
932932
"source": [
933933
"java_package_shap_values = explain_anomalies_with_shap(\n",
934-
" # random_forest_model=java_package_proxy_random_forest,\n",
935934
" random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n",
936935
" prepared_features=java_package_anomaly_detection_features_prepared\n",
937936
")\n",
@@ -944,6 +943,366 @@
944943
")"
945944
]
946945
},
946+
{
947+
"cell_type": "code",
948+
"execution_count": null,
949+
"id": "50ce9cbb",
950+
"metadata": {},
951+
"outputs": [],
952+
"source": [
953+
"# TODO delete next section if not used anymore"
954+
]
955+
},
956+
{
957+
"cell_type": "markdown",
958+
"id": "34a35377",
959+
"metadata": {},
960+
"source": [
961+
"\n",
962+
"\n",
963+
"### 🔍 1. **Summary Plot (Beeswarm)**\n",
964+
"\n",
965+
"```python\n",
966+
"shap.summary_plot(shap_values, X, feature_names=feature_names)\n",
967+
"```\n",
968+
"\n",
969+
"* **Best for:** Global understanding of which features drive anomalies.\n",
970+
"* **Adds:** Direction of impact (color shows feature value).\n",
971+
"* **Why:** Useful when you want to see how values push predictions toward normal or anomalous.\n",
972+
"\n",
973+
"---\n",
974+
"\n",
975+
"### 🧠 2. **Force Plot**\n",
976+
"\n",
977+
"```python\n",
978+
"shap.initjs()\n",
979+
"shap.force_plot(\n",
980+
" explainer.expected_value[1], # For class \"anomaly\"\n",
981+
" shap_values[1][i], # For specific instance i\n",
982+
" X.iloc[i], # Same instance input\n",
983+
" feature_names=feature_names\n",
984+
")\n",
985+
"```\n",
986+
"\n",
987+
"* **Best for:** Explaining *why a specific data point* is anomalous.\n",
988+
"* **Adds:** Visual breakdown of how each feature contributes to the score.\n",
989+
"* **Why:** Highly interpretable for debugging single nodes.\n",
990+
"\n",
991+
"---\n",
992+
"\n",
993+
"### 📈 3. **Dependence Plot**\n",
994+
"\n",
995+
"```python\n",
996+
"shap.dependence_plot(\"PageRank\", shap_values[1], X, feature_names=feature_names)\n",
997+
"```\n",
998+
"\n",
999+
"* **Best for:** Understanding how *one feature* affects anomaly scores.\n",
1000+
"* **Adds:** Color can show interaction with another feature.\n",
1001+
"* **Why:** Helps discover *nonlinear effects or interaction terms*.\n",
1002+
"\n",
1003+
"---\n",
1004+
"\n",
1005+
"### 🔗 4. **Interaction Value Plots**\n",
1006+
"\n",
1007+
"If the explainer was created with `TreeExplainer(model, feature_perturbation=\"tree_path_dependent\")`, you can use:\n",
1008+
"\n",
1009+
"```python\n",
1010+
"shap_interaction_values = explainer.shap_interaction_values(X)\n",
1011+
"shap.summary_plot(shap_interaction_values[1], X)\n",
1012+
"```\n",
1013+
"\n",
1014+
"* **Best for:** Revealing how features *interact* in creating anomalies.\n",
1015+
"* **Adds:** Pairs of features contributing together.\n",
1016+
"* **Why:** Especially interesting with graph metrics + embedding components.\n",
1017+
"\n",
1018+
"---\n",
1019+
"\n",
1020+
"### 🧭 5. **Decision Plot**\n",
1021+
"\n",
1022+
"```python\n",
1023+
"shap.decision_plot(\n",
1024+
" explainer.expected_value[1],\n",
1025+
" shap_values[1][sample_indices],\n",
1026+
" X.iloc[sample_indices],\n",
1027+
" feature_names=feature_names\n",
1028+
")\n",
1029+
"```\n",
1030+
"\n",
1031+
"* **Best for:** Tracing how a model arrives at a decision.\n",
1032+
"* **Adds:** Shows cumulative impact of features.\n",
1033+
"* **Why:** Good for **comparing multiple instances** and identifying tipping-point features.\n",
1034+
"\n",
1035+
"---\n",
1036+
"\n",
1037+
"### 🧊 6. **Waterfall Plot**\n",
1038+
"\n",
1039+
"```python\n",
1040+
"shap.plots.waterfall(shap.Explanation(\n",
1041+
" values=shap_values[1][i],\n",
1042+
" base_values=explainer.expected_value[1],\n",
1043+
" data=X.iloc[i],\n",
1044+
" feature_names=feature_names\n",
1045+
"))\n",
1046+
"```\n",
1047+
"\n",
1048+
"* **Best for:** Clear breakdown of prediction into additive components.\n",
1049+
"* **Why:** Cleaner than force plot; great in reports or UI.\n",
1050+
"\n",
1051+
"---\n",
1052+
"\n",
1053+
"### ✅ Recommendations for Your Use Case (Code Graph Anomaly Detection):\n",
1054+
"\n",
1055+
"| Goal | Recommended Plot |\n",
1056+
"| -------------------------------- | ------------------------------ |\n",
1057+
"| Global feature influence | Summary Plot (bar or beeswarm) |\n",
1058+
"| Understand single anomaly | Force Plot / Waterfall |\n",
1059+
"| Explore how a feature influences | Dependence Plot |\n",
1060+
"| Discover interactions | Interaction Plot |\n",
1061+
"| Debug how decision was made | Decision Plot |\n",
1062+
"\n",
1063+
"Depending on the type of insight needed (e.g., per node, across the graph, per anomaly cluster), different plot setups from the table above apply.\n"
1064+
]
1065+
},
1066+
{
1067+
"cell_type": "code",
1068+
"execution_count": null,
1069+
"id": "7025dd65",
1070+
"metadata": {},
1071+
"outputs": [],
1072+
"source": [
1073+
"def plot_shap_explained_beeswarm(\n",
1074+
" shap_values: numpy_typing.NDArray,\n",
1075+
" prepared_features: numpy_typing.NDArray,\n",
1076+
" feature_names: list[str],\n",
1077+
" title_prefix: str = \"\",\n",
1078+
") -> None:\n",
1079+
" \"\"\"\n",
1080+
" Explain anomalies using SHAP values and plot the global feature importance as a \"beeswarm\".\n",
1081+
" This function uses the SHAP library to visualize the impact of features on the model's predictions\n",
1082+
" for anomalies detected by the Isolation Forest model via the Random Forest proxy model.\n",
1083+
" \"\"\"\n",
1084+
"\n",
1085+
" shap.summary_plot(\n",
1086+
" shap_values[:, :, 1], # Class 1 = anomaly\n",
1087+
" prepared_features[:],\n",
1088+
" feature_names=feature_names,\n",
1089+
" plot_type=\"dot\",\n",
1090+
" max_display=20,\n",
1091+
" plot_size=(12, 6), # (width, height) in inches\n",
1092+
" show=False\n",
1093+
" )\n",
1094+
" plot.title(f\"How {title_prefix} features contribute to the anomaly score (beeswarm plot)\", fontsize=12)\n",
1095+
" plot.show()"
1096+
]
1097+
},
1098+
{
1099+
"cell_type": "code",
1100+
"execution_count": null,
1101+
"id": "ec6676c7",
1102+
"metadata": {},
1103+
"outputs": [],
1104+
"source": [
1105+
"plot_shap_explained_beeswarm(\n",
1106+
" shap_values=java_package_shap_values,\n",
1107+
" prepared_features=java_package_anomaly_detection_features_prepared,\n",
1108+
" feature_names=java_package_anomaly_detection_feature_names,\n",
1109+
" title_prefix=\"Java Package\"\n",
1110+
")"
1111+
]
1112+
},
1113+
{
1114+
"cell_type": "code",
1115+
"execution_count": null,
1116+
"id": "9b5a523d",
1117+
"metadata": {},
1118+
"outputs": [],
1119+
"source": [
1120+
"def plot_shap_explained_local_feature_importance(\n",
1121+
" index_to_explain,\n",
1122+
" random_forest_model: RandomForestClassifier,\n",
1123+
" prepared_features: np.ndarray,\n",
1124+
" feature_names: list[str],\n",
1125+
" title_prefix: str = \"\",\n",
1126+
" rounding_precision: int = 3,\n",
1127+
"):\n",
1128+
" # TODO Take explainer as input parameter\n",
1129+
" explainer = shap.TreeExplainer(random_forest_model)\n",
1130+
" shap_values = explainer.shap_values(prepared_features)\n",
1131+
"\n",
1132+
" # print(f\"Input data with prepared features: shape={prepared_features.shape}\")\n",
1133+
" # print(f\"Explainable AI SHAP values: shape={np.shape(shap_values)}\")\n",
1134+
" # print(f\"Explainable AI SHAP expected_value: shape={np.shape(explainer.expected_value)}\")\n",
1135+
" # print(f\"Explainable AI SHAP expected_value: type={type(explainer.expected_value)}\")\n",
1136+
" # print(f\"Explaining instance at index {index_to_explain} with anomaly label: {original_features.iloc[index_to_explain][anomaly_label_column]}\")\n",
1137+
"\n",
1138+
" shap_values_rounded = np.round(shap_values[:,:, 1][index_to_explain], rounding_precision)\n",
1139+
" prepared_features_rounded = prepared_features[:][index_to_explain].round(rounding_precision)\n",
1140+
" base_value_rounded = np.round(typing.cast(np.ndarray,explainer.expected_value)[1], rounding_precision)\n",
1141+
"\n",
1142+
" shap.force_plot(\n",
1143+
" base_value_rounded, # For class \"anomaly\"\n",
1144+
" # typing.cast(np.ndarray,explainer.expected_value)[1], # For class \"anomaly\"\n",
1145+
" # shap_values[:,:, 1][index_to_explain],\n",
1146+
" shap_values_rounded,\n",
1147+
" prepared_features_rounded,\n",
1148+
" # prepared_features[:][index_to_explain],\n",
1149+
" feature_names=feature_names,\n",
1150+
" matplotlib=True,\n",
1151+
" show=False,\n",
1152+
" contribution_threshold=0.07\n",
1153+
" )\n",
1154+
" plot.title(f\"{title_prefix} anomaly at index {index_to_explain} explained\", fontsize=14, loc='left')\n",
1155+
" plot.show()"
1156+
]
1157+
},
1158+
{
1159+
"cell_type": "code",
1160+
"execution_count": null,
1161+
"id": "77b0852c",
1162+
"metadata": {},
1163+
"outputs": [],
1164+
"source": [
1165+
"plot_shap_explained_local_feature_importance(\n",
1166+
" index_to_explain=4,\n",
1167+
" random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n",
1168+
" prepared_features=java_package_anomaly_detection_features_prepared,\n",
1169+
" feature_names=java_package_anomaly_detection_feature_names,\n",
1170+
" title_prefix=\"Java Package\",\n",
1171+
")"
1172+
]
1173+
},
1174+
{
1175+
"cell_type": "raw",
1176+
"id": "2df453b4",
1177+
"metadata": {
1178+
"vscode": {
1179+
"languageId": "raw"
1180+
}
1181+
},
1182+
"source": [
1183+
"# TODO delete if not needed anymore\n",
1184+
"def plot_shap_explained_feature_dependency(\n",
1185+
" index_to_explain: int,\n",
1186+
" random_forest_model: RandomForestClassifier,\n",
1187+
" prepared_features: np.ndarray,\n",
1188+
" feature_names: list[str],\n",
1189+
" title_prefix: str = \"\",\n",
1190+
"):\n",
1191+
" explainer = shap.TreeExplainer(random_forest_model)\n",
1192+
" shap_values = explainer.shap_values(prepared_features)\n",
1193+
"\n",
1194+
" shap.dependence_plot(\n",
1195+
" ind=index_to_explain, # Feature name or index\n",
1196+
" shap_values=shap_values[:, :, 1],\n",
1197+
" features=prepared_features[:],\n",
1198+
" feature_names=feature_names,\n",
1199+
" interaction_index=None, # Set to a feature name/index to see interactions\n",
1200+
" show=False,\n",
1201+
" )\n",
1202+
" plot.title(f\"{title_prefix} Feature contribution to anomaly score\")\n",
1203+
" plot.show()\n",
1204+
"\n",
1205+
"plot_shap_explained_feature_dependency(\n",
1206+
" index_to_explain=2,\n",
1207+
" random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n",
1208+
" prepared_features=java_package_anomaly_detection_features_prepared,\n",
1209+
" feature_names=java_package_anomaly_detection_feature_names,\n",
1210+
" title_prefix=\"Java Package\"\n",
1211+
")"
1212+
]
1213+
},
1214+
{
1215+
"cell_type": "code",
1216+
"execution_count": null,
1217+
"id": "fb7e14f9",
1218+
"metadata": {},
1219+
"outputs": [],
1220+
"source": [
1221+
"def plot_shap_explained_top_10_feature_dependence(\n",
1222+
" random_forest_model: RandomForestClassifier,\n",
1223+
" prepared_features: np.ndarray,\n",
1224+
" feature_names: list[str],\n",
1225+
" title_prefix: str = \"\",\n",
1226+
"):\n",
1227+
" explainer = shap.TreeExplainer(random_forest_model)\n",
1228+
" shap_values = explainer.shap_values(prepared_features)\n",
1229+
"\n",
1230+
" mean_abs_shap = np.abs(shap_values[:, :, 1]).mean(axis=0)\n",
1231+
" top_features = np.argsort(mean_abs_shap)[-10:][::-1] # top 10 indices\n",
1232+
" top_feature_names = [feature_names[i] for i in top_features] # Get names of top features\n",
1233+
" \n",
1234+
" figure, axes = plot.subplots(5, 2, figsize=(15, 20)) # 5 rows x 2 columns\n",
1235+
" figure.suptitle(f\"{title_prefix} Anomalies: Top 10 feature dependence plots\", fontsize=16)\n",
1236+
" axes = axes.flatten() # Flatten for easy iteration\n",
1237+
"\n",
1238+
" for index, feature in enumerate(top_feature_names):\n",
1239+
" shap.dependence_plot(\n",
1240+
" ind=feature, # Feature name or index\n",
1241+
" shap_values=shap_values[:, :, 1],\n",
1242+
" features=prepared_features[:],\n",
1243+
" feature_names=feature_names,\n",
1244+
" interaction_index=None, # Set to a feature name/index to see interactions\n",
1245+
" show=False,\n",
1246+
" ax=axes[index]\n",
1247+
" )\n",
1248+
"\n",
1249+
" plot.tight_layout(rect=(0.0, 0.02, 1.0, 0.98))\n",
1250+
" plot.show()\n",
1251+
"\n",
1252+
"plot_shap_explained_top_10_feature_dependence(\n",
1253+
" random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n",
1254+
" prepared_features=java_package_anomaly_detection_features_prepared,\n",
1255+
" feature_names=java_package_anomaly_detection_feature_names,\n",
1256+
" title_prefix=\"Java Package\"\n",
1257+
")"
1258+
]
1259+
},
1260+
{
1261+
"cell_type": "raw",
1262+
"id": "1ced99f1",
1263+
"metadata": {
1264+
"vscode": {
1265+
"languageId": "raw"
1266+
}
1267+
},
1268+
"source": [
1269+
"# TODO delete if not needed anymore\n",
1270+
"def plot_shap_explained_heatmap(\n",
1271+
" random_forest_model: RandomForestClassifier,\n",
1272+
" prepared_features: np.ndarray,\n",
1273+
" original_features: pd.DataFrame, \n",
1274+
" feature_names: list[str],\n",
1275+
" title_prefix: str = \"\",\n",
1276+
" anomaly_label_column: str = \"anomalyLabel\"\n",
1277+
"):\n",
1278+
" explainer = shap.TreeExplainer(random_forest_model)\n",
1279+
" shap_values = explainer.shap_values(prepared_features)\n",
1280+
"\n",
1281+
" # Create SHAP Explanation object\n",
1282+
" shap_explanation = shap.Explanation(\n",
1283+
" values=shap_values[:, :, 1],\n",
1284+
" base_values=typing.cast(np.ndarray, explainer.expected_value)[1], # For class \"anomaly\"\n",
1285+
" data=prepared_features[:],\n",
1286+
" feature_names=feature_names\n",
1287+
" )\n",
1288+
"\n",
1289+
" shap.plots.heatmap(\n",
1290+
" shap_explanation, \n",
1291+
" instance_order=\"leaves\", # Optional: use clustering to sort rows\n",
1292+
" show=False,\n",
1293+
" )\n",
1294+
" plot.title(f\"{title_prefix} Anomaly feature heatmap\")\n",
1295+
" plot.show()\n",
1296+
"\n",
1297+
"plot_shap_explained_heatmap(\n",
1298+
" random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n",
1299+
" prepared_features=java_package_anomaly_detection_features_prepared,\n",
1300+
" original_features=java_package_anomaly_detection_features,\n",
1301+
" feature_names=java_package_anomaly_detection_feature_names,\n",
1302+
" title_prefix=\"Java Package\"\n",
1303+
")"
1304+
]
1305+
},
9471306
{
9481307
"cell_type": "markdown",
9491308
"id": "27b33560",
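Setting the diff formatting aside, the proxy-model pattern this commit builds on can be sketched end to end. The sketch below uses synthetic stand-in data and illustrative variable names (none of them are from the notebook): an Isolation Forest flags anomalies, a Random Forest is fitted as a proxy to imitate those labels, and that proxy is what `shap.TreeExplainer` would then explain.

```python
# Minimal sketch of the pattern used in the notebook: an IsolationForest flags
# anomalies, a RandomForestClassifier is fitted as a proxy to imitate those
# labels, and the proxy is what shap.TreeExplainer would then explain.
# Data and variable names below are illustrative stand-ins, not notebook code.
import numpy as np
from sklearn.ensemble import IsolationForest, RandomForestClassifier

rng = np.random.default_rng(42)
features = rng.normal(size=(200, 4))  # stand-in for the prepared graph features
features[:5] += 6.0                   # inject a few obvious outliers

isolation_forest = IsolationForest(random_state=42).fit(features)
# IsolationForest.predict returns -1 for anomalies; map to 1 = anomaly
anomaly_labels = (isolation_forest.predict(features) == -1).astype(int)

proxy_model = RandomForestClassifier(random_state=42).fit(features, anomaly_labels)
agreement = (proxy_model.predict(features) == anomaly_labels).mean()
print(f"proxy agreement with IsolationForest labels: {agreement:.2f}")

# With the shap package installed, the beeswarm from the diff would follow as:
#   explainer = shap.TreeExplainer(proxy_model)
#   shap_values = explainer.shap_values(features)      # (n, n_features, 2)
#   shap.summary_plot(shap_values[:, :, 1], features)  # class 1 = anomaly
```

The proxy step exists because `TreeExplainer` supports tree ensembles like Random Forests directly, while explaining an Isolation Forest's raw anomaly score is less convenient; high agreement between proxy and detector is what makes the SHAP explanations trustworthy.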