diff --git a/lectures/multi_hyper.md b/lectures/multi_hyper.md
index c6919a189..21847b61a 100644
--- a/lectures/multi_hyper.md
+++ b/lectures/multi_hyper.md
@@ -3,10 +3,12 @@ jupytext:
   text_representation:
     extension: .md
     format_name: myst
+    format_version: 0.13
+    jupytext_version: 1.17.1
 kernelspec:
-  display_name: Python 3
-  language: python
   name: python3
+  display_name: Python 3 (ipykernel)
+  language: python
 ---

 (multi_hyper_v7)=
@@ -24,9 +26,17 @@ kernelspec:
 :depth: 2
 ```

+In addition to what's included in base Anaconda, we need to install JAX.
+
+```{code-cell} ipython3
+:tags: [hide-output]
+
+!pip install jax
+```
+
 ## Overview

-This lecture describes how an administrator deployed a **multivariate hypergeometric distribution** in order to access the fairness of a procedure for awarding research grants.
+This lecture describes how an administrator deployed a *multivariate hypergeometric distribution* in order to assess the fairness of a procedure for awarding research grants.

 In the lecture we'll learn about

@@ -35,12 +45,12 @@ In the lecture we'll learn about
 * using a Monte Carlo simulation of a multivariate normal distribution to evaluate the quality of a normal approximation
 * the administrator's problem and why the multivariate hypergeometric distribution is the right tool

-## The Administrator's Problem
+## The administrator's problem

 An administrator in charge of allocating research grants is in the following situation.

 To help us forget details that are none of our business here and to protect the anonymity of the administrator and the subjects, we call
-research proposals **balls** and continents of residence of authors of a proposal a **color**.
+research proposals *balls* and continents of residence of authors of a proposal a *color*.

 There are $K_i$ balls (proposals) of color $i$.

@@ -54,7 +64,7 @@ All $N$ of these balls are placed in an urn.

 Then $n$ balls are drawn randomly.

-The selection procedure is supposed to be **color blind** meaning that **ball quality**, a random variable that is supposed to be independent of **ball color**, governs whether a ball is drawn.
+The selection procedure is supposed to be *color blind*, meaning that *ball quality*, a random variable that is supposed to be independent of *ball color*, governs whether a ball is drawn.

 Thus, the selection procedure is supposed randomly to draw $n$ balls from the urn.

@@ -62,7 +72,7 @@ The $n$ balls drawn represent successful proposals and are awarded research fu

 The remaining $N-n$ balls receive no research funds.

-### Details of the Awards Procedure Under Study
+### Details of the awards procedure under study

 Let $k_i$ be the number of balls of color $i$ that are drawn.

@@ -74,8 +84,8 @@ $$
 X = \begin{bmatrix} k_1 \cr k_2 \cr \vdots \cr k_c \end{bmatrix}.
 $$

-To evaluate whether the selection procedure is **color blind** the administrator wants to study whether the particular realization of $X$ drawn can plausibly
-be said to be a random draw from the probability distribution that is implied by the **color blind** hypothesis.
+To evaluate whether the selection procedure is *color blind*, the administrator wants to study whether the particular realization of $X$ drawn can plausibly
+be said to be a random draw from the probability distribution that is implied by the *color blind* hypothesis.

 The appropriate probability distribution is the one described [here](https://en.wikipedia.org/wiki/Hypergeometric_distribution).
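+To fix ideas with a deliberately tiny example of our own (not the administrator's data), suppose the urn holds $c = 2$ colors with $K = (3, 2)$, so $N = 5$, and that $n = 2$ balls are drawn.
+
+Because every one of the ${5 \choose 2} = 10$ possible pairs is equally likely under the color-blind hypothesis, the probability of drawing one ball of each color is
+
+$$
+\Pr\{k_1 = 1, k_2 = 1\} = \frac{{3 \choose 1}{2 \choose 1}}{{5 \choose 2}} = \frac{6}{10} = 0.6.
+$$
+
+This counting argument is exactly what the multivariate hypergeometric distribution formalizes.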
@@ -104,18 +114,24 @@ evidence against the hypothesis that the selection process is *fair*, which
 here means *color blind* and truly are random draws without replacement from
 the population of $N$ balls.

-The right tool for the administrator's job is the **multivariate hypergeometric distribution**.
+The right tool for the administrator's job is the *multivariate hypergeometric distribution*.

-### Multivariate Hypergeometric Distribution
+### Multivariate hypergeometric distribution

 Let's start with some imports.

-```{code-cell} ipython
+```{code-cell} ipython3
 import matplotlib.pyplot as plt
 import numpy as np
-from scipy.special import comb
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from functools import partial
+from jax.scipy.special import gammaln
 from scipy.stats import normaltest
-from numba import jit, prange
+from typing import NamedTuple
+
+jax.config.update("jax_platform_name", "cpu")
 ```

 To recapitulate, we assume there are in total $c$ types of objects in an urn.
@@ -128,26 +144,24 @@ has the multivariate hypergeometric distribution.

 Note again that $N=\sum_{i=1}^{c} K_{i}$ is
 the total number of objects in the urn and $n=\sum_{i=1}^{c}k_{i}$.

-**Notation**
-
-We use the following notation for **binomial coefficients**: ${m \choose q} = \frac{m!}{(m-q)!}$.
+We use the following notation for *binomial coefficients*: ${m \choose q} = \frac{m!}{(m-q)! \, q!}$.

 The multivariate hypergeometric distribution has the following properties:

-**Probability mass function**:
+*Probability mass function*:

 $$
 \Pr \{X_{i}=k_{i} \ \forall i\} = \frac {\prod _{i=1}^{c}{\binom {K_{i}}{k_{i}}}}{\binom {N}{n}}
 $$

-**Mean**:
+*Mean*:

 $$
 {\displaystyle \operatorname {E} (X_{i})=n{\frac {K_{i}}{N}}}
 $$

-**Variances and covariances**:
+*Variances and covariances*:

 $$
 {\displaystyle \operatorname {Var} (X_{i})=n{\frac {N-n}{N-1}}\;{\frac {K_{i}}{N}}\left(1-{\frac {K_{i}}{N}}\right)}
@@ -157,97 +171,82 @@ $$
 {\displaystyle \operatorname {Cov} (X_{i},X_{j})=-n{\frac {N-n}{N-1}}\;{\frac {K_{i}}{N}}{\frac {K_{j}}{N}}}
 $$

-To do our work for us, we'll write an `Urn` class.
-
-```{code-cell} python3
-class Urn:
-
-    def __init__(self, K_arr):
-        """
-        Initialization given the number of each type i object in the urn.
-
-        Parameters
-        ----------
-        K_arr: ndarray(int)
-            number of each type i object.
-        """
-
-        self.K_arr = np.array(K_arr)
-        self.N = np.sum(K_arr)
-        self.c = len(K_arr)
-
-    def pmf(self, k_arr):
-        """
-        Probability mass function.
-
-        Parameters
-        ----------
-        k_arr: ndarray(int)
-            number of observed successes of each object.
-        """
-
-        K_arr, N = self.K_arr, self.N
-
-        k_arr = np.atleast_2d(k_arr)
-        n = np.sum(k_arr, 1)
-
-        num = np.prod(comb(K_arr, k_arr), 1)
-        denom = comb(N, n)
-
-        pr = num / denom
-
-        return pr
-
-    def moments(self, n):
-        """
-        Compute the mean and variance-covariance matrix for
-        multivariate hypergeometric distribution.
-
-        Parameters
-        ----------
-        n: int
-            number of draws.
-        """
-
-        K_arr, N, c = self.K_arr, self.N, self.c
-
-        # mean
-        μ = n * K_arr / N
-
-        # variance-covariance matrix
-        Σ = np.full((c, c), n * (N - n) / (N - 1) / N ** 2)
-        for i in range(c-1):
-            Σ[i, i] *= K_arr[i] * (N - K_arr[i])
-            for j in range(i+1, c):
-                Σ[i, j] *= - K_arr[i] * K_arr[j]
-                Σ[j, i] = Σ[i, j]
-
-        Σ[-1, -1] *= K_arr[-1] * (N - K_arr[-1])

-        return μ, Σ
-
-    def simulate(self, n, size=1, seed=None):
-        """
-        Simulate a sample from multivariate hypergeometric
-        distribution where at each draw we take n objects
-        from the urn without replacement.
- - Parameters - ---------- - n: int - number of objects for each draw. - size: int(optional) - sample size. - seed: int(optional) - random seed. - """ - - K_arr = self.K_arr - - gen = np.random.Generator(np.random.PCG64(seed)) - sample = gen.multivariate_hypergeometric(K_arr, n, size=size) - - return sample +We follow the same template used in other JAX lectures by building a small helper structure and standalone functions. + +```{code-cell} ipython3 +class UrnModel(NamedTuple): + K: jnp.ndarray + N: int + c: int + + +def create_urn(K_arr): + """Return an UrnModel containing totals for each color.""" + K_arr = jnp.asarray(K_arr, dtype=jnp.int32) + N = int(jnp.sum(K_arr)) + c = int(K_arr.size) + return UrnModel(K=K_arr, N=N, c=c) + + +def log_comb(n, k): + """Compute log binomial coefficients using gammaln.""" + n = jnp.asarray(n) + k = jnp.asarray(k) + return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1) + + +def pmf(urn, k_arr): + """Evaluate the multivariate hypergeometric PMF.""" + K_arr, N = urn.K, urn.N + k_arr = jnp.atleast_2d(jnp.asarray(k_arr, dtype=jnp.int32)) + n = jnp.sum(k_arr, axis=1) + num_log = jnp.sum(log_comb(K_arr[None, :], k_arr), axis=1) + denom_log = log_comb(N, n) + return jnp.exp(num_log - denom_log) + + +def moments(urn, n): + """Return the mean vector and covariance matrix.""" + K_arr, N = urn.K, urn.N + p = K_arr / N + μ = n * p + factor = n * (N - n) / (N - 1) + Σ = factor * (jnp.diag(p) - jnp.outer(p, p)) + return μ, Σ + + +def simulate(urn, n, size=1, seed=0): + """Simulate multivariate hypergeometric draws.""" + K_arr, c = urn.K, urn.c + n = int(n) + size = int(size) + key = jr.PRNGKey(seed) + + if size == 0: + return jnp.zeros((0, c), dtype=jnp.int32) + + def draw_once(key): + """Generate one draw via a lax.scan loop.""" + counts0 = jnp.zeros(c, dtype=jnp.int32) + + def body_fn(carry, _): + counts, remaining, key = carry + key, subkey = jr.split(key) + probs = remaining / jnp.sum(remaining) + u = jr.uniform(subkey) + cut = jnp.cumsum(probs) + idx = jnp.searchsorted(cut, u, side='right') + counts = counts.at[idx].add(1) + remaining = remaining.at[idx].add(-1) + return (counts, remaining, key), None + + (counts, _, _), _ = jax.lax.scan( + body_fn, (counts0, K_arr, key), None, length=n) + return counts + + draw_once = jax.jit(draw_once) + keys = jr.split(key, size) + return jax.vmap(draw_once)(keys) ``` ## Usage @@ -265,17 +264,16 @@ $$ P(2{\text{ black}},2{\text{ white}},2{\text{ red}})={{{5 \choose 2}{10 \choose 2}{15 \choose 2}} \over {30 \choose 6}}=0.079575596816976 $$ -```{code-cell} python3 -# construct the urn +```{code-cell} ipython3 K_arr = [5, 10, 15] -urn = Urn(K_arr) +urn = create_urn(K_arr) ``` -Now use the Urn Class method `pmf` to compute the probability of the outcome $X = \begin{bmatrix} 2 & 2 & 2 \end{bmatrix}$ +Now use `pmf` to compute the probability of the outcome $X = \begin{bmatrix} 2 & 2 & 2 \end{bmatrix}$. -```{code-cell} python3 -k_arr = [2, 2, 2] # array of number of observed successes -urn.pmf(k_arr) +```{code-cell} ipython3 +k_arr = [2, 2, 2] +pmf(urn, k_arr) ``` We can use the code to compute probabilities of a list of possible outcomes by @@ -283,27 +281,27 @@ constructing a 2-dimensional array `k_arr` and `pmf` will return an array of probabilities for observing each case. -```{code-cell} python3 +```{code-cell} ipython3 k_arr = [[2, 2, 2], [1, 3, 2]] -urn.pmf(k_arr) +pmf(urn, k_arr) ``` Now let's compute the mean vector and variance-covariance matrix. 
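+Before doing so, here is a quick sanity check on `pmf`: for an urn this small we can enumerate every feasible outcome and confirm that the probabilities sum to one and reproduce the mean $n K_i / N$. This brute-force sketch (with illustrative names `n_check` and `outcomes` introduced here) is practical only for small $c$ and $n$:
+
+```{code-cell} ipython3
+from itertools import product
+
+n_check = 6
+# enumerate every outcome with 0 <= k_i <= K_i and sum(k) = n_check
+outcomes = jnp.array([k for k in product(*(range(K + 1) for K in K_arr))
+                      if sum(k) == n_check])
+
+probs = pmf(urn, outcomes)
+print("total probability:", jnp.sum(probs))    # should be 1
+print("brute-force mean: ", probs @ outcomes)  # should equal n_check * K / N
+```
+
+With that cross-check in hand, we use the closed-form `moments`: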
-```{code-cell} python3
+```{code-cell} ipython3
 n = 6
-μ, Σ = urn.moments(n)
+μ, Σ = moments(urn, n)
 ```

-```{code-cell} python3
+```{code-cell} ipython3
 μ
 ```

-```{code-cell} python3
+```{code-cell} ipython3
 Σ
 ```

-### Back to The Administrator's Problem
+### Back to the administrator's problem

 Now let's turn to the grant administrator's problem.

@@ -311,75 +309,72 @@ Here the array of numbers of $i$ objects in the urn is
 $\left(157, 11, 46, 24\right)$.

-```{code-cell} python3
+```{code-cell} ipython3
 K_arr = [157, 11, 46, 24]
-urn = Urn(K_arr)
+urn = create_urn(K_arr)
 ```

 Let's compute the probability of the outcome $\left(10, 1, 4, 0 \right)$.

-```{code-cell} python3
+```{code-cell} ipython3
 k_arr = [10, 1, 4, 0]
-urn.pmf(k_arr)
+pmf(urn, k_arr)
 ```

-We can compute probabilities of three possible outcomes by constructing a 3-dimensional
-arrays `k_arr` and utilizing the method `pmf` of the `Urn` class.
+We can compute probabilities of three possible outcomes by constructing a 2-dimensional array `k_arr` with one row per outcome and applying `pmf`.

-```{code-cell} python3
+```{code-cell} ipython3
 k_arr = [[5, 5, 4 ,1], [10, 1, 2, 2], [13, 0, 2, 0]]
-urn.pmf(k_arr)
+pmf(urn, k_arr)
 ```

 Now let's compute the mean and variance-covariance matrix of $X$ when $n=6$.

-```{code-cell} python3
-n = 6 # number of draws
-μ, Σ = urn.moments(n)
+```{code-cell} ipython3
+n = 6
+μ, Σ = moments(urn, n)
 ```

-```{code-cell} python3
-# mean
+```{code-cell} ipython3
 μ
 ```

-```{code-cell} python3
-# variance-covariance matrix
+```{code-cell} ipython3
 Σ
 ```

 We can simulate a large sample and verify that sample means and covariances closely approximate the population means and covariances.

-```{code-cell} python3
-size = 10_000_000
-sample = urn.simulate(n, size=size)
+```{code-cell} ipython3
+size = 200_000
+sample = simulate(urn, n, size=size, seed=123)
 ```

-```{code-cell} python3
-# mean
-np.mean(sample, 0)
+```{code-cell} ipython3
+jnp.mean(sample, axis=0)
 ```

-```{code-cell} python3
-# variance covariance matrix
-np.cov(sample.T)
+```{code-cell} ipython3
+jnp.cov(sample.T)
 ```

 Evidently, the sample means and covariances approximate their population counterparts well.

-### Quality of Normal Approximation
+### Quality of normal approximation

 To judge the quality of a multivariate normal approximation to the multivariate hypergeometric distribution, we draw a large sample from
 a multivariate normal distribution with the mean vector and covariance matrix for the corresponding multivariate hypergeometric distribution
 and compare the simulated distribution with the population multivariate hypergeometric distribution.
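+One wrinkle is worth noting before we sample: because the components of $X$ always sum to $n$, each row of $\Sigma$ sums to zero, so $\Sigma$ is singular and the approximating normal distribution is degenerate. A quick illustrative check:
+
+```{code-cell} ipython3
+# each row of Σ sums to zero, so Σ annihilates the vector of ones
+print(jnp.sum(Σ, axis=1))
+# the rank should be c - 1 = 3, not c = 4 (up to numerical tolerance)
+print(jnp.linalg.matrix_rank(Σ))
+```
+
+This is why the sampler below passes `method='svd'` to `jax.random.multivariate_normal` instead of relying on its default Cholesky factorization, which requires a positive definite covariance matrix.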
-```{code-cell} python3
-sample_normal = np.random.multivariate_normal(μ, Σ, size=size)
+```{code-cell} ipython3
+key_normal = jr.PRNGKey(0)
+# Σ is singular because the components of X sum to n, so we use the
+# SVD-based sampler rather than the default Cholesky factorization
+sample_normal = jr.multivariate_normal(
+    key_normal, μ, Σ, shape=(size,), method='svd')
 ```

-```{code-cell} python3
+```{code-cell} ipython3
 def bivariate_normal(x, y, μ, Σ, i, j):

     μ_x, μ_y = μ[i], μ[j]
-    σ_x, σ_y = np.sqrt(Σ[i, i]), np.sqrt(Σ[j, j])
+    σ_x, σ_y = jnp.sqrt(Σ[i, i]), jnp.sqrt(Σ[j, j])
     σ_xy = Σ[i, j]

     x_μ = x - μ_x
@@ -387,35 +382,40 @@ def bivariate_normal(x, y, μ, Σ, i, j):
     ρ = σ_xy / (σ_x * σ_y)
     z = x_μ**2 / σ_x**2 + y_μ**2 / σ_y**2 - 2 * ρ * x_μ * y_μ / (σ_x * σ_y)
-    denom = 2 * np.pi * σ_x * σ_y * np.sqrt(1 - ρ**2)
+    denom = 2 * jnp.pi * σ_x * σ_y * jnp.sqrt(1 - ρ**2)

-    return np.exp(-z / (2 * (1 - ρ**2))) / denom
+    return jnp.exp(-z / (2 * (1 - ρ**2))) / denom
 ```

-```{code-cell} python3
-@jit
+```{code-cell} ipython3
+@partial(jax.jit, static_argnums=2)
 def count(vec1, vec2, n):
-    size = sample.shape[0]
-
-    count_mat = np.zeros((n+1, n+1))
-    for i in prange(size):
-        count_mat[vec1[i], vec2[i]] += 1
-
-    return count_mat
+    """Count joint frequencies of integer pairs using JAX bincount."""
+    vec1 = vec1.astype(jnp.int32)
+    vec2 = vec2.astype(jnp.int32)
+    base = n + 1
+    idx = vec1 * base + vec2
+    counts = jnp.bincount(idx, length=base * base)
+    return counts.reshape((base, base))
 ```

-```{code-cell} python3
+```{code-cell} ipython3
 c = urn.c
 fig, axs = plt.subplots(c, c, figsize=(14, 14))

 # grids for ploting the bivariate Gaussian
-x_grid = np.linspace(-2, n+1, 100)
-y_grid = np.linspace(-2, n+1, 100)
-X, Y = np.meshgrid(x_grid, y_grid)
+x_grid = jnp.linspace(-2, n+1, 100)
+y_grid = jnp.linspace(-2, n+1, 100)
+X, Y = jnp.meshgrid(x_grid, y_grid)
+# one bin per integer count 0, 1, ..., n
+bin_edges = list(range(0, n + 2))

 for i in range(c):
-    axs[i, i].hist(sample[:, i], bins=np.arange(0, n, 1), alpha=0.5, density=True, label='hypergeom')
-    axs[i, i].hist(sample_normal[:, i], bins=np.arange(0, n, 1), alpha=0.5, density=True, label='normal')
+    axs[i, i].hist(sample[:, i],
+                   bins=bin_edges, alpha=0.5,
+                   density=True, label='hypergeom')
+    axs[i, i].hist(sample_normal[:, i],
+                   bins=bin_edges, alpha=0.5,
+                   density=True, label='normal')
     axs[i, i].legend()
     axs[i, i].set_title('$k_{' +str(i+1) +'}$')
     for j in range(c):
@@ -423,14 +423,18 @@ for i in range(c):
             continue

         # bivariate Gaussian density function
-        Z = bivariate_normal(X, Y, μ, Σ, i, j)
-        cs = axs[i, j].contour(X, Y, Z, 4, colors="black", alpha=0.6)
+        Z = np.asarray(bivariate_normal(X, Y, μ, Σ, i, j))
+        cs = axs[i, j].contour(
+            X, Y, Z, 4, colors="black", alpha=0.6)
         axs[i, j].clabel(cs, inline=1, fontsize=10)

         # empirical multivariate hypergeometric distrbution
-        count_mat = count(sample[:, i], sample[:, j], n)
-        axs[i, j].pcolor(count_mat.T/size, cmap='Blues')
-        axs[i, j].set_title('$(k_{' +str(i+1) +'}, k_{' + str(j+1) + '})$')
+        count_mat = count(
+            sample[:, i], sample[:, j], n)
+        axs[i, j].pcolor(
+            count_mat.T/size, cmap='Blues')
+        axs[i, j].set_title(
+            '$(k_{' +str(i+1) +'}, k_{' + str(j+1) + '})$')

 plt.show()
 ```

@@ -454,7 +458,7 @@ The null hypothesis is that the sample follows normal distribution.

 > `normaltest` returns an array of p-values associated with tests for each $k_i$ sample.

-```{code-cell} python3
+```{code-cell} ipython3
 test_multihyper = normaltest(sample)
 test_multihyper.pvalue
 ```

 As we can see, all the p-values are almost $0$ and the null hypothesis is soundly rejected.

 By contrast, the sample from normal distribution does not reject the null hypothesis.
-```{code-cell} python3
+```{code-cell} ipython3
 test_normal = normaltest(sample_normal)
 test_normal.pvalue
 ```
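+Finally, the machinery above can be pointed back at the administrator's original question: how unusual is the observed outcome $\left(10, 1, 4, 0\right)$ if the procedure really is color blind? One illustrative Monte Carlo check — a sketch we add here, not a formal test — simulates many color-blind draws of the same total size and asks how often a draw is at least as improbable, under the pmf, as the observed one.
+
+```{code-cell} ipython3
+k_obs = jnp.array([10, 1, 4, 0])
+n_obs = int(jnp.sum(k_obs))    # 15 proposals were funded in total
+
+# simulate color-blind draws of the same total size
+draws = simulate(urn, n_obs, size=100_000, seed=42)
+
+p_obs = pmf(urn, k_obs)[0]     # probability of the observed outcome
+p_draws = pmf(urn, draws)      # probability of each simulated draw
+
+# share of color-blind draws at least as improbable as the observed one
+jnp.mean(p_draws <= p_obs)
+```
+
+A share close to one would say that the observed allocation is unremarkable under the color-blind hypothesis, while a share near zero would be evidence against it.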