|
14 | 14 | from sklearn.linear_model import ElasticNet as ElasticNet_sklearn |
15 | 15 | from sklearn.linear_model import LogisticRegression as LogReg_sklearn |
16 | 16 | from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn |
| 17 | +from sklearn.linear_model import PoissonRegressor, GammaRegressor |
17 | 18 | from sklearn.model_selection import GridSearchCV |
18 | 19 | from sklearn.svm import LinearSVC as LinearSVC_sklearn |
19 | 20 | from sklearn.utils.estimator_checks import check_estimator |
|
23 | 24 | from skglm.estimators import ( |
24 | 25 | GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet, |
25 | 26 | MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator) |
26 | | -from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox |
27 | | -from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE |
| 27 | +from skglm.datafits import ( |
| 28 | + Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox, Poisson, Gamma |
| 29 | +) |
| 30 | +from skglm.penalties import ( |
| 31 | + L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE |
| 32 | +) |
28 | 33 | from skglm.solvers import AndersonCD, FISTA, ProxNewton |
29 | 34 |
|
30 | 35 | n_samples = 50 |
@@ -629,5 +634,23 @@ def test_SLOPE_printing(): |
629 | 634 | assert isinstance(res, str) |
630 | 635 |
|
631 | 636 |
|
| 637 | +@pytest.mark.parametrize( |
| 638 | + "sklearn_reg, skglm_datafit", |
| 639 | + [(PoissonRegressor, Poisson), (GammaRegressor, Gamma)] |
| 640 | +) |
| 641 | +def test_inverse_link_prediction(sklearn_reg, skglm_datafit): |
| 642 | + np.random.seed(42) |
| 643 | + X = np.random.randn(20, 5) |
| 644 | + y = np.random.randint(1, 6, size=20) # Use 1-6 for both (Gamma needs y>0) |
| 645 | + sklearn_pred = sklearn_reg(alpha=0.0, max_iter=10_000, |
| 646 | + tol=1e-8).fit(X, y).predict(X) |
| 647 | + skglm_pred = GeneralizedLinearEstimator( |
| 648 | + datafit=skglm_datafit(), |
| 649 | + penalty=L1_plus_L2(0.0, l1_ratio=0.0), |
| 650 | + solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8) |
| 651 | + ).fit(X, y).predict(X) |
| 652 | + np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8) |
| 653 | + |
| 654 | + |
632 | 655 | if __name__ == "__main__": |
633 | 656 | pass |
0 commit comments