Skip to content

Commit 453c425

Browse files
committed
use opencv nms
1 parent bc92369 commit 453c425

File tree

1 file changed

+17
-146
lines changed

1 file changed

+17
-146
lines changed

MTM/NMS.py

Lines changed: 17 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -10,74 +10,11 @@
1010
1111
@author: Laurent Thomas
1212
"""
13-
from __future__ import division, print_function # for compatibility with Py2
14-
import pandas as pd
1513

16-
def Point_in_Rectangle(Point, Rectangle):
17-
'''Return True if a point (x,y) is contained in a Rectangle(x, y, width, height)'''
18-
# unpack variables
19-
Px, Py = Point
20-
Rx, Ry, w, h = Rectangle
14+
import cv2
2115

22-
return (Rx <= Px) and (Px <= Rx + w -1) and (Ry <= Py) and (Py <= Ry + h -1) # simply test if x_Point is in the range of x for the rectangle
2316

24-
25-
def computeIoU(BBox1,BBox2):
26-
'''
27-
Compute the IoU (Intersection over Union) between 2 rectangular bounding boxes defined by the top left (Xtop,Ytop) and bottom right (Xbot, Ybot) pixel coordinates
28-
Code adapted from https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
29-
'''
30-
#print('BBox1 : ', BBox1)
31-
#print('BBox2 : ', BBox2)
32-
33-
# Unpack input (python3 - tuple are no more supported as input in function definition - PEP3113 - Tuple can be used in as argument in a call but the function will not unpack it automatically)
34-
Xleft1, Ytop1, Width1, Height1 = BBox1
35-
Xleft2, Ytop2, Width2, Height2 = BBox2
36-
37-
# Compute bottom coordinates
38-
Xright1 = Xleft1 + Width1 -1 # we remove -1 from the width since we start with 1 pixel already (the top one)
39-
Ybot1 = Ytop1 + Height1 -1 # idem for the height
40-
41-
Xright2 = Xleft2 + Width2 -1
42-
Ybot2 = Ytop2 + Height2 -1
43-
44-
# determine the (x, y)-coordinates of the top left and bottom right points of the intersection rectangle
45-
Xleft = max(Xleft1, Xleft2)
46-
Ytop = max(Ytop1, Ytop2)
47-
Xright = min(Xright1, Xright2)
48-
Ybot = min(Ybot1, Ybot2)
49-
50-
# Compute boolean for inclusion
51-
BBox1_in_BBox2 = Point_in_Rectangle((Xleft1, Ytop1), BBox2) and Point_in_Rectangle((Xleft1, Ybot1), BBox2) and Point_in_Rectangle((Xright1, Ytop1), BBox2) and Point_in_Rectangle((Xright1, Ybot1), BBox2)
52-
BBox2_in_BBox1 = Point_in_Rectangle((Xleft2, Ytop2), BBox1) and Point_in_Rectangle((Xleft2, Ybot2), BBox1) and Point_in_Rectangle((Xright2, Ytop2), BBox1) and Point_in_Rectangle((Xright2, Ybot2), BBox1)
53-
54-
# Check that for the intersection box, Xtop,Ytop is indeed on the top left of Xbot,Ybot
55-
if BBox1_in_BBox2 or BBox2_in_BBox1:
56-
#print('One BBox is included within the other')
57-
IoU = 1
58-
59-
elif Xright<Xleft or Ybot<Ytop : # it means that there is no intersection (bbox is inverted)
60-
#print('No overlap')
61-
IoU = 0
62-
63-
else:
64-
# Compute area of the intersecting box
65-
Inter = (Xright - Xleft + 1) * (Ybot - Ytop + 1) # +1 since we are dealing with pixels. See a 1D example with 3 pixels for instance
66-
#print('Intersection area : ', Inter)
67-
68-
# Compute area of the union as Sum of the 2 BBox area - Intersection
69-
Union = Width1 * Height1 + Width2 * Height2 - Inter
70-
#print('Union : ', Union)
71-
72-
# Compute Intersection over union
73-
IoU = Inter/Union
74-
75-
#print('IoU : ',IoU)
76-
return IoU
77-
78-
79-
80-
def NMS(tableHit, scoreThreshold=None, sortAscending=False, N_object=float("inf"), maxOverlap=0.5):
17+
def NMS(tableHit, scoreThreshold=0, sortAscending=False, N_object=-1, maxOverlap=0.5):
8118
'''
8219
Perform Non-Maxima supression : it compares the hits after maxima/minima detection, and removes the ones that are too close (too large overlap)
8320
This function works both with an optionnal threshold on the score, and number of detected bbox
@@ -95,105 +32,39 @@ def NMS(tableHit, scoreThreshold=None, sortAscending=False, N_object=float("inf"
9532
- tableHit : (Panda DataFrame) Each row is a hit, with columns "TemplateName"(String),"BBox"(x,y,width,height),"Score"(float)
9633
9734
- scoreThreshold : Float (or None), used to remove hit with too low prediction score.
98-
If sortDescending=True (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept
99-
While if we use sortDescending=False (we use a difference measure ie we want to keep low score), the scores below that threshold are kept
35+
If sortAscending=False (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept
36+
If True (we use a difference measure ie we want to keep low score), the scores below that threshold are kept
10037
101-
- N_object : number of best hit to return (by increasing score). Min=1, eventhough it does not really make sense to do NMS with only 1 hit
38+
- N_object : maximum number of hit to return. Default=-1, ie return all hit passing NMS
10239
- maxOverlap : float between 0 and 1, the maximal overlap authorised between 2 bounding boxes, above this value, the bounding box of lower score is deleted
10340
- sortAscending : use True when low score means better prediction (Difference-based score), True otherwise (Correlation score)
10441
10542
OUTPUT
10643
Panda DataFrame with best detection after NMS, it contains max N detection (but potentially less)
10744
'''
45+
listBoxes = tableHit["BBox"].to_list()
46+
listScores = tableHit["Score"].to_list()
10847

109-
# Apply threshold on prediction score
110-
if scoreThreshold==None :
111-
threshTable = tableHit.copy() # copy to avoid modifying the input list in place
112-
113-
elif not sortAscending : # We keep rows above the threshold
114-
threshTable = tableHit[ tableHit['Score']>=scoreThreshold ]
115-
116-
elif sortAscending : # We keep hit below the threshold
117-
threshTable = tableHit[ tableHit['Score']<=scoreThreshold ]
118-
119-
# Sort score to have best predictions first (ie lower score if difference-based, higher score if correlation-based)
120-
# important as we loop testing the best boxes against the other boxes)
121-
threshTable.sort_values("Score", ascending=sortAscending, inplace=True) # Warning here is fine
122-
123-
124-
# Split the inital pool into Final Hit that are kept and restTable that can be tested
125-
# Initialisation : 1st keep is kept for sure, restTable is the rest of the list
126-
#print("\nInitialise final hit list with first best hit")
127-
outTable = threshTable.iloc[[0]].to_dict('records') # double square bracket to recover a DataFrame
128-
restTable = threshTable.iloc[1:].to_dict('records')
129-
130-
131-
# Loop to compute overlap
132-
while len(outTable)<N_object and restTable: # second condition is restTable is not empty
133-
134-
# Report state of the loop
135-
#print("\n\n\nNext while iteration")
136-
137-
#print("-> Final hit list")
138-
#for hit in outTable: print(hit)
139-
140-
#print("\n-> Remaining hit list")
141-
#for hit in restTable: print(hit)
48+
if sortAscending:
49+
listScores = [1-score for score in listScores] # NMS expect high-score for good predictions
50+
scoreThreshold = 1-scoreThreshold
14251

143-
# pick the next best peak in the rest of peak
144-
testHit_dico = restTable[0] # dico
145-
test_bbox = testHit_dico['BBox']
146-
#print("\nTest BBox:{} for overlap against higher score bboxes".format(test_bbox))
147-
148-
# Loop over hit in outTable to compute successively overlap with testHit
149-
for hit_dico in outTable:
150-
151-
# Recover Bbox from hit
152-
bbox2 = hit_dico['BBox']
153-
154-
# Compute the Intersection over Union between test_peak and current peak
155-
IoU = computeIoU(test_bbox, bbox2)
156-
157-
# Initialise the boolean value to true before test of overlap
158-
ToAppend = True
159-
160-
if IoU>maxOverlap:
161-
ToAppend = False
162-
#print("IoU above threshold\n")
163-
break # no need to test overlap with the other peaks
164-
165-
else:
166-
#print("IoU below threshold\n")
167-
# no overlap for this particular (test_peak,peak) pair, keep looping to test the other (test_peak,peak)
168-
continue
169-
170-
171-
# After testing against all peaks (for loop is over), append or not the peak to final
172-
if ToAppend:
173-
# Move the test_hit from restTable to outTable
174-
#print("Append {} to list of final hits, remove it from Remaining hit list".format(test_hit))
175-
outTable.append(testHit_dico)
176-
restTable.remove(testHit_dico)
177-
178-
else:
179-
# only remove the test_peak from restTable
180-
#print("Remove {} from Remaining hit list".format(test_hit))
181-
restTable.remove(testHit_dico)
182-
52+
indexes = cv2.dnn.NMSBoxes(listBoxes, listScores, scoreThreshold, maxOverlap, top_k=N_object)
53+
54+
indexes = [index[0] for index in indexes]
55+
outTable = tableHit[tableHit.index.isin(indexes)]
18356

184-
# Once function execution is done, return list of hit without overlap
185-
#print("\nCollected N expected hit, or no hit left to test")
186-
#print("NMS over\n")
187-
return pd.DataFrame(outTable)
57+
return outTable
18858

18959

19060
if __name__ == "__main__":
61+
import pandas as pd
19162
ListHit =[
19263
{'TemplateName':1,'BBox':(780, 350, 700, 480), 'Score':0.8},
19364
{'TemplateName':1,'BBox':(806, 416, 716, 442), 'Score':0.6},
19465
{'TemplateName':1,'BBox':(1074, 530, 680, 390), 'Score':0.4}
19566
]
19667

197-
FinalHits = NMS( pd.DataFrame(ListHit), scoreThreshold=0.7, sortAscending=False, maxOverlap=0.5 )
68+
FinalHits = NMS( pd.DataFrame(ListHit), scoreThreshold=0.61, sortAscending=True, maxOverlap=0.8, N_object=2 )
19869

19970
print(FinalHits)

0 commit comments

Comments
 (0)