Skip to content
This repository was archived by the owner on Oct 21, 2025. It is now read-only.

Commit 5019c98

Browse files
changed batchglm input object typing interface
1 parent 964d93f commit 5019c98

File tree

4 files changed

+41
-35
lines changed

4 files changed

+41
-35
lines changed

diffxpy/fit/fit.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818

1919
def model(
20-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
20+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
2121
formula_loc: Union[None, str] = None,
2222
formula_scale: Union[None, str] = "~1",
2323
as_numeric: Union[List[str], Tuple[str], str] = (),
@@ -226,7 +226,7 @@ def model(
226226

227227

228228
def residuals(
229-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
229+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
230230
formula_loc: Union[None, str] = None,
231231
formula_scale: Union[None, str] = "~1",
232232
as_numeric: Union[List[str], Tuple[str], str] = (),
@@ -400,7 +400,7 @@ def residuals(
400400

401401

402402
def partition(
403-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
403+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
404404
parts: Union[str, np.ndarray, list],
405405
gene_names: Union[np.ndarray, list] = None,
406406
sample_description: pd.DataFrame = None,
@@ -454,7 +454,7 @@ class _Partition:
454454

455455
def __init__(
456456
self,
457-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
457+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
458458
parts: Union[str, np.ndarray, list],
459459
gene_names: Union[np.ndarray, list] = None,
460460
sample_description: pd.DataFrame = None,
@@ -481,7 +481,7 @@ def __init__(
481481
same order as in data or string-type column identifier of size-factor containing
482482
column in sample description.
483483
"""
484-
if isinstance(data, glm.typing.InputDataBaseTyping):
484+
if isinstance(data, glm.typing.InputDataBase):
485485
self.x = data.x
486486
elif isinstance(data, anndata.AnnData) or isinstance(data, Raw):
487487
self.x = data.X

diffxpy/testing/det.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1544,7 +1544,7 @@ def __init__(
15441544
super().__init__()
15451545
if isinstance(data, anndata.AnnData) or isinstance(data, anndata.Raw):
15461546
data = data.X
1547-
elif isinstance(data, glm.typing.InputDataBaseTyping):
1547+
elif isinstance(data, glm.typing.InputDataBase):
15481548
data = data.x
15491549
self._x = data
15501550
self.sample_description = sample_description
@@ -1669,7 +1669,7 @@ def __init__(
16691669
super().__init__()
16701670
if isinstance(data, anndata.AnnData) or isinstance(data, anndata.Raw):
16711671
data = data.X
1672-
elif isinstance(data, glm.typing.InputDataBaseTyping):
1672+
elif isinstance(data, glm.typing.InputDataBase):
16731673
data = data.x
16741674
self._x = data
16751675
self.sample_description = sample_description
@@ -2103,7 +2103,7 @@ def __init__(
21032103
self.groups = list(np.asarray(groups))
21042104

21052105
# values of parameter estimates: coefficients x genes array with one coefficient per group
2106-
self._theta_mle = model_estim.par_link_loc
2106+
self._theta_mle = model_estim.a_var
21072107
# standard deviation of estimates: coefficients x genes array with one coefficient per group
21082108
# theta_sd = sqrt(diagonal(fisher_inv))
21092109
self._theta_sd = np.sqrt(np.diagonal(model_estim.fisher_inv, axis1=-2, axis2=-1)).T
@@ -2349,7 +2349,6 @@ def _test(self, **kwargs):
23492349

23502350
def _test_pairs(self, groups0, groups1):
23512351
num_features = self.model_estim.x.shape[1]
2352-
23532352
pvals = np.tile(np.NaN, [len(groups0), len(groups1), num_features])
23542353

23552354
for i, g0 in enumerate(groups0):

diffxpy/testing/tests.py

Lines changed: 28 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ def _fit(
3939
quick_scale: bool = None,
4040
close_session=True,
4141
dtype="float64"
42-
) -> glm.typing.InputDataBaseTyping:
42+
) -> glm.typing.InputDataBase:
4343
"""
4444
:param noise_model: str, noise model to use in model-based unit_test. Possible options:
4545
@@ -186,7 +186,7 @@ def _fit(
186186

187187

188188
def lrt(
189-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
189+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
190190
full_formula_loc: str,
191191
reduced_formula_loc: str,
192192
full_formula_scale: str = "~1",
@@ -370,7 +370,7 @@ def lrt(
370370

371371

372372
def wald(
373-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
373+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
374374
factor_loc_totest: Union[str, List[str]] = None,
375375
coef_to_test: Union[str, List[str]] = None,
376376
formula_loc: Union[None, str] = None,
@@ -547,7 +547,7 @@ def wald(
547547
if isinstance(as_numeric, str):
548548
as_numeric = [as_numeric]
549549

550-
# # Parse input data formats:
550+
# Parse input data formats:
551551
gene_names = parse_gene_names(data, gene_names)
552552
if dmat_loc is None and dmat_scale is None:
553553
sample_description = parse_sample_description(data, sample_description)
@@ -644,7 +644,7 @@ def wald(
644644

645645

646646
def t_test(
647-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
647+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
648648
grouping,
649649
gene_names: Union[np.ndarray, list] = None,
650650
sample_description: pd.DataFrame = None,
@@ -686,7 +686,7 @@ def t_test(
686686

687687

688688
def rank_test(
689-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
689+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
690690
grouping: Union[str, np.ndarray, list],
691691
gene_names: Union[np.ndarray, list] = None,
692692
sample_description: pd.DataFrame = None,
@@ -728,7 +728,7 @@ def rank_test(
728728

729729

730730
def two_sample(
731-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
731+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
732732
grouping: Union[str, np.ndarray, list],
733733
as_numeric: Union[List[str], Tuple[str], str] = (),
734734
test: str = "t-test",
@@ -883,16 +883,14 @@ def two_sample(
883883
data=data,
884884
gene_names=gene_names,
885885
grouping=grouping,
886-
is_sig_zerovar=is_sig_zerovar,
887-
dtype=dtype
886+
is_sig_zerovar=is_sig_zerovar
888887
)
889888
elif test.lower() == 'rank':
890889
de_test = rank_test(
891890
data=data,
892891
gene_names=gene_names,
893892
grouping=grouping,
894-
is_sig_zerovar=is_sig_zerovar,
895-
dtype=dtype
893+
is_sig_zerovar=is_sig_zerovar
896894
)
897895
else:
898896
raise ValueError('two_sample(): Parameter `test="%s"` not recognized.' % test)
@@ -901,7 +899,7 @@ def two_sample(
901899

902900

903901
def pairwise(
904-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
902+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
905903
grouping: Union[str, np.ndarray, list],
906904
as_numeric: Union[List[str], Tuple[str], str] = (),
907905
test: str = 'z-test',
@@ -1036,6 +1034,8 @@ def pairwise(
10361034
design_scale=dmat,
10371035
gene_names=gene_names,
10381036
size_factors=size_factors,
1037+
init_a="closed_form",
1038+
init_b="closed_form",
10391039
batch_size=batch_size,
10401040
training_strategy=training_strategy,
10411041
quick_scale=quick_scale,
@@ -1058,6 +1058,10 @@ def pairwise(
10581058
correction_type=pval_correction
10591059
)
10601060
else:
1061+
if isinstance(data, anndata.AnnData) or isinstance(data, anndata.Raw):
1062+
data = data.X
1063+
elif isinstance(data, glm.typing.InputDataBase):
1064+
data = data.x
10611065
groups = np.unique(grouping)
10621066
pvals = np.tile(np.NaN, [len(groups), len(groups), data.shape[1]])
10631067
pvals[np.eye(pvals.shape[0]).astype(bool)] = 0
@@ -1073,16 +1077,19 @@ def pairwise(
10731077
for j, g2 in enumerate(groups[(i + 1):]):
10741078
j = j + i + 1
10751079

1076-
sel = (grouping == g1) | (grouping == g2)
1080+
idx = np.where(np.logical_or(
1081+
grouping == g1,
1082+
grouping == g2
1083+
))[0]
10771084
de_test_temp = two_sample(
1078-
data=data[sel],
1079-
grouping=grouping[sel],
1085+
data=data[idx],
1086+
grouping=grouping[idx],
10801087
as_numeric=as_numeric,
10811088
test=test,
10821089
gene_names=gene_names,
1083-
sample_description=sample_description.iloc[sel],
1090+
sample_description=sample_description.iloc[idx, :],
10841091
noise_model=noise_model,
1085-
size_factors=size_factors[sel] if size_factors is not None else None,
1092+
size_factors=size_factors[idx] if size_factors is not None else None,
10861093
batch_size=batch_size,
10871094
training_strategy=training_strategy,
10881095
quick_scale=quick_scale,
@@ -1112,7 +1119,7 @@ def pairwise(
11121119

11131120

11141121
def versus_rest(
1115-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
1122+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
11161123
grouping: Union[str, np.ndarray, list],
11171124
as_numeric: Union[List[str], Tuple[str], str] = (),
11181125
test: str = 'wald',
@@ -1274,7 +1281,7 @@ def versus_rest(
12741281

12751282

12761283
def partition(
1277-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
1284+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
12781285
parts: Union[str, np.ndarray, list],
12791286
gene_names: Union[np.ndarray, list] = None,
12801287
sample_description: pd.DataFrame = None
@@ -1317,7 +1324,7 @@ class _Partition:
13171324

13181325
def __init__(
13191326
self,
1320-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
1327+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
13211328
parts: Union[str, np.ndarray, list],
13221329
gene_names: Union[np.ndarray, list] = None,
13231330
sample_description: pd.DataFrame = None
@@ -1332,7 +1339,7 @@ def __init__(
13321339
:param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
13331340
:param sample_description: optional pandas.DataFrame containing sample annotations
13341341
"""
1335-
if isinstance(data, glm.typing.InputDataBaseTyping):
1342+
if isinstance(data, glm.typing.InputDataBase):
13361343
self.x = data.x
13371344
elif isinstance(data, anndata.AnnData) or isinstance(data, Raw):
13381345
self.x = data.X

diffxpy/testing/utils.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,13 +18,13 @@
1818

1919

2020
def parse_gene_names(
21-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
21+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
2222
gene_names: Union[list, np.ndarray, None]
2323
):
2424
if gene_names is None:
2525
if anndata is not None and (isinstance(data, anndata.AnnData) or isinstance(data, Raw)):
2626
gene_names = data.var_names
27-
elif isinstance(data, glm.typing.InputDataBaseTyping):
27+
elif isinstance(data, glm.typing.InputDataBase):
2828
gene_names = data.features
2929
else:
3030
raise ValueError("Missing gene names")
@@ -33,7 +33,7 @@ def parse_gene_names(
3333

3434

3535
def parse_sample_description(
36-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
36+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
3737
sample_description: Union[pd.DataFrame, None]
3838
) -> pd.DataFrame:
3939
"""
@@ -57,7 +57,7 @@ def parse_sample_description(
5757
assert data.X.shape[0] == sample_description.shape[0], \
5858
"data matrix and sample description must contain same number of cells: %i, %i" % \
5959
(data.X.shape[0], sample_description.shape[0])
60-
elif isinstance(data, glm.typing.InputDataBaseTyping):
60+
elif isinstance(data, glm.typing.InputDataBase):
6161
assert data.x.shape[0] == sample_description.shape[0], \
6262
"data matrix and sample description must contain same number of cells: %i, %i" % \
6363
(data.x.shape[0], sample_description.shape[0])
@@ -70,7 +70,7 @@ def parse_sample_description(
7070

7171
def parse_size_factors(
7272
size_factors: Union[np.ndarray, pd.core.series.Series, np.ndarray],
73-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
73+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBase],
7474
sample_description: pd.DataFrame
7575
) -> Union[np.ndarray, None]:
7676
"""

0 commit comments

Comments
 (0)