1111
1212class NameSpace (SimpleNamespace ):
1313
14- def __init__ (self , predictions , ground_truths , expected ):
15- super (NameSpace , self ).__init__ (predictions = predictions ,
16- ground_truths = ground_truths ,
17- expected = expected )
14+ def __init__ (self ,
15+ predictions ,
16+ ground_truths ,
17+ expected ,
18+ expected_without_subclasses = None ):
19+ super (NameSpace , self ).__init__ (
20+ predictions = predictions ,
21+ ground_truths = ground_truths ,
22+ expected = expected ,
23+ expected_without_subclasses = expected_without_subclasses or expected )
1824
1925
2026def get_radio (name , answer_name ):
@@ -109,7 +115,8 @@ def get_object_pairs(tool_fn, **kwargs):
109115 ** kwargs ,
110116 subclasses = [get_radio ("is_animal" , answer_name = "yes" )])
111117 ],
112- expected = {'cat' : [1 , 0 , 0 , 0 ]}),
118+ expected = {'cat' : [1 , 0 , 0 , 0 ]},
119+ expected_without_subclasses = {'cat' : [1 , 0 , 0 , 0 ]}),
113120 NameSpace (predictions = [
114121 tool_fn ("cat" ,
115122 ** kwargs ,
@@ -121,7 +128,8 @@ def get_object_pairs(tool_fn, **kwargs):
121128 ** kwargs ,
122129 subclasses = [get_radio ("is_animal" , answer_name = "no" )])
123130 ],
124- expected = {'cat' : [0 , 1 , 0 , 1 ]}),
131+ expected = {'cat' : [0 , 1 , 0 , 1 ]},
132+ expected_without_subclasses = {'cat' : [1 , 0 , 0 , 0 ]}),
125133 NameSpace (predictions = [
126134 tool_fn ("cat" ,
127135 ** kwargs ,
@@ -136,7 +144,8 @@ def get_object_pairs(tool_fn, **kwargs):
136144 ** kwargs ,
137145 subclasses = [get_radio ("is_animal" , answer_name = "no" )])
138146 ],
139- expected = {'cat' : [1 , 1 , 0 , 0 ]}),
147+ expected = {'cat' : [1 , 1 , 0 , 0 ]},
148+ expected_without_subclasses = {'cat' : [1 , 1 , 0 , 0 ]}),
140149 NameSpace (predictions = [
141150 tool_fn ("cat" ,
142151 ** kwargs ,
@@ -154,6 +163,10 @@ def get_object_pairs(tool_fn, **kwargs):
154163 expected = {
155164 'cat' : [0 , 1 , 0 , 1 ],
156165 'dog' : [0 , 1 , 0 , 0 ]
166+ },
167+ expected_without_subclasses = {
168+ 'cat' : [1 , 0 , 0 , 0 ],
169+ 'dog' : [0 , 1 , 0 , 0 ]
157170 }),
158171 NameSpace (
159172 predictions = [tool_fn ("cat" , ** kwargs ),
@@ -171,7 +184,10 @@ def get_object_pairs(tool_fn, **kwargs):
171184 ground_truths = [tool_fn ("cat" , ** kwargs ),
172185 tool_fn ("cat" , ** kwargs )],
173186 expected = {'cat' : [1 , 0 , 0 , 1 ]}),
174- NameSpace (predictions = [], ground_truths = [], expected = []),
187+ NameSpace (predictions = [],
188+ ground_truths = [],
189+ expected = [],
190+ expected_without_subclasses = []),
175191 NameSpace (predictions = [],
176192 ground_truths = [tool_fn ("cat" , ** kwargs )],
177193 expected = {'cat' : [0 , 0 , 0 , 1 ]}),
@@ -183,7 +199,7 @@ def get_object_pairs(tool_fn, **kwargs):
183199 expected = {
184200 'cat' : [0 , 1 , 0 , 0 ],
185201 'dog' : [0 , 0 , 0 , 1 ]
186- }),
202+ })
187203 ]
188204
189205
0 commit comments