|
931 | 931 | "outputs": [], |
932 | 932 | "source": [ |
933 | 933 | "java_package_shap_values = explain_anomalies_with_shap(\n", |
934 | | - " # random_forest_model=java_package_proxy_random_forest,\n", |
935 | 934 | " random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n", |
936 | 935 | " prepared_features=java_package_anomaly_detection_features_prepared\n", |
937 | 936 | ")\n", |
|
944 | 943 | ")" |
945 | 944 | ] |
946 | 945 | }, |
| 946 | + { |
| 947 | + "cell_type": "code", |
| 948 | + "execution_count": null, |
| 949 | + "id": "50ce9cbb", |
| 950 | + "metadata": {}, |
| 951 | + "outputs": [], |
| 952 | + "source": [ |
| 953 | + "# TODO delete next section if not used anymore" |
| 954 | + ] |
| 955 | + }, |
| 956 | + { |
| 957 | + "cell_type": "markdown", |
| 958 | + "id": "34a35377", |
| 959 | + "metadata": {}, |
| 960 | + "source": [ |
| 961 | + "\n", |
| 962 | + "\n", |
| 963 | + "### 🔍 1. **Summary Plot (Beeswarm)**\n", |
| 964 | + "\n", |
| 965 | + "```python\n", |
| 966 | + "shap.summary_plot(shap_values, X, feature_names=feature_names)\n", |
| 967 | + "```\n", |
| 968 | + "\n", |
| 969 | + "* **Best for:** Global understanding of which features drive anomalies.\n", |
| 970 | + "* **Adds:** Direction of impact (color shows feature value).\n", |
| 971 | + "* **Why:** Useful when you want to see how values push predictions toward normal or anomalous.\n", |
| 972 | + "\n", |
| 973 | + "---\n", |
| 974 | + "\n", |
| 975 | + "### 🧠 2. **Force Plot**\n", |
| 976 | + "\n", |
| 977 | + "```python\n", |
| 978 | + "shap.initjs()\n", |
| 979 | + "shap.force_plot(\n", |
| 980 | + " explainer.expected_value[1], # For class \"anomaly\"\n", |
| 981 | + " shap_values[1][i], # For specific instance i\n", |
| 982 | + " X.iloc[i], # Same instance input\n", |
| 983 | + " feature_names=feature_names\n", |
| 984 | + ")\n", |
| 985 | + "```\n", |
| 986 | + "\n", |
| 987 | + "* **Best for:** Explaining *why a specific data point* is anomalous.\n", |
| 988 | + "* **Adds:** Visual breakdown of how each feature contributes to the score.\n", |
| 989 | + "* **Why:** Highly interpretable for debugging single nodes.\n", |
| 990 | + "\n", |
| 991 | + "---\n", |
| 992 | + "\n", |
| 993 | + "### 📈 3. **Dependence Plot**\n", |
| 994 | + "\n", |
| 995 | + "```python\n", |
| 996 | + "shap.dependence_plot(\"PageRank\", shap_values[1], X, feature_names=feature_names)\n", |
| 997 | + "```\n", |
| 998 | + "\n", |
| 999 | + "* **Best for:** Understanding how *one feature* affects anomaly scores.\n", |
| 1000 | + "* **Adds:** Color can show interaction with another feature.\n", |
| 1001 | + "* **Why:** Helps discover *nonlinear effects or interaction terms*.\n", |
| 1002 | + "\n", |
| 1003 | + "---\n", |
| 1004 | + "\n", |
| 1005 | + "### 🔗 4. **Interaction Value Plots**\n", |
| 1006 | + "\n", |
| 1007 | + "If your model was trained with `TreeExplainer(model, feature_perturbation=\"tree_path_dependent\")`, you can use:\n", |
| 1008 | + "\n", |
| 1009 | + "```python\n", |
| 1010 | + "shap_interaction_values = explainer.shap_interaction_values(X)\n", |
| 1011 | + "shap.summary_plot(shap_interaction_values[1], X)\n", |
| 1012 | + "```\n", |
| 1013 | + "\n", |
| 1014 | + "* **Best for:** Revealing how features *interact* in creating anomalies.\n", |
| 1015 | + "* **Adds:** Pairs of features contributing together.\n", |
| 1016 | + "* **Why:** Especially interesting with graph metrics + embedding components.\n", |
| 1017 | + "\n", |
| 1018 | + "---\n", |
| 1019 | + "\n", |
| 1020 | + "### 🧭 5. **Decision Plot**\n", |
| 1021 | + "\n", |
| 1022 | + "```python\n", |
| 1023 | + "shap.decision_plot(\n", |
| 1024 | + " explainer.expected_value[1],\n", |
| 1025 | + " shap_values[1][sample_indices],\n", |
| 1026 | + " X.iloc[sample_indices],\n", |
| 1027 | + " feature_names=feature_names\n", |
| 1028 | + ")\n", |
| 1029 | + "```\n", |
| 1030 | + "\n", |
| 1031 | + "* **Best for:** Tracing how a model arrives at a decision.\n", |
| 1032 | + "* **Adds:** Shows cumulative impact of features.\n", |
| 1033 | + "* **Why:** Good for **comparing multiple instances** and identifying tipping-point features.\n", |
| 1034 | + "\n", |
| 1035 | + "---\n", |
| 1036 | + "\n", |
| 1037 | + "### 🧊 6. **Waterfall Plot**\n", |
| 1038 | + "\n", |
| 1039 | + "```python\n", |
| 1040 | + "shap.plots.waterfall(shap.Explanation(\n", |
| 1041 | + " values=shap_values[1][i],\n", |
| 1042 | + " base_values=explainer.expected_value[1],\n", |
| 1043 | + " data=X.iloc[i],\n", |
| 1044 | + " feature_names=feature_names\n", |
| 1045 | + "))\n", |
| 1046 | + "```\n", |
| 1047 | + "\n", |
| 1048 | + "* **Best for:** Clear breakdown of prediction into additive components.\n", |
| 1049 | + "* **Why:** Cleaner than force plot; great in reports or UI.\n", |
| 1050 | + "\n", |
| 1051 | + "---\n", |
| 1052 | + "\n", |
| 1053 | + "### ✅ Recommendations for Your Use Case (Code Graph Anomaly Detection):\n", |
| 1054 | + "\n", |
| 1055 | + "| Goal | Recommended Plot |\n", |
| 1056 | + "| -------------------------------- | ------------------------------ |\n", |
| 1057 | + "| Global feature influence | Summary Plot (bar or beeswarm) |\n", |
| 1058 | + "| Understand single anomaly | Force Plot / Waterfall |\n", |
| 1059 | + "| Explore how a feature influences | Dependence Plot |\n", |
| 1060 | + "| Discover interactions | Interaction Plot |\n", |
| 1061 | + "| Debug how decision was made | Decision Plot |\n", |
| 1062 | + "\n", |
| 1063 | + "Let me know what type of insight you're most interested in (e.g., per node, across the graph, per anomaly cluster), and I can recommend specific plot setups or generate templates for you.\n" |
| 1064 | + ] |
| 1065 | + }, |
| 1066 | + { |
| 1067 | + "cell_type": "code", |
| 1068 | + "execution_count": null, |
| 1069 | + "id": "7025dd65", |
| 1070 | + "metadata": {}, |
| 1071 | + "outputs": [], |
| 1072 | + "source": [ |
| 1073 | + "def plot_shap_explained_beeswarm(\n", |
| 1074 | + " shap_values: numpy_typing.NDArray,\n", |
| 1075 | + " prepared_features: numpy_typing.NDArray,\n", |
| 1076 | + " feature_names: list[str],\n", |
| 1077 | + " title_prefix: str = \"\",\n", |
| 1078 | + ") -> None:\n", |
| 1079 | + " \"\"\"\n", |
| 1080 | + " Explain anomalies using SHAP values and plot the global feature importance as a \"beeswarm\".\n", |
| 1081 | + " This function uses the SHAP library to visualize the impact of features on the model's predictions\n", |
| 1082 | + " for anomalies detected by the Isolation Forest model via the Random Forest proxy model.\n", |
| 1083 | + " \"\"\"\n", |
| 1084 | + "\n", |
| 1085 | + " shap.summary_plot(\n", |
| 1086 | + " shap_values[:, :, 1], # Class 1 = anomaly\n", |
| 1087 | + " prepared_features[:],\n", |
| 1088 | + " feature_names=feature_names,\n", |
| 1089 | + " plot_type=\"dot\",\n", |
| 1090 | + " max_display=20,\n", |
| 1091 | + " plot_size=(12, 6), # (width, height) in inches\n", |
| 1092 | + " show=False\n", |
| 1093 | + " )\n", |
| 1094 | + " plot.title(f\"How {title_prefix} features contribute to the anomaly score (beeswarm plot)\", fontsize=12)\n", |
| 1095 | + " plot.show()" |
| 1096 | + ] |
| 1097 | + }, |
| 1098 | + { |
| 1099 | + "cell_type": "code", |
| 1100 | + "execution_count": null, |
| 1101 | + "id": "ec6676c7", |
| 1102 | + "metadata": {}, |
| 1103 | + "outputs": [], |
| 1104 | + "source": [ |
| 1105 | + "plot_shap_explained_beeswarm(\n", |
| 1106 | + " shap_values=java_package_shap_values,\n", |
| 1107 | + " prepared_features=java_package_anomaly_detection_features_prepared,\n", |
| 1108 | + " feature_names=java_package_anomaly_detection_feature_names,\n", |
| 1109 | + " title_prefix=\"Java Package\"\n", |
| 1110 | + ")" |
| 1111 | + ] |
| 1112 | + }, |
| 1113 | + { |
| 1114 | + "cell_type": "code", |
| 1115 | + "execution_count": null, |
| 1116 | + "id": "9b5a523d", |
| 1117 | + "metadata": {}, |
| 1118 | + "outputs": [], |
| 1119 | + "source": [ |
| 1120 | + "def plot_shap_explained_local_feature_importance(\n", |
| 1121 | + " index_to_explain,\n", |
| 1122 | + " random_forest_model: RandomForestClassifier,\n", |
| 1123 | + " prepared_features: np.ndarray,\n", |
| 1124 | + " feature_names: list[str],\n", |
| 1125 | + " title_prefix: str = \"\",\n", |
| 1126 | + " rounding_precision: int = 3,\n", |
| 1127 | + "):\n", |
| 1128 | + " # TODO Take explainer as input parameter\n", |
| 1129 | + " explainer = shap.TreeExplainer(random_forest_model)\n", |
| 1130 | + " shap_values = explainer.shap_values(prepared_features)\n", |
| 1131 | + "\n", |
| 1132 | + " # print(f\"Input data with prepared features: shape={prepared_features.shape}\")\n", |
| 1133 | + " # print(f\"Explainable AI SHAP values: shape={np.shape(shap_values)}\")\n", |
| 1134 | + " # print(f\"Explainable AI SHAP expected_value: shape={np.shape(explainer.expected_value)}\")\n", |
| 1135 | + " # print(f\"Explainable AI SHAP expected_value: type={type(explainer.expected_value)}\")\n", |
| 1136 | + " # print(f\"Explaining instance at index {index_to_explain} with anomaly label: {original_features.iloc[index_to_explain][anomaly_label_column]}\")\n", |
| 1137 | + "\n", |
| 1138 | + " shap_values_rounded = np.round(shap_values[:,:, 1][index_to_explain], rounding_precision)\n", |
| 1139 | + " prepared_features_rounded = prepared_features[:][index_to_explain].round(rounding_precision)\n", |
| 1140 | + " base_value_rounded = np.round(typing.cast(np.ndarray,explainer.expected_value)[1], rounding_precision)\n", |
| 1141 | + "\n", |
| 1142 | + " shap.force_plot(\n", |
| 1143 | + " base_value_rounded, # For class \"anomaly\"\n", |
| 1144 | + " # typing.cast(np.ndarray,explainer.expected_value)[1], # For class \"anomaly\"\n", |
| 1145 | + " # shap_values[:,:, 1][index_to_explain],\n", |
| 1146 | + " shap_values_rounded,\n", |
| 1147 | + " prepared_features_rounded,\n", |
| 1148 | + " # prepared_features[:][index_to_explain],\n", |
| 1149 | + " feature_names=feature_names,\n", |
| 1150 | + " matplotlib=True,\n", |
| 1151 | + " show=False,\n", |
| 1152 | + " contribution_threshold=0.07\n", |
| 1153 | + " )\n", |
| 1154 | + " plot.title(f\"{title_prefix} anomaly feature {feature_names[index_to_explain]} explained\", fontsize=14, loc='left')\n", |
| 1155 | + " plot.show()" |
| 1156 | + ] |
| 1157 | + }, |
| 1158 | + { |
| 1159 | + "cell_type": "code", |
| 1160 | + "execution_count": null, |
| 1161 | + "id": "77b0852c", |
| 1162 | + "metadata": {}, |
| 1163 | + "outputs": [], |
| 1164 | + "source": [ |
| 1165 | + "plot_shap_explained_local_feature_importance(\n", |
| 1166 | + " index_to_explain=4,\n", |
| 1167 | + " random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n", |
| 1168 | + " prepared_features=java_package_anomaly_detection_features_prepared,\n", |
| 1169 | + " feature_names=java_package_anomaly_detection_feature_names,\n", |
| 1170 | + " title_prefix=\"Java Package\",\n", |
| 1171 | + ")" |
| 1172 | + ] |
| 1173 | + }, |
| 1174 | + { |
| 1175 | + "cell_type": "raw", |
| 1176 | + "id": "2df453b4", |
| 1177 | + "metadata": { |
| 1178 | + "vscode": { |
| 1179 | + "languageId": "raw" |
| 1180 | + } |
| 1181 | + }, |
| 1182 | + "source": [ |
| 1183 | + "# TODO delete if not needed anymore\n", |
| 1184 | + "def plot_shap_explained_feature_dependency(\n", |
| 1185 | + " index_to_explain: int,\n", |
| 1186 | + " random_forest_model: RandomForestClassifier,\n", |
| 1187 | + " prepared_features: np.ndarray,\n", |
| 1188 | + " feature_names: list[str],\n", |
| 1189 | + " title_prefix: str = \"\",\n", |
| 1190 | + "):\n", |
| 1191 | + " explainer = shap.TreeExplainer(random_forest_model)\n", |
| 1192 | + " shap_values = explainer.shap_values(prepared_features)\n", |
| 1193 | + "\n", |
| 1194 | + " shap.dependence_plot(\n", |
| 1195 | + " ind=index_to_explain, # Feature name or index\n", |
| 1196 | + " shap_values=shap_values[:, :, 1],\n", |
| 1197 | + " features=prepared_features[:],\n", |
| 1198 | + " feature_names=feature_names,\n", |
| 1199 | + " interaction_index=None, # Set to a feature name/index to see interactions\n", |
| 1200 | + " show=False,\n", |
| 1201 | + " )\n", |
| 1202 | + " plot.title(f\"{title_prefix} Feature contribution to anomaly score\")\n", |
| 1203 | + " plot.show()\n", |
| 1204 | + "\n", |
| 1205 | + "plot_shap_explained_feature_dependency(\n", |
| 1206 | + " index_to_explain=2,\n", |
| 1207 | + " random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n", |
| 1208 | + " prepared_features=java_package_anomaly_detection_features_prepared,\n", |
| 1209 | + " feature_names=java_package_anomaly_detection_feature_names,\n", |
| 1210 | + " title_prefix=\"Java Package\"\n", |
| 1211 | + ")" |
| 1212 | + ] |
| 1213 | + }, |
| 1214 | + { |
| 1215 | + "cell_type": "code", |
| 1216 | + "execution_count": null, |
| 1217 | + "id": "fb7e14f9", |
| 1218 | + "metadata": {}, |
| 1219 | + "outputs": [], |
| 1220 | + "source": [ |
| 1221 | + "def plot_shap_explained_top_10_feature_dependence(\n", |
| 1222 | + " random_forest_model: RandomForestClassifier,\n", |
| 1223 | + " prepared_features: np.ndarray,\n", |
| 1224 | + " feature_names: list[str],\n", |
| 1225 | + " title_prefix: str = \"\",\n", |
| 1226 | + "):\n", |
| 1227 | + " explainer = shap.TreeExplainer(random_forest_model)\n", |
| 1228 | + " shap_values = explainer.shap_values(prepared_features)\n", |
| 1229 | + "\n", |
| 1230 | + " mean_abs_shap = np.abs(shap_values[:, :, 1]).mean(axis=0)\n", |
| 1231 | + " top_features = np.argsort(mean_abs_shap)[-10:][::-1] # top 10 indices\n", |
| 1232 | + " top_feature_names = [feature_names[i] for i in top_features] # Get names of top features\n", |
| 1233 | + " \n", |
| 1234 | + " figure, axes = plot.subplots(5, 2, figsize=(15, 20)) # 5 rows x 2 columns\n", |
| 1235 | + " figure.suptitle(f\"{title_prefix} Anomalies: Top 10 feature dependence plots\", fontsize=16)\n", |
| 1236 | + " axes = axes.flatten() # Flatten for easy iteration\n", |
| 1237 | + "\n", |
| 1238 | + " for index, feature in enumerate(top_feature_names):\n", |
| 1239 | + " shap.dependence_plot(\n", |
| 1240 | + " ind=feature, # Feature name or index\n", |
| 1241 | + " shap_values=shap_values[:, :, 1],\n", |
| 1242 | + " features=prepared_features[:],\n", |
| 1243 | + " feature_names=feature_names,\n", |
| 1244 | + " interaction_index=None, # Set to a feature name/index to see interactions\n", |
| 1245 | + " show=False,\n", |
| 1246 | + " ax=axes[index]\n", |
| 1247 | + " )\n", |
| 1248 | + "\n", |
| 1249 | + " plot.tight_layout(rect=(0.0, 0.02, 1.0, 0.98))\n", |
| 1250 | + " plot.show()\n", |
| 1251 | + "\n", |
| 1252 | + "plot_shap_explained_top_10_feature_dependence(\n", |
| 1253 | + " random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n", |
| 1254 | + " prepared_features=java_package_anomaly_detection_features_prepared,\n", |
| 1255 | + " feature_names=java_package_anomaly_detection_feature_names,\n", |
| 1256 | + " title_prefix=\"Java Package\"\n", |
| 1257 | + ")" |
| 1258 | + ] |
| 1259 | + }, |
| 1260 | + { |
| 1261 | + "cell_type": "raw", |
| 1262 | + "id": "1ced99f1", |
| 1263 | + "metadata": { |
| 1264 | + "vscode": { |
| 1265 | + "languageId": "raw" |
| 1266 | + } |
| 1267 | + }, |
| 1268 | + "source": [ |
| 1269 | + "# TODO delete if not needed anymore\n", |
| 1270 | + "def plot_shap_explained_heatmap(\n", |
| 1271 | + " random_forest_model: RandomForestClassifier,\n", |
| 1272 | + " prepared_features: np.ndarray,\n", |
| 1273 | + " original_features: pd.DataFrame, \n", |
| 1274 | + " feature_names: list[str],\n", |
| 1275 | + " title_prefix: str = \"\",\n", |
| 1276 | + " anomaly_label_column: str = \"anomalyLabel\"\n", |
| 1277 | + "):\n", |
| 1278 | + " explainer = shap.TreeExplainer(random_forest_model)\n", |
| 1279 | + " shap_values = explainer.shap_values(prepared_features)\n", |
| 1280 | + "\n", |
| 1281 | + " # Create SHAP Explanation object\n", |
| 1282 | + " shap_explanation = shap.Explanation(\n", |
| 1283 | + " values=shap_values[:, :, 1],\n", |
| 1284 | + " base_values=typing.cast(np.ndarray, explainer.expected_value)[1], # For class \"anomaly\"\n", |
| 1285 | + " data=prepared_features[:],\n", |
| 1286 | + " feature_names=feature_names\n", |
| 1287 | + " )\n", |
| 1288 | + "\n", |
| 1289 | + " shap.heatmap_plot(\n", |
| 1290 | + " shap_explanation, \n", |
| 1291 | + " instance_order=\"leaves\", # Optional: use clustering to sort rows\n", |
| 1292 | + " show=False,\n", |
| 1293 | + " )\n", |
| 1294 | + " plot.title(f\"{title_prefix} Anomaly feature heatmap\")\n", |
| 1295 | + " plot.show()\n", |
| 1296 | + "\n", |
| 1297 | + "plot_shap_explained_heatmap(\n", |
| 1298 | + " random_forest_model=java_package_anomaly_detection_results.random_forest_classifier,\n", |
| 1299 | + " prepared_features=java_package_anomaly_detection_features_prepared,\n", |
| 1300 | + " original_features=java_package_anomaly_detection_features,\n", |
| 1301 | + " feature_names=java_package_anomaly_detection_feature_names,\n", |
| 1302 | + " title_prefix=\"Java Package\"\n", |
| 1303 | + ")" |
| 1304 | + ] |
| 1305 | + }, |
947 | 1306 | { |
948 | 1307 | "cell_type": "markdown", |
949 | 1308 | "id": "27b33560", |
|