|
12 | 12 | # Dariusz Brzezinski |
13 | 13 | # License: MIT |
14 | 14 |
|
15 | | -import warnings |
16 | 15 | import functools |
17 | | - |
18 | | -from inspect import getcallargs |
| 16 | +import warnings |
19 | 17 |
|
20 | 18 | import numpy as np |
21 | 19 | import scipy as sp |
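The import hunk above drops from inspect import getcallargs: inspect.getcallargs has been deprecated since Python 3.5 in favour of Signature.bind(), and the refactored body below switches to signature(), which the old code already used for sensitivity_specificity_support. A minimal sketch of the replacement pattern, with a hypothetical metric function that is not part of the patch:

# Minimal sketch (not from the patch): Signature.bind() + apply_defaults()
# yields the same "parameter name -> value" mapping that getcallargs gave.
from inspect import signature


def example_metric(y_true, y_pred, average="binary"):
    # Hypothetical metric, for illustration only.
    return 0.0


bound = signature(example_metric).bind([0, 1], [1, 1])
bound.apply_defaults()
# bound.arguments maps every parameter to its value, defaults included,
# e.g. {'y_true': [0, 1], 'y_pred': [1, 1], 'average': 'binary'}
print(bound.arguments)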
@@ -731,56 +729,56 @@ def make_index_balanced_accuracy(alpha=0.1, squared=True): |
731 | 729 | def decorate(scoring_func): |
732 | 730 | @functools.wraps(scoring_func) |
733 | 731 | def compute_score(*args, **kwargs): |
734 | | - # Create the list of tags |
735 | | - tags_scoring_func = getcallargs(scoring_func, *args, **kwargs) |
| 732 | + signature_scoring_func = signature(scoring_func) |
| 733 | + params_scoring_func = set(signature_scoring_func.parameters.keys()) |
| 734 | + |
736 | 735 | # check that the scoring function does not need a score |
737 | 736 | # and only a prediction |
738 | | - if ( |
739 | | - "y_score" in tags_scoring_func |
740 | | - or "y_prob" in tags_scoring_func |
741 | | - or "y2" in tags_scoring_func |
742 | | - ): |
| 737 | + prohibitied_y_pred = set(["y_score", "y_prob", "y2"]) |
| 738 | + if prohibitied_y_pred.intersection(params_scoring_func): |
743 | 739 | raise AttributeError( |
744 | 740 | "The function {} has an unsupported" |
745 | 741 | " attribute. Metric with`y_pred` are the" |
746 | 742 | " only supported metrics is the only" |
747 | | - " supported." |
| 743 | + " supported.".format(scoring_func.__name__) |
748 | 744 | ) |
749 | | - # Compute the score from the scoring function |
750 | | - _score = scoring_func(*args, **kwargs) |
751 | | - # Square if desired |
| 745 | + |
| 746 | + args_scoring_func = signature_scoring_func.bind(*args, **kwargs) |
| 747 | + args_scoring_func.apply_defaults() |
| 748 | + _score = scoring_func( |
| 749 | + *args_scoring_func.args, **args_scoring_func.kwargs |
| 750 | + ) |
752 | 751 | if squared: |
753 | 752 | _score = np.power(_score, 2) |
754 | | - # Get the signature of the sens/spec function |
755 | | - sens_spec_sig = signature(sensitivity_specificity_support) |
756 | | - # We need to extract from kwargs only the one needed by the |
757 | | - # specificity and specificity |
758 | | - params_sens_spec = set(sens_spec_sig._parameters.keys()) |
759 | | - # Make the intersection between the parameters |
760 | | - sel_params = params_sens_spec.intersection(set(tags_scoring_func)) |
761 | | - # Create a sub dictionary |
762 | | - tags_scoring_func = {k: tags_scoring_func[k] for k in sel_params} |
763 | | - # Check if the metric is the geometric mean |
| 753 | + |
| 754 | + signature_sens_spec = signature(sensitivity_specificity_support) |
| 755 | + params_sens_spec = set(signature_sens_spec.parameters.keys()) |
| 756 | + common_params = params_sens_spec.intersection( |
| 757 | + set(args_scoring_func.arguments.keys()) |
| 758 | + ) |
| 759 | + |
| 760 | + args_sens_spec = { |
| 761 | + k: args_scoring_func.arguments[k] for k in common_params |
| 762 | + } |
| 763 | + |
764 | 764 | if scoring_func.__name__ == "geometric_mean_score": |
765 | | - if "average" in tags_scoring_func: |
766 | | - if tags_scoring_func["average"] == "multiclass": |
767 | | - tags_scoring_func["average"] = "macro" |
768 | | - # We do not support multilabel so the only average supported |
769 | | - # is binary |
| 765 | + if "average" in args_sens_spec: |
| 766 | + if args_sens_spec["average"] == "multiclass": |
| 767 | + args_sens_spec["average"] = "macro" |
770 | 768 | elif ( |
771 | 769 | scoring_func.__name__ == "accuracy_score" |
772 | 770 | or scoring_func.__name__ == "jaccard_score" |
773 | 771 | ): |
774 | | - tags_scoring_func["average"] = "binary" |
775 | | - # Create the list of parameters through signature binding |
776 | | - tags_sens_spec = sens_spec_sig.bind(**tags_scoring_func) |
777 | | - # Call the sens/spec function |
778 | | - sen, spe, _ = sensitivity_specificity_support( |
779 | | - *tags_sens_spec.args, **tags_sens_spec.kwargs |
| 772 | + # We do not support multilabel so the only average supported |
| 773 | + # is binary |
| 774 | + args_sens_spec["average"] = "binary" |
| 775 | + |
| 776 | + sensitivity, specificity, _ = sensitivity_specificity_support( |
| 777 | + **args_sens_spec |
780 | 778 | ) |
781 | | - # Compute the dominance |
782 | | - dom = sen - spe |
783 | | - return (1.0 + alpha * dom) * _score |
| 779 | + |
| 780 | + dominance = sensitivity - specificity |
| 781 | + return (1.0 + alpha * dominance) * _score |
784 | 782 |
|
785 | 783 | return compute_score |
786 | 784 |
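A way to read the refactored compute_score: the caller's arguments are bound against the scoring function's signature, defaults are filled in with apply_defaults(), the score is computed (and squared if requested), and only the parameters that sensitivity_specificity_support also declares are forwarded to it. Below is a standalone sketch of that forwarding pattern using hypothetical stand-in functions, not the library code itself:

# Sketch of the argument-forwarding pattern used in the new compute_score:
# bind the caller's arguments, then pass on only the ones the downstream
# function also accepts.
from inspect import signature


def scoring_func(y_true, y_pred, average="binary", correction=0.0):
    # Stand-in for a wrapped metric such as geometric_mean_score.
    return 1.0


def sens_spec(y_true, y_pred, average="binary"):
    # Stand-in for sensitivity_specificity_support.
    return 0.9, 0.8, None


bound = signature(scoring_func).bind([0, 1, 1], [0, 1, 0], average="macro")
bound.apply_defaults()

# Intersect parameter names: ``correction`` is dropped because sens_spec
# does not accept it.
common = set(signature(sens_spec).parameters) & set(bound.arguments)
forwarded = {name: bound.arguments[name] for name in common}

sensitivity, specificity, _ = sens_spec(**forwarded)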
|
|
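Taken together, the decorated scorer still computes the index balanced accuracy as before the refactor: with squared=True it returns (1 + alpha * (sensitivity - specificity)) * score**2. A short usage sketch of the public API being patched here, wrapping geometric_mean_score (the metric special-cased in the hunk above):

# Usage sketch: make_index_balanced_accuracy is a decorator factory, so the
# wrapped metric keeps the call signature of the original scoring function.
from imblearn.metrics import geometric_mean_score, make_index_balanced_accuracy

y_true = [0, 1, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1]

iba_gmean = make_index_balanced_accuracy(alpha=0.1, squared=True)(geometric_mean_score)
print(iba_gmean(y_true, y_pred))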