11from typing import Union , List , Dict , Callable , Tuple
22
3+ import anndata
4+ import logging
35import numpy as np
46import pandas as pd
57import patsy
1012from batchglm .xarray_sparse import SparseXArrayDataSet
1113from diffxpy import pkg_constants
1214from diffxpy .models .batch_bfgs .optim import Estim_BFGS
13- from diffxpy . testing . det import anndata , logger , DifferentialExpressionTestLRT , DifferentialExpressionTestWald , \
15+ from . det import DifferentialExpressionTestLRT , DifferentialExpressionTestWald , \
1416 DifferentialExpressionTestTT , DifferentialExpressionTestRank , _DifferentialExpressionTestSingle , \
1517 DifferentialExpressionTestZTestLazy , DifferentialExpressionTestZTest , DifferentialExpressionTestPairwise , \
1618 DifferentialExpressionTestVsRest , _DifferentialExpressionTestMulti , DifferentialExpressionTestByPartition , \
1719 DifferentialExpressionTestWaldCont , DifferentialExpressionTestLRTCont
20+ from .utils import parse_gene_names , parse_data , parse_sample_description , parse_size_factors , parse_grouping
1821
22+ logger = logging .getLogger ("diffxpy" )
1923
20- def _parse_gene_names (data , gene_names ):
21- if gene_names is None :
22- if anndata is not None and (isinstance (data , anndata .AnnData ) or isinstance (data , anndata .base .Raw )):
23- gene_names = data .var_names
24- elif isinstance (data , xr .DataArray ):
25- gene_names = data ["features" ]
26- elif isinstance (data , xr .Dataset ):
27- gene_names = data ["features" ]
28- else :
29- raise ValueError ("Missing gene names" )
30-
31- return np .asarray (gene_names )
32-
33-
34- def _parse_data (data , gene_names ) -> xr .DataArray :
35- X = data_utils .xarray_from_data (data , dims = ("observations" , "features" ))
36- if gene_names is not None :
37- X .coords ["features" ] = gene_names
38-
39- return X
40-
41-
42- def _parse_sample_description (data , sample_description = None ) -> pd .DataFrame :
43- if sample_description is None :
44- if anndata is not None and isinstance (data , anndata .AnnData ):
45- sample_description = data_utils .sample_description_from_anndata (
46- dataset = data ,
47- )
48- elif isinstance (data , xr .Dataset ):
49- sample_description = data_utils .sample_description_from_xarray (
50- dataset = data ,
51- dim = "observations" ,
52- )
53- else :
54- raise ValueError (
55- "Please specify `sample_description` or provide `data` as xarray.Dataset or anndata.AnnData " +
56- "with corresponding sample annotations"
57- )
58-
59- if anndata is not None and isinstance (data , anndata .base .Raw ):
60- # anndata.base.Raw does not have attribute shape.
61- assert data .X .shape [0 ] == sample_description .shape [0 ], \
62- "data matrix and sample description must contain same number of cells"
63- else :
64- assert data .shape [0 ] == sample_description .shape [0 ], \
65- "data matrix and sample description must contain same number of cells"
66- return sample_description
67-
68-
69- def _parse_size_factors (size_factors , data ):
70- if size_factors is not None :
71- if isinstance (size_factors , pd .core .series .Series ):
72- size_factors = size_factors .values
73- assert size_factors .shape [0 ] == data .shape [0 ], "data matrix and size factors must contain same number of cells"
74- return size_factors
75-
76-
77- def design_matrix (
78- data = None ,
79- sample_description : pd .DataFrame = None ,
80- formula : str = None ,
81- dmat : pd .DataFrame = None
82- ) -> Union [patsy .design_info .DesignMatrix , xr .Dataset ]:
83- """ Build design matrix for fit of generalized linear model.
84-
85- This is necessary for wald tests and likelihood ratio tests.
86- This function only carries through formatting if dmat is directly supplied.
87-
88- :param data: input data
89- :param formula: model formula.
90- :param sample_description: optional pandas.DataFrame containing sample annotations
91- :param dmat: model design matrix
92- """
93- if data is None and sample_description is None and dmat is None :
94- raise ValueError ("Supply either data or sample_description or dmat." )
95- if dmat is None and formula is None :
96- raise ValueError ("Supply either dmat or formula." )
97-
98- if dmat is None :
99- sample_description = _parse_sample_description (data , sample_description )
100- dmat = data_utils .design_matrix (sample_description = sample_description , formula = formula )
101-
102- return dmat
103- else :
104- ar = xr .DataArray (dmat , dims = ("observations" , "design_params" ))
105- ar .coords ["design_params" ] = dmat .columns
106-
107- ds = xr .Dataset ({
108- "design" : ar ,
109- })
110-
111- return ds
112-
113-
114- def coef_names (
115- data = None ,
116- sample_description : pd .DataFrame = None ,
117- formula : str = None ,
118- dmat : pd .DataFrame = None
119- ) -> list :
120- """ Output coefficient names of model only.
121-
122- :param data: input data
123- :param formula: model formula.
124- :param sample_description: optional pandas.DataFrame containing sample annotations
125- :param dmat: model design matrix
126- """
127- return design_matrix (
128- data = data ,
129- sample_description = sample_description ,
130- formula = formula ,
131- dmat = dmat
132- ).design_info .column_names
24+ # Use this to suppress matrix subclass PendingDepreceationWarnings from numpy:
25+ np .warnings .filterwarnings ("ignore" )
13326
13427
13528def _fit (
@@ -408,10 +301,10 @@ def lrt(
408301 if isinstance (as_numeric , str ):
409302 as_numeric = [as_numeric ]
410303
411- gene_names = _parse_gene_names (data , gene_names )
412- X = _parse_data (data , gene_names )
413- sample_description = _parse_sample_description (data , sample_description )
414- size_factors = _parse_size_factors (size_factors = size_factors , data = X )
304+ gene_names = parse_gene_names (data , gene_names )
305+ X = parse_data (data , gene_names )
306+ sample_description = parse_sample_description (data , sample_description )
307+ size_factors = parse_size_factors (size_factors = size_factors , data = X )
415308
416309 full_design_loc = data_utils .design_matrix (
417310 sample_description = sample_description ,
@@ -623,11 +516,11 @@ def wald(
623516 as_numeric = [as_numeric ]
624517
625518 # # Parse input data formats:
626- gene_names = _parse_gene_names (data , gene_names )
627- X = _parse_data (data , gene_names )
519+ gene_names = parse_gene_names (data , gene_names )
520+ X = parse_data (data , gene_names )
628521 if dmat_loc is None and dmat_scale is None :
629- sample_description = _parse_sample_description (data , sample_description )
630- size_factors = _parse_size_factors (size_factors = size_factors , data = X )
522+ sample_description = parse_sample_description (data , sample_description )
523+ size_factors = parse_size_factors (size_factors = size_factors , data = X )
631524
632525 if dmat_loc is None :
633526 design_loc = data_utils .design_matrix (
@@ -714,20 +607,6 @@ def wald(
714607 return de_test
715608
716609
717- def _parse_grouping (data , sample_description , grouping ):
718- if isinstance (grouping , str ):
719- sample_description = _parse_sample_description (data , sample_description )
720- grouping = sample_description [grouping ]
721- return np .squeeze (np .asarray (grouping ))
722-
723-
724- def _split_X (data , grouping ):
725- groups = np .unique (grouping )
726- x0 = data [np .where (grouping == groups [0 ])[0 ]]
727- x1 = data [np .where (grouping == groups [1 ])[0 ]]
728- return x0 , x1
729-
730-
731610def t_test (
732611 data : Union [anndata .AnnData , anndata .base .Raw , xr .DataArray , xr .Dataset , np .ndarray , scipy .sparse .csr_matrix ],
733612 grouping ,
@@ -752,11 +631,11 @@ def t_test(
752631 Whether data is already logged. If True, log-fold changes are computed as fold changes on this data.
753632 If False, log-fold changes are computed as log-fold changes on this data.
754633 """
755- gene_names = _parse_gene_names (data , gene_names )
756- X = _parse_data (data , gene_names )
634+ gene_names = parse_gene_names (data , gene_names )
635+ X = parse_data (data , gene_names )
757636 if isinstance (X , SparseXArrayDataSet ):
758637 X = X .X
759- grouping = _parse_grouping (data , sample_description , grouping )
638+ grouping = parse_grouping (data , sample_description , grouping )
760639
761640 de_test = DifferentialExpressionTestTT (
762641 data = X .astype (dtype ),
@@ -792,11 +671,11 @@ def rank_test(
792671 Whether data is already logged. If True, log-fold changes are computed as fold changes on this data.
793672 If False, log-fold changes are computed as log-fold changes on this data.
794673 """
795- gene_names = _parse_gene_names (data , gene_names )
796- X = _parse_data (data , gene_names )
674+ gene_names = parse_gene_names (data , gene_names )
675+ X = parse_data (data , gene_names )
797676 if isinstance (X , SparseXArrayDataSet ):
798677 X = X .X
799- grouping = _parse_grouping (data , sample_description , grouping )
678+ grouping = parse_grouping (data , sample_description , grouping )
800679
801680 de_test = DifferentialExpressionTestRank (
802681 data = X .astype (dtype ),
@@ -910,9 +789,9 @@ def two_sample(
910789 raise ValueError ('base.two_sample(): Do not specify `noise_model` if using test t-test or wilcoxon: ' +
911790 'The t-test is based on a gaussian noise model and wilcoxon is model free.' )
912791
913- gene_names = _parse_gene_names (data , gene_names )
914- X = _parse_data (data , gene_names )
915- grouping = _parse_grouping (data , sample_description , grouping )
792+ gene_names = parse_gene_names (data , gene_names )
793+ X = parse_data (data , gene_names )
794+ grouping = parse_grouping (data , sample_description , grouping )
916795 sample_description = pd .DataFrame ({"grouping" : grouping })
917796
918797 groups = np .unique (grouping )
@@ -1116,10 +995,10 @@ def pairwise(
1116995
1117996 # Do not store all models but only p-value and q-value matrix:
1118997 # genes x groups x groups
1119- gene_names = _parse_gene_names (data , gene_names )
1120- X = _parse_data (data , gene_names )
1121- sample_description = _parse_sample_description (data , sample_description )
1122- grouping = _parse_grouping (data , sample_description , grouping )
998+ gene_names = parse_gene_names (data , gene_names )
999+ X = parse_data (data , gene_names )
1000+ sample_description = parse_sample_description (data , sample_description )
1001+ grouping = parse_grouping (data , sample_description , grouping )
11231002 sample_description = pd .DataFrame ({"grouping" : grouping })
11241003
11251004 if test .lower () == 'z-test' or test .lower () == 'z_test' or test .lower () == 'ztest' :
@@ -1324,10 +1203,10 @@ def versus_rest(
13241203
13251204 # Do not store all models but only p-value and q-value matrix:
13261205 # genes x groups
1327- gene_names = _parse_gene_names (data , gene_names )
1328- X = _parse_data (data , gene_names )
1329- sample_description = _parse_sample_description (data , sample_description )
1330- grouping = _parse_grouping (data , sample_description , grouping )
1206+ gene_names = parse_gene_names (data , gene_names )
1207+ X = parse_data (data , gene_names )
1208+ sample_description = parse_sample_description (data , sample_description )
1209+ grouping = parse_grouping (data , sample_description , grouping )
13311210 sample_description = pd .DataFrame ({"grouping" : grouping })
13321211
13331212 groups = np .unique (grouping )
@@ -1427,10 +1306,10 @@ def __init__(
14271306 :param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
14281307 :param sample_description: optional pandas.DataFrame containing sample annotations
14291308 """
1430- self .X = _parse_data (data , gene_names )
1431- self .gene_names = _parse_gene_names (data , gene_names )
1432- self .sample_description = _parse_sample_description (data , sample_description )
1433- self .partition = _parse_grouping (data , sample_description , partition )
1309+ self .X = parse_data (data , gene_names )
1310+ self .gene_names = parse_gene_names (data , gene_names )
1311+ self .sample_description = parse_sample_description (data , sample_description )
1312+ self .partition = parse_grouping (data , sample_description , partition )
14341313 self .partitions = np .unique (self .partition )
14351314 self .partition_idx = [np .where (self .partition == x )[0 ] for x in self .partitions ]
14361315
@@ -1874,9 +1753,9 @@ def continuous_1d(
18741753 if isinstance (as_numeric , tuple ):
18751754 as_numeric = list (as_numeric )
18761755
1877- X = _parse_data (data , gene_names )
1878- gene_names = _parse_gene_names (data , gene_names )
1879- sample_description = _parse_sample_description (data , sample_description )
1756+ X = parse_data (data , gene_names )
1757+ gene_names = parse_gene_names (data , gene_names )
1758+ sample_description = parse_sample_description (data , sample_description )
18801759
18811760 # Check that continuous factor is contained in sample description
18821761 if continuous not in sample_description .columns :
0 commit comments