Commit e1581d0

Revert "tf-unet-fix" and fix. (#237)
1 parent: 54911b2

18 files changed: +47, -32 lines

utils/cv/brats.py (4 additions, 2 deletions)

```diff
@@ -236,10 +236,12 @@ def submit_predictions(self, prediction):
             )
         self.__current_img_id += 1
 
-    def summarize_accuracy(self):
+    def _summarize_accuracy(self):
         from utils.cv.nnUNet.nnunet.evaluation.region_based_evaluation import evaluate_regions, get_brats_regions
         evaluate_regions(
-            self.__processed_predictions_dir_path, Path(self.__preprocessed_dir_path, "labelsTr"), get_brats_regions()
+            str(self.__processed_predictions_dir_path),
+            str(Path(self.__preprocessed_dir_path, "labelsTr")),
+            get_brats_regions()
         )
         with open(Path(self.__processed_predictions_dir_path, "summary.csv")) as f:
             for line in f:
```
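Besides the rename, the only behavioral change here is wrapping both path arguments in str(...), which suggests evaluate_regions expects plain strings rather than pathlib.Path objects. A minimal sketch of the conversion pattern, using hypothetical directory names in place of the dataset's private attributes:

```python
from pathlib import Path

# Hypothetical directories standing in for the dataset's private attributes.
processed_predictions_dir = Path("/data/brats/predictions")
preprocessed_dir = Path("/data/brats/preprocessed")

# Convert Path objects to plain str before handing them to an API that
# performs string operations on its path arguments.
predictions_arg = str(processed_predictions_dir)
labels_arg = str(Path(preprocessed_dir, "labelsTr"))

print(predictions_arg)  # /data/brats/predictions
print(labels_arg)       # /data/brats/preprocessed/labelsTr
```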

utils/cv/coco.py (2 additions, 2 deletions)

```diff
@@ -171,13 +171,13 @@ def submit_mask_prediction(self, id_in_batch, bbox, score, category, mask):
             "segmentation": mask
         })
 
-    def summarize_accuracy(self):
+    def _summarize_accuracy(self):
         """
         A function summarizing the accuracy achieved on the images obtained with get_input_array() calls on which
         predictions done where supplied with submit_bbox_prediction() function.
         """
         if self.do_skip():
-            return {}
+            return
 
         if self._task == "bbox":
             predictions = np.array(self._predictions)
```

utils/cv/imagenet.py (2 additions, 2 deletions)

```diff
@@ -152,13 +152,13 @@ def submit_predictions(self, id_in_batch: int, top_1_index: int, top_5_indices:
         self.__top_1_count += int(ground_truth == top_1_index)
         self.__top_5_count += int(ground_truth in top_5_indices)
 
-    def summarize_accuracy(self):
+    def _summarize_accuracy(self):
         """
         A function summarizing the accuracy achieved on the images obtained with get_input_array() calls on which
         predictions done where supplied with submit_predictions() function.
         """
         if self.do_skip():
-            return {}
+            return
 
         top_1_accuracy = self.__top_1_count / self.__current_img
         # print("\n Top-1 accuracy = {:.3f}".format(top_1_accuracy))
```

utils/cv/kits.py (2 additions, 2 deletions)

```diff
@@ -232,13 +232,13 @@ def submit_predictions(self, prediction):
         self.__calc_dice_score(full_prediction, ground_truth)
         self.__current_img_id += 1
 
-    def summarize_accuracy(self):
+    def _summarize_accuracy(self):
         if self.__current_img_id < 1:
             utils.print_warning_message(
                 "Not a single image has been completed - cannot calculate accuracy. Note that images of KiTS dataset "
                 "are processed in slices due to their size. That implies that complete processing of one image can "
                 "involve many passes through the network.")
-            return {"mean_kidney_acc": None, "mean_tumor_acc": None, "mean_composite_acc": None}
+            return
 
         mean_kidney = self.__kidney_score / self.__current_img_id
         # print("\n Mean kidney segmentation accuracy = {:.3f}".format(mean_kidney))
```
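The kits.py early exit previously returned a dict of None sentinels, which print_accuracy_metrics detected via its "None in accuracy_results.values()" check; that check is removed in the utils/helpers.py hunk below, so the early exit becomes a bare return that the new base-class wrapper maps to an empty dict. A small illustration of the two behaviors:

```python
# Old behavior: sentinel Nones signalled "no metrics" to print_accuracy_metrics.
old_result = {"mean_kidney_acc": None, "mean_tumor_acc": None, "mean_composite_acc": None}
print(None in old_result.values())  # True -> "Accuracy metrics are unavailable."

# New behavior: a bare return yields None, which the base-class wrapper
# (see the utils/helpers.py hunk below) normalizes to {}; an empty dict
# then triggers "No accuracy metrics to print." instead.
new_result = None
print(new_result if new_result is not None else {})  # {}
```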

utils/helpers.py (18 additions, 5 deletions)

```diff
@@ -48,19 +48,32 @@ def do_skip(self) -> bool:
     def reset(self) -> bool:
         raise NotImplementedError
 
-    def summarize_accuracy(self) -> dict:
+    def _summarize_accuracy(self) -> dict:
         raise NotImplementedError
 
+    # don't override this method, override _summarize_accuracy instead
+    def summarize_accuracy(self) -> dict:
+        accuracy_results = self._summarize_accuracy()
+        if accuracy_results is None:
+            accuracy_results = {}
+        assert type(accuracy_results) is dict
+        for k, v in accuracy_results.items():
+            assert isinstance(k, str)
+            try:
+                float(v)
+            except Exception as e:
+                raise e
+        return accuracy_results
+
     def print_accuracy_metrics(self) -> dict:
         accuracy_results = self.summarize_accuracy()
-        assert type(accuracy_results) is dict
-        if len(accuracy_results) == 0 or None in accuracy_results.values():
-            print_warning_message("Accuracy metrics are unavailable.")
+        if len(accuracy_results) == 0:
+            print_warning_message("No accuracy metrics to print.")
         else:
             max_len = 14
             indent = 2 * " "
             print(f"\n{indent}ACCURACY")
             for metric in accuracy_results.keys():
-                print(f"{3 * indent}{metric}{(max_len - len(metric)) * ' '}{3 * indent}" +
+                print(f"{3 * indent}{metric[:max_len]}{(max_len - len(metric)) * ' '}{3 * indent}" +
                       "= {:>7.3f}".format(float(accuracy_results[metric])))
         return accuracy_results
```
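Taken together, this is a template-method pattern: subclasses implement _summarize_accuracy(), while the base class's summarize_accuracy() normalizes a None result to an empty dict and asserts that every metric is a string key with a float-convertible value. A self-contained sketch of the contract; the base-class name Dataset is assumed here purely for illustration, since the diff shows only the methods:

```python
class Dataset:
    def _summarize_accuracy(self) -> dict:
        raise NotImplementedError

    # Don't override this method; override _summarize_accuracy instead.
    def summarize_accuracy(self) -> dict:
        accuracy_results = self._summarize_accuracy()
        if accuracy_results is None:   # a bare "return" in a subclass becomes {}
            accuracy_results = {}
        assert type(accuracy_results) is dict
        for k, v in accuracy_results.items():
            assert isinstance(k, str)  # metric names must be strings
            float(v)                   # raises if a metric value is not numeric
        return accuracy_results


class SkippedDataset(Dataset):
    def _summarize_accuracy(self):
        return  # mirrors the new coco/imagenet/kits early exits


print(SkippedDataset().summarize_accuracy())  # {}
```

The print_accuracy_metrics change also truncates metric names to max_len characters: with the old code, a name longer than 14 characters produced empty padding (a negative count times ' ' is the empty string), so the printed columns fell out of alignment.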

utils/nlp/alpaca_instruct.py (1 addition, 1 deletion)

```diff
@@ -102,7 +102,7 @@ def metric_max_over_ground_truth(metric_fn, pred, gt):
         self._f1 += metric_max_over_ground_truth(f1_score, answer, ground_truth)
         self._count += 1
 
-    def summarize_accuracy(self):
+    def _summarize_accuracy(self):
         exact_match = self._exact_match / self._count
         f1 = self._f1 / self._count
         return {"exact_match": exact_match, "f1": f1}
```
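A dataset like this one already returns a dict of numeric metrics, so it passes the new base-class validation unchanged. A quick standalone check of the wrapper's rules, with hypothetical values:

```python
results = {"exact_match": 0.81, "f1": 0.88}  # hypothetical metric values

assert type(results) is dict
for k, v in results.items():
    assert isinstance(k, str)  # metric names are strings
    float(v)                   # both values convert cleanly, so validation passes

# A non-numeric value such as {"f1": "n/a"} would raise ValueError at float(v).
```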

utils/nlp/cnn_dailymail.py (1 addition, 1 deletion)

```diff
@@ -212,7 +212,7 @@ def get_bigrams(text):
         self.__rouge2_count += rouge2_score(normalize(summary), normalize(ground_truth))
         self.__unsubmitted_count -= 1
 
-    def summarize_accuracy(self):
+    def _summarize_accuracy(self):
         """
         A function summarizing the accuracy achieved on the texts obtained with get_*_array() calls on which
         predicted summaries were supplied with submit_predictions() function.
```

utils/nlp/conll2003.py (1 addition, 1 deletion)

```diff
@@ -183,7 +183,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truth):
         self.__f1_count += metric_max_over_ground_truths(f1_score, prediction, ground_truth)
         self.__unsubmitted_count -= 1
 
-    def summarize_accuracy(self):
+    def _summarize_accuracy(self):
         """
         A function summarizing the accuracy achieved on the sequences obtained with get_*_array() calls on which
         predictions were supplied with submit_predictions() function.
```

utils/nlp/lambada.py (1 addition, 1 deletion)

```diff
@@ -156,7 +156,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
         self.__f1_count += metric_max_over_ground_truths(f1_score, answer, ground_truths)
         self.__unsubmitted_count -= 1
 
-    def summarize_accuracy(self):
+    def _summarize_accuracy(self):
         """
         A function summarizing the accuracy achieved on the questions obtained with get_*_array() calls on which
         predicted answers were supplied with submit_predictions() function.
```

utils/nlp/mrpc.py (1 addition, 1 deletion)

```diff
@@ -83,7 +83,7 @@ def submit_predictions(self, prediction, label):
         if label == prediction:
             self.__correct += 1
 
-    def summarize_accuracy(self):
+    def _summarize_accuracy(self):
         """
         A function summarizing the obtained accuracy for the model
```