Skip to content

Commit 9efa683

Browse files
author
donglaiw
committed
test singly: enable output filenames to be the same as input names
1 parent 9ef9231 commit 9efa683

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

connectomics/config/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@
484484
# Do inference one-by-on (load a volume when needed).
485485
_C.INFERENCE.DO_SINGLY = False
486486
_C.INFERENCE.DO_SINGLY_START_INDEX = 0
487+
_C.INFERENCE.DO_SINGLY_STEP = 1
487488

488489
_C.INFERENCE.PAD_SIZE = None
489490
_C.INFERENCE.UNPAD = True

connectomics/data/dataset/build.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,12 @@ def _get_file_list(name: Union[str, List[str]],
6767
return name
6868

6969
suffix = name.split('.')[-1]
70-
if suffix == 'txt': # a text file saving the absolute path
70+
if suffix == 'txt': # a text file saving the absolute or relative path
7171
with open(name) as file:
72-
filelist = [line.rstrip('\n') for line in file]
72+
if prefix is None:
73+
filelist = [line.rstrip('\n') for line in file]
74+
else:
75+
filelist = [os.path.join(prefix, line.rstrip('\n')) for line in file]
7376
return filelist
7477

7578
suffix = name.split('/')[-1]

connectomics/engine/trainer.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -274,22 +274,26 @@ def test(self):
274274

275275
def test_singly(self):
276276
dir_name = _get_file_list(self.cfg.DATASET.INPUT_PATH)
277-
img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=dir_name[0])
278277
assert len(dir_name) == 1 # avoid ambiguity when DO_SINGLY is True
278+
img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=dir_name[0])
279+
num_file = len(img_name)
280+
281+
if os.path.isfile(self.cfg.INFERENCE.OUTPUT_NAME):
282+
output_name = _get_file_list(self.cfg.DATASET.OUTPUT_NAME, prefix=self.output_dir)
283+
else:
284+
# same filename but different location
285+
if self.output_dir != dir_name[0]:
286+
output_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=self.output_dir)
287+
else:
288+
output_name = [x+'_result.h5' for x in img_name]
279289

280-
# save input image names for further reference
290+
# save input image names for future reference
281291
fw = open(os.path.join(self.output_dir, "images.txt"), "w")
282292
fw.write('\n'.join(img_name))
283293
fw.close()
284294

285-
num_file = len(img_name)
286-
start_idx = self.cfg.INFERENCE.DO_SINGLY_START_INDEX
287-
digits = int(math.log10(num_file))+1
288-
for i in range(start_idx, num_file):
289-
self.test_filename = self.cfg.INFERENCE.OUTPUT_NAME + \
290-
'_' + str(i).zfill(digits) + '.h5'
291-
self.test_filename = self.augmentor.update_name(
292-
self.test_filename)
295+
for i in range(self.cfg.INFERENCE.DO_SINGLY_START_INDEX, num_file, self.cfg.INFERENCE.DO_SINGLY_STEP):
296+
self.test_filename = output_name[i]
293297
if not os.path.exists(self.test_filename):
294298
dataset = get_dataset(
295299
self.cfg, self.augmentor, self.mode, self.rank,

0 commit comments

Comments
 (0)