Skip to content

Commit 12b2612

Browse files
authored
Merge pull request #115 from ihincks/feature-domain-tests
Tests for Domains
2 parents aa6fb0b + cba6647 commit 12b2612

File tree

4 files changed

+407
-104
lines changed

4 files changed

+407
-104
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description-file = README.rst
44
[bdist_wheel]
55
universal = 1
66

7-
[pytest]
7+
[tool:pytest]
88
python_files=tests/*.py
99
python_classes=Test
1010
python_functions=*_test

src/qinfer/domains.py

Lines changed: 73 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,18 @@ def dtype(self):
9090
@abc.abstractproperty
9191
def n_members(self):
9292
"""
93-
Returns the number of members in the domain if it
93+
Returns the number of members in the domain if it
9494
`is_finite`, otherwise, returns `np.inf`.
9595
96-
:type: ``int`` or ``np.inf``
96+
:type: ``int`` or ``np.inf``
9797
"""
9898
pass
9999

100100
@abc.abstractproperty
101101
def example_point(self):
102102
"""
103-
Returns any single point guaranteed to be in the domain, but
104-
no other guarantees; useful for testing purposes.
103+
Returns any single point guaranteed to be in the domain, but
104+
no other guarantees; useful for testing purposes.
105105
This is given as a size 1 ``np.array`` of type `dtype`.
106106
107107
:type: ``np.ndarray``
@@ -111,9 +111,9 @@ def example_point(self):
111111
@abc.abstractproperty
112112
def values(self):
113113
"""
114-
Returns an `np.array` of type `dtype` containing
114+
Returns an `np.array` of type `dtype` containing
115115
some values from the domain.
116-
For domains where `is_finite` is ``True``, all elements
116+
For domains where `is_finite` is ``True``, all elements
117117
of the domain will be yielded exactly once.
118118
119119
:rtype: `np.ndarray`
@@ -136,7 +136,7 @@ def is_discrete(self):
136136
@abc.abstractmethod
137137
def in_domain(self, points):
138138
"""
139-
Returns ``True`` if all of the given points are in the domain,
139+
Returns ``True`` if all of the given points are in the domain,
140140
``False`` otherwise.
141141
142142
:param np.ndarray points: An `np.ndarray` of type `self.dtype`.
@@ -149,13 +149,13 @@ def in_domain(self, points):
149149

150150
class RealDomain(Domain):
151151
"""
152-
A domain specifying a contiguous (and possibly open ended) subset
152+
A domain specifying a contiguous (and possibly open ended) subset
153153
of the real numbers.
154154
155-
:param float min: A number specifying the lowest possible value of the
155+
:param float min: A number specifying the lowest possible value of the
156+
domain.
157+
:param float max: A number specifying the largest possible value of the
156158
domain.
157-
:param float max: A number specifying the largest possible value of the
158-
domain.
159159
"""
160160

161161
def __init__(self, min=-np.inf, max=np.inf):
@@ -167,20 +167,20 @@ def __init__(self, min=-np.inf, max=np.inf):
167167
@property
168168
def min(self):
169169
"""
170-
Returns the minimum value of the domain.
170+
Returns the minimum value of the domain.
171171
172172
:rtype: `float`
173173
"""
174174
return self._min
175175
@property
176176
def max(self):
177177
"""
178-
Returns the maximum value of the domain.
178+
Returns the maximum value of the domain.
179179
180180
:rtype: `float`
181181
"""
182182
return self._max
183-
183+
184184
@property
185185
def is_continuous(self):
186186
"""
@@ -211,7 +211,7 @@ def dtype(self):
211211
@property
212212
def n_members(self):
213213
"""
214-
Returns the number of members in the domain if it
214+
Returns the number of members in the domain if it
215215
`is_finite`, otherwise, returns `None`.
216216
217217
:type: ``np.inf``
@@ -221,8 +221,8 @@ def n_members(self):
221221
@property
222222
def example_point(self):
223223
"""
224-
Returns any single point guaranteed to be in the domain, but
225-
no other guarantees; useful for testing purposes.
224+
Returns any single point guaranteed to be in the domain, but
225+
no other guarantees; useful for testing purposes.
226226
This is given as a size 1 ``np.array`` of type ``dtype``.
227227
228228
:type: ``np.ndarray``
@@ -237,9 +237,9 @@ def example_point(self):
237237
@property
238238
def values(self):
239239
"""
240-
Returns an `np.array` of type `self.dtype` containing
240+
Returns an `np.array` of type `self.dtype` containing
241241
some values from the domain.
242-
For domains where ``is_finite`` is ``True``, all elements
242+
For domains where ``is_finite`` is ``True``, all elements
243243
of the domain will be yielded exactly once.
244244
245245
:rtype: `np.ndarray`
@@ -250,27 +250,32 @@ def values(self):
250250

251251
def in_domain(self, points):
252252
"""
253-
Returns ``True`` if all of the given points are in the domain,
253+
Returns ``True`` if all of the given points are in the domain,
254254
``False`` otherwise.
255255
256256
:param np.ndarray points: An `np.ndarray` of type `self.dtype`.
257257
258258
:rtype: `bool`
259259
"""
260-
return np.all(points >= self._min) and np.all(points <= self._max)
260+
if np.all(np.isreal(points)):
261+
are_greater = np.all(np.greater_equal(points, self._min))
262+
are_smaller = np.all(np.less_equal(points, self._max))
263+
return are_greater and are_smaller
264+
else:
265+
return False
261266

262267
class IntegerDomain(Domain):
263268
"""
264-
A domain specifying a contiguous (and possibly open ended) subset
265-
of the integers.
269+
A domain specifying a contiguous (and possibly open ended) subset
270+
of the integers.
266271
267-
Internally minimum and maximum are represented as
272+
Internally minimum and maximum are represented as
268273
floats in order to handle the case of infinite maximum, and minimums. The
269-
integer conversion function will be applied to the min and max values.
274+
integer conversion function will be applied to the min and max values.
270275
271-
:param int min: A number specifying the lowest possible value of the
272-
domain.
273-
:param int max: A number specifying the largest possible value of the
276+
:param int min: A number specifying the lowest possible value of the
277+
domain.
278+
:param int max: A number specifying the largest possible value of the
274279
domain.
275280
276281
Note: Yes, it is slightly unpythonic to specify `max` instead of `max`+1.
@@ -285,20 +290,20 @@ def __init__(self, min=0, max=np.inf):
285290
@property
286291
def min(self):
287292
"""
288-
Returns the minimum value of the domain.
293+
Returns the minimum value of the domain.
289294
290295
:rtype: `float` or `np.inf`
291296
"""
292-
return int(self._min) if not np.isinf(self._min) else self._min
297+
return int(self._min) if not np.isinf(self._min) else self._min
293298
@property
294299
def max(self):
295300
"""
296-
Returns the maximum value of the domain.
301+
Returns the maximum value of the domain.
297302
298303
:rtype: `float` or `np.inf`
299304
"""
300305
return int(self._max) if not np.isinf(self._max) else self._max
301-
306+
302307

303308
@property
304309
def is_continuous(self):
@@ -330,7 +335,7 @@ def dtype(self):
330335
@property
331336
def n_members(self):
332337
"""
333-
Returns the number of members in the domain if it
338+
Returns the number of members in the domain if it
334339
`is_finite`, otherwise, returns `np.inf`.
335340
336341
:type: ``int`` or ``np.inf``
@@ -343,8 +348,8 @@ def n_members(self):
343348
@property
344349
def example_point(self):
345350
"""
346-
Returns any single point guaranteed to be in the domain, but
347-
no other guarantees; useful for testing purposes.
351+
Returns any single point guaranteed to be in the domain, but
352+
no other guarantees; useful for testing purposes.
348353
This is given as a size 1 ``np.array`` of type ``dtype``.
349354
350355
:type: ``np.ndarray``
@@ -359,9 +364,9 @@ def example_point(self):
359364
@property
360365
def values(self):
361366
"""
362-
Returns an `np.array` of type `self.dtype` containing
367+
Returns an `np.array` of type `self.dtype` containing
363368
some values from the domain.
364-
For domains where ``is_finite`` is ``True``, all elements
369+
For domains where ``is_finite`` is ``True``, all elements
365370
of the domain will be yielded exactly once.
366371
367372
:rtype: `np.ndarray`
@@ -375,26 +380,31 @@ def values(self):
375380

376381
def in_domain(self, points):
377382
"""
378-
Returns ``True`` if all of the given points are in the domain,
383+
Returns ``True`` if all of the given points are in the domain,
379384
``False`` otherwise.
380385
381386
:param np.ndarray points: An `np.ndarray` of type `self.dtype`.
382387
383388
:rtype: `bool`
384389
"""
385-
are_integer = np.all(np.mod(points,1) == 0)
386-
are_greater = np.all(points >= self._min)
387-
are_smaller = np.all(points <= self._max)
388-
return are_integer and are_greater and are_smaller
389-
390+
if np.all(np.isreal(points)):
391+
try:
392+
are_integer = np.all(np.mod(points, 1) == 0)
393+
except TypeError:
394+
are_integer = False
395+
are_greater = np.all(np.greater_equal(points, self._min))
396+
are_smaller = np.all(np.less_equal(points, self._max))
397+
return are_integer and are_greater and are_smaller
398+
else:
399+
return False
390400

391401
class MultinomialDomain(Domain):
392402
"""
393-
A domain specifying k-tuples of non-negative integers which
403+
A domain specifying k-tuples of non-negative integers which
394404
sum to a specific value.
395405
396406
:param int n_meas: The sum of any tuple in the domain.
397-
:param int n_elements: The number of elements in a tuple.
407+
:param int n_elements: The number of elements in a tuple.
398408
"""
399409

400410
def __init__(self, n_meas, n_elements=2):
@@ -419,7 +429,7 @@ def n_elements(self):
419429
:rtype: `int`
420430
"""
421431
return self._n_elements
422-
432+
423433

424434
@property
425435
def is_continuous(self):
@@ -451,7 +461,7 @@ def dtype(self):
451461
@property
452462
def n_members(self):
453463
"""
454-
Returns the number of members in the domain if it
464+
Returns the number of members in the domain if it
455465
`is_finite`, otherwise, returns `None`.
456466
457467
:type: ``int``
@@ -461,20 +471,20 @@ def n_members(self):
461471
@property
462472
def example_point(self):
463473
"""
464-
Returns any single point guaranteed to be in the domain, but
465-
no other guarantees; useful for testing purposes.
474+
Returns any single point guaranteed to be in the domain, but
475+
no other guarantees; useful for testing purposes.
466476
This is given as a size 1 ``np.array`` of type ``dtype``.
467477
468478
:type: ``np.ndarray``
469479
"""
470-
return np.array([([self.n_meas] + [0] * (self.n_elements-1))], dtype=self.dtype)
480+
return np.array([([self.n_meas] + [0] * (self.n_elements-1),)], dtype=self.dtype)
471481

472482
@property
473483
def values(self):
474484
"""
475-
Returns an `np.array` of type `self.dtype` containing
485+
Returns an `np.array` of type `self.dtype` containing
476486
some values from the domain.
477-
For domains where ``is_finite`` is ``True``, all elements
487+
For domains where ``is_finite`` is ``True``, all elements
478488
of the domain will be yielded exactly once.
479489
480490
:rtype: `np.ndarray`
@@ -483,32 +493,32 @@ def values(self):
483493
# This code comes from Jared Goguen at http://stackoverflow.com/a/37712597/1082565
484494
partition_array = np.empty((self.n_members, self.n_elements), dtype=int)
485495
masks = np.identity(self.n_elements, dtype=int)
486-
for i, c in enumerate(combinations_with_replacement(masks, self.n_meas)):
496+
for i, c in enumerate(combinations_with_replacement(masks, self.n_meas)):
487497
partition_array[i,:] = sum(c)
488498

489499
# Convert to dtype before returning
490500
return self.from_regular_array(partition_array)
491-
501+
492502
## METHODS ##
493503

494504
def to_regular_array(self, A):
495505
"""
496-
Converts from an array of type `self.dtype` to an array
497-
of type `int` with an additional index labeling the
506+
Converts from an array of type `self.dtype` to an array
507+
of type `int` with an additional index labeling the
498508
tuple indeces.
499509
500510
:param np.ndarray A: An `np.array` of type `self.dtype`.
501511
502512
:rtype: `np.ndarray`
503513
"""
504-
# this could be a static method, but we choose to be consistent with
514+
# this could be a static method, but we choose to be consistent with
505515
# from_regular_array
506516
return A.view((int, len(A.dtype.names))).reshape(A.shape + (-1,))
507517

508518
def from_regular_array(self, A):
509519
"""
510-
Converts from an array of type `int` where the last index
511-
is assumed to have length `self.n_elements` to an array
520+
Converts from an array of type `int` where the last index
521+
is assumed to have length `self.n_elements` to an array
512522
of type `self.d_type` with one fewer index.
513523
514524
:param np.ndarray A: An `np.array` of type `int`.
@@ -520,14 +530,14 @@ def from_regular_array(self, A):
520530

521531
def in_domain(self, points):
522532
"""
523-
Returns ``True`` if all of the given points are in the domain,
533+
Returns ``True`` if all of the given points are in the domain,
524534
``False`` otherwise.
525535
526536
:param np.ndarray points: An `np.ndarray` of type `self.dtype`.
527537
528538
:rtype: `bool`
529539
"""
530540
array_view = self.to_regular_array(points)
531-
return np.all(array_view >= 0) and np.all(np.sum(array_view, axis=-1) == self.n_meas)
532-
533-
541+
non_negative = np.all(np.greater_equal(array_view, 0))
542+
correct_sum = np.all(np.sum(array_view, axis=-1) == self.n_meas)
543+
return non_negative and correct_sum

0 commit comments

Comments
 (0)