Skip to content

Commit 2ba9a7a

Browse files
authored
Merge pull request #138 from ihincks/feature-dirichlet-distribution
DirichletDistribution
2 parents 3c9cc7e + f4b18f6 commit 2ba9a7a

File tree

4 files changed

+71
-3
lines changed

4 files changed

+71
-3
lines changed

doc/source/apiref/distributions.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ Specific Distributions
5656

5757
.. autoclass:: BetaBinomialDistribution
5858
:members:
59+
60+
.. autoclass:: DirichletDistribution
61+
:members:
5962

6063
.. autoclass:: GammaDistribution
6164
:members:

src/qinfer/distributions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
'SlantedNormalDistribution',
7777
'LogNormalDistribution',
7878
'BetaDistribution',
79+
'DirichletDistribution',
7980
'BetaBinomialDistribution',
8081
'GammaDistribution',
8182
'GinibreUniform',
@@ -1006,6 +1007,31 @@ def n_rvs(self):
10061007

10071008
def sample(self, n=1):
10081009
return self.dist.rvs(size=n)[:, np.newaxis]
1010+
1011+
class DirichletDistribution(Distribution):
1012+
r"""
1013+
The dirichlet distribution, whose pdf at :math:`x` is proportional to
1014+
:math:`\prod_i x_i^{\alpha_i-1}`.
1015+
1016+
:param alpha: The list of concentration parameters.
1017+
"""
1018+
def __init__(self, alpha):
1019+
self._alpha = np.array(alpha)
1020+
if self.alpha.ndim != 1:
1021+
raise ValueError('The input alpha must be a 1D list of concentration parameters.')
1022+
1023+
self._dist = st.dirichlet(alpha=self.alpha)
1024+
1025+
@property
1026+
def alpha(self):
1027+
return self._alpha
1028+
1029+
@property
1030+
def n_rvs(self):
1031+
return self._alpha.size
1032+
1033+
def sample(self, n=1):
1034+
return self._dist.rvs(size=n)
10091035

10101036
class BetaBinomialDistribution(Distribution):
10111037
r"""

src/qinfer/tests/test_concrete_models.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
NormalDistribution,
6565
BetaDistribution, UniformDistribution,
6666
PostselectedDistribution,
67-
ConstrainedSumDistribution,
67+
ConstrainedSumDistribution, DirichletDistribution,
6868
DirectViewParallelizedModel,
6969
GaussianHyperparameterizedModel
7070
)
@@ -291,8 +291,7 @@ class TestMultinomialModel(ConcreteModelTest, DerandomizedTestCase):
291291
def instantiate_model(self):
292292
return MultinomialModel(NDieModel(n=6))
293293
def instantiate_prior(self):
294-
unif = UniformDistribution(np.array([[0,1],[0,1],[0,1],[0,1],[0,1],[0,1]]))
295-
return ConstrainedSumDistribution(unif, desired_total=1)
294+
return DirichletDistribution([1,2,3,10,1,3])
296295
def instantiate_expparams(self):
297296
return np.arange(10).astype(self.model.expparams_dtype)
298297

src/qinfer/tests/test_distributions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,46 @@ def test_betabinomial_n_rvs(self):
347347
"""
348348
dist = BetaBinomialDistribution(10, alpha=10,beta=42)
349349
assert(dist.n_rvs == 1)
350+
351+
class TestDirichletDistribution(DerandomizedTestCase):
352+
"""
353+
Tests ``DirichletDistribution``
354+
"""
355+
356+
## TEST METHODS ##
357+
358+
def test_dirichlet_moments(self):
359+
"""
360+
Distributions: Checks that the dirichlet distribution has the right
361+
moments, with either of the two input formats
362+
"""
363+
alpha = [1,2,3,4]
364+
alpha_np = np.array(alpha)
365+
alpha_0 = alpha_np.sum()
366+
mean = alpha_np / alpha_0
367+
var = alpha_np * (alpha_0 - alpha_np) / (alpha_0 **2 * (alpha_0+1))
368+
369+
dist = DirichletDistribution(alpha)
370+
samples = dist.sample(100000)
371+
372+
assert samples.shape == (100000, alpha_np.size)
373+
assert_almost_equal(samples.mean(axis=0), mean, 2)
374+
assert_almost_equal(samples.var(axis=0), var, 2)
375+
376+
alpha = np.array([8,7,5,2,2])
377+
alpha_np = np.array(alpha)
378+
alpha_0 = alpha_np.sum()
379+
mean = alpha_np / alpha_0
380+
var = alpha_np * (alpha_0 - alpha_np) / (alpha_0 **2 * (alpha_0+1))
381+
382+
dist = DirichletDistribution(alpha)
383+
samples = dist.sample(100000)
384+
385+
assert samples.shape == (100000, alpha_np.size)
386+
assert_almost_equal(samples.mean(axis=0), mean, 2)
387+
assert_almost_equal(samples.var(axis=0), var, 2)
388+
389+
350390

351391
class TestGammaDistribution(DerandomizedTestCase):
352392
"""

0 commit comments

Comments
 (0)