11"""Function to create the Multi-Comparison Matrix (MCM) results visualisation."""
22
3- __maintainer__ = ["TonyBagnall" ]
3+ __maintainer__ = ["hadifawaz1999" , " TonyBagnall" ]
44
55__all__ = ["create_multi_comparison_matrix" ]
66
1515
1616
1717def create_multi_comparison_matrix (
18- df_results ,
18+ results ,
1919 save_path = "./mcm" ,
2020 formats = None ,
21- used_statistic = "Accuracy " ,
21+ statistic_name = "Score " ,
2222 plot_1v1_comparisons = False ,
2323 higher_stat_better = True ,
2424 include_pvalue = True ,
2525 pvalue_test = "wilcoxon" ,
2626 pvalue_test_params = None ,
27- pvalue_correction = None ,
27+ pvalue_correction = "Holm" ,
2828 pvalue_threshold = 0.05 ,
29- use_mean = "mean-difference" ,
3029 order_stats = "average-statistic" ,
3130 order_stats_increasing = False ,
32- dataset_column = None ,
3331 precision = 4 ,
34- load_analysis = False ,
3532 row_comparates = None ,
3633 col_comparates = None ,
3734 excluded_row_comparates = None ,
@@ -43,20 +40,25 @@ def create_multi_comparison_matrix(
    colorbar_value=None,
    win_tie_loss_labels=None,
    include_legend=True,
-    show_symetry=True,
+    show_symmetry=True,
):
4845 """Generate the Multi-Comparison Matrix (MCM) [1]_.
4946
5047 MCM summarises a set of results for multiple estimators evaluated on multiple
5148 datasets. The MCM is a heatmap that shows absolute performance and tests for
5249 significant difference. It is configurable inmany ways.
5350
+    Note: this implementation uses different pvalue parameters from the original
+    by default. To use the original parameters, set ``pvalue_test_params`` to
+    ``{"zero_method": "pratt", "alternative": "two-sided"}`` and
+    ``pvalue_correction`` to ``None``.
+
    Parameters
    ----------
-    df_results: str or pd.DataFrame
-        A csv file containing results in `n_problems,n_estimators` format. The first
-        row should contain the names of the estimators and the first column can
-        contain the names of the problems if `dataset_column` is true.
+    results: pd.DataFrame
+        A dataframe of scores. Columns are the names of the estimators and rows are
+        the different problems. The estimator names present in the columns will be
+        used as the comparate names in the MCM.
    save_path: str, default = './mcm'
        The output directory for the results. If you want to save the results with a
        different filename, you must include the filename in the path.
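For clarity, a minimal sketch of calling the renamed interface; the estimator names and scores are hypothetical, and only parameters shown in this diff are used:

    import pandas as pd

    from aeon.visualisation import create_multi_comparison_matrix

    # rows are problems, columns are estimators (all values hypothetical)
    results = pd.DataFrame(
        {"HC2": [0.90, 0.85, 0.88], "ROCKET": [0.88, 0.83, 0.86]},
        index=["dataset_a", "dataset_b", "dataset_c"],
    )
    # the drawn MCM is returned (the function ends with ``return temp`` below)
    mcm = create_multi_comparison_matrix(results, statistic_name="Accuracy", formats=None)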
@@ -65,33 +67,28 @@ def create_multi_comparison_matrix(
        File formats to save in the save_path.
        - If None, no files are saved.
        - Valid formats are 'pdf', 'png', 'json', 'csv', 'tex'.
-    used_statistic: str, default = 'Score'
-        Name of the metric being assessed (e.g. accuracy, error, mse).
-    save_as_json: bool, default = True
-        Whether or not to save the python analysis dict into a json file format.
+    statistic_name: str, default = 'Score'
+        Name of the metric being assessed (e.g. accuracy, error, mse).
+        By default it is just generically labelled as 'Score'.
    plot_1v1_comparisons: bool, default = False
-        Whether or not to plot the 1v1 scatter results.
+        Whether to plot the 1v1 scatter results.
    higher_stat_better: bool, default = True
        Whether a higher value of the statistic is better, used when deciding
        wins and losses.
    include_pvalue: bool, default = True
-        Condition whether or not include a pvalue stats.
+        Whether to include a pvalue in the comparison stats.
    pvalue_test: str, default = 'wilcoxon'
        The statistical test to produce the pvalue stats. Currently only wilcoxon is
        supported.
    pvalue_test_params: dict, default = None
        The parameter set for the pvalue_test used. If pvalue_test is set to
        Wilcoxon, one should check the scipy.stats.wilcoxon parameters. If
        Wilcoxon is set and this parameter is None, then the default setup
-        is {"zero_method": "pratt", "alternative": "greater"}.
+        is {"zero_method": "wilcox", "alternative": "greater"}.
    pvalue_correction: str, default = 'Holm'
        Correction to use for the pvalue significance test, None or "Holm".
    pvalue_threshold: float, default = 0.05
        Threshold for considering a comparison significant. If pvalue <
        pvalue_threshold, the comparison is significant.
-    use_mean: str, default = 'mean-difference'
-        The mean used to compare two estimators. The only option available
-        is 'mean-difference', which is the difference between the arithmetic means
-        over all datasets.
    order_stats: str, default = 'average-statistic'
        The way to order the used_statistic; the default setup orders by the
        average statistic over all datasets.
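Following the note added to the docstring, the original paper's pvalue behaviour can be restored explicitly; a sketch, reusing the hypothetical `results` DataFrame from the earlier sketch:

    create_multi_comparison_matrix(
        results,
        pvalue_test_params={"zero_method": "pratt", "alternative": "two-sided"},
        pvalue_correction=None,
    )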
@@ -108,17 +105,13 @@ def create_multi_comparison_matrix(
    order_stats_increasing: bool, default = False
        If True, the order_stats are sorted in increasing order; otherwise they are
        sorted in decreasing order.
-    dataset_column: str, default = 'dataset_name'
-        The name of the datasets column in the csv file.
    precision: int, default = 4
        The number of digits after the decimal point.
-    load_analysis: bool, default = False
-        If True, attempts to load the analysis json file.
    row_comparates: list of str, default = None
-        A list of included row comparates, if None, all of the comparates in the study
+        A list of included row comparates; if None, all the comparates in the study
        are placed in the rows.
    col_comparates: list of str, default = None
-        A list of included col comparates, if None, all of the comparates in the
+        A list of included col comparates; if None, all the comparates in the
        study are placed in the cols.
    excluded_row_comparates: list of str, default = None
        A list of excluded row comparates. If None, all comparates are included.
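A sketch of how the row/col filtering parameters compose, again with the hypothetical estimator names from above; here only HC2 occupies the rows:

    create_multi_comparison_matrix(
        results,
        row_comparates=["HC2"],
        col_comparates=["HC2", "ROCKET"],
    )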
@@ -145,9 +138,9 @@ def create_multi_comparison_matrix(
        The tuple must contain exactly three strings, representing win, tie, and
        loss outcomes for the row comparate (r) against the column comparate (c).
    include_legend: bool, default = True
-        Whether or not to show the legend on the MCM.
-    show_symetry: bool, default = True
-        Whether or not to show the symmetrical part of the heatmap.
+        Whether to show the legend on the MCM.
+    show_symmetry: bool, default = True
+        Whether to show the symmetrical part of the heatmap.

    Returns
    -------
@@ -158,7 +151,7 @@ def create_multi_comparison_matrix(
    -------
    >>> from aeon.visualisation import create_multi_comparison_matrix  # doctest: +SKIP
    >>> create_multi_comparison_matrix(
-    ...     df_results="results.csv",
+    ...     results=pd.read_csv("results.csv"),  # assuming pandas is imported as pd
    ...     save_path="reports/mymcm",
    ...     formats=("png", "json")
    ... )  # doctest: +SKIP
@@ -173,12 +166,6 @@ def create_multi_comparison_matrix(
        Evaluations That Is Stable Under Manipulation Of The Comparate Set
        arXiv preprint arXiv:2305.11921, 2023.
    """
-    if isinstance(df_results, str):
-        try:
-            df_results = pd.read_csv(df_results)
-        except Exception as e:
-            raise ValueError(f"No dataframe or valid path is given: Exception {e}")
-
    formats = _normalize_formats(formats)

    if win_tie_loss_labels is None:
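The labels are only defaulted when the caller passes None; a sketch of overriding them with an illustrative three-string tuple, as the docstring requires:

    create_multi_comparison_matrix(
        results,
        win_tie_loss_labels=("wins", "ties", "losses"),
        show_symmetry=False,
    )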
@@ -192,23 +179,20 @@ def create_multi_comparison_matrix(
    win_label, tie_label, loss_label = win_tie_loss_labels

    analysis = _get_analysis(
-        df_results,
+        results,
        save_path=save_path,
        formats=formats,
-        used_statistic=used_statistic,
+        used_statistic=statistic_name,
        plot_1v1_comparisons=plot_1v1_comparisons,
        higher_stat_better=higher_stat_better,
        include_pvalue=include_pvalue,
        pvalue_test=pvalue_test,
        pvalue_test_params=pvalue_test_params,
        pvalue_correction=pvalue_correction,
        pvalue_threshhold=pvalue_threshold,
-        use_mean=use_mean,
        order_stats=order_stats,
        order_stats_increasing=order_stats_increasing,
-        dataset_column=dataset_column,
        precision=precision,
-        load_analysis=load_analysis,
    )

    # start drawing heatmap
@@ -228,15 +212,15 @@ def create_multi_comparison_matrix(
        colorbar_value=colorbar_value,
        win_tie_loss_labels=win_tie_loss_labels,
        include_legend=include_legend,
-        show_symetry=show_symetry,
+        show_symmetry=show_symmetry,
    )
    return temp


def _get_analysis(
    df_results,
    save_path="./",
-    formats=("json"),
+    formats="json",
    used_statistic="Score",
    plot_1v1_comparisons=False,
    higher_stat_better=True,
@@ -245,12 +229,9 @@ def _get_analysis(
    pvalue_test_params=None,
    pvalue_correction=None,
    pvalue_threshhold=0.05,
-    use_mean="mean-difference",
    order_stats="average-statistic",
    order_stats_increasing=False,
-    dataset_column=None,
    precision=4,
-    load_analysis=False,
):
    _check_soft_dependencies("matplotlib")
    import matplotlib as mpl
@@ -344,19 +325,7 @@ def _plot_1v1(
        plt.clf()
        plt.close()

-    save_file = f"{save_path}_analysis.json"
-
-    if load_analysis and os.path.exists(save_file):
-        with open(save_file) as json_file:
-            analysis = json.load(json_file)
-
-        analysis.setdefault("order_stats_increasing", order_stats_increasing)
-
-        return analysis
-
    analysis = {
-        "dataset-column": dataset_column,
-        "use-mean": use_mean,
        "order-stats": order_stats,
        "order_stats_increasing": order_stats_increasing,
        "used-statistics": used_statistic,
@@ -397,7 +366,6 @@ def _plot_1v1(
                pvalue_test=pvalue_test,
                pvalue_test_params=pvalue_test_params,
                pvalue_threshhold=pvalue_threshhold,
-                use_mean=use_mean,
            )

            analysis[pairwise_key] = pairwise_content
@@ -431,6 +399,7 @@ def _plot_1v1(

    _re_order_comparates(df_results=df_results, analysis=analysis)

+    save_file = f"{save_path}_analysis.json"
    if "json" in formats:
        with open(save_file, "w") as fjson:
            json.dump(analysis, fjson, cls=_NpEncoder)
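Because the analysis dict is still written whenever 'json' is among the requested formats, it can be reloaded for inspection; a sketch assuming the default save_path of './mcm':

    import json

    with open("./mcm_analysis.json") as f:
        analysis = json.load(f)
    print(analysis["order-stats"], analysis["n-comparates"])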
@@ -454,7 +423,7 @@ def _draw(
    colorbar_value=None,
    win_tie_loss_labels=None,
    higher_stat_better=True,
-    show_symetry=True,
+    show_symmetry=True,
    include_legend=True,
):
    _check_soft_dependencies("matplotlib")
@@ -676,7 +645,7 @@ def _draw(

        latex_row = []

-        if can_be_symmetrical and (not show_symetry):
+        if can_be_symmetrical and (not show_symmetry):
            start_j = i

        for j in range(start_j, n_cols):
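For intuition, a standalone sketch (not the library's code) of how disabling show_symmetry restricts drawing to the upper triangle of the matrix; the initialisation of start_j to 0 is assumed:

    n_rows = n_cols = 3
    show_symmetry = False
    for i in range(n_rows):
        start_j = i if not show_symmetry else 0  # skip cells below the diagonal
        for j in range(start_j, n_cols):
            print(i, j)  # only cells with j >= i are visited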
@@ -918,31 +887,11 @@ def _decode_results_data_frame(df, analysis):

    """
    df_columns = list(df.columns)  # extract columns from data frame
-
-    # check if dataset column name is correct
-
-    if analysis["dataset-column"] is not None:
-        if analysis["dataset-column"] not in df_columns:
-            raise KeyError("The column " + analysis["dataset-column"] + " is missing.")
-
-    # get number of examples (datasets)
-    # n_datasets = len(np.unique(np.asarray(df[analysis['dataset-column']])))
-    n_datasets = len(df.index)
-
-    analysis["n-datasets"] = n_datasets  # add number of examples to dictionary
-
-    if analysis["dataset-column"] is not None:
-        analysis["dataset-names"] = list(
-            df[analysis["dataset-column"]]
-        )  # add example names to dict
-        df_columns.remove(
-            analysis["dataset-column"]
-        )  # drop the dataset column name from columns list
-        # and keep comparate names
-
    comparate_names = df_columns.copy()
    n_comparates = len(comparate_names)

+    # add number of examples to dictionary
+    analysis["n-datasets"] = len(df.index)
    # add the information about comparates to dict
    analysis["comparate-names"] = comparate_names
    analysis["n-comparates"] = n_comparates
@@ -956,7 +905,6 @@ def _get_pairwise_content(
    pvalue_test="wilcoxon",
    pvalue_test_params=None,
    pvalue_threshhold=0.05,
-    use_mean="mean-difference",
):
    content = {}

@@ -992,9 +940,8 @@ def _get_pairwise_content(

    else:
        raise ValueError(f"{pvalue_test} test is not supported yet")
-    if use_mean == "mean-difference":
-        content["mean"] = np.mean(x) - np.mean(y)

+    content["mean"] = np.mean(x) - np.mean(y)
    return content


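With the use_mean branch gone, the cell statistic is always the mean-difference; a short worked sketch of the value a cell would report (scores hypothetical):

    import numpy as np

    x = np.array([0.90, 0.85, 0.88])  # row comparate scores per dataset
    y = np.array([0.88, 0.83, 0.86])  # column comparate scores per dataset
    mean_diff = np.mean(x) - np.mean(y)  # the "mean" entry of the pairwise content
    print(round(mean_diff, 4))  # 0.02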
@@ -1070,13 +1017,7 @@ def _re_order_comparates(df_results, analysis):
        stats.append(analysis["average-statistic"][analysis["comparate-names"][i]])

    elif analysis["order-stats"] == "average-rank":
-        if analysis["dataset-column"] is not None:
-            np_results = np.asarray(
-                df_results.drop([analysis["dataset-column"]], axis=1)
-            )
-        else:
-            np_results = np.asarray(df_results)
-
+        np_results = np.asarray(df_results)
        df = pd.DataFrame(columns=["comparate-name", "values"])

        for i, comparate_name in enumerate(analysis["comparate-names"]):
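For reference, a rough sketch of what average-rank ordering computes, assuming higher scores rank better (as with higher_stat_better=True); this is an illustration, not the function's exact code:

    import numpy as np
    from scipy.stats import rankdata

    scores = np.array([[0.90, 0.88], [0.85, 0.83], [0.88, 0.86]])  # datasets x estimators
    ranks = np.array([rankdata(-row) for row in scores])  # rank 1 = best per dataset
    avg_ranks = ranks.mean(axis=0)  # lower average rank = better comparate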
@@ -1169,7 +1110,7 @@ def _get_cell_legend(
    tie_label="r=c",
    loss_label="r<c",
):
-    cell_legend = _capitalize_label(analysis["use-mean"])
+    cell_legend = _capitalize_label("mean-difference")
    longest_string = len(cell_legend)

    win_tie_loss_string = f"{win_label} / {tie_label} / {loss_label}"
@@ -1190,7 +1131,6 @@ def _get_cell_legend(
def _capitalize_label(s):
    if len(s.split("-")) == 1:
        return s.capitalize()
-
    else:
        return "-".join(ss.capitalize() for ss in s.split("-"))

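A quick usage check of the helper on plain and hyphenated labels:

    assert _capitalize_label("score") == "Score"
    assert _capitalize_label("mean-difference") == "Mean-Difference"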