1313import scipy .sparse
1414import sparse
1515from typing import Union , Dict , Tuple , List , Set
16+ from batchglm .models .glm_norm import Model
17+ from batchglm .utils .input import InputDataGLM
18+ from batchglm .train .numpy .glm_norm import Estimator
1619
1720from .utils import split_x , dmat_unique
1821from ..stats import stats
@@ -968,7 +971,7 @@ def plot_comparison_ols_coef(
968971 import matplotlib .pyplot as plt
969972 from matplotlib import gridspec
970973 from matplotlib import rcParams
971- from batchglm . api . models . tf1 . glm_norm import Estimator , InputDataGLM
974+
972975
973976 plt .ioff ()
974977
@@ -983,12 +986,12 @@ def plot_comparison_ols_coef(
983986 size_factors = self .model_estim .model_container .size_factors ,
984987 feature_names = self .model_estim .model_container .features ,
985988 )
989+ model = Model (input_data = input_data_ols )
986990 estim_ols = Estimator (
987- input_data = input_data_ols ,
991+ model = model ,
988992 init_model = None ,
989993 init_a = "standard" ,
990994 init_b = "standard" ,
991- dtype = self .model_estim .model_container .theta_location .dtype
992995 )
993996 estim_ols .initialize ()
994997 store_ols = estim_ols .finalize ()
@@ -999,7 +1002,7 @@ def plot_comparison_ols_coef(
9991002 # Prepare parameter summary of both model fits.
10001003 par_loc = self .model_estim .model_container .data .coords ["design_loc_params" ].values
10011004
1002- theta_location_ols = store_ols .theta_location
1005+ theta_location_ols = store_ols .model_container . theta_location
10031006 theta_location_ols [1 :, :] = (theta_location_ols [1 :, :] + theta_location_ols [[0 ], :]) / theta_location_ols [[0 ], :]
10041007
10051008 theta_location_user = self .model_estim .model_container .theta_location
@@ -1107,7 +1110,6 @@ def plot_comparison_ols_pred(
11071110 import matplotlib .pyplot as plt
11081111 from matplotlib import gridspec
11091112 from matplotlib import rcParams
1110- from batchglm .api .models .tf1 .glm_norm import Estimator , InputDataGLM
11111113
11121114 plt .ioff ()
11131115
@@ -1122,12 +1124,12 @@ def plot_comparison_ols_pred(
11221124 size_factors = self .model_estim .model_container .size_factors ,
11231125 feature_names = self .model_estim .model_container .features ,
11241126 )
1127+ model = Model (input_data = input_data_ols )
11251128 estim_ols = Estimator (
1126- input_data = input_data_ols ,
1129+ model = model ,
11271130 init_model = None ,
11281131 init_a = "standard" ,
11291132 init_b = "standard" ,
1130- dtype = self .model_estim .model_container .theta_location .dtype
11311133 )
11321134 estim_ols .initialize ()
11331135 store_ols = estim_ols .finalize ()
@@ -1164,8 +1166,8 @@ def plot_comparison_ols_pred(
11641166 y_user = self .model_estim .model_container .inverse_link_loc (
11651167 np .matmul (self .model_estim .model_container .design_loc [pred_n_cells , :], self .model_estim .model_container .theta_location ).flatten ()
11661168 )
1167- y_ols = store_ols .inverse_link_loc (
1168- np .matmul (store_ols .design_loc [pred_n_cells , :], store_ols .theta_location ).flatten ()
1169+ y_ols = store_ols .model_container . inverse_link_loc (
1170+ np .matmul (store_ols .model_container . design_loc [pred_n_cells , :], store_ols . model_container .theta_location ).flatten ()
11691171 )
11701172 if log1p_transform :
11711173 x = np .log (x + 1 )
0 commit comments