Skip to content
This repository was archived by the owner on Oct 21, 2025. It is now read-only.

Commit 08482df

Browse files
committed
Fix plotting.
1 parent 2861451 commit 08482df

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

diffxpy/testing/det.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import scipy.sparse
1414
import sparse
1515
from typing import Union, Dict, Tuple, List, Set
16+
from batchglm.models.glm_norm import Model
17+
from batchglm.utils.input import InputDataGLM
18+
from batchglm.train.numpy.glm_norm import Estimator
1619

1720
from .utils import split_x, dmat_unique
1821
from ..stats import stats
@@ -968,7 +971,7 @@ def plot_comparison_ols_coef(
968971
import matplotlib.pyplot as plt
969972
from matplotlib import gridspec
970973
from matplotlib import rcParams
971-
from batchglm.api.models.tf1.glm_norm import Estimator, InputDataGLM
974+
972975

973976
plt.ioff()
974977

@@ -983,12 +986,12 @@ def plot_comparison_ols_coef(
983986
size_factors=self.model_estim.model_container.size_factors,
984987
feature_names=self.model_estim.model_container.features,
985988
)
989+
model = Model(input_data=input_data_ols)
986990
estim_ols = Estimator(
987-
input_data=input_data_ols,
991+
model=model,
988992
init_model=None,
989993
init_a="standard",
990994
init_b="standard",
991-
dtype=self.model_estim.model_container.theta_location.dtype
992995
)
993996
estim_ols.initialize()
994997
store_ols = estim_ols.finalize()
@@ -999,7 +1002,7 @@ def plot_comparison_ols_coef(
9991002
# Prepare parameter summary of both model fits.
10001003
par_loc = self.model_estim.model_container.data.coords["design_loc_params"].values
10011004

1002-
theta_location_ols = store_ols.theta_location
1005+
theta_location_ols = store_ols.model_container.theta_location
10031006
theta_location_ols[1:, :] = (theta_location_ols[1:, :] + theta_location_ols[[0], :]) / theta_location_ols[[0], :]
10041007

10051008
theta_location_user = self.model_estim.model_container.theta_location
@@ -1107,7 +1110,6 @@ def plot_comparison_ols_pred(
11071110
import matplotlib.pyplot as plt
11081111
from matplotlib import gridspec
11091112
from matplotlib import rcParams
1110-
from batchglm.api.models.tf1.glm_norm import Estimator, InputDataGLM
11111113

11121114
plt.ioff()
11131115

@@ -1122,12 +1124,12 @@ def plot_comparison_ols_pred(
11221124
size_factors=self.model_estim.model_container.size_factors,
11231125
feature_names=self.model_estim.model_container.features,
11241126
)
1127+
model = Model(input_data=input_data_ols)
11251128
estim_ols = Estimator(
1126-
input_data=input_data_ols,
1129+
model=model,
11271130
init_model=None,
11281131
init_a="standard",
11291132
init_b="standard",
1130-
dtype=self.model_estim.model_container.theta_location.dtype
11311133
)
11321134
estim_ols.initialize()
11331135
store_ols = estim_ols.finalize()
@@ -1164,8 +1166,8 @@ def plot_comparison_ols_pred(
11641166
y_user = self.model_estim.model_container.inverse_link_loc(
11651167
np.matmul(self.model_estim.model_container.design_loc[pred_n_cells, :], self.model_estim.model_container.theta_location).flatten()
11661168
)
1167-
y_ols = store_ols.inverse_link_loc(
1168-
np.matmul(store_ols.design_loc[pred_n_cells, :], store_ols.theta_location).flatten()
1169+
y_ols = store_ols.model_container.inverse_link_loc(
1170+
np.matmul(store_ols.model_container.design_loc[pred_n_cells, :], store_ols.model_container.theta_location).flatten()
11691171
)
11701172
if log1p_transform:
11711173
x = np.log(x+1)

0 commit comments

Comments
 (0)