3434from operator import mul
3535from scipy .special import binom
3636from math import factorial
37- from itertools import combinations_with_replacement
37+ from itertools import combinations_with_replacement , product
3838import numpy as np
39- from numpy . lib . recfunctions import merge_arrays
39+ from . utils import join_struct_arrays , separate_struct_array
4040
4141import abc
4242
@@ -152,14 +152,24 @@ class ProductDomain(Domain):
152152 """
153153 A domain made from the cartesian product of other domains.
154154
155- :param list domains: A list of domains.
155+ :param Domain domains: ``Domain`` objects as separate arguments,
156+ or as a singe list of ``Domain``s.
156157 """
157- def __init__ (self , domains ):
158- super (ProductDomain , self ).__init__ ()
158+ def __init__ (self , * domains ):
159+
160+ if len (domains ) == 1 :
161+ try :
162+ self ._domains = list (domains [0 ])
163+ except :
164+ self ._domains = domains
165+ else :
166+ self ._domains = domains
167+
159168 self ._domains = domains
160- self ._example_point = merge_arrays (
161- [domain .example_point for domain in self ._domains ],
162- flatten = True , usemask = False )
169+ self ._dtypes = [domain .example_point .dtype for domain in self ._domains ]
170+ self ._example_point = join_struct_arrays (
171+ [np .array (domain .example_point ) for domain in self ._domains ]
172+ )
163173 self ._dtype = self ._example_point .dtype
164174
165175 @property
@@ -224,12 +234,49 @@ def values(self):
224234 :rtype: `np.ndarray`
225235 """
226236 if self .is_finite :
227- raise NotImplemented ()
237+ separate_values = [domain .values for domain in self ._domains ]
238+ return np .concatenate ([
239+ join_struct_arrays (map (np .array , value ))
240+ for value in product (* separate_values )
241+ ])
228242 else :
229243 return self .example_point
230244
245+ ## METHODS ##
246+
247+ def _mytype (self , array ):
248+ # astype does weird stuff with struct names, and possibly
249+ # depends on numpy version; hopefully
250+ # the following is a bit more predictable since it passes through
251+ # uint8
252+ return separate_struct_array (array , self .dtype )[0 ]
253+
254+ def to_regular_arrays (self , array ):
255+ """
256+ Expands from an array of type `self.dtype` into a list of
257+ arrays with dtypes corresponding to the factor domains.
231258
232- ## ABSTRACT METHODS ##
259+ :param np.ndarray array: An `np.array` of type `self.dtype`.
260+
261+ :rtype: ``list``
262+ """
263+ return separate_struct_array (self ._mytype (array ), self ._dtypes )
264+
265+ def from_regular_arrays (self , arrays ):
266+ """
267+ Merges a list of arrays (of the same shape) of dtypes
268+ corresponding to the factor domains into a single array
269+ with the dtype of the ``ProductDomain``.
270+
271+ :param list array: A list of ``np.ndarray``s
272+
273+ :rtype: `np.ndarray`
274+ """
275+ return self ._mytype (join_struct_arrays ([
276+ array .astype (dtype )
277+ for dtype , array in zip (self ._dtypes , arrays )
278+ ]))
279+
233280
234281 def in_domain (self , points ):
235282 """
@@ -240,7 +287,11 @@ def in_domain(self, points):
240287
241288 :rtype: `bool`
242289 """
243- raise NotImplemented ()
290+ return all ([
291+ domain .in_domain (array )
292+ for domain , array in
293+ zip (self ._domains , separate_struct_array (points , self ._dtypes ))
294+ ])
244295
245296
246297## CLASSES ###################################################################
0 commit comments