Skip to content

Commit 4a1c493

Browse files
committed
finished implementing ProductDomain
1 parent 425ef8f commit 4a1c493

File tree

2 files changed

+104
-11
lines changed

2 files changed

+104
-11
lines changed

src/qinfer/domains.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
from operator import mul
3535
from scipy.special import binom
3636
from math import factorial
37-
from itertools import combinations_with_replacement
37+
from itertools import combinations_with_replacement, product
3838
import numpy as np
39-
from numpy.lib.recfunctions import merge_arrays
39+
from .utils import join_struct_arrays, separate_struct_array
4040

4141
import abc
4242

@@ -152,14 +152,24 @@ class ProductDomain(Domain):
152152
"""
153153
A domain made from the cartesian product of other domains.
154154
155-
:param list domains: A list of domains.
155+
:param Domain domains: ``Domain`` objects as separate arguments,
156+
or as a singe list of ``Domain``s.
156157
"""
157-
def __init__(self, domains):
158-
super(ProductDomain, self).__init__()
158+
def __init__(self, *domains):
159+
160+
if len(domains) == 1:
161+
try:
162+
self._domains = list(domains[0])
163+
except:
164+
self._domains = domains
165+
else:
166+
self._domains = domains
167+
159168
self._domains = domains
160-
self._example_point = merge_arrays(
161-
[domain.example_point for domain in self._domains],
162-
flatten = True, usemask = False)
169+
self._dtypes = [domain.example_point.dtype for domain in self._domains]
170+
self._example_point = join_struct_arrays(
171+
[np.array(domain.example_point) for domain in self._domains]
172+
)
163173
self._dtype = self._example_point.dtype
164174

165175
@property
@@ -224,12 +234,49 @@ def values(self):
224234
:rtype: `np.ndarray`
225235
"""
226236
if self.is_finite:
227-
raise NotImplemented()
237+
separate_values = [domain.values for domain in self._domains]
238+
return np.concatenate([
239+
join_struct_arrays(map(np.array, value))
240+
for value in product(*separate_values)
241+
])
228242
else:
229243
return self.example_point
230244

245+
## METHODS ##
246+
247+
def _mytype(self, array):
248+
# astype does weird stuff with struct names, and possibly
249+
# depends on numpy version; hopefully
250+
# the following is a bit more predictable since it passes through
251+
# uint8
252+
return separate_struct_array(array, self.dtype)[0]
253+
254+
def to_regular_arrays(self, array):
255+
"""
256+
Expands from an array of type `self.dtype` into a list of
257+
arrays with dtypes corresponding to the factor domains.
231258
232-
## ABSTRACT METHODS ##
259+
:param np.ndarray array: An `np.array` of type `self.dtype`.
260+
261+
:rtype: ``list``
262+
"""
263+
return separate_struct_array(self._mytype(array), self._dtypes)
264+
265+
def from_regular_arrays(self, arrays):
266+
"""
267+
Merges a list of arrays (of the same shape) of dtypes
268+
corresponding to the factor domains into a single array
269+
with the dtype of the ``ProductDomain``.
270+
271+
:param list array: A list of ``np.ndarray``s
272+
273+
:rtype: `np.ndarray`
274+
"""
275+
return self._mytype(join_struct_arrays([
276+
array.astype(dtype)
277+
for dtype, array in zip(self._dtypes, arrays)
278+
]))
279+
233280

234281
def in_domain(self, points):
235282
"""
@@ -240,7 +287,11 @@ def in_domain(self, points):
240287
241288
:rtype: `bool`
242289
"""
243-
raise NotImplemented()
290+
return all([
291+
domain.in_domain(array)
292+
for domain, array in
293+
zip(self._domains, separate_struct_array(points, self._dtypes))
294+
])
244295

245296

246297
## CLASSES ###################################################################

src/qinfer/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,48 @@ def pretty_time(secs, force_h=False, force_m=False):
454454
def safe_shape(arr, idx=0, default=1):
455455
shape = np.shape(arr)
456456
return shape[idx] if idx < len(shape) else default
457+
458+
def join_struct_arrays(arrays):
459+
"""
460+
Takes a list of possibly structured arrays, concatenates their
461+
dtypes, and returns one big array with that dtype. Does the
462+
inverse of ``separate_struct_array``.
463+
464+
:param list arrays: List of ``np.ndarray``s
465+
"""
466+
# taken from http://stackoverflow.com/questions/5355744/numpy-joining-structured-arrays
467+
sizes = np.array([a.itemsize for a in arrays])
468+
offsets = np.r_[0, sizes.cumsum()]
469+
shape = arrays[0].shape
470+
joint = np.empty(shape + (offsets[-1],), dtype=np.uint8)
471+
for a, size, offset in zip(arrays, sizes, offsets):
472+
joint[...,offset:offset+size] = np.atleast_1d(a).view(np.uint8).reshape(shape + (size,))
473+
dtype = sum((a.dtype.descr for a in arrays), [])
474+
return joint.ravel().view(dtype)
475+
476+
def separate_struct_array(array, dtypes):
477+
"""
478+
Takes an array with a structured dtype, and separates it out into
479+
a list of arrays with dtypes coming from the input ``dtypes``.
480+
Does the inverse of ``join_struct_arrays``.
481+
482+
:param np.ndarray array: Structured array.
483+
:param dtypes: List of ``np.dtype``, or just a ``np.dtype`` and the number of
484+
them is figured out automatically by counting bytes.
485+
"""
486+
try:
487+
offsets = np.cumsum([np.dtype(dtype).itemsize for dtype in dtypes])
488+
except TypeError:
489+
dtype_size = np.dtype(dtypes).itemsize
490+
num_fields = int(array.nbytes / (array.size * dtype_size))
491+
offsets = np.cumsum([dtype_size] * num_fields)
492+
dtypes = [dtypes] * num_fields
493+
offsets = np.concatenate([[0], offsets]).astype(int)
494+
uint_array = array.view(np.uint8).reshape(array.shape + (-1,))
495+
return [
496+
uint_array[..., offsets[idx]:offsets[idx+1]].flatten().view(dtype)
497+
for idx, dtype in enumerate(dtypes)
498+
]
457499

458500

459501
#==============================================================================

0 commit comments

Comments
 (0)