|
1 | 1 | import logging |
2 | | -from typing import Union |
| 2 | +from typing import Tuple, Union |
3 | 3 |
|
| 4 | +import dask |
4 | 5 | import numpy as np |
5 | 6 | import scipy.sparse |
6 | 7 |
|
7 | | -from .external import closedform_glm_mean, closedform_glm_scale |
| 8 | +from .external import closedform_glm_scale |
8 | 9 |
|
9 | 10 | logger = logging.getLogger("batchglm") |
10 | 11 |
|
11 | 12 |
|
12 | | -def closedform_norm_glm_mean( |
13 | | - x: Union[np.ndarray, scipy.sparse.csr_matrix], |
14 | | - design_loc: np.ndarray, |
15 | | - constraints_loc, |
16 | | - size_factors=None, |
17 | | - link_fn=lambda x: x, |
18 | | - inv_link_fn=lambda x: x, |
19 | | -): |
20 | | - r""" |
21 | | - Calculates a closed-form solution for the `mean` parameters of normal GLMs. |
22 | | -
|
23 | | - :param x: The sample data |
24 | | - :param design_loc: design matrix for location |
25 | | - :param constraints_loc: tensor (all parameters x dependent parameters) |
26 | | - Tensor that encodes how complete parameter set which includes dependent |
27 | | - parameters arises from indepedent parameters: all = <constraints, indep>. |
28 | | - This form of constraints is used in vector generalized linear models (VGLMs). |
29 | | - :param size_factors: size factors for X |
30 | | - :return: tuple: (groupwise_means, mean, rmsd) |
31 | | - """ |
32 | | - return closedform_glm_mean( |
33 | | - x=x, |
34 | | - dmat=design_loc, |
35 | | - constraints=constraints_loc, |
36 | | - size_factors=size_factors, |
37 | | - link_fn=link_fn, |
38 | | - inv_link_fn=inv_link_fn, |
39 | | - ) |
40 | | - |
41 | | - |
42 | 13 | def closedform_norm_glm_logsd( |
43 | | - x: Union[np.ndarray, scipy.sparse.csr_matrix], |
44 | | - design_scale: np.ndarray, |
| 14 | + x: Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array], |
| 15 | + design_scale: Union[np.ndarray, dask.array.core.Array], |
45 | 16 | constraints=None, |
46 | 17 | size_factors=None, |
47 | 18 | groupwise_means=None, |
@@ -71,3 +42,73 @@ def compute_scales_fun(variance, mean): |
71 | 42 | link_fn=link_fn, |
72 | 43 | compute_scales_fun=compute_scales_fun, |
73 | 44 | ) |
| 45 | + |
| 46 | + |
| 47 | +def init_par(model, init_location: str, init_scale: str) -> Tuple[np.ndarray, np.ndarray, bool, bool]: |
| 48 | + r""" |
| 49 | + standard: |
| 50 | + Only initialise intercept and keep other coefficients as zero. |
| 51 | +
|
| 52 | + closed-form: |
| 53 | + Initialize with Maximum Likelihood / Maximum of Momentum estimators |
| 54 | +
|
| 55 | + Idea: |
| 56 | + $$ |
| 57 | + \theta &= f(x) \\ |
| 58 | + \Rightarrow f^{-1}(\theta) &= x \\ |
| 59 | + &= (D \cdot D^{+}) \cdot x \\ |
| 60 | + &= D \cdot (D^{+} \cdot x) \\ |
| 61 | + &= D \cdot x' = f^{-1}(\theta) |
| 62 | + $$ |
| 63 | + """ |
| 64 | + |
| 65 | + groupwise_means = None |
| 66 | + |
| 67 | + init_location_str = init_location.lower() |
| 68 | + # Chose option if auto was chosen |
| 69 | + auto_or_closed_form = init_location_str == "auto" or init_location_str == "closed_form" |
| 70 | + if auto_or_closed_form or init_location_str == "all_zero": |
| 71 | + if auto_or_closed_form: |
| 72 | + logger.warning( |
| 73 | + ( |
| 74 | + "There is no need for closed form location model initialization" |
| 75 | + "because it is already closed form - falling back to zeros" |
| 76 | + ) |
| 77 | + ) |
| 78 | + init_theta_location = np.zeros([model.num_loc_params, model.num_features]) |
| 79 | + elif init_location_str == "standard": |
| 80 | + overall_means = np.mean(model.x, axis=0) # directly calculate the mean |
| 81 | + init_theta_location = np.zeros([model.num_loc_params, model.num_features]) |
| 82 | + init_theta_location[0, :] = np.log(overall_means) |
| 83 | + else: |
| 84 | + raise ValueError("init_location string %s not recognized" % init_location) |
| 85 | + |
| 86 | + init_scale_str = init_scale.lower() |
| 87 | + if init_scale_str == "auto": |
| 88 | + init_scale_str = "standard" |
| 89 | + |
| 90 | + if init_scale_str == "standard": |
| 91 | + groupwise_scales, init_scale_intercept, rmsd_b = closedform_norm_glm_logsd( |
| 92 | + x=model.x, |
| 93 | + design_scale=model.design_scale[:, [0]], |
| 94 | + constraints=model.constraints_scale[[0], :][:, [0]], |
| 95 | + size_factors=model.size_factors, |
| 96 | + groupwise_means=None, |
| 97 | + link_fn=lambda r: np.log(r + np.nextafter(0, 1, dtype=r.dtype)), |
| 98 | + ) |
| 99 | + init_theta_scale = np.zeros([model.num_scale_params, model.num_features]) |
| 100 | + init_theta_scale[0, :] = init_scale_intercept |
| 101 | + elif init_scale_str == "closed_form": |
| 102 | + groupwise_scales, init_theta_scale, rmsd_b = closedform_norm_glm_logsd( |
| 103 | + x=model.x, |
| 104 | + design_scale=model.design_scale, |
| 105 | + constraints=model.constraints_scale, |
| 106 | + size_factors=model.size_factors, |
| 107 | + groupwise_means=groupwise_means, |
| 108 | + ) |
| 109 | + elif init_scale_str == "all_zero": |
| 110 | + init_theta_scale = np.zeros([model.num_scale_params, model.x.shape[1]]) |
| 111 | + else: |
| 112 | + raise ValueError("init_scale string %s not recognized" % init_scale_str) |
| 113 | + |
| 114 | + return init_theta_location, init_theta_scale, True, True |
0 commit comments