33 from anndata .base import Raw
44except ImportError :
55 from anndata import Raw
6+ import batchglm .api as glm
67import logging
78import numpy as np
89import pandas as pd
910import patsy
1011import scipy .sparse
1112from typing import Union , List , Dict , Callable , Tuple
1213
13- from batchglm import data as data_utils
14- from batchglm .models .base import _EstimatorBase , _InputDataBase
1514from diffxpy import pkg_constants
1615from diffxpy .models .batch_bfgs .optim import Estim_BFGS
1716from .det import DifferentialExpressionTestLRT , DifferentialExpressionTestWald , \
@@ -40,7 +39,7 @@ def _fit(
4039 quick_scale : bool = None ,
4140 close_session = True ,
4241 dtype = "float64"
43- ) -> _EstimatorBase :
42+ ) -> glm . typing . InputDataBaseTyping :
4443 """
4544 :param noise_model: str, noise model to use in model-based unit_test. Possible options:
4645
@@ -187,7 +186,7 @@ def _fit(
187186
188187
189188def lrt (
190- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
189+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
191190 full_formula_loc : str ,
192191 reduced_formula_loc : str ,
193192 full_formula_scale : str = "~1" ,
@@ -298,25 +297,25 @@ def lrt(
298297 sample_description = sample_description
299298 )
300299
301- full_design_loc = data_utils .design_matrix (
300+ full_design_loc = glm . data .design_matrix (
302301 sample_description = sample_description ,
303302 formula = full_formula_loc ,
304303 as_categorical = [False if x in as_numeric else True for x in sample_description .columns .values ],
305304 return_type = "patsy"
306305 )
307- reduced_design_loc = data_utils .design_matrix (
306+ reduced_design_loc = glm . data .design_matrix (
308307 sample_description = sample_description ,
309308 formula = reduced_formula_loc ,
310309 as_categorical = [False if x in as_numeric else True for x in sample_description .columns .values ],
311310 return_type = "patsy"
312311 )
313- full_design_scale = data_utils .design_matrix (
312+ full_design_scale = glm . data .design_matrix (
314313 sample_description = sample_description ,
315314 formula = full_formula_scale ,
316315 as_categorical = [False if x in as_numeric else True for x in sample_description .columns .values ],
317316 return_type = "patsy"
318317 )
319- reduced_design_scale = data_utils .design_matrix (
318+ reduced_design_scale = glm . data .design_matrix (
320319 sample_description = sample_description ,
321320 formula = reduced_formula_scale ,
322321 as_categorical = [False if x in as_numeric else True for x in sample_description .columns .values ],
@@ -371,7 +370,7 @@ def lrt(
371370
372371
373372def wald (
374- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
373+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
375374 factor_loc_totest : Union [str , List [str ]] = None ,
376375 coef_to_test : Union [str , List [str ]] = None ,
377376 formula_loc : Union [None , str ] = None ,
@@ -597,7 +596,7 @@ def wald(
597596 elif coef_to_test is not None :
598597 # Directly select coefficients to test from design matrix (xarray):
599598 # Check that coefficients to test are not dependent parameters if constraints are given:
600- coef_loc_names = data_utils .view_coef_names (design_loc ).tolist ()
599+ coef_loc_names = glm . data .view_coef_names (design_loc ).tolist ()
601600 if not np .all ([x in coef_loc_names for x in coef_to_test ]):
602601 raise ValueError (
603602 "the requested test coefficients %s were found in model coefficients %s" %
@@ -645,7 +644,7 @@ def wald(
645644
646645
647646def t_test (
648- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
647+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
649648 grouping ,
650649 gene_names : Union [np .ndarray , list ] = None ,
651650 sample_description : pd .DataFrame = None ,
@@ -687,7 +686,7 @@ def t_test(
687686
688687
689688def rank_test (
690- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
689+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
691690 grouping : Union [str , np .ndarray , list ],
692691 gene_names : Union [np .ndarray , list ] = None ,
693692 sample_description : pd .DataFrame = None ,
@@ -729,7 +728,7 @@ def rank_test(
729728
730729
731730def two_sample (
732- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
731+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
733732 grouping : Union [str , np .ndarray , list ],
734733 as_numeric : Union [List [str ], Tuple [str ], str ] = (),
735734 test : str = "t-test" ,
@@ -902,7 +901,7 @@ def two_sample(
902901
903902
904903def pairwise (
905- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
904+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
906905 grouping : Union [str , np .ndarray , list ],
907906 as_numeric : Union [List [str ], Tuple [str ], str ] = (),
908907 test : str = 'z-test' ,
@@ -1026,7 +1025,7 @@ def pairwise(
10261025
10271026 if test .lower () == 'z-test' or test .lower () == 'z_test' or test .lower () == 'ztest' :
10281027 # -1 in formula removes intercept
1029- dmat = data_utils .design_matrix (
1028+ dmat = glm . data .design_matrix (
10301029 sample_description ,
10311030 formula = "~ 1 - 1 + grouping"
10321031 )
@@ -1113,7 +1112,7 @@ def pairwise(
11131112
11141113
11151114def versus_rest (
1116- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
1115+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
11171116 grouping : Union [str , np .ndarray , list ],
11181117 as_numeric : Union [List [str ], Tuple [str ], str ] = (),
11191118 test : str = 'wald' ,
@@ -1275,7 +1274,7 @@ def versus_rest(
12751274
12761275
12771276def partition (
1278- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
1277+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
12791278 parts : Union [str , np .ndarray , list ],
12801279 gene_names : Union [np .ndarray , list ] = None ,
12811280 sample_description : pd .DataFrame = None
@@ -1318,7 +1317,7 @@ class _Partition:
13181317
13191318 def __init__ (
13201319 self ,
1321- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
1320+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
13221321 parts : Union [str , np .ndarray , list ],
13231322 gene_names : Union [np .ndarray , list ] = None ,
13241323 sample_description : pd .DataFrame = None
@@ -1333,7 +1332,7 @@ def __init__(
13331332 :param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
13341333 :param sample_description: optional pandas.DataFrame containing sample annotations
13351334 """
1336- if isinstance (data , _InputDataBase ):
1335+ if isinstance (data , glm . typing . InputDataBaseTyping ):
13371336 self .x = data .x
13381337 elif isinstance (data , anndata .AnnData ) or isinstance (data , Raw ):
13391338 self .x = data .X
0 commit comments