Skip to content
This repository was archived by the owner on Oct 21, 2025. It is now read-only.

Commit 357df55

Browse files
refactored utils from de.testing into separate file
1 parent 4c085d3 commit 357df55

File tree

5 files changed

+199
-178
lines changed

5 files changed

+199
-178
lines changed

diffxpy/api/test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from diffxpy.testing.tests import design_matrix, coef_names, lrt, wald, t_test, rank_test, two_sample, pairwise, \
1+
from diffxpy.testing import lrt, wald, t_test, rank_test, two_sample, pairwise, \
22
versus_rest, partition, continuous_1d
3+
from diffxpy.testing import design_matrix, coef_names

diffxpy/testing/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .tests import lrt, wald, t_test, rank_test, two_sample, pairwise, \
2+
versus_rest, partition, continuous_1d
3+
from .utils import design_matrix, coef_names

diffxpy/testing/det.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import xarray as xr
88
import patsy
99

10-
from diffxpy.testing.tests import _split_X, t_test
10+
from .utils import split_X, dmat_unique
1111

1212
try:
1313
import anndata
@@ -20,16 +20,7 @@
2020
from . import correction
2121
from diffxpy import pkg_constants
2222

23-
logger = logging.getLogger(__name__)
24-
25-
# Use this to suppress matrix subclass PendingDeprecationWarnings from numpy:
26-
np.warnings.filterwarnings("ignore")
27-
28-
def _dmat_unique(dmat, sample_description):
29-
dmat, idx = np.unique(dmat, axis=0, return_index=True)
30-
sample_description = sample_description.iloc[idx].reset_index(drop=True)
31-
32-
return dmat, sample_description
23+
logger = logging.getLogger("diffxpy")
3324

3425

3526
class _Estimation(GeneralizedLinearModel, metaclass=abc.ABCMeta):
@@ -614,7 +605,7 @@ def _log_fold_change(self, factors: Union[Dict, Tuple, Set, List], base=np.e):
614605
dmat = self.full_estim.design_loc
615606

616607
# make rows unique
617-
dmat, sample_description = _dmat_unique(dmat, sample_description)
608+
dmat, sample_description = dmat_unique(dmat, sample_description)
618609

619610
# factors = factors.intersection(di.term_names)
620611

@@ -628,7 +619,7 @@ def _log_fold_change(self, factors: Union[Dict, Tuple, Set, List], base=np.e):
628619
dmat[:, neg_sel] = 0
629620

630621
# make the design matrix + sample description unique again
631-
dmat, sample_description = _dmat_unique(dmat, sample_description)
622+
dmat, sample_description = dmat_unique(dmat, sample_description)
632623

633624
locations = self.full_estim.inverse_link_loc(dmat.dot(self.full_estim.par_link_loc))
634625
locations = np.log(locations) / np.log(base)
@@ -696,7 +687,7 @@ def locations(self):
696687
sample_description = self.sample_description[[f.name() for f in di.factor_infos]]
697688
dmat = self.full_estim.design_loc
698689

699-
dmat, sample_description = _dmat_unique(dmat, sample_description)
690+
dmat, sample_description = dmat_unique(dmat, sample_description)
700691

701692
retval = self.full_estim.inverse_link_loc(dmat.dot(self.full_estim.par_link_loc))
702693
retval = pd.DataFrame(retval, columns=self.full_estim.features)
@@ -718,7 +709,7 @@ def scales(self):
718709
sample_description = self.sample_description[[f.name() for f in di.factor_infos]]
719710
dmat = self.full_estim.design_scale
720711

721-
dmat, sample_description = _dmat_unique(dmat, sample_description)
712+
dmat, sample_description = dmat_unique(dmat, sample_description)
722713

723714
retval = self.full_estim.inverse_link_scale(dmat.doc(self.full_estim.par_link_scale))
724715
retval = pd.DataFrame(retval, columns=self.full_estim.features)
@@ -901,6 +892,7 @@ def summary(self, qval_thres=None, fc_upper_thres=None,
901892
def plot_vs_ttest(self, log10=False):
902893
import matplotlib.pyplot as plt
903894
import seaborn as sns
895+
from .tests import t_test
904896

905897
grouping = np.asarray(self.model_estim.design_loc[:, self.coef_loc_totest])
906898
ttest = t_test(
@@ -935,7 +927,7 @@ def __init__(self, data, grouping, gene_names, is_logged):
935927
self.grouping = grouping
936928
self._gene_names = np.asarray(gene_names)
937929

938-
x0, x1 = _split_X(data, grouping)
930+
x0, x1 = split_X(data, grouping)
939931

940932
# Only compute p-values for genes with non-zero observations and non-zero group-wise variance.
941933
mean_x0 = x0.mean(axis=0).astype(dtype=np.float)
@@ -1040,7 +1032,7 @@ def __init__(self, data, grouping, gene_names, is_logged):
10401032
self.grouping = grouping
10411033
self._gene_names = np.asarray(gene_names)
10421034

1043-
x0, x1 = _split_X(data, grouping)
1035+
x0, x1 = split_X(data, grouping)
10441036

10451037
mean_x0 = x0.mean(axis=0).astype(dtype=np.float)
10461038
mean_x1 = x1.mean(axis=0).astype(dtype=np.float)
@@ -1118,6 +1110,7 @@ def summary(self, qval_thres=None, fc_upper_thres=None,
11181110
def plot_vs_ttest(self, log10=False):
11191111
import matplotlib.pyplot as plt
11201112
import seaborn as sns
1113+
from .tests import t_test
11211114

11221115
grouping = self.grouping
11231116
ttest = t_test(

diffxpy/testing/tests.py

Lines changed: 39 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Union, List, Dict, Callable, Tuple
22

3+
import anndata
4+
import logging
35
import numpy as np
46
import pandas as pd
57
import patsy
@@ -10,126 +12,17 @@
1012
from batchglm.xarray_sparse import SparseXArrayDataSet
1113
from diffxpy import pkg_constants
1214
from diffxpy.models.batch_bfgs.optim import Estim_BFGS
13-
from diffxpy.testing.det import anndata, logger, DifferentialExpressionTestLRT, DifferentialExpressionTestWald, \
15+
from .det import DifferentialExpressionTestLRT, DifferentialExpressionTestWald, \
1416
DifferentialExpressionTestTT, DifferentialExpressionTestRank, _DifferentialExpressionTestSingle, \
1517
DifferentialExpressionTestZTestLazy, DifferentialExpressionTestZTest, DifferentialExpressionTestPairwise, \
1618
DifferentialExpressionTestVsRest, _DifferentialExpressionTestMulti, DifferentialExpressionTestByPartition, \
1719
DifferentialExpressionTestWaldCont, DifferentialExpressionTestLRTCont
20+
from .utils import parse_gene_names, parse_data, parse_sample_description, parse_size_factors, parse_grouping
1821

22+
logger = logging.getLogger("diffxpy")
1923

20-
def _parse_gene_names(data, gene_names):
21-
if gene_names is None:
22-
if anndata is not None and (isinstance(data, anndata.AnnData) or isinstance(data, anndata.base.Raw)):
23-
gene_names = data.var_names
24-
elif isinstance(data, xr.DataArray):
25-
gene_names = data["features"]
26-
elif isinstance(data, xr.Dataset):
27-
gene_names = data["features"]
28-
else:
29-
raise ValueError("Missing gene names")
30-
31-
return np.asarray(gene_names)
32-
33-
34-
def _parse_data(data, gene_names) -> xr.DataArray:
35-
X = data_utils.xarray_from_data(data, dims=("observations", "features"))
36-
if gene_names is not None:
37-
X.coords["features"] = gene_names
38-
39-
return X
40-
41-
42-
def _parse_sample_description(data, sample_description=None) -> pd.DataFrame:
43-
if sample_description is None:
44-
if anndata is not None and isinstance(data, anndata.AnnData):
45-
sample_description = data_utils.sample_description_from_anndata(
46-
dataset=data,
47-
)
48-
elif isinstance(data, xr.Dataset):
49-
sample_description = data_utils.sample_description_from_xarray(
50-
dataset=data,
51-
dim="observations",
52-
)
53-
else:
54-
raise ValueError(
55-
"Please specify `sample_description` or provide `data` as xarray.Dataset or anndata.AnnData " +
56-
"with corresponding sample annotations"
57-
)
58-
59-
if anndata is not None and isinstance(data, anndata.base.Raw):
60-
# anndata.base.Raw does not have attribute shape.
61-
assert data.X.shape[0] == sample_description.shape[0], \
62-
"data matrix and sample description must contain same number of cells"
63-
else:
64-
assert data.shape[0] == sample_description.shape[0], \
65-
"data matrix and sample description must contain same number of cells"
66-
return sample_description
67-
68-
69-
def _parse_size_factors(size_factors, data):
70-
if size_factors is not None:
71-
if isinstance(size_factors, pd.core.series.Series):
72-
size_factors = size_factors.values
73-
assert size_factors.shape[0] == data.shape[0], "data matrix and size factors must contain same number of cells"
74-
return size_factors
75-
76-
77-
def design_matrix(
78-
data=None,
79-
sample_description: pd.DataFrame = None,
80-
formula: str = None,
81-
dmat: pd.DataFrame = None
82-
) -> Union[patsy.design_info.DesignMatrix, xr.Dataset]:
83-
""" Build design matrix for fit of generalized linear model.
84-
85-
This is necessary for wald tests and likelihood ratio tests.
86-
This function only carries through formatting if dmat is directly supplied.
87-
88-
:param data: input data
89-
:param formula: model formula.
90-
:param sample_description: optional pandas.DataFrame containing sample annotations
91-
:param dmat: model design matrix
92-
"""
93-
if data is None and sample_description is None and dmat is None:
94-
raise ValueError("Supply either data or sample_description or dmat.")
95-
if dmat is None and formula is None:
96-
raise ValueError("Supply either dmat or formula.")
97-
98-
if dmat is None:
99-
sample_description = _parse_sample_description(data, sample_description)
100-
dmat = data_utils.design_matrix(sample_description=sample_description, formula=formula)
101-
102-
return dmat
103-
else:
104-
ar = xr.DataArray(dmat, dims=("observations", "design_params"))
105-
ar.coords["design_params"] = dmat.columns
106-
107-
ds = xr.Dataset({
108-
"design": ar,
109-
})
110-
111-
return ds
112-
113-
114-
def coef_names(
115-
data=None,
116-
sample_description: pd.DataFrame = None,
117-
formula: str = None,
118-
dmat: pd.DataFrame = None
119-
) -> list:
120-
""" Output coefficient names of model only.
121-
122-
:param data: input data
123-
:param formula: model formula.
124-
:param sample_description: optional pandas.DataFrame containing sample annotations
125-
:param dmat: model design matrix
126-
"""
127-
return design_matrix(
128-
data=data,
129-
sample_description=sample_description,
130-
formula=formula,
131-
dmat=dmat
132-
).design_info.column_names
24+
# Use this to suppress matrix subclass PendingDeprecationWarnings from numpy:
25+
np.warnings.filterwarnings("ignore")
13326

13427

13528
def _fit(
@@ -408,10 +301,10 @@ def lrt(
408301
if isinstance(as_numeric, str):
409302
as_numeric = [as_numeric]
410303

411-
gene_names = _parse_gene_names(data, gene_names)
412-
X = _parse_data(data, gene_names)
413-
sample_description = _parse_sample_description(data, sample_description)
414-
size_factors = _parse_size_factors(size_factors=size_factors, data=X)
304+
gene_names = parse_gene_names(data, gene_names)
305+
X = parse_data(data, gene_names)
306+
sample_description = parse_sample_description(data, sample_description)
307+
size_factors = parse_size_factors(size_factors=size_factors, data=X)
415308

416309
full_design_loc = data_utils.design_matrix(
417310
sample_description=sample_description,
@@ -623,11 +516,11 @@ def wald(
623516
as_numeric = [as_numeric]
624517

625518
# # Parse input data formats:
626-
gene_names = _parse_gene_names(data, gene_names)
627-
X = _parse_data(data, gene_names)
519+
gene_names = parse_gene_names(data, gene_names)
520+
X = parse_data(data, gene_names)
628521
if dmat_loc is None and dmat_scale is None:
629-
sample_description = _parse_sample_description(data, sample_description)
630-
size_factors = _parse_size_factors(size_factors=size_factors, data=X)
522+
sample_description = parse_sample_description(data, sample_description)
523+
size_factors = parse_size_factors(size_factors=size_factors, data=X)
631524

632525
if dmat_loc is None:
633526
design_loc = data_utils.design_matrix(
@@ -714,20 +607,6 @@ def wald(
714607
return de_test
715608

716609

717-
def _parse_grouping(data, sample_description, grouping):
718-
if isinstance(grouping, str):
719-
sample_description = _parse_sample_description(data, sample_description)
720-
grouping = sample_description[grouping]
721-
return np.squeeze(np.asarray(grouping))
722-
723-
724-
def _split_X(data, grouping):
725-
groups = np.unique(grouping)
726-
x0 = data[np.where(grouping == groups[0])[0]]
727-
x1 = data[np.where(grouping == groups[1])[0]]
728-
return x0, x1
729-
730-
731610
def t_test(
732611
data: Union[anndata.AnnData, anndata.base.Raw, xr.DataArray, xr.Dataset, np.ndarray, scipy.sparse.csr_matrix],
733612
grouping,
@@ -752,11 +631,11 @@ def t_test(
752631
Whether data is already logged. If True, log-fold changes are computed as fold changes on this data.
753632
If False, log-fold changes are computed as log-fold changes on this data.
754633
"""
755-
gene_names = _parse_gene_names(data, gene_names)
756-
X = _parse_data(data, gene_names)
634+
gene_names = parse_gene_names(data, gene_names)
635+
X = parse_data(data, gene_names)
757636
if isinstance(X, SparseXArrayDataSet):
758637
X = X.X
759-
grouping = _parse_grouping(data, sample_description, grouping)
638+
grouping = parse_grouping(data, sample_description, grouping)
760639

761640
de_test = DifferentialExpressionTestTT(
762641
data=X.astype(dtype),
@@ -792,11 +671,11 @@ def rank_test(
792671
Whether data is already logged. If True, log-fold changes are computed as fold changes on this data.
793672
If False, log-fold changes are computed as log-fold changes on this data.
794673
"""
795-
gene_names = _parse_gene_names(data, gene_names)
796-
X = _parse_data(data, gene_names)
674+
gene_names = parse_gene_names(data, gene_names)
675+
X = parse_data(data, gene_names)
797676
if isinstance(X, SparseXArrayDataSet):
798677
X = X.X
799-
grouping = _parse_grouping(data, sample_description, grouping)
678+
grouping = parse_grouping(data, sample_description, grouping)
800679

801680
de_test = DifferentialExpressionTestRank(
802681
data=X.astype(dtype),
@@ -910,9 +789,9 @@ def two_sample(
910789
raise ValueError('base.two_sample(): Do not specify `noise_model` if using test t-test or wilcoxon: ' +
911790
'The t-test is based on a gaussian noise model and wilcoxon is model free.')
912791

913-
gene_names = _parse_gene_names(data, gene_names)
914-
X = _parse_data(data, gene_names)
915-
grouping = _parse_grouping(data, sample_description, grouping)
792+
gene_names = parse_gene_names(data, gene_names)
793+
X = parse_data(data, gene_names)
794+
grouping = parse_grouping(data, sample_description, grouping)
916795
sample_description = pd.DataFrame({"grouping": grouping})
917796

918797
groups = np.unique(grouping)
@@ -1116,10 +995,10 @@ def pairwise(
1116995

1117996
# Do not store all models but only p-value and q-value matrix:
1118997
# genes x groups x groups
1119-
gene_names = _parse_gene_names(data, gene_names)
1120-
X = _parse_data(data, gene_names)
1121-
sample_description = _parse_sample_description(data, sample_description)
1122-
grouping = _parse_grouping(data, sample_description, grouping)
998+
gene_names = parse_gene_names(data, gene_names)
999+
X = parse_data(data, gene_names)
1000+
sample_description = parse_sample_description(data, sample_description)
1001+
grouping = parse_grouping(data, sample_description, grouping)
11231002
sample_description = pd.DataFrame({"grouping": grouping})
11241003

11251004
if test.lower() == 'z-test' or test.lower() == 'z_test' or test.lower() == 'ztest':
@@ -1324,10 +1203,10 @@ def versus_rest(
13241203

13251204
# Do not store all models but only p-value and q-value matrix:
13261205
# genes x groups
1327-
gene_names = _parse_gene_names(data, gene_names)
1328-
X = _parse_data(data, gene_names)
1329-
sample_description = _parse_sample_description(data, sample_description)
1330-
grouping = _parse_grouping(data, sample_description, grouping)
1206+
gene_names = parse_gene_names(data, gene_names)
1207+
X = parse_data(data, gene_names)
1208+
sample_description = parse_sample_description(data, sample_description)
1209+
grouping = parse_grouping(data, sample_description, grouping)
13311210
sample_description = pd.DataFrame({"grouping": grouping})
13321211

13331212
groups = np.unique(grouping)
@@ -1427,10 +1306,10 @@ def __init__(
14271306
:param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
14281307
:param sample_description: optional pandas.DataFrame containing sample annotations
14291308
"""
1430-
self.X = _parse_data(data, gene_names)
1431-
self.gene_names = _parse_gene_names(data, gene_names)
1432-
self.sample_description = _parse_sample_description(data, sample_description)
1433-
self.partition = _parse_grouping(data, sample_description, partition)
1309+
self.X = parse_data(data, gene_names)
1310+
self.gene_names = parse_gene_names(data, gene_names)
1311+
self.sample_description = parse_sample_description(data, sample_description)
1312+
self.partition = parse_grouping(data, sample_description, partition)
14341313
self.partitions = np.unique(self.partition)
14351314
self.partition_idx = [np.where(self.partition == x)[0] for x in self.partitions]
14361315

@@ -1874,9 +1753,9 @@ def continuous_1d(
18741753
if isinstance(as_numeric, tuple):
18751754
as_numeric = list(as_numeric)
18761755

1877-
X = _parse_data(data, gene_names)
1878-
gene_names = _parse_gene_names(data, gene_names)
1879-
sample_description = _parse_sample_description(data, sample_description)
1756+
X = parse_data(data, gene_names)
1757+
gene_names = parse_gene_names(data, gene_names)
1758+
sample_description = parse_sample_description(data, sample_description)
18801759

18811760
# Check that continuous factor is contained in sample description
18821761
if continuous not in sample_description.columns:

0 commit comments

Comments
 (0)