Skip to content

Commit d4b8c64

Browse files
Fix error in mislabelling predictions.
1 parent 093706f commit d4b8c64

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

classifier/predictor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ def init_from_config_path(cls, config_path):
2727

2828
def predict_from_array(self, arr) -> Dict[str, float]:
2929
"""Returns a prediction value the sample belongs to each class."""
30-
pred = self.model.predict(arr[np.newaxis, ...]).ravel().tolist()
31-
pred = [round(x, 3) for x in pred] # values between 0-1
30+
# in this model, 'Normal Images' is the positive class (labeled by 1)
31+
pred_arr = self.model.predict(arr[np.newaxis, ...]).ravel().tolist()
32+
# so we convert the probability to predict for 'Fire_Images'
33+
pred = [1 - probability for probability in pred_arr]
3234
return {class_label: prob for class_label, prob in zip(self.targets, pred)}
3335

3436
def predict_from_file(self, file_object):

0 commit comments

Comments
 (0)