Skip to content

Commit 7b7eaac

Browse files
committed
Added a couple of tests for ProductDomain
1 parent 4a1c493 commit 7b7eaac

File tree

1 file changed

+96
-1
lines changed

1 file changed

+96
-1
lines changed

src/qinfer/tests/test_domains.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
)
4141
import abc
4242
from qinfer import (
43-
Domain, RealDomain, IntegerDomain, MultinomialDomain
43+
Domain, ProductDomain,
44+
RealDomain, IntegerDomain, MultinomialDomain
4445
)
4546

4647
import unittest
@@ -175,3 +176,97 @@ def test_array_conversion(self):
175176
assert_equal(self.domain.from_regular_array(arr2), arr1)
176177
assert_equal(self.domain.to_regular_array(self.domain.from_regular_array(arr2)), arr2)
177178
assert_equal(self.domain.from_regular_array(self.domain.to_regular_array(arr1)), arr1)
179+
180+
181+
class TestIntegerIntegerProductDomain(ConcreteDomainTest, DerandomizedTestCase):
182+
"""
183+
Tests ProductDomain([IntegerDomain, IntegerDomain])
184+
"""
185+
186+
def instantiate_domain(self):
187+
return ProductDomain(
188+
IntegerDomain(min=0,max=5),
189+
IntegerDomain(min=-2, max=2)
190+
)
191+
def instantiate_good_values(self):
192+
return [
193+
np.array([(0,0)], dtype=[('','<i8'),('','<i8')]),
194+
np.array([(5,2),(0,-2)], dtype=[('','<i8'),('','<i8')])
195+
]
196+
def instantiate_bad_values(self):
197+
return [
198+
np.array([(0.5,0)], dtype=[('','f8'),('','<i8')]),
199+
np.array([(6,2),(0,-2)], dtype=[('','<i8'),('','<i8')])
200+
]
201+
202+
def test_array_conversion(self):
203+
arr1 = np.array([(5,2),(0,-1)], dtype=[('','<i8'),('','<i8')])
204+
arr2 = [np.array([5,0]), np.array([2,-1])]
205+
206+
assert_equal(self.domain.to_regular_arrays(arr1), arr2)
207+
assert_equal(self.domain.from_regular_arrays(arr2), arr1)
208+
assert_equal(self.domain.to_regular_arrays(self.domain.from_regular_arrays(arr2)), arr2)
209+
assert_equal(self.domain.from_regular_arrays(self.domain.to_regular_arrays(arr1)), arr1)
210+
211+
#override
212+
def test_is_finite(self):
213+
assert(self.domain.is_finite)
214+
assert(self.domain.is_discrete)
215+
216+
class TestIntegerIntegerProductDomain2(ConcreteDomainTest, DerandomizedTestCase):
217+
"""
218+
Tests ProductDomain([IntegerDomain, IntegerDomain])
219+
"""
220+
221+
def instantiate_domain(self):
222+
return ProductDomain(
223+
IntegerDomain(min=0,max=5),
224+
IntegerDomain(min=-2, max=np.inf),
225+
IntegerDomain(min=-np.inf, max=np.inf)
226+
)
227+
def instantiate_good_values(self):
228+
return [
229+
np.array([(0,0,0)], dtype=[('','<i8'),('','<i8'),('','<i8')]),
230+
np.array([(5,2,0),(0,-2,10)], dtype=[('','<i8'),('','<i8'),('','<i8')])
231+
]
232+
def instantiate_bad_values(self):
233+
return [
234+
np.array([(0.5,0,10)], dtype=[('','f8'),('','<i8'),('','<i8')]),
235+
np.array([(6,2,10),(0,-2,10)], dtype=[('','<i8'),('','<i8'),('','<i8')])
236+
]
237+
238+
#override
239+
def test_is_finite(self):
240+
assert(not self.domain.is_finite)
241+
assert(self.domain.is_discrete)
242+
243+
class TestIntegerMultinomialProductDomain(ConcreteDomainTest, DerandomizedTestCase):
244+
"""
245+
Tests ProductDomain([IntegerDomain, IntegerDomain])
246+
"""
247+
248+
def instantiate_domain(self):
249+
return ProductDomain(
250+
IntegerDomain(min=0,max=5),
251+
MultinomialDomain(5, n_elements=3)
252+
)
253+
def instantiate_good_values(self):
254+
return [
255+
np.array([(0,[1,2,2])], dtype=[('','<i8'),('','<i8', 3)]),
256+
np.array([(5,[5,0,0]),(0,[1,0,4])], dtype=[('','<i8'),('','<i8', 3)])
257+
]
258+
def instantiate_bad_values(self):
259+
return [
260+
np.array([(-10,[1,2,2])], dtype=[('','<i8'),('k','<i8', 3)]),
261+
np.array([(5,[-1,6,0]),(0,[1,0,4])], dtype=[('','<i8'),('k','<i8', 3)])
262+
]
263+
264+
def test_array_conversion(self):
265+
arr1 = np.array([(5,[1,2,2]),(0,[5,0,0])], dtype=[('','<i8'),('k','<i8', 3)])
266+
arr2 = [np.array([5,0]), np.array([([1,2,2],), ([5,0,0],)], dtype=[('k','<i8',3)])]
267+
268+
assert_equal(self.domain.to_regular_arrays(arr1), arr2)
269+
assert_equal(self.domain.from_regular_arrays(arr2), arr1)
270+
assert_equal(self.domain.to_regular_arrays(self.domain.from_regular_arrays(arr2)), arr2)
271+
assert_equal(self.domain.from_regular_arrays(self.domain.to_regular_arrays(arr1)), arr1)
272+

0 commit comments

Comments
 (0)