
Commit 6d99d7e

[ENH] MCM plot function tidy. (#3099)
* mcm tidy * more fixes * more fixes * docs and holm * Update note on pvalue parameters usage Clarified note on pvalue parameters for configuration. * Automatic `pre-commit` fixes --------- Co-authored-by: MatthewMiddlehurst <25731235+MatthewMiddlehurst@users.noreply.github.com>
1 parent 63dda95 commit 6d99d7e

3 files changed: +59 -104 lines changed

aeon/visualisation/results/_mcm.py

Lines changed: 38 additions & 98 deletions
@@ -1,6 +1,6 @@
 """Function to create the Multi-Comparison Matrix (MCM) results visualisation."""
 
-__maintainer__ = ["TonyBagnall"]
+__maintainer__ = ["hadifawaz1999", "TonyBagnall"]
 
 __all__ = ["create_multi_comparison_matrix"]
 
@@ -15,23 +15,20 @@
 
 
 def create_multi_comparison_matrix(
-    df_results,
+    results,
     save_path="./mcm",
     formats=None,
-    used_statistic="Accuracy",
+    statistic_name="Score",
     plot_1v1_comparisons=False,
     higher_stat_better=True,
     include_pvalue=True,
     pvalue_test="wilcoxon",
     pvalue_test_params=None,
-    pvalue_correction=None,
+    pvalue_correction="Holm",
     pvalue_threshold=0.05,
-    use_mean="mean-difference",
     order_stats="average-statistic",
     order_stats_increasing=False,
-    dataset_column=None,
     precision=4,
-    load_analysis=False,
     row_comparates=None,
     col_comparates=None,
     excluded_row_comparates=None,
@@ -43,20 +40,25 @@ def create_multi_comparison_matrix(
     colorbar_value=None,
     win_tie_loss_labels=None,
     include_legend=True,
-    show_symetry=True,
+    show_symmetry=True,
 ):
     """Generate the Multi-Comparison Matrix (MCM) [1]_.
 
     MCM summarises a set of results for multiple estimators evaluated on multiple
     datasets. The MCM is a heatmap that shows absolute performance and tests for
     significant difference. It is configurable in many ways.
 
+    Note: this implementation uses different pvalue parameters from the original
+    by default. To use the original parameters, set ``pvalue_test_params`` to
+    ``{"zero_method": "pratt", "alternative": "two-sided"}`` and
+    ``pvalue_correction`` to ``None``.
+
     Parameters
     ----------
-    df_results: str or pd.DataFrame
-        A csv file containing results in `n_problems,n_estimators` format. The first
-        row should contain the names of the estimators and the first column can
-        contain the names of the problems if `dataset_column` is true.
+    results: pd.DataFrame
+        A dataframe of scores. Columns are the names of the estimators and rows are
+        the different problems. The estimator names present in the columns will be
+        used as the comparate names in the MCM.
     save_path: str, default = './mcm'
         The output directory for the results. If you want to save the results with a
         different filename, you must include the filename in the path.
@@ -65,33 +67,28 @@ def create_multi_comparison_matrix(
         File formats to save in the save_path.
         - If None, no files are saved.
         - Valid formats are 'pdf', 'png', 'json', 'csv', 'tex'.
-    used_statistic: str, default = 'Score'
-        Name of the metric being assesses (e.g. accuracy, error, mse).
-    save_as_json: bool, default = True
-        Whether or not to save the python analysis dict into a json file format.
+    statistic_name: str, default = 'Score'
+        Name of the metric being assessed (e.g. accuracy, error, mse).
+        By default it is just generically labelled as 'Score'.
     plot_1v1_comparisons: bool, default = False
-        Whether or not to plot the 1v1 scatter results.
+        Whether to plot the 1v1 scatter results.
     higher_stat_better: bool, default = True
         Whether a higher value of the statistic counts as a win for a comparison.
     include_pvalue: bool, default = True
-        Condition whether or not include a pvalue stats.
+        Whether to include pvalue stats.
     pvalue_test: str, default = 'wilcoxon'
         The statistical test to produce the pvalue stats. Currently only wilcoxon is
         supported.
     pvalue_test_params: dict, default = None
         The parameters for the pvalue_test used. If pvalue_test is set
         to 'wilcoxon', check the scipy.stats.wilcoxon parameters;
         if 'wilcoxon' is set and this parameter is None, then the default setup
-        is {"zero_method": "pratt", "alternative": "greater"}.
+        is {"zero_method": "wilcox", "alternative": "greater"}.
     pvalue_correction: str, default = 'Holm'
         Correction to use for the pvalue significance test, None or "Holm".
     pvalue_threshold: float, default = 0.05
         Threshold for considering a comparison significant. If pvalue <
         pvalue_threshold, the comparison is significant.
-    use_mean: str, default = 'mean-difference'
-        The mean used to compare two estimators. The only option available
-        is 'mean-difference' which is the difference between arithmetic mean
-        over all datasets.
     order_stats: str, default = 'average-statistic'
         The way to order the statistics; the default setup orders by the average
         statistic over all datasets.
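
Note: the Wilcoxon defaults above are the crux of the behaviour change. A rough
sketch of the difference, assuming the scipy.stats.wilcoxon backend named in the
docstring and synthetic scores (this is illustrative, not code from the commit):

    import numpy as np
    from scipy.stats import wilcoxon

    rng = np.random.default_rng(0)
    x = rng.random(10)  # scores of comparate r on ten datasets
    y = rng.random(10)  # scores of comparate c on the same datasets

    # New default when pvalue_test_params is None: one-sided, zeros discarded.
    p_new = wilcoxon(x, y, zero_method="wilcox", alternative="greater").pvalue

    # Original MCM parameters: two-sided, zeros included in the ranking.
    p_orig = wilcoxon(x, y, zero_method="pratt", alternative="two-sided").pvalue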
@@ -108,17 +105,13 @@ def create_multi_comparison_matrix(
     order_stats_increasing: bool, default = False
         If True, the order_stats will be ordered in increasing order, otherwise they
         are ordered in decreasing order.
-    dataset_column: str, default = 'dataset_name'
-        The name of the datasets column in the csv file.
     precision: int, default = 4
         The number of digits after the decimal point.
-    load_analysis: bool, default = False
-        If True attempts to load the analysis json file.
     row_comparates: list of str, default = None
-        A list of included row comparates, if None, all of the comparates in the study
+        A list of included row comparates, if None, all the comparates in the study
         are placed in the rows.
     col_comparates: list of str, default = None
-        A list of included col comparates, if None, all of the comparates in the
+        A list of included col comparates, if None, all the comparates in the
         study are placed in the cols.
     excluded_row_comparates: list of str, default = None
         A list of excluded row comparates. If None, all comparates are included.
@@ -145,9 +138,9 @@ def create_multi_comparison_matrix(
         The tuple must contain exactly three strings, representing win, tie, and
         loss outcomes for the row comparate (r) against the column comparate (c).
     include_legend: bool, default = True
-        Whether or not to show the legend on the MCM.
-    show_symetry: bool, default = True
-        Whether or not to show the symmetrical part of the heatmap.
+        Whether to show the legend on the MCM.
+    show_symmetry: bool, default = True
+        Whether to show the symmetrical part of the heatmap.
 
     Returns
     -------
@@ -158,7 +151,7 @@ def create_multi_comparison_matrix(
     -------
     >>> from aeon.visualisation import create_multi_comparison_matrix  # doctest: +SKIP
     >>> create_multi_comparison_matrix(
-    ...     df_results="results.csv",
+    ...     results="results.csv",
     ...     save_path="reports/mymcm",
     ...     formats=("png", "json")
     ... )  # doctest: +SKIP
@@ -173,12 +166,6 @@ def create_multi_comparison_matrix(
        Evaluations That Is Stable Under Manipulation Of The Comparate Set
        arXiv preprint arXiv:2305.11921, 2023.
     """
-    if isinstance(df_results, str):
-        try:
-            df_results = pd.read_csv(df_results)
-        except Exception as e:
-            raise ValueError(f"No dataframe or valid path is given: Exception {e}")
-
     formats = _normalize_formats(formats)
 
     if win_tie_loss_labels is None:
@@ -192,23 +179,20 @@ def create_multi_comparison_matrix(
     win_label, tie_label, loss_label = win_tie_loss_labels
 
     analysis = _get_analysis(
-        df_results,
+        results,
         save_path=save_path,
         formats=formats,
-        used_statistic=used_statistic,
+        used_statistic=statistic_name,
         plot_1v1_comparisons=plot_1v1_comparisons,
         higher_stat_better=higher_stat_better,
         include_pvalue=include_pvalue,
         pvalue_test=pvalue_test,
         pvalue_test_params=pvalue_test_params,
         pvalue_correction=pvalue_correction,
         pvalue_threshhold=pvalue_threshold,
-        use_mean=use_mean,
         order_stats=order_stats,
         order_stats_increasing=order_stats_increasing,
-        dataset_column=dataset_column,
         precision=precision,
-        load_analysis=load_analysis,
     )
 
     # start drawing heatmap
@@ -228,15 +212,15 @@ def create_multi_comparison_matrix(
         colorbar_value=colorbar_value,
         win_tie_loss_labels=win_tie_loss_labels,
         include_legend=include_legend,
-        show_symetry=show_symetry,
+        show_symmetry=show_symmetry,
     )
     return temp
 
 
 def _get_analysis(
     df_results,
     save_path="./",
-    formats=("json"),
+    formats="json",
     used_statistic="Score",
     plot_1v1_comparisons=False,
     higher_stat_better=True,
@@ -245,12 +229,9 @@ def _get_analysis(
     pvalue_test_params=None,
     pvalue_correction=None,
     pvalue_threshhold=0.05,
-    use_mean="mean-difference",
     order_stats="average-statistic",
     order_stats_increasing=False,
-    dataset_column=None,
     precision=4,
-    load_analysis=False,
 ):
     _check_soft_dependencies("matplotlib")
     import matplotlib as mpl
@@ -344,19 +325,7 @@ def _plot_1v1(
             plt.clf()
             plt.close()
 
-    save_file = f"{save_path}_analysis.json"
-
-    if load_analysis and os.path.exists(save_file):
-        with open(save_file) as json_file:
-            analysis = json.load(json_file)
-
-        analysis.setdefault("order_stats_increasing", order_stats_increasing)
-
-        return analysis
-
     analysis = {
-        "dataset-column": dataset_column,
-        "use-mean": use_mean,
         "order-stats": order_stats,
         "order_stats_increasing": order_stats_increasing,
         "used-statistics": used_statistic,
@@ -397,7 +366,6 @@ def _plot_1v1(
                 pvalue_test=pvalue_test,
                 pvalue_test_params=pvalue_test_params,
                 pvalue_threshhold=pvalue_threshhold,
-                use_mean=use_mean,
             )
 
             analysis[pairwise_key] = pairwise_content
@@ -431,6 +399,7 @@ def _plot_1v1(
 
     _re_order_comparates(df_results=df_results, analysis=analysis)
 
+    save_file = f"{save_path}_analysis.json"
     if "json" in formats:
         with open(save_file, "w") as fjson:
             json.dump(analysis, fjson, cls=_NpEncoder)
@@ -454,7 +423,7 @@ def _draw(
     colorbar_value=None,
     win_tie_loss_labels=None,
     higher_stat_better=True,
-    show_symetry=True,
+    show_symmetry=True,
     include_legend=True,
 ):
     _check_soft_dependencies("matplotlib")
@@ -676,7 +645,7 @@ def _draw(
 
         latex_row = []
 
-        if can_be_symmetrical and (not show_symetry):
+        if can_be_symmetrical and (not show_symmetry):
             start_j = i
 
         for j in range(start_j, n_cols):
@@ -918,31 +887,11 @@ def _decode_results_data_frame(df, analysis):
 
     """
    df_columns = list(df.columns)  # extract columns from data frame
-
-    # check if dataset column name is correct
-
-    if analysis["dataset-column"] is not None:
-        if analysis["dataset-column"] not in df_columns:
-            raise KeyError("The column " + analysis["dataset-column"] + " is missing.")
-
-    # get number of examples (datasets)
-    # n_datasets = len(np.unique(np.asarray(df[analysis['dataset-column']])))
-    n_datasets = len(df.index)
-
-    analysis["n-datasets"] = n_datasets  # add number of examples to dictionary
-
-    if analysis["dataset-column"] is not None:
-        analysis["dataset-names"] = list(
-            df[analysis["dataset-column"]]
-        )  # add example names to dict
-        df_columns.remove(
-            analysis["dataset-column"]
-        )  # drop the dataset column name from columns list
-        # and keep comparate names
-
    comparate_names = df_columns.copy()
    n_comparates = len(comparate_names)
 
+    # add number of examples to dictionary
+    analysis["n-datasets"] = len(df.index)
    # add the information about comparates to dict
    analysis["comparate-names"] = comparate_names
    analysis["n-comparates"] = n_comparates
@@ -956,7 +905,6 @@ def _get_pairwise_content(
    pvalue_test="wilcoxon",
    pvalue_test_params=None,
    pvalue_threshhold=0.05,
-    use_mean="mean-difference",
 ):
    content = {}
 
@@ -992,9 +940,8 @@ def _get_pairwise_content(
 
        else:
            raise ValueError(f"{pvalue_test} test is not supported yet")
-    if use_mean == "mean-difference":
-        content["mean"] = np.mean(x) - np.mean(y)
 
+    content["mean"] = np.mean(x) - np.mean(y)
    return content
 
 
@@ -1070,13 +1017,7 @@ def _re_order_comparates(df_results, analysis):
            stats.append(analysis["average-statistic"][analysis["comparate-names"][i]])
 
    elif analysis["order-stats"] == "average-rank":
-        if analysis["dataset-column"] is not None:
-            np_results = np.asarray(
-                df_results.drop([analysis["dataset-column"]], axis=1)
-            )
-        else:
-            np_results = np.asarray(df_results)
-
+        np_results = np.asarray(df_results)
        df = pd.DataFrame(columns=["comparate-name", "values"])
 
        for i, comparate_name in enumerate(analysis["comparate-names"]):
@@ -1169,7 +1110,7 @@ def _get_cell_legend(
    tie_label="r=c",
    loss_label="r<c",
 ):
-    cell_legend = _capitalize_label(analysis["use-mean"])
+    cell_legend = _capitalize_label("mean-difference")
    longest_string = len(cell_legend)
 
    win_tie_loss_string = f"{win_label} / {tie_label} / {loss_label}"
@@ -1190,7 +1131,6 @@ def _capitalize_label(s):
 def _capitalize_label(s):
    if len(s.split("-")) == 1:
        return s.capitalize()
-
    else:
        return "-".join(ss.capitalize() for ss in s.split("-"))
 
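
Since create_multi_comparison_matrix no longer reads csv files itself, callers now
load their results into a dataframe first. A minimal usage sketch of the revised
signature (the file name and metric name here are hypothetical):

    import pandas as pd

    from aeon.visualisation import create_multi_comparison_matrix

    # Columns are estimator names, rows are problems.
    results = pd.read_csv("results.csv")

    fig = create_multi_comparison_matrix(
        results,                    # renamed from df_results
        save_path="reports/mymcm",
        formats=("png", "json"),
        statistic_name="Accuracy",  # renamed from used_statistic
    )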

aeon/visualisation/results/tests/test_mcm.py

Lines changed: 19 additions & 4 deletions
@@ -23,7 +23,25 @@ def test_mcm():
        np.random.rand(10, 3),  # 10 rows, 3 columns of random numbers
        columns=["Classifier1", "Classifier2", "Classifier3"],
    )
-    fig = create_multi_comparison_matrix(df, formats=())
+    fig = create_multi_comparison_matrix(df)
+    assert isinstance(fig, plt.Figure)
+
+
+@pytest.mark.skipif(
+    not _check_soft_dependencies("matplotlib", severity="none"),
+    reason="skip test if required soft dependency not available",
+)
+def test_mcm_original_pvalue():
+    """Test the multi-comparison-matrix visualisation."""
+    import matplotlib.pyplot as plt
+
+    df = pd.DataFrame(
+        np.random.rand(10, 3),  # 10 rows, 3 columns of random numbers
+        columns=["Classifier1", "Classifier2", "Classifier3"],
+    )
+    fig = create_multi_comparison_matrix(
+        df, pvalue_test_params={"zero_method": "pratt", "alternative": "two-sided"}
+    )
    assert isinstance(fig, plt.Figure)
 
 
@@ -48,6 +66,3 @@ def test_mcm_file_save():
        pvalue_correction="Holm",
    )
    assert isinstance(fig, plt.Figure)
-
-
-    # assert os.path.isfile(os.path.join(tmp, "test1.pdf"))
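
The new test exercises the docstring note; fully restoring the original behaviour
also means disabling the Holm correction, which now defaults on. A sketch along the
same lines as the tests:

    import numpy as np
    import pandas as pd

    from aeon.visualisation import create_multi_comparison_matrix

    df = pd.DataFrame(
        np.random.rand(10, 3),
        columns=["Classifier1", "Classifier2", "Classifier3"],
    )
    # Original MCM pvalue configuration: two-sided pratt Wilcoxon, no correction.
    fig = create_multi_comparison_matrix(
        df,
        pvalue_test_params={"zero_method": "pratt", "alternative": "two-sided"},
        pvalue_correction=None,  # the new default is "Holm"
    )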
