66
77from labelbox .data .metrics .iou import data_row_miou
88from labelbox .data .serialization import NDJsonConverter , LBV1Converter
9- from labelbox .data .annotation_types import Label , RasterData
9+ from labelbox .data .annotation_types import Label , RasterData , Mask
1010
1111
12- def check_iou (pair ):
12+ def check_iou (pair , mask = False ):
1313 default = Label (data = RasterData (uid = "ckppihxc10005aeyjen11h7jh" ))
14- assert math .isclose (
15- data_row_miou (
16- next (LBV1Converter .deserialize ([pair .labels ])),
17- next (NDJsonConverter .deserialize (pair .predictions ), default )),
18- pair .expected )
14+ prediction = next (NDJsonConverter .deserialize (pair .predictions ), default )
15+ label = next (LBV1Converter .deserialize ([pair .labels ]))
16+ if mask :
17+ for annotation in [* prediction .annotations , * label .annotations ]:
18+ if isinstance (annotation .value , Mask ):
19+ annotation .value .mask .arr = np .frombuffer (
20+ base64 .b64decode (annotation .value .mask .url .encode ('utf-8' )),
21+ dtype = np .uint8 ).reshape ((32 , 32 , 3 ))
22+ assert math .isclose (data_row_miou (label , prediction ), pair .expected )
1923
2024
2125def strings_to_fixtures (strings ):
@@ -25,11 +29,7 @@ def strings_to_fixtures(strings):
2529def test_overlapping (polygon_pair , box_pair , mask_pair ):
2630 check_iou (polygon_pair )
2731 check_iou (box_pair )
28- #with patch('labelbox.data.metrics.iou.url_to_numpy',
29- # side_effect=lambda x: np.frombuffer(
30- # base64.b64decode(x.encode('utf-8')), dtype=np.uint8).reshape(
31- # (32, 32, 3))):
32- # #check_iou(mask_pair)
32+ check_iou (mask_pair , True )
3333
3434
3535@parametrize ("pair" ,
0 commit comments