Skip to content
This repository was archived by the owner on Oct 21, 2025. It is now read-only.

Commit 0249161

Browse files
Merge pull request #107 from theislab/dev
Dev
2 parents 118a607 + 964d93f commit 0249161

File tree

10 files changed

+103
-73
lines changed

10 files changed

+103
-73
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ resources/*
1414
*/*.ipynb_checkpoints/
1515
**/.DS_Store
1616
docs/_templates/
17-
17+
dist/
1818
!**/.gitignore
1919

diffxpy/fit/fit.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@
33
from anndata.base import Raw
44
except ImportError:
55
from anndata import Raw
6+
import batchglm.api as glm
67
import logging
78
import numpy as np
89
import pandas as pd
910
import patsy
1011
import scipy.sparse
1112
from typing import Union, List, Dict, Callable, Tuple
1213

13-
from batchglm.models.base import _InputDataBase
1414
from .external import _fit
1515
from .external import parse_gene_names, parse_sample_description, parse_size_factors, parse_grouping, \
1616
constraint_system_from_star
1717

1818

1919
def model(
20-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
20+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
2121
formula_loc: Union[None, str] = None,
2222
formula_scale: Union[None, str] = "~1",
2323
as_numeric: Union[List[str], Tuple[str], str] = (),
@@ -164,6 +164,10 @@ def model(
164164
165165
Should be "float32" for single precision or "float64" for double precision.
166166
:param kwargs: [Debugging] Additional arguments will be passed to the _fit method.
167+
:return:
168+
An estimator instance that contains all estimation relevant attributes and the model in estim.model.
169+
The attributes of the model depend on the noise model and the covariates used.
170+
We provide documentation for the model class in the model section of the documentation.
167171
"""
168172
if len(kwargs) != 0:
169173
logging.getLogger("diffxpy").debug("additional kwargs: %s", str(kwargs))
@@ -222,7 +226,7 @@ def model(
222226

223227

224228
def residuals(
225-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
229+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
226230
formula_loc: Union[None, str] = None,
227231
formula_scale: Union[None, str] = "~1",
228232
as_numeric: Union[List[str], Tuple[str], str] = (),
@@ -396,7 +400,7 @@ def residuals(
396400

397401

398402
def partition(
399-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
403+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
400404
parts: Union[str, np.ndarray, list],
401405
gene_names: Union[np.ndarray, list] = None,
402406
sample_description: pd.DataFrame = None,
@@ -450,7 +454,7 @@ class _Partition:
450454

451455
def __init__(
452456
self,
453-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
457+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
454458
parts: Union[str, np.ndarray, list],
455459
gene_names: Union[np.ndarray, list] = None,
456460
sample_description: pd.DataFrame = None,
@@ -477,7 +481,7 @@ def __init__(
477481
same order as in data or string-type column identifier of size-factor containing
478482
column in sample description.
479483
"""
480-
if isinstance(data, _InputDataBase):
484+
if isinstance(data, glm.typing.InputDataBaseTyping):
481485
self.x = data.x
482486
elif isinstance(data, anndata.AnnData) or isinstance(data, Raw):
483487
self.x = data.X

diffxpy/testing/det.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,18 @@
11
import abc
2+
try:
3+
import anndata
4+
except ImportError:
5+
anndata = None
6+
import batchglm.api as glm
27
import logging
3-
from typing import Union, Dict, Tuple, List, Set
8+
import numpy as np
9+
import patsy
410
import pandas as pd
511
from random import sample
612
import scipy.sparse
7-
8-
import numpy as np
9-
import patsy
13+
from typing import Union, Dict, Tuple, List, Set
1014

1115
from .utils import split_x, dmat_unique
12-
13-
try:
14-
import anndata
15-
except ImportError:
16-
anndata = None
17-
18-
from batchglm.models.base import _EstimatorBase, _InputDataBase
19-
2016
from ..stats import stats
2117
from . import correction
2218
from diffxpy import pkg_constants
@@ -468,17 +464,17 @@ class DifferentialExpressionTestLRT(_DifferentialExpressionTestSingle):
468464

469465
sample_description: pd.DataFrame
470466
full_design_loc_info: patsy.design_info
471-
full_estim: _EstimatorBase
467+
full_estim: glm.typing.EstimatorBaseTyping
472468
reduced_design_loc_info: patsy.design_info
473-
reduced_estim: _EstimatorBase
469+
reduced_estim: glm.typing.EstimatorBaseTyping
474470

475471
def __init__(
476472
self,
477473
sample_description: pd.DataFrame,
478474
full_design_loc_info: patsy.design_info,
479-
full_estim: _EstimatorBase,
475+
full_estim: glm.typing.EstimatorBaseTyping,
480476
reduced_design_loc_info: patsy.design_info,
481-
reduced_estim: _EstimatorBase
477+
reduced_estim: glm.typing.EstimatorBaseTyping
482478
):
483479
super().__init__()
484480
self.sample_description = sample_description
@@ -689,7 +685,7 @@ class DifferentialExpressionTestWald(_DifferentialExpressionTestSingle):
689685
Single wald test per gene.
690686
"""
691687

692-
model_estim: _EstimatorBase
688+
model_estim: glm.typing.EstimatorBaseTyping
693689
sample_description: pd.DataFrame
694690
coef_loc_totest: np.ndarray
695691
theta_mle: np.ndarray
@@ -699,7 +695,7 @@ class DifferentialExpressionTestWald(_DifferentialExpressionTestSingle):
699695

700696
def __init__(
701697
self,
702-
model_estim: _EstimatorBase,
698+
model_estim: glm.typing.EstimatorBaseTyping,
703699
col_indices: np.ndarray,
704700
noise_model: str,
705701
sample_description: pd.DataFrame
@@ -1548,7 +1544,7 @@ def __init__(
15481544
super().__init__()
15491545
if isinstance(data, anndata.AnnData) or isinstance(data, anndata.Raw):
15501546
data = data.X
1551-
elif isinstance(data, _InputDataBase):
1547+
elif isinstance(data, glm.typing.InputDataBaseTyping):
15521548
data = data.x
15531549
self._x = data
15541550
self.sample_description = sample_description
@@ -1673,7 +1669,7 @@ def __init__(
16731669
super().__init__()
16741670
if isinstance(data, anndata.AnnData) or isinstance(data, anndata.Raw):
16751671
data = data.X
1676-
elif isinstance(data, _InputDataBase):
1672+
elif isinstance(data, glm.typing.InputDataBaseTyping):
16771673
data = data.x
16781674
self._x = data
16791675
self.sample_description = sample_description
@@ -2090,13 +2086,13 @@ class DifferentialExpressionTestZTest(_DifferentialExpressionTestMulti):
20902086
Pairwise unit_test between more than 2 groups per gene.
20912087
"""
20922088

2093-
model_estim: _EstimatorBase
2089+
model_estim: glm.typing.EstimatorBaseTyping
20942090
theta_mle: np.ndarray
20952091
theta_sd: np.ndarray
20962092

20972093
def __init__(
20982094
self,
2099-
model_estim: _EstimatorBase,
2095+
model_estim: glm.typing.EstimatorBaseTyping,
21002096
grouping,
21012097
groups,
21022098
correction_type: str
@@ -2293,13 +2289,13 @@ class DifferentialExpressionTestZTestLazy(_DifferentialExpressionTestMulti):
22932289
memory.
22942290
"""
22952291

2296-
model_estim: _EstimatorBase
2292+
model_estim: glm.typing.EstimatorBaseTyping
22972293
_theta_mle: np.ndarray
22982294
_theta_sd: np.ndarray
22992295

23002296
def __init__(
23012297
self,
2302-
model_estim: _EstimatorBase,
2298+
model_estim: glm.typing.EstimatorBaseTyping,
23032299
grouping, groups,
23042300
correction_type="global"
23052301
):
@@ -2856,15 +2852,15 @@ def summary(self, qval_thres=None, fc_upper_thres=None,
28562852

28572853
class _DifferentialExpressionTestCont(_DifferentialExpressionTestSingle):
28582854
_de_test: _DifferentialExpressionTestSingle
2859-
_model_estim: _EstimatorBase
2855+
_model_estim: glm.typing.EstimatorBaseTyping
28602856
_size_factors: np.ndarray
28612857
_continuous_coords: np.ndarray
28622858
_spline_coefs: list
28632859

28642860
def __init__(
28652861
self,
28662862
de_test: _DifferentialExpressionTestSingle,
2867-
model_estim: _EstimatorBase,
2863+
model_estim: glm.typing.EstimatorBaseTyping,
28682864
size_factors: np.ndarray,
28692865
continuous_coords: np.ndarray,
28702866
spline_coefs: list,

diffxpy/testing/tests.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
from anndata.base import Raw
44
except ImportError:
55
from anndata import Raw
6+
import batchglm.api as glm
67
import logging
78
import numpy as np
89
import pandas as pd
910
import patsy
1011
import scipy.sparse
1112
from typing import Union, List, Dict, Callable, Tuple
1213

13-
from batchglm import data as data_utils
14-
from batchglm.models.base import _EstimatorBase, _InputDataBase
1514
from diffxpy import pkg_constants
1615
from diffxpy.models.batch_bfgs.optim import Estim_BFGS
1716
from .det import DifferentialExpressionTestLRT, DifferentialExpressionTestWald, \
@@ -40,7 +39,7 @@ def _fit(
4039
quick_scale: bool = None,
4140
close_session=True,
4241
dtype="float64"
43-
) -> _EstimatorBase:
42+
) -> glm.typing.InputDataBaseTyping:
4443
"""
4544
:param noise_model: str, noise model to use in model-based unit_test. Possible options:
4645
@@ -187,7 +186,7 @@ def _fit(
187186

188187

189188
def lrt(
190-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
189+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
191190
full_formula_loc: str,
192191
reduced_formula_loc: str,
193192
full_formula_scale: str = "~1",
@@ -298,25 +297,25 @@ def lrt(
298297
sample_description=sample_description
299298
)
300299

301-
full_design_loc = data_utils.design_matrix(
300+
full_design_loc = glm.data.design_matrix(
302301
sample_description=sample_description,
303302
formula=full_formula_loc,
304303
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values],
305304
return_type="patsy"
306305
)
307-
reduced_design_loc = data_utils.design_matrix(
306+
reduced_design_loc = glm.data.design_matrix(
308307
sample_description=sample_description,
309308
formula=reduced_formula_loc,
310309
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values],
311310
return_type="patsy"
312311
)
313-
full_design_scale = data_utils.design_matrix(
312+
full_design_scale = glm.data.design_matrix(
314313
sample_description=sample_description,
315314
formula=full_formula_scale,
316315
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values],
317316
return_type="patsy"
318317
)
319-
reduced_design_scale = data_utils.design_matrix(
318+
reduced_design_scale = glm.data.design_matrix(
320319
sample_description=sample_description,
321320
formula=reduced_formula_scale,
322321
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values],
@@ -371,7 +370,7 @@ def lrt(
371370

372371

373372
def wald(
374-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
373+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
375374
factor_loc_totest: Union[str, List[str]] = None,
376375
coef_to_test: Union[str, List[str]] = None,
377376
formula_loc: Union[None, str] = None,
@@ -597,7 +596,7 @@ def wald(
597596
elif coef_to_test is not None:
598597
# Directly select coefficients to test from design matrix (xarray):
599598
# Check that coefficients to test are not dependent parameters if constraints are given:
600-
coef_loc_names = data_utils.view_coef_names(design_loc).tolist()
599+
coef_loc_names = glm.data.view_coef_names(design_loc).tolist()
601600
if not np.all([x in coef_loc_names for x in coef_to_test]):
602601
raise ValueError(
603602
"the requested test coefficients %s were found in model coefficients %s" %
@@ -645,7 +644,7 @@ def wald(
645644

646645

647646
def t_test(
648-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
647+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
649648
grouping,
650649
gene_names: Union[np.ndarray, list] = None,
651650
sample_description: pd.DataFrame = None,
@@ -687,7 +686,7 @@ def t_test(
687686

688687

689688
def rank_test(
690-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
689+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
691690
grouping: Union[str, np.ndarray, list],
692691
gene_names: Union[np.ndarray, list] = None,
693692
sample_description: pd.DataFrame = None,
@@ -729,7 +728,7 @@ def rank_test(
729728

730729

731730
def two_sample(
732-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
731+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
733732
grouping: Union[str, np.ndarray, list],
734733
as_numeric: Union[List[str], Tuple[str], str] = (),
735734
test: str = "t-test",
@@ -902,7 +901,7 @@ def two_sample(
902901

903902

904903
def pairwise(
905-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
904+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
906905
grouping: Union[str, np.ndarray, list],
907906
as_numeric: Union[List[str], Tuple[str], str] = (),
908907
test: str = 'z-test',
@@ -1026,7 +1025,7 @@ def pairwise(
10261025

10271026
if test.lower() == 'z-test' or test.lower() == 'z_test' or test.lower() == 'ztest':
10281027
# -1 in formula removes intercept
1029-
dmat = data_utils.design_matrix(
1028+
dmat = glm.data.design_matrix(
10301029
sample_description,
10311030
formula="~ 1 - 1 + grouping"
10321031
)
@@ -1113,7 +1112,7 @@ def pairwise(
11131112

11141113

11151114
def versus_rest(
1116-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
1115+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
11171116
grouping: Union[str, np.ndarray, list],
11181117
as_numeric: Union[List[str], Tuple[str], str] = (),
11191118
test: str = 'wald',
@@ -1275,7 +1274,7 @@ def versus_rest(
12751274

12761275

12771276
def partition(
1278-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
1277+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
12791278
parts: Union[str, np.ndarray, list],
12801279
gene_names: Union[np.ndarray, list] = None,
12811280
sample_description: pd.DataFrame = None
@@ -1318,7 +1317,7 @@ class _Partition:
13181317

13191318
def __init__(
13201319
self,
1321-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
1320+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
13221321
parts: Union[str, np.ndarray, list],
13231322
gene_names: Union[np.ndarray, list] = None,
13241323
sample_description: pd.DataFrame = None
@@ -1333,7 +1332,7 @@ def __init__(
13331332
:param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
13341333
:param sample_description: optional pandas.DataFrame containing sample annotations
13351334
"""
1336-
if isinstance(data, _InputDataBase):
1335+
if isinstance(data, glm.typing.InputDataBaseTyping):
13371336
self.x = data.x
13381337
elif isinstance(data, anndata.AnnData) or isinstance(data, Raw):
13391338
self.x = data.X

0 commit comments

Comments
 (0)