Skip to content

Commit 9ab3fca

Browse files
committed
Update nll_clslsr.py
update nll_clsslsr
1 parent fee0a47 commit 9ab3fca

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

pymic/net_run_nll/nll_clslsr.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import os
55
import scipy
6-
import sys
76
import torch
87
import numpy as np
98
import pandas as pd
@@ -27,7 +26,11 @@ def get_confident_map(gt, pred, CL_type = 'both'):
2726
2827
:return: A tensor representing the noisiness of each pixel.
2928
"""
30-
import cleanlab
29+
try:
30+
import cleanlab
31+
assert(cleanlab.__version__ == '1.0.1')
32+
except:
33+
raise ValueError("Error: cleanlab 1.0.1 required. Please install it by `pip install cleanlab==1.0.1`")
3134
prob = scipy.special.softmax(pred, axis = 1)
3235
if CL_type in ['both', 'Qij']:
3336
noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1)
@@ -146,15 +149,7 @@ def test_time_dropout(m):
146149
dst_path = os.path.join(save_dir, filename)
147150
conf_map.save(dst_path)
148151

149-
def get_confidence_map():
150-
"""
151-
The main function to get the confidence map during inference.
152-
"""
153-
if(len(sys.argv) < 2):
154-
print('Number of arguments should be 3. e.g.')
155-
print(' python nll_clslsr.py config.cfg')
156-
exit()
157-
cfg_file = str(sys.argv[1])
152+
def get_confidence_map(cfg_file):
158153
config = parse_config(cfg_file)
159154
config = synchronize_config(config)
160155

@@ -173,7 +168,7 @@ def get_confidence_map():
173168
one_transform = transform_dict[name](transform_param)
174169
transform_list.append(one_transform)
175170
data_transform = transforms.Compose(transform_list)
176-
print('transform list', transform_list)
171+
177172
csv_file = config['dataset']['train_csv']
178173
modal_num = config['dataset'].get('modal_num', 1)
179174
dataset = NiftyDataset(root_dir = config['dataset']['root_dir'],
@@ -201,7 +196,4 @@ def get_confidence_map():
201196
"label": df_train["label"]}
202197
train_cl_csv = csv_file.replace(".csv", "_clslsr.csv")
203198
df_cl = pd.DataFrame.from_dict(train_cl_dict)
204-
df_cl.to_csv(train_cl_csv, index = False)
205-
206-
if __name__ == "__main__":
207-
get_confidence_map()
199+
df_cl.to_csv(train_cl_csv, index = False)

0 commit comments

Comments
 (0)