Skip to content
This repository was archived by the owner on Nov 9, 2022. It is now read-only.

Commit a4566a2

Browse files
authored
Merge pull request #4 from dronedeploy/basedir_options
[in progress] finishing up saving to basedir
2 parents 1f22ee8 + eea1c30 commit a4566a2

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

libs/inference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def predict(self, imagefile, predsfile, size=1200):
100100
mask = category2mask(prediction)
101101
cv2.imwrite(predsfile, mask)
102102

103-
def run_inference(dataset, model_name='baseline_model'):
103+
def run_inference(dataset, model_name='baseline_model', basedir="predictions"):
104+
if not os.path.isdir(basedir):
105+
os.mkdir(basedir)
104106

105107
size = 1200
106108
modelpath = 'models'
@@ -116,7 +118,7 @@ def run_inference(dataset, model_name='baseline_model'):
116118

117119
imagefile = f'{dataset}/images/{scene}-ortho.tif'
118120
labelfile = f'{dataset}/labels/{scene}-label.png'
119-
predsfile = f"{scene}-prediction.png"
121+
predsfile = f"{basedir}/{scene}-prediction.png"
120122

121123
if not os.path.exists(imagefile):
122124
#print(f"image {imagefile} not found, skipping.")

libs/inference_keras.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def run_inference_on_file(imagefile, predsfile, model, size=300):
5353
mask = category2mask(prediction)
5454
Image.fromarray(mask).save(predsfile)
5555

56-
def run_inference(dataset, model=None, model_path=None, basedir='.'):
56+
def run_inference(dataset, model=None, model_path=None, basedir='predictions'):
57+
if not os.path.isdir(basedir):
58+
os.mkdir(basedir)
5759
if model is None and model_path is None:
5860
raise Exception("model or model_path required")
5961

0 commit comments

Comments
 (0)