
Commit 382f222

Normal Model (#139)

* [WIP] The initializer...does everything?
* [WIP] Refactor to fit model, accuracy test passes.
* [WIP] Remove unnecessary tests.
* Seems right?
* Fill in functions that might be needed for testing.
* Make tests more consistent.
* Clean up model a bit.
* [WIP] Fix log likelihood?
* Remove unfinished methods
* [WIP] Add normal scale modeling
* [WIP] Fix normal model ll
* Small fixes and clean up.
* Fix missing export.
* Fix init_par function
* Fix init_par for scale.
* Add docs.
* Move to completed.
* Fix typing issue.
* Fix typing issue.
* Small test cleanups.
* Move ll function.
* Move ll function.
* Ok remove!
* Add changes from update_api
* pre-commit
* Missed that variable!

1 parent db83227 · commit 382f222

File tree

12 files changed: +343 -107 lines

batchglm/models/glm_norm/model.py

Lines changed: 2 additions & 55 deletions
@@ -12,95 +12,42 @@ class Model(_ModelGLM, metaclass=abc.ABCMeta):
     """Generalized Linear Model (GLM) with normal noise."""
 
     def link_loc(self, data) -> Union[np.ndarray, dask.array.core.Array]:
-        """Short summary.
-
-        :param type data: Description of parameter `data`.
-        :return: Description of returned object.
-        :rtype: type
-
-        """
         return data
 
     def inverse_link_loc(self, data) -> Union[np.ndarray, dask.array.core.Array]:
-        """Short summary.
-
-        :param type data: Description of parameter `data`.
-        :return: Description of returned object.
-        :rtype: type
-
-        """
         return data
 
     def link_scale(self, data) -> Union[np.ndarray, dask.array.core.Array]:
-        """Short summary.
-
-        :param type data: Description of parameter `data`.
-        :return: Description of returned object.
-        :rtype: type
-
-        """
         return np.log(data)
 
     def inverse_link_scale(self, data) -> Union[np.ndarray, dask.array.core.Array]:
-        """Short summary.
-
-        :param type data: Description of parameter `data`.
-        :return: Description of returned object.
-        :rtype: type
-
-        """
         return np.exp(data)
 
     @property
     def eta_loc(self) -> Union[np.ndarray, dask.array.core.Array]:
-        """Short summary.
-
-        :return: Description of returned object.
-        :rtype: np.ndarray
-
-        """
         eta = np.matmul(self.design_loc, self.theta_location_constrained)
         if self.size_factors is not None:
-            eta *= np.expand_dims(self.size_factors, axis=1)
+            eta *= self.size_factors
         return eta
 
     def eta_loc_j(self, j) -> Union[np.ndarray, dask.array.core.Array]:
-        """Short summary.
-
-        :param type j: Description of parameter `j`.
-        :return: Description of returned object.
-        :rtype: np.ndarray
-
-        """
         # Make sure that dimensionality of sliced array is kept:
         if isinstance(j, int) or isinstance(j, np.int32) or isinstance(j, np.int64):
             j = [j]
         eta = np.matmul(self.design_loc, self.theta_location_constrained[:, j])
         if self.size_factors is not None:
-            eta *= np.expand_dims(self.size_factors, axis=1)
+            eta *= self.size_factors
         eta = self.np_clip_param(eta, "eta_loc")
         return eta
 
     # Re-parameterizations:
 
     @property
     def mean(self) -> Union[np.ndarray, dask.array.core.Array]:
-        """Short summary.
-
-        :return: Description of returned object.
-        :rtype: np.ndarray
-
-        """
         return self.location
 
     @property
     def sd(self) -> Union[np.ndarray, dask.array.core.Array]:
-        """Short summary.
-
-        :return: Description of returned object.
-        :rtype: np.ndarray
-
-        """
         return self.scale
 
     # param constraints:
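
Editor's note: since the diff strips the placeholder docstrings, a quick reminder of what these links do. The normal model uses an identity link for the location and a log link for the scale, so the scale stays positive. A self-contained sketch of that parameterization in plain NumPy (names here are illustrative, not batchglm API):

    import numpy as np

    rng = np.random.default_rng(0)
    design = rng.normal(size=(100, 3))       # observations x parameters
    theta_loc = rng.normal(size=(3, 5))      # parameters x features
    theta_scale = rng.normal(size=(3, 5)) * 0.1

    # Identity link for the location: mean = eta_loc directly.
    mean = design @ theta_loc
    # Log link for the scale: sd = exp(eta_scale), guaranteed positive.
    sd = np.exp(design @ theta_scale)

    assert np.all(sd > 0)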

batchglm/models/glm_norm/utils.py

Lines changed: 75 additions & 34 deletions
@@ -1,47 +1,18 @@
 import logging
-from typing import Union
+from typing import Tuple, Union
 
+import dask
 import numpy as np
 import scipy.sparse
 
-from .external import closedform_glm_mean, closedform_glm_scale
+from .external import closedform_glm_scale
 
 logger = logging.getLogger("batchglm")
 
 
-def closedform_norm_glm_mean(
-    x: Union[np.ndarray, scipy.sparse.csr_matrix],
-    design_loc: np.ndarray,
-    constraints_loc,
-    size_factors=None,
-    link_fn=lambda x: x,
-    inv_link_fn=lambda x: x,
-):
-    r"""
-    Calculates a closed-form solution for the `mean` parameters of normal GLMs.
-
-    :param x: The sample data
-    :param design_loc: design matrix for location
-    :param constraints_loc: tensor (all parameters x dependent parameters)
-        Tensor that encodes how the complete parameter set, which includes dependent
-        parameters, arises from the independent parameters: all = <constraints, indep>.
-        This form of constraints is used in vector generalized linear models (VGLMs).
-    :param size_factors: size factors for X
-    :return: tuple: (groupwise_means, mean, rmsd)
-    """
-    return closedform_glm_mean(
-        x=x,
-        dmat=design_loc,
-        constraints=constraints_loc,
-        size_factors=size_factors,
-        link_fn=link_fn,
-        inv_link_fn=inv_link_fn,
-    )
-
-
 def closedform_norm_glm_logsd(
-    x: Union[np.ndarray, scipy.sparse.csr_matrix],
-    design_scale: np.ndarray,
+    x: Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array],
+    design_scale: Union[np.ndarray, dask.array.core.Array],
     constraints=None,
     size_factors=None,
     groupwise_means=None,
@@ -71,3 +42,73 @@ def compute_scales_fun(variance, mean):
         link_fn=link_fn,
         compute_scales_fun=compute_scales_fun,
     )
+
+
+def init_par(model, init_location: str, init_scale: str) -> Tuple[np.ndarray, np.ndarray, bool, bool]:
+    r"""
+    standard:
+    Only initialise the intercept and keep the other coefficients at zero.
+
+    closed-form:
+    Initialize with maximum-likelihood / method-of-moments estimators.
+
+    Idea:
+    $$
+    \theta &= f(x) \\
+    \Rightarrow f^{-1}(\theta) &= x \\
+    &= (D \cdot D^{+}) \cdot x \\
+    &= D \cdot (D^{+} \cdot x) \\
+    &= D \cdot x' = f^{-1}(\theta)
+    $$
+    """
+    groupwise_means = None
+
+    init_location_str = init_location.lower()
+    # Choose an option if "auto" was requested.
+    auto_or_closed_form = init_location_str == "auto" or init_location_str == "closed_form"
+    if auto_or_closed_form or init_location_str == "all_zero":
+        if auto_or_closed_form:
+            logger.warning(
+                "There is no need for closed-form location model initialization "
+                "because the model is already closed form - falling back to zeros"
+            )
+        init_theta_location = np.zeros([model.num_loc_params, model.num_features])
+    elif init_location_str == "standard":
+        overall_means = np.mean(model.x, axis=0)  # directly calculate the mean
+        init_theta_location = np.zeros([model.num_loc_params, model.num_features])
+        init_theta_location[0, :] = np.log(overall_means)
+    else:
+        raise ValueError("init_location string %s not recognized" % init_location)
+
+    init_scale_str = init_scale.lower()
+    if init_scale_str == "auto":
+        init_scale_str = "standard"
+
+    if init_scale_str == "standard":
+        groupwise_scales, init_scale_intercept, rmsd_b = closedform_norm_glm_logsd(
+            x=model.x,
+            design_scale=model.design_scale[:, [0]],
+            constraints=model.constraints_scale[[0], :][:, [0]],
+            size_factors=model.size_factors,
+            groupwise_means=None,
+            link_fn=lambda r: np.log(r + np.nextafter(0, 1, dtype=r.dtype)),
+        )
+        init_theta_scale = np.zeros([model.num_scale_params, model.num_features])
+        init_theta_scale[0, :] = init_scale_intercept
+    elif init_scale_str == "closed_form":
+        groupwise_scales, init_theta_scale, rmsd_b = closedform_norm_glm_logsd(
+            x=model.x,
+            design_scale=model.design_scale,
+            constraints=model.constraints_scale,
+            size_factors=model.size_factors,
+            groupwise_means=groupwise_means,
+        )
+    elif init_scale_str == "all_zero":
+        init_theta_scale = np.zeros([model.num_scale_params, model.x.shape[1]])
+    else:
+        raise ValueError("init_scale string %s not recognized" % init_scale_str)
+
+    return init_theta_location, init_theta_scale, True, True
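
Editor's note: the pseudo-inverse idea in the init_par docstring can be stated concretely: map the data through the link function, then project onto the column space of the design matrix D via its Moore-Penrose pseudoinverse D+. A minimal sketch of that idea under the assumption of a dense NumPy design matrix (illustrative only, not batchglm's actual code path):

    import numpy as np

    rng = np.random.default_rng(1)
    D = rng.normal(size=(50, 4))    # design matrix (observations x parameters)
    x = rng.normal(size=(50, 8))    # data (observations x features)

    def link(v):
        return v                    # identity link for the normal location

    # theta0 = D^+ @ link(x), so D @ theta0 approximates link(x):
    theta0 = np.linalg.pinv(D) @ link(x)

    # D @ theta0 is the projection of link(x) onto the column space of D.
    residual = link(x) - D @ theta0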

batchglm/train/numpy/base_glm/model_container.py

Lines changed: 1 addition & 1 deletion
@@ -327,7 +327,7 @@ def fim(self) -> Union[np.ndarray, dask.array.core.Array]:
         fim_scale_scale = self.fim_scale_scale
         fim_location_scale = self.fim_location_scale
         fim_ba = np.transpose(fim_location_scale, axes=[0, 2, 1])
-        return -np.concatenate(
+        return np.concatenate(
             [
                 np.concatenate([fim_location_location, fim_location_scale], axis=2),
                 np.concatenate([fim_ba, fim_scale_scale], axis=2),
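
Editor's note: for context, the matrix assembled here is the standard two-by-two block form of the Fisher information over the location parameters $a$ and scale parameters $b$ (textbook notation, not taken from the diff itself):

    \mathcal{I}(a, b) =
    \begin{pmatrix}
    \mathcal{I}_{aa} & \mathcal{I}_{ab} \\
    \mathcal{I}_{ab}^{\top} & \mathcal{I}_{bb}
    \end{pmatrix}

The expected Fisher information is positive semidefinite by construction, so dropping the leading minus sign is consistent with that convention.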

batchglm/train/numpy/glm_nb/model_container.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def fim_weight_location_location(self) -> Union[np.ndarray, dask.array.core.Arra
         Fisher inverse matrix weights
         :return: observations x features
         """
-        return -self.location * self.scale / (self.scale + self.location)
+        return self.location * self.scale / (self.scale + self.location)
 
     @property
     def ybar(self) -> Union[np.ndarray, dask.array.core.Array]:
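
Editor's note: this is the same kind of sign correction as above. Reading `self.location` as the negative-binomial mean $\mu$ and `self.scale` as the dispersion/size parameter $\phi$ (following the surrounding code), the per-observation location FIM weight under a log link is

    w_{\text{loc}} = \frac{\mu \, \phi}{\mu + \phi} \;\ge\; 0
    \quad \text{for } \mu \ge 0,\ \phi \ge 0,

so the earlier leading minus produced weights of the wrong sign.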
batchglm/train/numpy/glm_norm/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+from .estimator import Estimator
+from .model_container import ModelContainer
batchglm/train/numpy/glm_norm/estimator.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+import logging
+
+import numpy as np
+
+from .external import EstimatorGlm, Model, init_par
+from .model_container import ModelContainer
+
+logger = logging.getLogger("batchglm")
+
+
+class Estimator(EstimatorGlm):
+    def __init__(
+        self,
+        model: Model,
+        init_location: str = "AUTO",
+        init_scale: str = "AUTO",
+        # batch_size: Optional[Union[Tuple[int, int], int]] = None,
+        quick_scale: bool = False,
+        dtype: str = "float64",
+    ):
+        """
+        Performs initialisation and creates a new estimator.
+
+        :param model:
+            The GLM model to be fit.
+        :param init_location: (Optional)
+            Low-level initial values for a. Can be:
+
+            - str:
+              * "auto": automatically choose best initialization
+              * "standard": initialize intercept with observed mean
+              * "closed_form": try to initialize with closed form
+            - np.ndarray: direct initialization of 'a'
+        :param init_scale: (Optional)
+            Low-level initial values for b. Can be:
+
+            - str:
+              * "auto": automatically choose best initialization
+              * "random": initialize with random values
+              * "standard": initialize with zeros
+              * "closed_form": try to initialize with closed form
+            - np.ndarray: direct initialization of 'b'
+        :param quick_scale: bool
+            Whether `scale` will be fitted faster and possibly less accurately.
+            Useful in scenarios where fitting the exact `scale` is not absolutely necessary.
+        :param dtype: Numerical precision.
+        """
+        init_theta_location, init_theta_scale, train_loc, train_scale = init_par(
+            model=model, init_location=init_location, init_scale=init_scale
+        )
+        init_theta_location = init_theta_location.astype(dtype)
+        init_theta_scale = init_theta_scale.astype(dtype)
+        self._train_scale = train_scale
+        self._train_loc = train_loc
+        if quick_scale:
+            self._train_scale = False
+        _model_container = ModelContainer(
+            model=model,
+            init_theta_location=init_theta_location,
+            init_theta_scale=init_theta_scale,
+            chunk_size_genes=model.chunk_size_genes,
+            dtype=dtype,
+        )
+        super(Estimator, self).__init__(model_container=_model_container, dtype=dtype)
+
+    def train(
+        self,
+        **kwargs,
+    ):
+        model = self._model_container.model
+        if self._train_loc:
+            theta_location, _, _, _ = np.linalg.lstsq(model.design_loc, model.x)
+            self._model_container.theta_location = theta_location
+            self._train_loc = False
+        super().train(**kwargs)
+        self._train_loc = True
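
Editor's note: the `train` override exploits that, for the normal model with its identity link, the location MLE is exactly an ordinary least-squares fit, so it can be computed in one shot with `np.linalg.lstsq` instead of iteratively. A self-contained sketch of just that step on synthetic data (illustrative names, not the estimator's full pipeline):

    import numpy as np

    rng = np.random.default_rng(2)
    design_loc = rng.normal(size=(200, 3))   # observations x parameters
    theta_true = rng.normal(size=(3, 10))    # parameters x features
    x = design_loc @ theta_true + rng.normal(scale=0.1, size=(200, 10))

    # One-shot location fit, as in Estimator.train: OLS is the normal-model MLE.
    theta_location, residuals, rank, sv = np.linalg.lstsq(design_loc, x, rcond=None)

    assert np.allclose(theta_location, theta_true, atol=0.1)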
batchglm/train/numpy/glm_norm/external.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+import batchglm.utils.data as data_utils
+from batchglm import pkg_constants
+from batchglm.models.base_glm.utils import closedform_glm_scale
+from batchglm.models.glm_norm.model import Model
+from batchglm.models.glm_norm.utils import closedform_norm_glm_logsd, init_par
+
+# import necessary base_glm layers
+from batchglm.train.numpy.base_glm import BaseModelContainer, EstimatorGlm
+from batchglm.utils.linalg import groupwise_solve_lm
