Skip to content

Commit 9b0e75b

Browse files
committed
add mask parameter to computeScoreMap
1 parent 74d5e0c commit 9b0e75b

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

MTM/__init__.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import cv2
22
import numpy as np
33
import pandas as pd
4+
import warnings
45
from skimage.feature import peak_local_max
56
from scipy.signal import find_peaks
67
from .version import __version__
@@ -46,7 +47,7 @@ def _findLocalMin_(corrMap, score_threshold=0.4):
4647
return _findLocalMax_(-corrMap, -score_threshold)
4748

4849

49-
def computeScoreMap(template, image, method=cv2.TM_CCOEFF_NORMED):
50+
def computeScoreMap(template, image, method=cv2.TM_CCOEFF_NORMED, mask=None):
5051
'''
5152
Compute score map provided numpy array for template and image.
5253
Automatically converts images if necessary
@@ -60,8 +61,20 @@ def computeScoreMap(template, image, method=cv2.TM_CCOEFF_NORMED):
6061
template = np.float32(template)
6162
image = np.float32(image)
6263

64+
if mask:
65+
66+
if method not in (0,3):
67+
mask = None
68+
warnings.warn("Template matching method not compatible with use of mask (only TM_SQDIFF or TM_CCORR_NORMED).\n-> Ignoring mask.")
69+
70+
else: # correct method
71+
# Check that mask has the same dimensions and type than template
72+
sameDimension = mask.shape == template.shape
73+
sameType = mask.dtype == template.dtype
74+
if not sameDimension and sameType: mask = None
75+
6376
# Compute correlation map
64-
return cv2.matchTemplate(template, image, method)
77+
return cv2.matchTemplate(template, image, method, mask=mask)
6578

6679

6780
def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=float("inf"), score_threshold=0.5, searchBox=None):
@@ -97,7 +110,7 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
97110
xOffset=yOffset=0
98111

99112
listHit = []
100-
for templateName, template in listTemplates:
113+
for templateName, template in listTemplates: # put the mask as 3rd tuple member ? but then unwrap in the loop rather
101114

102115
#print('\nSearch with template : ',templateName)
103116

0 commit comments

Comments
 (0)