Skip to content

Commit 4d0aee3

Browse files
authored
Merge pull request QInfer#117 from ihincks/feature-product-domain
Feature: ProductDomain
2 parents d93aaf3 + 28ed55f commit 4d0aee3

File tree

4 files changed

+297
-3
lines changed

4 files changed

+297
-3
lines changed

doc/source/apiref/domains.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ Class Reference
3232
.. autoclass:: Domain
3333
:members:
3434

35+
36+
:class:`ProductDomain` - Cartesian product of multiple domains
37+
--------------------------------------------------------------
38+
39+
Class Reference
40+
~~~~~~~~~~~~~~~
41+
.. autoclass:: ProductDomain
42+
:members:
43+
44+
3545
:class:`RealDomain` - (A subset of) Real Numbers
3646
------------------------------------------------
3747

@@ -58,4 +68,4 @@ This domain is used by :class:`MultinomialModel`.
5868
Class Reference
5969
~~~~~~~~~~~~~~~
6070
.. autoclass:: MultinomialDomain
61-
:members:
71+
:members:

src/qinfer/domains.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@
3030

3131
from builtins import range
3232
from future.utils import with_metaclass
33+
from functools import reduce
3334

35+
from operator import mul
3436
from scipy.special import binom
3537
from math import factorial
36-
from itertools import combinations_with_replacement
38+
from itertools import combinations_with_replacement, product
3739
import numpy as np
40+
from .utils import join_struct_arrays, separate_struct_array
3841

3942
import abc
4043

@@ -44,6 +47,7 @@
4447

4548
__all__ = [
4649
'Domain',
50+
'ProductDomain',
4751
'RealDomain',
4852
'IntegerDomain',
4953
'MultinomialDomain'
@@ -145,6 +149,149 @@ def in_domain(self, points):
145149
"""
146150
pass
147151

152+
class ProductDomain(Domain):
153+
"""
154+
A domain made from the cartesian product of other domains.
155+
156+
:param Domain domains: ``Domain`` instances as separate arguments,
157+
or as a singe list of ``Domain`` instances.
158+
"""
159+
def __init__(self, *domains):
160+
161+
if len(domains) == 1:
162+
try:
163+
self._domains = list(domains[0])
164+
except:
165+
self._domains = domains
166+
else:
167+
self._domains = domains
168+
169+
self._domains = domains
170+
self._dtypes = [domain.example_point.dtype for domain in self._domains]
171+
self._example_point = join_struct_arrays(
172+
[np.array(domain.example_point) for domain in self._domains]
173+
)
174+
self._dtype = self._example_point.dtype
175+
176+
@property
177+
def is_continuous(self):
178+
"""
179+
Whether or not the domain has an uncountable number of values.
180+
181+
:type: `bool`
182+
"""
183+
return any([domain.is_continuous for domain in self._domains])
184+
185+
@property
186+
def is_finite(self):
187+
"""
188+
Whether or not the domain contains a finite number of points.
189+
190+
:type: `bool`
191+
"""
192+
return all([domain.is_finite for domain in self._domains])
193+
194+
@property
195+
def dtype(self):
196+
"""
197+
The numpy dtype of a single element of the domain.
198+
199+
:type: `np.dtype`
200+
"""
201+
return self._dtype
202+
203+
@property
204+
def n_members(self):
205+
"""
206+
Returns the number of members in the domain if it
207+
`is_finite`, otherwise, returns `np.inf`.
208+
209+
:type: ``int`` or ``np.inf``
210+
"""
211+
if self.is_finite:
212+
return reduce(mul, [domain.n_members for domain in self._domains], 1)
213+
else:
214+
return np.inf
215+
216+
@property
217+
def example_point(self):
218+
"""
219+
Returns any single point guaranteed to be in the domain, but
220+
no other guarantees; useful for testing purposes.
221+
This is given as a size 1 ``np.array`` of type `dtype`.
222+
223+
:type: ``np.ndarray``
224+
"""
225+
return self._example_point
226+
227+
@property
228+
def values(self):
229+
"""
230+
Returns an `np.array` of type `dtype` containing
231+
some values from the domain.
232+
For domains where `is_finite` is ``True``, all elements
233+
of the domain will be yielded exactly once.
234+
235+
:rtype: `np.ndarray`
236+
"""
237+
separate_values = [domain.values for domain in self._domains]
238+
return np.concatenate([
239+
join_struct_arrays(list(map(np.array, value)))
240+
for value in product(*separate_values)
241+
])
242+
243+
## METHODS ##
244+
245+
def _mytype(self, array):
246+
# astype does weird stuff with struct names, and possibly
247+
# depends on numpy version; hopefully
248+
# the following is a bit more predictable since it passes through
249+
# uint8
250+
return separate_struct_array(array, self.dtype)[0]
251+
252+
def to_regular_arrays(self, array):
253+
"""
254+
Expands from an array of type `self.dtype` into a list of
255+
arrays with dtypes corresponding to the factor domains.
256+
257+
:param np.ndarray array: An `np.array` of type `self.dtype`.
258+
259+
:rtype: ``list``
260+
"""
261+
return separate_struct_array(self._mytype(array), self._dtypes)
262+
263+
def from_regular_arrays(self, arrays):
264+
"""
265+
Merges a list of arrays (of the same shape) of dtypes
266+
corresponding to the factor domains into a single array
267+
with the dtype of the ``ProductDomain``.
268+
269+
:param list array: A list with each element of type ``np.ndarray``
270+
271+
:rtype: `np.ndarray`
272+
"""
273+
return self._mytype(join_struct_arrays([
274+
array.astype(dtype)
275+
for dtype, array in zip(self._dtypes, arrays)
276+
]))
277+
278+
279+
def in_domain(self, points):
280+
"""
281+
Returns ``True`` if all of the given points are in the domain,
282+
``False`` otherwise.
283+
284+
:param np.ndarray points: An `np.ndarray` of type `self.dtype`.
285+
286+
:rtype: `bool`
287+
"""
288+
return all([
289+
domain.in_domain(array)
290+
for domain, array in
291+
zip(self._domains, separate_struct_array(points, self._dtypes))
292+
])
293+
294+
148295
## CLASSES ###################################################################
149296

150297
class RealDomain(Domain):

src/qinfer/tests/test_domains.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
)
4141
import abc
4242
from qinfer import (
43-
Domain, RealDomain, IntegerDomain, MultinomialDomain
43+
Domain, ProductDomain,
44+
RealDomain, IntegerDomain, MultinomialDomain
4445
)
4546

4647
import unittest
@@ -175,3 +176,97 @@ def test_array_conversion(self):
175176
assert_equal(self.domain.from_regular_array(arr2), arr1)
176177
assert_equal(self.domain.to_regular_array(self.domain.from_regular_array(arr2)), arr2)
177178
assert_equal(self.domain.from_regular_array(self.domain.to_regular_array(arr1)), arr1)
179+
180+
181+
class TestIntegerIntegerProductDomain(ConcreteDomainTest, DerandomizedTestCase):
182+
"""
183+
Tests ProductDomain([IntegerDomain, IntegerDomain])
184+
"""
185+
186+
def instantiate_domain(self):
187+
return ProductDomain(
188+
IntegerDomain(min=0,max=5),
189+
IntegerDomain(min=-2, max=2)
190+
)
191+
def instantiate_good_values(self):
192+
return [
193+
np.array([(0,0)], dtype=[('','<i8'),('','<i8')]),
194+
np.array([(5,2),(0,-2)], dtype=[('','<i8'),('','<i8')])
195+
]
196+
def instantiate_bad_values(self):
197+
return [
198+
np.array([(0.5,0)], dtype=[('','f8'),('','<i8')]),
199+
np.array([(6,2),(0,-2)], dtype=[('','<i8'),('','<i8')])
200+
]
201+
202+
def test_array_conversion(self):
203+
arr1 = np.array([(5,2),(0,-1)], dtype=[('','<i8'),('','<i8')])
204+
arr2 = [np.array([5,0]), np.array([2,-1])]
205+
206+
assert_equal(self.domain.to_regular_arrays(arr1), arr2)
207+
assert_equal(self.domain.from_regular_arrays(arr2), arr1)
208+
assert_equal(self.domain.to_regular_arrays(self.domain.from_regular_arrays(arr2)), arr2)
209+
assert_equal(self.domain.from_regular_arrays(self.domain.to_regular_arrays(arr1)), arr1)
210+
211+
#override
212+
def test_is_finite(self):
213+
assert(self.domain.is_finite)
214+
assert(self.domain.is_discrete)
215+
216+
class TestIntegerIntegerProductDomain2(ConcreteDomainTest, DerandomizedTestCase):
217+
"""
218+
Tests ProductDomain([IntegerDomain, IntegerDomain])
219+
"""
220+
221+
def instantiate_domain(self):
222+
return ProductDomain(
223+
IntegerDomain(min=0,max=5),
224+
IntegerDomain(min=-2, max=np.inf),
225+
IntegerDomain(min=-np.inf, max=np.inf)
226+
)
227+
def instantiate_good_values(self):
228+
return [
229+
np.array([(0,0,0)], dtype=[('','<i8'),('','<i8'),('','<i8')]),
230+
np.array([(5,2,0),(0,-2,10)], dtype=[('','<i8'),('','<i8'),('','<i8')])
231+
]
232+
def instantiate_bad_values(self):
233+
return [
234+
np.array([(0.5,0,10)], dtype=[('','f8'),('','<i8'),('','<i8')]),
235+
np.array([(6,2,10),(0,-2,10)], dtype=[('','<i8'),('','<i8'),('','<i8')])
236+
]
237+
238+
#override
239+
def test_is_finite(self):
240+
assert(not self.domain.is_finite)
241+
assert(self.domain.is_discrete)
242+
243+
class TestIntegerMultinomialProductDomain(ConcreteDomainTest, DerandomizedTestCase):
244+
"""
245+
Tests ProductDomain([IntegerDomain, IntegerDomain])
246+
"""
247+
248+
def instantiate_domain(self):
249+
return ProductDomain(
250+
IntegerDomain(min=0,max=5),
251+
MultinomialDomain(5, n_elements=3)
252+
)
253+
def instantiate_good_values(self):
254+
return [
255+
np.array([(0,[1,2,2])], dtype=[('','<i8'),('','<i8', 3)]),
256+
np.array([(5,[5,0,0]),(0,[1,0,4])], dtype=[('','<i8'),('','<i8', 3)])
257+
]
258+
def instantiate_bad_values(self):
259+
return [
260+
np.array([(-10,[1,2,2])], dtype=[('','<i8'),('k','<i8', 3)]),
261+
np.array([(5,[-1,6,0]),(0,[1,0,4])], dtype=[('','<i8'),('k','<i8', 3)])
262+
]
263+
264+
def test_array_conversion(self):
265+
arr1 = np.array([(5,[1,2,2]),(0,[5,0,0])], dtype=[('','<i8'),('k','<i8', 3)])
266+
arr2 = [np.array([5,0]), np.array([([1,2,2],), ([5,0,0],)], dtype=[('k','<i8',3)])]
267+
268+
assert_equal(self.domain.to_regular_arrays(arr1), arr2)
269+
assert_equal(self.domain.from_regular_arrays(arr2), arr1)
270+
assert_equal(self.domain.to_regular_arrays(self.domain.from_regular_arrays(arr2)), arr2)
271+
assert_equal(self.domain.from_regular_arrays(self.domain.to_regular_arrays(arr1)), arr1)
272+

src/qinfer/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,48 @@ def pretty_time(secs, force_h=False, force_m=False):
462462
def safe_shape(arr, idx=0, default=1):
463463
shape = np.shape(arr)
464464
return shape[idx] if idx < len(shape) else default
465+
466+
def join_struct_arrays(arrays):
467+
"""
468+
Takes a list of possibly structured arrays, concatenates their
469+
dtypes, and returns one big array with that dtype. Does the
470+
inverse of ``separate_struct_array``.
471+
472+
:param list arrays: List of ``np.ndarray``s
473+
"""
474+
# taken from http://stackoverflow.com/questions/5355744/numpy-joining-structured-arrays
475+
sizes = np.array([a.itemsize for a in arrays])
476+
offsets = np.r_[0, sizes.cumsum()]
477+
shape = arrays[0].shape
478+
joint = np.empty(shape + (offsets[-1],), dtype=np.uint8)
479+
for a, size, offset in zip(arrays, sizes, offsets):
480+
joint[...,offset:offset+size] = np.atleast_1d(a).view(np.uint8).reshape(shape + (size,))
481+
dtype = sum((a.dtype.descr for a in arrays), [])
482+
return joint.ravel().view(dtype)
483+
484+
def separate_struct_array(array, dtypes):
485+
"""
486+
Takes an array with a structured dtype, and separates it out into
487+
a list of arrays with dtypes coming from the input ``dtypes``.
488+
Does the inverse of ``join_struct_arrays``.
489+
490+
:param np.ndarray array: Structured array.
491+
:param dtypes: List of ``np.dtype``, or just a ``np.dtype`` and the number of
492+
them is figured out automatically by counting bytes.
493+
"""
494+
try:
495+
offsets = np.cumsum([np.dtype(dtype).itemsize for dtype in dtypes])
496+
except TypeError:
497+
dtype_size = np.dtype(dtypes).itemsize
498+
num_fields = int(array.nbytes / (array.size * dtype_size))
499+
offsets = np.cumsum([dtype_size] * num_fields)
500+
dtypes = [dtypes] * num_fields
501+
offsets = np.concatenate([[0], offsets]).astype(int)
502+
uint_array = array.view(np.uint8).reshape(array.shape + (-1,))
503+
return [
504+
uint_array[..., offsets[idx]:offsets[idx+1]].flatten().view(dtype)
505+
for idx, dtype in enumerate(dtypes)
506+
]
465507

466508
def sqrtm_psd(A, est_error=True, check_finite=True):
467509
"""

0 commit comments

Comments
 (0)