33import logging
44import os
55import scipy
6- import sys
76import torch
87import numpy as np
98import pandas as pd
@@ -27,7 +26,11 @@ def get_confident_map(gt, pred, CL_type = 'both'):
2726
2827 :return: A tensor representing the noisiness of each pixel.
2928 """
30- import cleanlab
29+ try :
30+ import cleanlab
31+ assert (cleanlab .__version__ == '1.0.1' )
32+ except :
33+ raise ValueError ("Error: cleanlab 1.0.1 required. Please install it by `pip install cleanlab==1.0.1`" )
3134 prob = scipy .special .softmax (pred , axis = 1 )
3235 if CL_type in ['both' , 'Qij' ]:
3336 noise = cleanlab .pruning .get_noise_indices (gt , prob , prune_method = 'both' , n_jobs = 1 )
@@ -146,15 +149,7 @@ def test_time_dropout(m):
146149 dst_path = os .path .join (save_dir , filename )
147150 conf_map .save (dst_path )
148151
149- def get_confidence_map ():
150- """
151- The main function to get the confidence map during inference.
152- """
153- if (len (sys .argv ) < 2 ):
154- print ('Number of arguments should be 3. e.g.' )
155- print (' python nll_clslsr.py config.cfg' )
156- exit ()
157- cfg_file = str (sys .argv [1 ])
152+ def get_confidence_map (cfg_file ):
158153 config = parse_config (cfg_file )
159154 config = synchronize_config (config )
160155
@@ -173,7 +168,7 @@ def get_confidence_map():
173168 one_transform = transform_dict [name ](transform_param )
174169 transform_list .append (one_transform )
175170 data_transform = transforms .Compose (transform_list )
176- print ( 'transform list' , transform_list )
171+
177172 csv_file = config ['dataset' ]['train_csv' ]
178173 modal_num = config ['dataset' ].get ('modal_num' , 1 )
179174 dataset = NiftyDataset (root_dir = config ['dataset' ]['root_dir' ],
@@ -201,7 +196,4 @@ def get_confidence_map():
201196 "label" : df_train ["label" ]}
202197 train_cl_csv = csv_file .replace (".csv" , "_clslsr.csv" )
203198 df_cl = pd .DataFrame .from_dict (train_cl_dict )
204- df_cl .to_csv (train_cl_csv , index = False )
205-
206- if __name__ == "__main__" :
207- get_confidence_map ()
199+ df_cl .to_csv (train_cl_csv , index = False )
0 commit comments