|
267 | 267 | }, |
268 | 268 | }, |
269 | 269 | SchemaOptional("seed"): int, |
| 270 | + SchemaOptional("features_to_explain"): [Or(int, str)], |
270 | 271 | }, |
271 | 272 | SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])}, |
272 | 273 | SchemaOptional("post_training_bias"): {"methods": Or(str, [str])}, |
@@ -1308,6 +1309,7 @@ def __init__( |
1308 | 1309 | num_clusters: Optional[int] = None, |
1309 | 1310 | text_config: Optional[TextConfig] = None, |
1310 | 1311 | image_config: Optional[ImageConfig] = None, |
| 1312 | + features_to_explain: Optional[List[Union[str, int]]] = None, |
1311 | 1313 | ): |
1312 | 1314 | """Initializes config for SHAP analysis. |
1313 | 1315 |
|
@@ -1343,6 +1345,14 @@ def __init__( |
1343 | 1345 | text features. Default is None. |
1344 | 1346 | image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image |
1345 | 1347 | features. Default is None. |
| 1348 | + features_to_explain: A list of names or indices of dataset features to compute SHAP |
| 1349 | + values for. If not provided, SHAP values are computed for all features by default. |
| 1350 | + Currently only supported for tabular datasets. |
| 1351 | +
|
| 1352 | + Raises: |
| 1353 | + ValueError: when ``agg_method`` is invalid, ``baseline`` and ``num_clusters`` are provided |
| 1354 | + together, or ``features_to_explain`` is specified when ``text_config`` or |
| 1355 | + ``image_config`` is provided |
1346 | 1356 | """ # noqa E501 # pylint: disable=c0301 |
1347 | 1357 | if agg_method is not None and agg_method not in [ |
1348 | 1358 | "mean_abs", |
@@ -1376,6 +1386,13 @@ def __init__( |
1376 | 1386 | ) |
1377 | 1387 | if image_config: |
1378 | 1388 | _set(image_config.get_image_config(), "image_config", self.shap_config) |
| 1389 | + if features_to_explain is not None and ( |
| 1390 | + text_config is not None or image_config is not None |
| 1391 | + ): |
| 1392 | + raise ValueError( |
| 1393 | + "`features_to_explain` is not supported for datasets containing text features or images." |
| 1394 | + ) |
| 1395 | + _set(features_to_explain, "features_to_explain", self.shap_config) |
1379 | 1396 |
|
1380 | 1397 | def get_explainability_config(self): |
1381 | 1398 | """Returns a shap config dictionary.""" |
|
0 commit comments