Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 186 additions & 2 deletions dynamax/hidden_markov_model/models/categorical_hmm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Categorical Hidden Markov Model."""
from typing import NamedTuple, Optional, Tuple, Union
from typing import NamedTuple, Optional, Tuple, Union, List

import jax.numpy as jnp
from jax import lax
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
Expand All @@ -15,7 +16,8 @@
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions
from dynamax.parameters import ParameterProperties, ParameterSet, PropertySet
from dynamax.types import IntScalar, Scalar
from dynamax.utils.utils import pytree_sum
from dynamax.types import PRNGKeyT
from dynamax.utils.utils import pytree_sum, ensure_array_has_batch_dim, low_rank_pinv, multilinear_product, cp_decomp


class ParamsCategoricalHMMEmissions(NamedTuple):
Expand Down Expand Up @@ -118,6 +120,43 @@ def m_step(self, params, props, batch_stats, m_step_state):
probs = tfd.Dirichlet(self.emission_prior_concentration + emission_stats['sum_x']).mode()
params = params._replace(probs=probs)
return params, m_step_state

def calc_sample_moment(self,
                       emissions: Float[Array, "num_batches num_timesteps emission_dim"],
                       order: Union[int, List[int]]):
    r"""Estimate a cross moment of the one-hot encoded emissions.

    The emissions are one-hot encoded and, for every window start $t$, the
    outer product of the encoded emissions at the requested time offsets is
    formed. Because the HMM is time homogeneous, the outer products are
    averaged over all window starts and all sequences, so that for example
    the following are treated as interchangeable:

    $$\mathbb{E}[x_1 \otimes x_2 \otimes x_3]$$

    $$\mathbb{E}[x_{t+1} \otimes x_{t+2} \otimes x_{t+3}]$$

    Args:
        emissions: integer class labels with a trailing axis of size one
            (it is squeezed before one-hot encoding).
        order: either an int $n$, meaning the offsets $0, \dots, n-1$, or an
            explicit list of time offsets. The axes of the result follow
            the order in which the offsets are listed.

    Returns:
        The averaged outer product, with one axis of size ``num_classes``
        per requested offset.

    Raises:
        ValueError: if no offsets are requested, or if the sequences are too
            short to contain a single window covering the largest offset.
    """
    x = one_hot(jnp.squeeze(emissions, -1), num_classes=self.num_classes)
    B, T, _ = x.shape
    if isinstance(order, int):
        order = list(range(order))
    if len(order) == 0:
        raise ValueError("`order` must request at least one time offset.")

    order_len = max(order) + 1
    # Number of window starts for which every requested offset is in range.
    T_effective = T - order_len + 1
    if T_effective <= 0:
        raise ValueError(
            f"Need at least {order_len} timesteps to estimate a moment with "
            f"offsets {list(order)}, but the emissions only have {T}."
        )

    # Interleaved einsum call: each operand shares the batch (0) and window
    # (1) axes and contributes its own class axis (2 + offset) to the output.
    einsum_args = []
    output_indices = []
    for j in order:
        einsum_args.append(x[:, j:T_effective + j, :])
        einsum_args.append([0, 1, 2 + j])
        output_indices.append(2 + j)
    einsum_args.append(output_indices)
    sum_outer_products = jnp.einsum(*einsum_args)

    # Exactly B * T_effective outer products were summed; dividing by B * T
    # would shrink the estimate by a factor that depends on the largest offset.
    return sum_outer_products / (B * T_effective)

def calc_pos_sample_mean(self, emissions, pos):
    """Return the mean one-hot emission at timestep `pos`, averaged over sequences.

    Args:
        emissions: class labels with a trailing axis of size one (it is
            squeezed before one-hot encoding); the leading axis is averaged out.
        pos: index along the time axis (axis 1) to read.

    Returns:
        Vector of length ``num_classes`` holding the empirical class
        frequencies at that timestep.
    """
    labels = jnp.squeeze(emissions, -1)
    encoded = one_hot(labels, num_classes=self.num_classes)
    at_position = encoded[:, pos, :]
    return jnp.mean(at_position, axis=0)


class CategoricalHMM(HMM):
Expand Down Expand Up @@ -186,3 +225,148 @@ def initialize(self,
params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_probs=emission_probs)
return ParamsCategoricalHMM(**params), ParamsCategoricalHMM(**props)


def get_view(self,
             batch_emissions: Float[Array, "num_batches num_timesteps emission_dim"],
             target: int,
             num_init: int = 100,
             num_iter: int = 1000,
             key: Array = jr.PRNGKey(0)
             ) -> Array:
    r"""Return the sample conditional means from the requested view.

    Specifically, return the conditional mean of the emissions at time t+target
    given the hidden state at time t+1.

    $$\mathbb{E}[x_{t+target} \mid y_{t+1}=h]$$

    The symmetrized second moment is used to build a whitening matrix, the
    symmetrized third moment is whitened with it, and the whitened tensor is
    decomposed with the robust tensor power method. The resulting
    eigenpairs are mapped back through the pseudoinverse of the whitening
    matrix.

    Args:
        batch_emissions: the emission data.
        target: the requested timestep relative to the hidden state being conditioned on.
        num_init: number of random starting points used in the robust tensor power method.
        num_iter: number of iterations in the robust tensor power method.
        key: random number generator key for the random starting points.
            Defaults to ``jr.PRNGKey(0)``.

    Returns:
        Conditional mean vectors $\mu$, one row per recovered component.
    """
    num_states = self.num_states
    sym_M2, sym_M3 = self.moment_view(batch_emissions, target, num_states)

    # Whitening matrix: keep the num_states eigenpairs of the second moment
    # that are largest in magnitude and rescale each direction by
    # 1 / sqrt(eigenvalue).
    spectrum, basis = jnp.linalg.eigh(sym_M2)
    keep = jnp.argsort(jnp.abs(spectrum))[-num_states:]
    inv_sqrt_scale = jnp.diag(1/jnp.sqrt(spectrum[keep]))
    whiten = basis[:, keep] @ inv_sqrt_scale
    # Pseudoinverse of the transposed whitening matrix, used to undo it below.
    unwhiten = jnp.linalg.pinv(whiten.T)

    # Tensor decomposition of the whitened third moment.
    whitened_M3 = multilinear_product(sym_M3, [whiten, whiten, whiten])
    components, weights = cp_decomp(whitened_M3, L=num_init, N=num_iter, k=num_states, key=key)

    return jnp.diag(weights) @ components @ unwhiten.T


def fit_moments(
    self,
    params: ParameterSet,
    props: PropertySet,
    emissions: Union[Float[Array, "num_timesteps emission_dim"],
                     Float[Array, "num_batches num_timesteps emission_dim"]],
    num_init: int=100,
    num_iter: int=1000,
    key: Array=jr.PRNGKey(0)
) -> ParameterSet:
    r"""Estimate the parameters using method of moments.

    Specifically, compute emission distribution and transition matrix from the second
    and third moments. Since the model is time homogeneous, you can take it over all
    consecutive 2 or 3 timesteps respectively. To recover the initial distribution, take
    the mean over the first timestep of each sequence using the known emission
    distribution to find the hidden state distribution.

    Concretely, the views for ``target=1`` and ``target=2`` are computed
    with `get_view`; the first is used as the emission estimate, the second
    multiplied by the rank-``num_states`` pseudoinverse of the first gives
    the transition estimate, and the initial estimate is the pseudoinverse
    of the transposed emission estimate applied to the mean first emission.

    Args:
        params: model parameters $\theta$
        props: properties specifying which parameters should be learned
            (not read by this method).
        emissions: observations from data.
        num_init: number of random starting points used in the robust tensor power method.
        num_iter: number of iterations in the robust tensor power method.
        key: random number generator key; it is split so that each view
            gets its own key.

    Returns:
        new parameters

    """
    batch_emissions = ensure_array_has_batch_dim(emissions, self.emission_shape)
    key_2, key_3 = jr.split(key, 2)
    # Forward the caller's settings instead of hard-coded 100 / 1000.
    # NOTE(review): the two views come from independent decompositions;
    # presumably their component ordering has to agree -- TODO confirm.
    mu_1 = self.get_view(batch_emissions, 1, num_init, num_iter, key_2)
    mu_2 = self.get_view(batch_emissions, 2, num_init, num_iter, key_3)
    k = self.num_states

    transition_params = mu_2 @ low_rank_pinv(mu_1, k)
    emission_params = mu_1

    first_step_mean = self.emission_component.calc_pos_sample_mean(batch_emissions, 0)
    initial_params = low_rank_pinv(emission_params.T, k) @ first_step_mean
    # NOTE(review): the three fields are replaced with raw arrays; verify
    # this matches the structure `initialize` builds for `params`.
    params = params._replace(initial=initial_params, transitions=transition_params, emissions=emission_params)

    # Return the parameter set itself (not a 1-tuple), as annotated.
    return params


def moment_view(self,
                batch_emissions: Float[Array, "num_batches num_timesteps emission_dim"],
                target: int,
                k: int
                ) -> Tuple[Array, Array]:
    r"""Perform the symmetrizing operation to get a particular view of the
    second and third order moments.

    Specifically, compute the second and third moments. Since the model is time
    homogeneous, you can take it over all consecutive 2 or 3 timesteps respectively.

    The two time offsets other than `target` (the "sources") are then mapped
    onto the target view: the second moment of the two sources is
    contracted on both modes, and the third moment on its two source modes,
    with matrices built from pairwise moments and rank-`k` pseudoinverses.

    Args:
        batch_emissions: the emission data.
        target: the requested view. 0 and 1 select that offset; any other
            value uses offsets 0 and 1 as the sources.
        k: the number of hidden states.

    Returns:
        sym_M2: symmetrized second order moment corresponding to view of `target`.
        sym_M3: symmetrized third order moment corresponding to view of `target`.
    """
    if target == 0:
        source_1, source_2 = 1, 2
    elif target == 1:
        source_1, source_2 = 0, 2
    else:
        source_1, source_2 = 0, 1

    sample_moment = self.emission_component.calc_sample_moment

    A = sample_moment(batch_emissions, [target, source_2])
    B = sample_moment(batch_emissions, [source_1, source_2])
    C = sample_moment(batch_emissions, [target, source_1])
    D = sample_moment(batch_emissions, [source_2, source_1])

    # The source/source second moment is the same quantity as B, so reuse it
    # rather than recomputing it from the data.
    M2 = B
    M3 = sample_moment(batch_emissions, 3)

    d = self.emission_component.num_classes

    # Matrices that carry each source mode over to the target view.
    sym_pre = jnp.transpose(A @ low_rank_pinv(B, k))
    sym_post = jnp.transpose(C @ low_rank_pinv(D, k))

    # The target mode of the third moment is left untouched (identity).
    M3_args = [jnp.eye(d)]*3
    M3_args[source_1] = sym_pre
    M3_args[source_2] = sym_post

    sym_M2 = multilinear_product(M2, [sym_pre, sym_post])
    sym_M3 = multilinear_product(M3, M3_args)
    return sym_M2, sym_M3
29 changes: 29 additions & 0 deletions dynamax/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,32 @@ def _loss_fn(unc_params, minibatch):

params = from_unconstrained(unc_params, props)
return params, losses

def fit_moments(
    self,
    params: ParameterSet,
    props: PropertySet,
    emissions: Union[Float[Array, "num_timesteps emission_dim"],
                     Float[Array, "num_batches num_timesteps emission_dim"]],
    key: Array=jr.PRNGKey(0)
) -> ParameterSet:
    r"""Estimate the parameters using method of moments.

    Specifically, compute the second and third moments. Since the model is time
    homogeneous, you can take it over all consecutive 2 or 3 timesteps respectively.

    $$M_2 = \mathbb{E}[x_1 \otimes x_2]$$
    $$M_3 = \mathbb{E}[x_1 \otimes x_2 \otimes x_3]$$

    This implementation is a placeholder that always raises; models that
    support moment-based fitting must override it.

    Args:
        params: model parameters $\theta$
        props: properties specifying which parameters should be learned
        emissions: observations from data.
        key: random number generator key.

    Returns:
        new parameters

    Raises:
        NotImplementedError: always.

    """
    # `NotImplemented` is a comparison sentinel, not an exception: raising it
    # produces a TypeError instead of the intended error.
    raise NotImplementedError("fit_moments is not implemented for this model.")
78 changes: 77 additions & 1 deletion dynamax/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from functools import partial
from jax import jit
from jax import vmap
from jax import vmap, lax
from jax.tree_util import tree_map, tree_leaves, tree_flatten, tree_unflatten
from jaxtyping import Array, Int
from scipy.optimize import linear_sum_assignment
Expand Down Expand Up @@ -220,3 +220,79 @@ def psd_solve(A, b, diagonal_boost=1e-9):
def symmetrize(A):
    """Return the average of `A` and its transpose over the last two axes.

    Any leading axes are left untouched, so a stack of matrices is
    symmetrized matrix by matrix.
    """
    transposed = jnp.swapaxes(A, -1, -2)
    return 0.5 * (A + transposed)

def multilinear_product(core, factors):
    """Multilinear map of a core tensor of order p with p matrices.

    Mode ``i`` of the core is contracted with the FIRST axis of
    ``factors[i]``. For an order 3 core tensor of shape (I, J, K), the
    factor matrices therefore have shape (I, P), (J, Q) and (K, R)
    respectively, and the output has shape (P, Q, R):

        out[p, q, r] = sum_{i,j,k} core[i, j, k] * F0[i, p] * F1[j, q] * F2[k, r]

    Args:
        core: tensor of order p.
        factors: sequence of p matrices, one per mode of `core`.

    Returns:
        The contracted tensor, whose i-th axis has the size of the second
        axis of ``factors[i]``.
    """
    order = core.ndim
    assert order == len(factors), (
        f"expected {order} factor matrices for an order-{order} core, "
        f"got {len(factors)}"
    )
    # Interleaved einsum call: core axes are labelled 0..order-1 and the
    # output axes order..2*order-1; factor i links label i to label i+order.
    einsum_args = [core, list(range(order))]
    for mode, factor in enumerate(factors):
        einsum_args += [factor, [mode, mode + order]]
    einsum_args.append(list(range(order, 2 * order)))

    return jnp.einsum(*einsum_args)

def low_rank_pinv(X, k):
    """Find the Moore-Penrose Pseudoinverse of a matrix X with rank k.
    This is more robust than jnp.linalg.pinv since the sample cross moments
    will likely have higher rank than the population cross moments.

    The SVD returns the singular values in descending order; only the
    leading k singular triplets are kept and the pseudoinverse is formed
    from them as ``V_k @ diag(1 / s_k) @ U_k.T``.

    Args:
        X: matrix of shape (m, n).
        k: number of leading singular values to keep.

    Returns:
        Matrix of shape (n, m): the pseudoinverse of the rank-k truncation
        of `X`.
    """
    # The reduced SVD is enough: everything beyond the first k singular
    # vectors is thrown away, so the full orthogonal bases are never needed.
    u, s, vt = jnp.linalg.svd(X, full_matrices=False)
    u_trunc = u[:, :k]
    s_trunc = s[:k]
    vt_trunc = vt[:k, :]
    # Scaling the columns of V_k is the same as multiplying by diag(1 / s_k).
    return (vt_trunc.T * (1.0 / s_trunc)) @ u_trunc.T

def rtpm_eigvals(X, y):
    """Evaluate the tensor X on the vector $y$ in all three modes.

    This is the eigenvalue of X associated with the eigenvector $y$,
    computed as the multilinear product of X with ``[y, y, y]``.
    """
    factors = [y, y, y]
    return multilinear_product(X, factors)

def rtpm(X, key=jr.PRNGKey(0), L=100, N=1000):
    """Run one round of the robust tensor power method on the tensor X.

    `L` random unit vectors are each refined with `N` power iterations, the
    candidate with the largest eigenvalue is refined for another `N`
    iterations, and the corresponding rank-one term is subtracted from X.

    Args:
        X: order-3 tensor whose three axes all have the same size.
        key: random number generator key for the random starting points.
        L: number of random starting points.
        N: number of power iterations per run.

    Returns:
        A pair ``(deflated, (eigvec, eigval))``: the deflated tensor, and the
        robust eigenvector together with its eigenvalue.
    """
    assert X.ndim == 3
    assert len(set(X.shape)) == 1
    dim = X.shape[0]
    restart_keys = jr.split(key, L)
    identity = jnp.eye(dim)

    def power_step(theta, _):
        # theta <- X(I, theta, theta), renormalized to unit length.
        image = multilinear_product(X, [identity, theta, theta]).squeeze(-1)
        return image / jnp.linalg.norm(image), None

    def run_power_iterations(theta_start):
        theta_end, _ = lax.scan(power_step, theta_start, length=N)
        return theta_end

    def run_restart(restart_key):
        # Random point on the unit sphere in R^dim, kept as a column vector.
        draw = jr.normal(restart_key, shape=(dim, 1))
        return run_power_iterations(draw / jnp.linalg.norm(draw))

    candidates = vmap(run_restart)(restart_keys)
    scores = vmap(lambda theta: rtpm_eigvals(X, theta))(candidates)
    best = jnp.argmax(scores)

    # Polish the winning candidate with a further N iterations.
    theta_hat = run_power_iterations(candidates[best])
    lambda_hat = rtpm_eigvals(X, theta_hat).squeeze()
    theta_hat = theta_hat.squeeze()

    rank_one = jnp.einsum('a,b,c->abc', theta_hat, theta_hat, theta_hat)
    deflated = X - lambda_hat * rank_one
    return deflated, (theta_hat, lambda_hat)

def cp_decomp(X, L, N, k, key):
    """Extract `k` robust eigenpairs of X by repeated deflation.

    The robust tensor power method is applied `k` times in sequence; each
    round receives the tensor deflated by the previous round and its own
    random key split from `key`.

    Args:
        X: tensor to decompose.
        L: number of random starting points per round.
        N: number of power iterations per run.
        k: number of eigenpairs to extract.
        key: random number generator key.

    Returns:
        The robust eigenvectors and eigenvalues, stacked along a leading
        axis of length `k`.
    """
    deflate_once = partial(rtpm, L=L, N=N)
    round_keys = jr.split(key, k)
    _, eigenpairs = lax.scan(deflate_once, X, round_keys)
    eigvecs, eigvals = eigenpairs
    return eigvecs, eigvals