Skip to content
This repository was archived by the owner on Oct 21, 2025. It is now read-only.

Commit 9fa1e0b

Browse files
deprecated usage of formula in wald()
Callers now have to pass formula_loc explicitly; I think this is clearer. The default for formula_scale is now "~1".
1 parent 516489f commit 9fa1e0b

File tree

3 files changed

+90
-74
lines changed

3 files changed

+90
-74
lines changed

diffxpy/testing/base.py

Lines changed: 88 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
import abc
22
import logging
33
from typing import Union, Dict, Tuple, List, Set, Callable
4-
54
import pandas as pd
5+
import warnings
66

77
import numpy as np
88
import xarray as xr
9-
9+
import patsy
1010
try:
1111
import anndata
1212
except ImportError:
1313
anndata = None
1414

15-
import patsy
1615
import batchglm.data as data_utils
17-
from batchglm.api.models.glm_nb import Model as GeneralizedLinearModel
16+
from batchglm.models.glm_nb import Model as GeneralizedLinearModel
1817

1918
from ..stats import stats
2019
from . import correction
@@ -129,7 +128,7 @@ def __init__(self):
129128
self._pval = None
130129
self._qval = None
131130
self._mean = None
132-
self._log_probs = None
131+
self._log_likelihood = None
133132

134133
@property
135134
@abc.abstractmethod
@@ -179,10 +178,10 @@ def _ave(self):
179178
pass
180179

181180
@property
182-
def log_probs(self):
183-
if self._log_probs is None:
184-
self._log_probs = self._ll().compute()
185-
return self._log_probs
181+
def log_likelihood(self):
182+
if self._log_likelihood is None:
183+
self._log_likelihood = self._ll().compute()
184+
return self._log_likelihood
186185

187186
@property
188187
def mean(self):
@@ -675,7 +674,7 @@ def _ll(self):
675674
676675
:return: xr.DataArray
677676
"""
678-
return np.sum(self.model_estim.log_likelihood, axis=0)
677+
return self.model_estim.log_likelihood
679678

680679
def _ave(self):
681680
"""
@@ -730,8 +729,8 @@ def summary(self, qval_thres=None, fc_upper_thres=None,
730729
if len(self.theta_sd.shape) == 1:
731730
res["coef_sd"] = self.theta_sd
732731
# add in info from bfgs
733-
if self.log_probs is not None:
734-
res["ll"] = self.log_probs
732+
if self.log_likelihood is not None:
733+
res["ll"] = self.log_likelihood
735734
if self._error_codes is not None:
736735
res["err"] = self._error_codes
737736
if self._niter is not None:
@@ -1243,7 +1242,7 @@ def X(self):
12431242
return self.model_estim.X
12441243

12451244
@property
1246-
def log_probs(self):
1245+
def log_likelihood(self):
12471246
return np.sum(self.model_estim.log_probs(), axis=0)
12481247

12491248
@property
@@ -1448,7 +1447,7 @@ def X(self):
14481447
return self.model_estim.X
14491448

14501449
@property
1451-
def log_probs(self):
1450+
def log_likelihood(self):
14521451
return np.sum(self.model_estim.log_probs(), axis=0)
14531452

14541453
@property
@@ -1912,8 +1911,8 @@ def mean(self) -> np.ndarray:
19121911
return self._de_test.mean
19131912

19141913
@property
1915-
def log_probs(self) -> np.ndarray:
1916-
return self._de_test.log_probs
1914+
def log_likelihood(self) -> np.ndarray:
1915+
return self._de_test.log_likelihood
19171916

19181917
def summary(self, nonnumeric=False, qval_thres=None, fc_upper_thres=None,
19191918
fc_lower_thres=None, mean_thres=None) -> pd.DataFrame:
@@ -2096,6 +2095,8 @@ def plot_genes(
20962095
self,
20972096
genes,
20982097
hue=None,
2098+
size=1,
2099+
log=True,
20992100
nonnumeric=False,
21002101
save=None,
21012102
show=True,
@@ -2108,6 +2109,7 @@ def plot_genes(
21082109
21092110
:param genes: Gene IDs to plot.
21102111
:param hue: Confounder to include in plot.
2112+
:param size: Point size.
21112113
:param nonnumeric:
21122114
:param save: Path+file name stem to save plots to.
21132115
File will be save+"_genes.png". Does not save if save is None.
@@ -2151,21 +2153,33 @@ def plot_genes(
21512153
ax = plt.subplot(gs[i])
21522154
axs.append(ax)
21532155

2156+
y = self.X[:, genes[0]]
2157+
yhat = self._continuous_model(idx=g, nonnumeric=nonnumeric)
2158+
if log:
2159+
y = np.log(y + 1)
2160+
yhat = np.log(yhat + 1)
2161+
21542162
sns.scatterplot(
21552163
x=self._continuous_coords,
2156-
y=self.X[:, genes[0]],
2164+
y=y,
21572165
hue=hue,
2158-
ax=ax
2166+
size=size,
2167+
ax=ax,
2168+
legend=False
21592169
)
21602170
sns.lineplot(
21612171
x=self._continuous_coords,
2162-
y=self._continuous_model(idx=g, nonnumeric=nonnumeric),
2172+
y=yhat,
21632173
hue=hue,
21642174
ax=ax
21652175
)
2176+
21662177
ax.set_title(genes[i])
21672178
ax.set_xlabel("continuous")
2168-
ax.set_ylabel("expression")
2179+
if log:
2180+
ax.set_ylabel("log expression")
2181+
else:
2182+
ax.set_ylabel("expression")
21692183

21702184
# Save, show and return figure.
21712185
if save is not None:
@@ -2536,57 +2550,56 @@ def _fit(
25362550
raise ValueError('base.test(): `noise_model="%s"` not recognized.' % noise_model)
25372551
else:
25382552
if noise_model == "nb" or noise_model == "negative_binomial":
2539-
import batchglm.api.models.glm_nb as test_model
2540-
2541-
logger.info("Fitting model...")
2542-
logger.debug(" * Assembling input data...")
2543-
input_data = test_model.InputData.new(
2544-
data=data,
2545-
design_loc=design_loc,
2546-
design_scale=design_scale,
2547-
constraints_loc=constraints_loc,
2548-
constraints_scale=constraints_scale,
2549-
size_factors=size_factors,
2550-
feature_names=gene_names,
2551-
)
2553+
from batchglm.api.models.glm_nb import Estimator, InputData
2554+
else:
2555+
raise ValueError('base.test(): `noise_model="%s"` not recognized.' % noise_model)
25522556

2553-
logger.debug(" * Set up Estimator...")
2554-
constructor_args = {}
2555-
if batch_size is not None:
2556-
constructor_args["batch_size"] = batch_size
2557-
if quick_scale is not None:
2558-
constructor_args["quick_scale"] = quick_scale
2559-
estim = test_model.Estimator(
2560-
input_data=input_data,
2561-
init_model=init_model,
2562-
init_a=init_a,
2563-
init_b=init_b,
2564-
provide_optimizers=provide_optimizers,
2565-
termination_type=termination_type,
2566-
dtype=dtype,
2567-
**constructor_args
2568-
)
2557+
logger.info("Fitting model...")
2558+
logger.debug(" * Assembling input data...")
2559+
input_data = InputData.new(
2560+
data=data,
2561+
design_loc=design_loc,
2562+
design_scale=design_scale,
2563+
constraints_loc=constraints_loc,
2564+
constraints_scale=constraints_scale,
2565+
size_factors=size_factors,
2566+
feature_names=gene_names,
2567+
)
25692568

2570-
logger.debug(" * Initializing Estimator...")
2571-
estim.initialize()
2569+
logger.debug(" * Set up Estimator...")
2570+
constructor_args = {}
2571+
if batch_size is not None:
2572+
constructor_args["batch_size"] = batch_size
2573+
if quick_scale is not None:
2574+
constructor_args["quick_scale"] = quick_scale
2575+
estim = Estimator(
2576+
input_data=input_data,
2577+
init_model=init_model,
2578+
init_a=init_a,
2579+
init_b=init_b,
2580+
provide_optimizers=provide_optimizers,
2581+
termination_type=termination_type,
2582+
dtype=dtype,
2583+
**constructor_args
2584+
)
25722585

2573-
logger.debug(" * Run estimation...")
2574-
# training:
2575-
if callable(training_strategy):
2576-
# call training_strategy if it is a function
2577-
training_strategy(estim)
2578-
else:
2579-
estim.train_sequence(training_strategy)
2586+
logger.debug(" * Initializing Estimator...")
2587+
estim.initialize()
25802588

2581-
if close_session:
2582-
logger.debug(" * Finalize estimation...")
2583-
model = estim.finalize()
2584-
else:
2585-
model = estim
2586-
logger.debug(" * Model fitting done.")
2589+
logger.debug(" * Run estimation...")
2590+
# training:
2591+
if callable(training_strategy):
2592+
# call training_strategy if it is a function
2593+
training_strategy(estim)
2594+
else:
2595+
estim.train_sequence(training_strategy=training_strategy)
25872596

2597+
if close_session:
2598+
logger.debug(" * Finalize estimation...")
2599+
model = estim.finalize()
25882600
else:
2589-
raise ValueError('base.test(): `noise_model="%s"` not recognized.' % noise_model)
2601+
model = estim
2602+
logger.debug(" * Model fitting done.")
25902603

25912604
return model
25922605

@@ -2788,10 +2801,9 @@ def lrt(
27882801
def wald(
27892802
data,
27902803
factor_loc_totest: Union[str, List[str]] = None,
2791-
coef_to_test: Union[str, List[str]] = None, # e.g. coef_to_test="B"
2792-
formula: str = None,
2804+
coef_to_test: Union[str, List[str]] = None,
27932805
formula_loc: str = None,
2794-
formula_scale: str = None,
2806+
formula_scale: str = "~1",
27952807
as_numeric: Union[List[str], Tuple[str], str] = (),
27962808
init_a: Union[np.ndarray, str] = "AUTO",
27972809
init_b: Union[np.ndarray, str] = "AUTO",
@@ -2918,10 +2930,6 @@ def wald(
29182930
if len(kwargs) != 0:
29192931
logger.debug("additional kwargs: %s", str(kwargs))
29202932

2921-
if formula_loc is None:
2922-
formula_loc = formula
2923-
if formula_scale is None:
2924-
formula_scale = formula
29252933
if dmat_loc is None and formula_loc is None:
29262934
raise ValueError("Supply either dmat_loc or formula_loc or formula.")
29272935
if dmat_scale is None and formula_scale is None:
@@ -2959,7 +2967,8 @@ def wald(
29592967
else:
29602968
design_scale = dmat_scale
29612969

2962-
# Coefficients to test:
2970+
# Define indices of coefficients to test:
2971+
contraints_loc_temp = constraints_loc if constraints_loc is not None else np.eye(design_loc.shape[-1])
29632972
if factor_loc_totest is not None:
29642973
# Select coefficients to test via formula model:
29652974
col_indices = np.concatenate([
@@ -2992,6 +3001,12 @@ def wald(
29923001
])
29933002
else:
29943003
raise ValueError("either set factor_loc_totest or coef_to_test")
3004+
# Check that all tested coefficients are independent:
3005+
for x in col_indices:
3006+
if np.sum(contraints_loc_temp[x,:]) != 1:
3007+
raise ValueError("Constraints input is wrong: not all tested coefficients are unconstrained.")
3008+
# Adjust tested coefficients from dependent to independent (fitted) parameters:
3009+
col_indices = np.array([np.where(contraints_loc_temp[x,:] == 1)[0][0] for x in col_indices])
29953010

29963011
## Fit GLM:
29973012
model = _fit(

diffxpy/unit_test/test_continuous.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_null_distribution_wald(self, n_cells: int = 2000, n_genes: int = 100):
8181
:param n_genes: Number of genes to simulate (number of tests).
8282
"""
8383
logging.getLogger("tensorflow").setLevel(logging.INFO)
84-
logging.getLogger("batchglm").setLevel(logging.INFO)
84+
logging.getLogger("batchglm").setLevel(logging.WARNING)
8585
logging.getLogger("diffxpy").setLevel(logging.WARNING)
8686

8787
sim = Simulator(num_observations=n_cells, num_features=n_genes)

diffxpy/unit_test/test_single.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_null_distribution_wald(self, n_cells: int = 2000, n_genes: int = 100):
3838
factor_loc_totest="condition",
3939
formula="~ 1 + condition + batch",
4040
sample_description=random_sample_description,
41+
batch_size=500,
4142
training_strategy="DEFAULT",
4243
dtype="float64"
4344
)

0 commit comments

Comments
 (0)