@@ -305,6 +305,7 @@ def plot_volcano(
305305 plt .show ()
306306
307307 plt .close (fig )
308+ plt .ion ()
308309
309310 if return_axs :
310311 return ax
@@ -853,12 +854,19 @@ def summary(
853854 def plot_vs_ttest (
854855 self ,
855856 log10 = False ,
857+ show : bool = True ,
858+ save : Union [str , None ] = None ,
859+ suffix : str = "_plot_vs_ttest.png" ,
856860 return_axs : bool = False
857861 ):
858862 """
859863 Normalizes data by size factors if any were used in model.
860864
861865 :param log10:
866+ :param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
867+ :param save: Path+file name stem to save plots to.
868+ File will be save+suffix. Does not save if save is None.
869+ :param suffix: Suffix for file name to save plot to. Also use this to set the file type.
862870 :param return_axs: Whether to return axis objects.
863871
864872 :return:
@@ -867,12 +875,17 @@ def plot_vs_ttest(
867875 import seaborn as sns
868876 from .tests import t_test
869877
878+ plt .ioff ()
879+
870880 grouping = np .asarray (self .model_estim .input_data .design_loc [:, self .coef_loc_totest ])
871881 # Normalize by size factors that were used in regression.
872- sf = np .broadcast_to (np .expand_dims (self .model_estim .input_data .size_factors , axis = 1 ),
873- shape = self .model_estim .x .shape )
882+ if self .model_estim .input_data .size_factors is not None :
883+ sf = np .broadcast_to (np .expand_dims (self .model_estim .input_data .size_factors , axis = 1 ),
884+ shape = self .model_estim .x .shape )
885+ else :
886+ sf = np .ones (shape = (self .model_estim .x .shape [0 ], 1 ))
874887 ttest = t_test (
875- data = self .model_estim .X . multiply ( 1 / sf , copy = True ) ,
888+ data = self .model_estim .x / sf ,
876889 grouping = grouping ,
877890 gene_names = self .gene_ids ,
878891 )
@@ -889,6 +902,16 @@ def plot_vs_ttest(
889902
890903 ax .set (xlabel = "t-test" , ylabel = 'wald test' )
891904
905+ # Save, show and return figure.
906+ if save is not None :
907+ plt .savefig (save + suffix )
908+
909+ if show :
910+ plt .show ()
911+
912+ plt .close (fig )
913+ plt .ion ()
914+
892915 if return_axs :
893916 return ax
894917 else :
@@ -930,6 +953,8 @@ def plot_comparison_ols_coef(
930953 from matplotlib import rcParams
931954 from batchglm .api .models .glm_norm import Estimator , InputDataGLM
932955
956+ plt .ioff ()
957+
933958 # Run OLS model fit to have comparison coefficients.
934959 if self ._store_ols is None :
935960 input_data_ols = InputDataGLM (
@@ -1067,6 +1092,8 @@ def plot_comparison_ols_pred(
10671092 from matplotlib import rcParams
10681093 from batchglm .api .models .glm_norm import Estimator , InputDataGLM
10691094
1095+ plt .ioff ()
1096+
10701097 # Run OLS model fit to have comparison coefficients.
10711098 if self ._store_ols is None :
10721099 input_data_ols = InputDataGLM (
0 commit comments