11import abc
22import logging
33from typing import Union , Dict , Tuple , List , Set , Callable
4-
54import pandas as pd
5+ import warnings
66
77import numpy as np
88import xarray as xr
9-
9+ import patsy
1010try :
1111 import anndata
1212except ImportError :
1313 anndata = None
1414
15- import patsy
1615import batchglm .data as data_utils
17- from batchglm .api . models .glm_nb import Model as GeneralizedLinearModel
16+ from batchglm .models .glm_nb import Model as GeneralizedLinearModel
1817
1918from ..stats import stats
2019from . import correction
@@ -129,7 +128,7 @@ def __init__(self):
129128 self ._pval = None
130129 self ._qval = None
131130 self ._mean = None
132- self ._log_probs = None
131+ self ._log_likelihood = None
133132
134133 @property
135134 @abc .abstractmethod
@@ -179,10 +178,10 @@ def _ave(self):
179178 pass
180179
181180 @property
182- def log_probs (self ):
183- if self ._log_probs is None :
184- self ._log_probs = self ._ll ().compute ()
185- return self ._log_probs
181+ def log_likelihood (self ):
182+ if self ._log_likelihood is None :
183+ self ._log_likelihood = self ._ll ().compute ()
184+ return self ._log_likelihood
186185
187186 @property
188187 def mean (self ):
@@ -675,7 +674,7 @@ def _ll(self):
675674
676675 :return: xr.DataArray
677676 """
678- return np . sum ( self .model_estim .log_likelihood , axis = 0 )
677+ return self .model_estim .log_likelihood
679678
680679 def _ave (self ):
681680 """
@@ -730,8 +729,8 @@ def summary(self, qval_thres=None, fc_upper_thres=None,
730729 if len (self .theta_sd .shape ) == 1 :
731730 res ["coef_sd" ] = self .theta_sd
732731 # add in info from bfgs
733- if self .log_probs is not None :
734- res ["ll" ] = self .log_probs
732+ if self .log_likelihood is not None :
733+ res ["ll" ] = self .log_likelihood
735734 if self ._error_codes is not None :
736735 res ["err" ] = self ._error_codes
737736 if self ._niter is not None :
@@ -1243,7 +1242,7 @@ def X(self):
12431242 return self .model_estim .X
12441243
12451244 @property
1246- def log_probs (self ):
1245+ def log_likelihood (self ):
12471246 return np .sum (self .model_estim .log_probs (), axis = 0 )
12481247
12491248 @property
@@ -1448,7 +1447,7 @@ def X(self):
14481447 return self .model_estim .X
14491448
14501449 @property
1451- def log_probs (self ):
1450+ def log_likelihood (self ):
14521451 return np .sum (self .model_estim .log_probs (), axis = 0 )
14531452
14541453 @property
@@ -1912,8 +1911,8 @@ def mean(self) -> np.ndarray:
19121911 return self ._de_test .mean
19131912
19141913 @property
1915- def log_probs (self ) -> np .ndarray :
1916- return self ._de_test .log_probs
1914+ def log_likelihood (self ) -> np .ndarray :
1915+ return self ._de_test .log_likelihood
19171916
19181917 def summary (self , nonnumeric = False , qval_thres = None , fc_upper_thres = None ,
19191918 fc_lower_thres = None , mean_thres = None ) -> pd .DataFrame :
@@ -2096,6 +2095,8 @@ def plot_genes(
20962095 self ,
20972096 genes ,
20982097 hue = None ,
2098+ size = 1 ,
2099+ log = True ,
20992100 nonnumeric = False ,
21002101 save = None ,
21012102 show = True ,
@@ -2108,6 +2109,7 @@ def plot_genes(
21082109
21092110 :param genes: Gene IDs to plot.
21102111 :param hue: Confounder to include in plot.
2112+ :param size: Point size.
21112113 :param nonnumeric:
21122114 :param save: Path+file name stem to save plots to.
21132115 File will be save+"_genes.png". Does not save if save is None.
@@ -2151,21 +2153,33 @@ def plot_genes(
21512153 ax = plt .subplot (gs [i ])
21522154 axs .append (ax )
21532155
2156+ y = self .X [:, genes [0 ]]
2157+ yhat = self ._continuous_model (idx = g , nonnumeric = nonnumeric )
2158+ if log :
2159+ y = np .log (y + 1 )
2160+ yhat = np .log (yhat + 1 )
2161+
21542162 sns .scatterplot (
21552163 x = self ._continuous_coords ,
2156- y = self . X [:, genes [ 0 ]] ,
2164+ y = y ,
21572165 hue = hue ,
2158- ax = ax
2166+ size = size ,
2167+ ax = ax ,
2168+ legend = False
21592169 )
21602170 sns .lineplot (
21612171 x = self ._continuous_coords ,
2162- y = self . _continuous_model ( idx = g , nonnumeric = nonnumeric ) ,
2172+ y = yhat ,
21632173 hue = hue ,
21642174 ax = ax
21652175 )
2176+
21662177 ax .set_title (genes [i ])
21672178 ax .set_xlabel ("continuous" )
2168- ax .set_ylabel ("expression" )
2179+ if log :
2180+ ax .set_ylabel ("log expression" )
2181+ else :
2182+ ax .set_ylabel ("expression" )
21692183
21702184 # Save, show and return figure.
21712185 if save is not None :
@@ -2536,57 +2550,56 @@ def _fit(
25362550 raise ValueError ('base.test(): `noise_model="%s"` not recognized.' % noise_model )
25372551 else :
25382552 if noise_model == "nb" or noise_model == "negative_binomial" :
2539- import batchglm .api .models .glm_nb as test_model
2540-
2541- logger .info ("Fitting model..." )
2542- logger .debug (" * Assembling input data..." )
2543- input_data = test_model .InputData .new (
2544- data = data ,
2545- design_loc = design_loc ,
2546- design_scale = design_scale ,
2547- constraints_loc = constraints_loc ,
2548- constraints_scale = constraints_scale ,
2549- size_factors = size_factors ,
2550- feature_names = gene_names ,
2551- )
2553+ from batchglm .api .models .glm_nb import Estimator , InputData
2554+ else :
2555+ raise ValueError ('base.test(): `noise_model="%s"` not recognized.' % noise_model )
25522556
2553- logger .debug (" * Set up Estimator..." )
2554- constructor_args = {}
2555- if batch_size is not None :
2556- constructor_args ["batch_size" ] = batch_size
2557- if quick_scale is not None :
2558- constructor_args ["quick_scale" ] = quick_scale
2559- estim = test_model .Estimator (
2560- input_data = input_data ,
2561- init_model = init_model ,
2562- init_a = init_a ,
2563- init_b = init_b ,
2564- provide_optimizers = provide_optimizers ,
2565- termination_type = termination_type ,
2566- dtype = dtype ,
2567- ** constructor_args
2568- )
2557+ logger .info ("Fitting model..." )
2558+ logger .debug (" * Assembling input data..." )
2559+ input_data = InputData .new (
2560+ data = data ,
2561+ design_loc = design_loc ,
2562+ design_scale = design_scale ,
2563+ constraints_loc = constraints_loc ,
2564+ constraints_scale = constraints_scale ,
2565+ size_factors = size_factors ,
2566+ feature_names = gene_names ,
2567+ )
25692568
2570- logger .debug (" * Initializing Estimator..." )
2571- estim .initialize ()
2569+ logger .debug (" * Set up Estimator..." )
2570+ constructor_args = {}
2571+ if batch_size is not None :
2572+ constructor_args ["batch_size" ] = batch_size
2573+ if quick_scale is not None :
2574+ constructor_args ["quick_scale" ] = quick_scale
2575+ estim = Estimator (
2576+ input_data = input_data ,
2577+ init_model = init_model ,
2578+ init_a = init_a ,
2579+ init_b = init_b ,
2580+ provide_optimizers = provide_optimizers ,
2581+ termination_type = termination_type ,
2582+ dtype = dtype ,
2583+ ** constructor_args
2584+ )
25722585
2573- logger .debug (" * Run estimation..." )
2574- # training:
2575- if callable (training_strategy ):
2576- # call training_strategy if it is a function
2577- training_strategy (estim )
2578- else :
2579- estim .train_sequence (training_strategy )
2586+ logger .debug (" * Initializing Estimator..." )
2587+ estim .initialize ()
25802588
2581- if close_session :
2582- logger .debug (" * Finalize estimation..." )
2583- model = estim .finalize ()
2584- else :
2585- model = estim
2586- logger .debug (" * Model fitting done." )
2589+ logger .debug (" * Run estimation..." )
2590+ # training:
2591+ if callable (training_strategy ):
2592+ # call training_strategy if it is a function
2593+ training_strategy (estim )
2594+ else :
2595+ estim .train_sequence (training_strategy = training_strategy )
25872596
2597+ if close_session :
2598+ logger .debug (" * Finalize estimation..." )
2599+ model = estim .finalize ()
25882600 else :
2589- raise ValueError ('base.test(): `noise_model="%s"` not recognized.' % noise_model )
2601+ model = estim
2602+ logger .debug (" * Model fitting done." )
25902603
25912604 return model
25922605
@@ -2788,10 +2801,9 @@ def lrt(
27882801def wald (
27892802 data ,
27902803 factor_loc_totest : Union [str , List [str ]] = None ,
2791- coef_to_test : Union [str , List [str ]] = None , # e.g. coef_to_test="B"
2792- formula : str = None ,
2804+ coef_to_test : Union [str , List [str ]] = None ,
27932805 formula_loc : str = None ,
2794- formula_scale : str = None ,
2806+ formula_scale : str = "~1" ,
27952807 as_numeric : Union [List [str ], Tuple [str ], str ] = (),
27962808 init_a : Union [np .ndarray , str ] = "AUTO" ,
27972809 init_b : Union [np .ndarray , str ] = "AUTO" ,
@@ -2918,10 +2930,6 @@ def wald(
29182930 if len (kwargs ) != 0 :
29192931 logger .debug ("additional kwargs: %s" , str (kwargs ))
29202932
2921- if formula_loc is None :
2922- formula_loc = formula
2923- if formula_scale is None :
2924- formula_scale = formula
29252933 if dmat_loc is None and formula_loc is None :
29262934 raise ValueError ("Supply either dmat_loc or formula_loc or formula." )
29272935 if dmat_scale is None and formula_scale is None :
@@ -2959,7 +2967,8 @@ def wald(
29592967 else :
29602968 design_scale = dmat_scale
29612969
2962- # Coefficients to test:
2970+ # Define indices of coefficients to test:
2971+ contraints_loc_temp = constraints_loc if constraints_loc is not None else np .eye (design_loc .shape [- 1 ])
29632972 if factor_loc_totest is not None :
29642973 # Select coefficients to test via formula model:
29652974 col_indices = np .concatenate ([
@@ -2992,6 +3001,12 @@ def wald(
29923001 ])
29933002 else :
29943003 raise ValueError ("either set factor_loc_totest or coef_to_test" )
3004+ # Check that all tested coefficients are independent:
3005+ for x in col_indices :
3006+ if np .sum (contraints_loc_temp [x ,:]) != 1 :
3007+ raise ValueError ("Constraints input is wrong: not all tested coefficients are unconstrained." )
3008+ # Adjust tested coefficients from dependent to independent (fitted) parameters:
3009+ col_indices = np .array ([np .where (contraints_loc_temp [x ,:] == 1 )[0 ][0 ] for x in col_indices ])
29953010
29963011 ## Fit GLM:
29973012 model = _fit (
0 commit comments