@@ -9,7 +9,7 @@ public class DATA_evaluation {
99 private String [] testDataResults ;
1010 private int [][] confustionMatrix ;
1111
12- private boolean createConfusionMatrix ;
12+ private boolean createConfusionMatrix = false ;
1313
1414 protected DATA_evaluation (String [] testDataResults , int columnCount , String [][][] predictedTestData , int [][] sortedProbability , int numberOfClasses ) {
1515 this .testDataResults = testDataResults ;
@@ -18,6 +18,7 @@ protected DATA_evaluation(String[] testDataResults, int columnCount, String[][][
1818 this .sortedProbability = sortedProbability ;
1919 this .numberOfClasses = numberOfClasses ;
2020 this .confustionMatrix = new int [this .numberOfClasses ][this .numberOfClasses ];
21+ this .createConfusionMatrix = false ;
2122 }
2223
2324 private void confusionMatrix () {
@@ -60,15 +61,40 @@ private void confusionMatrix() {
6061
6162 }
6263
63- protected int [][] getConfusionMatrix () {
64- if (!createConfusionMatrix ) {
64+ protected int [][] getConfusionMatrixSimple () {
65+ if (!this . createConfusionMatrix ) {
6566 confusionMatrix ();
6667 }
6768
68- return this .confustionMatrix ;
69+ int [][] confusionMatrixSimple = new int [this .numberOfClasses ][2 ];
70+ for (int i = 0 ; i < this .numberOfClasses ; i ++) {
71+ for (int j = 0 ; j < 2 ; j ++) {
72+ confusionMatrixSimple [i ][j ] = 0 ;
73+ }
74+ }
75+
76+ for (int i = 0 ; i < this .columnCount ; i ++) {
77+ if (this .testDataResults [i ].equals (this .predictedTestData [i ][this .sortedProbability [i ][0 ]][0 ])) {
78+ confusionMatrixSimple [this .sortedProbability [i ][0 ]][0 ]++;
79+ }
80+ else {
81+ confusionMatrixSimple [this .sortedProbability [i ][0 ]][1 ]++;
82+ }
83+ }
84+
85+
86+
87+ for (int i = 0 ; i < 3 ; i ++) {
88+ for (int j = 0 ; j < 2 ; j ++) {
89+ System .out .print (confusionMatrixSimple [i ][j ] + " " );
90+ }
91+ System .out .println ();
92+ }
93+
94+ return confusionMatrixSimple ;
6995 }
7096 protected float [][] getConfusionMatrixNormalized () {
71- if (!createConfusionMatrix ) {
97+ if (!this . createConfusionMatrix ) {
7298 confusionMatrix ();
7399 }
74100 float [][] confusionMatrixNormalized = new float [this .numberOfClasses ][this .numberOfClasses ];
@@ -89,14 +115,14 @@ protected float[][] getConfusionMatrixNormalized() {
89115 }
90116
91117 }
92-
118+ /*
93119 for (int i = 0; i < this.numberOfClasses; i++) {
94120 for (int j = 0; j < this.numberOfClasses; j++) {
95121 System.out.print(confusionMatrixNormalized[i][j] + " ");
96122 }
97123 System.out.println();
98124 }
99-
125+ */
100126
101127 return confusionMatrixNormalized ;
102128 }
0 commit comments