@@ -39,7 +39,7 @@ def _fit(
3939 quick_scale : bool = None ,
4040 close_session = True ,
4141 dtype = "float64"
42- ) -> glm .typing .InputDataBaseTyping :
42+ ) -> glm .typing .InputDataBase :
4343 """
4444 :param noise_model: str, noise model to use in model-based unit_test. Possible options:
4545
@@ -186,7 +186,7 @@ def _fit(
186186
187187
188188def lrt (
189- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
189+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
190190 full_formula_loc : str ,
191191 reduced_formula_loc : str ,
192192 full_formula_scale : str = "~1" ,
@@ -370,7 +370,7 @@ def lrt(
370370
371371
372372def wald (
373- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
373+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
374374 factor_loc_totest : Union [str , List [str ]] = None ,
375375 coef_to_test : Union [str , List [str ]] = None ,
376376 formula_loc : Union [None , str ] = None ,
@@ -547,7 +547,7 @@ def wald(
547547 if isinstance (as_numeric , str ):
548548 as_numeric = [as_numeric ]
549549
550- # # Parse input data formats:
550+ # Parse input data formats:
551551 gene_names = parse_gene_names (data , gene_names )
552552 if dmat_loc is None and dmat_scale is None :
553553 sample_description = parse_sample_description (data , sample_description )
@@ -644,7 +644,7 @@ def wald(
644644
645645
646646def t_test (
647- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
647+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
648648 grouping ,
649649 gene_names : Union [np .ndarray , list ] = None ,
650650 sample_description : pd .DataFrame = None ,
@@ -686,7 +686,7 @@ def t_test(
686686
687687
688688def rank_test (
689- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
689+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
690690 grouping : Union [str , np .ndarray , list ],
691691 gene_names : Union [np .ndarray , list ] = None ,
692692 sample_description : pd .DataFrame = None ,
@@ -728,7 +728,7 @@ def rank_test(
728728
729729
730730def two_sample (
731- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
731+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
732732 grouping : Union [str , np .ndarray , list ],
733733 as_numeric : Union [List [str ], Tuple [str ], str ] = (),
734734 test : str = "t-test" ,
@@ -883,16 +883,14 @@ def two_sample(
883883 data = data ,
884884 gene_names = gene_names ,
885885 grouping = grouping ,
886- is_sig_zerovar = is_sig_zerovar ,
887- dtype = dtype
886+ is_sig_zerovar = is_sig_zerovar
888887 )
889888 elif test .lower () == 'rank' :
890889 de_test = rank_test (
891890 data = data ,
892891 gene_names = gene_names ,
893892 grouping = grouping ,
894- is_sig_zerovar = is_sig_zerovar ,
895- dtype = dtype
893+ is_sig_zerovar = is_sig_zerovar
896894 )
897895 else :
898896 raise ValueError ('two_sample(): Parameter `test="%s"` not recognized.' % test )
@@ -901,7 +899,7 @@ def two_sample(
901899
902900
903901def pairwise (
904- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
902+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
905903 grouping : Union [str , np .ndarray , list ],
906904 as_numeric : Union [List [str ], Tuple [str ], str ] = (),
907905 test : str = 'z-test' ,
@@ -1036,6 +1034,8 @@ def pairwise(
10361034 design_scale = dmat ,
10371035 gene_names = gene_names ,
10381036 size_factors = size_factors ,
1037+ init_a = "closed_form" ,
1038+ init_b = "closed_form" ,
10391039 batch_size = batch_size ,
10401040 training_strategy = training_strategy ,
10411041 quick_scale = quick_scale ,
@@ -1058,6 +1058,10 @@ def pairwise(
10581058 correction_type = pval_correction
10591059 )
10601060 else :
1061+ if isinstance (data , anndata .AnnData ) or isinstance (data , anndata .Raw ):
1062+ data = data .X
1063+ elif isinstance (data , glm .typing .InputDataBase ):
1064+ data = data .x
10611065 groups = np .unique (grouping )
10621066 pvals = np .tile (np .NaN , [len (groups ), len (groups ), data .shape [1 ]])
10631067 pvals [np .eye (pvals .shape [0 ]).astype (bool )] = 0
@@ -1073,16 +1077,19 @@ def pairwise(
10731077 for j , g2 in enumerate (groups [(i + 1 ):]):
10741078 j = j + i + 1
10751079
1076- sel = (grouping == g1 ) | (grouping == g2 )
1080+ idx = np .where (np .logical_or (
1081+ grouping == g1 ,
1082+ grouping == g2
1083+ ))[0 ]
10771084 de_test_temp = two_sample (
1078- data = data [sel ],
1079- grouping = grouping [sel ],
1085+ data = data [idx ],
1086+ grouping = grouping [idx ],
10801087 as_numeric = as_numeric ,
10811088 test = test ,
10821089 gene_names = gene_names ,
1083- sample_description = sample_description .iloc [sel ],
1090+ sample_description = sample_description .iloc [idx , : ],
10841091 noise_model = noise_model ,
1085- size_factors = size_factors [sel ] if size_factors is not None else None ,
1092+ size_factors = size_factors [idx ] if size_factors is not None else None ,
10861093 batch_size = batch_size ,
10871094 training_strategy = training_strategy ,
10881095 quick_scale = quick_scale ,
@@ -1112,7 +1119,7 @@ def pairwise(
11121119
11131120
11141121def versus_rest (
1115- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
1122+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
11161123 grouping : Union [str , np .ndarray , list ],
11171124 as_numeric : Union [List [str ], Tuple [str ], str ] = (),
11181125 test : str = 'wald' ,
@@ -1274,7 +1281,7 @@ def versus_rest(
12741281
12751282
12761283def partition (
1277- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
1284+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
12781285 parts : Union [str , np .ndarray , list ],
12791286 gene_names : Union [np .ndarray , list ] = None ,
12801287 sample_description : pd .DataFrame = None
@@ -1317,7 +1324,7 @@ class _Partition:
13171324
13181325 def __init__ (
13191326 self ,
1320- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
1327+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
13211328 parts : Union [str , np .ndarray , list ],
13221329 gene_names : Union [np .ndarray , list ] = None ,
13231330 sample_description : pd .DataFrame = None
@@ -1332,7 +1339,7 @@ def __init__(
13321339 :param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
13331340 :param sample_description: optional pandas.DataFrame containing sample annotations
13341341 """
1335- if isinstance (data , glm .typing .InputDataBaseTyping ):
1342+ if isinstance (data , glm .typing .InputDataBase ):
13361343 self .x = data .x
13371344 elif isinstance (data , anndata .AnnData ) or isinstance (data , Raw ):
13381345 self .x = data .X
0 commit comments