1+ import pytest
12import numpy as np
3+ import pandas as pd
24
3- from skglm .solvers import LBFGS
45from skglm .penalties import L2
5- from skglm .datafits import Logistic
6+ from skglm .solvers import LBFGS
7+ from skglm .datafits import Logistic , Cox
68
79from sklearn .linear_model import LogisticRegression
810
9- from skglm .utils .data import make_correlated_data
1011from skglm .utils .jit_compilation import compiled_clone
12+ from skglm .utils .data import make_correlated_data , make_dummy_survival_data
1113
1214
1315def test_lbfgs_L2_logreg ():
1416 reg = 1.
15- n_samples , n_features = 50 , 10
17+ n_samples , n_features = 100 , 50
1618
1719 X , y , _ = make_correlated_data (
1820 n_samples , n_features , random_state = 0 )
@@ -21,19 +23,59 @@ def test_lbfgs_L2_logreg():
2123 # fit L-BFGS
2224 datafit = compiled_clone (Logistic ())
2325 penalty = compiled_clone (L2 (reg ))
24- w , * _ = LBFGS ().solve (X , y , datafit , penalty )
26+ w , * _ = LBFGS (tol = 1e-12 ).solve (X , y , datafit , penalty )
2527
2628 # fit scikit learn
2729 estimator = LogisticRegression (
2830 penalty = 'l2' ,
2931 C = 1 / (n_samples * reg ),
30- fit_intercept = False
31- )
32- estimator .fit (X , y )
32+ fit_intercept = False ,
33+ tol = 1e-12 ,
34+ ).fit (X , y )
35+
36+ np .testing .assert_allclose (w , estimator .coef_ .flatten ())
37+
38+
@pytest.mark.parametrize("use_efron", [True, False])
def test_L2_Cox(use_efron):
    """Check the L-BFGS fit of an L2-penalized Cox datafit against ``lifelines``.

    ``use_efron`` is forwarded to ``Cox`` and is also used as ``with_ties``
    when generating the survival data. When ``lifelines`` cannot be imported
    the test is reported as an expected failure instead of running.
    """
    try:
        from lifelines import CoxPHFitter
    except ModuleNotFoundError:
        pytest.xfail(
            "Testing L2 Cox Estimator requires `lifelines` packages\n "
            "Run `pip install lifelines`"
        )

    alpha = 10.
    n_samples, n_features = 100, 50

    tm, s, X = make_dummy_survival_data(
        n_samples,
        n_features,
        normalize=True,
        with_ties=use_efron,
        random_state=0,
    )
    # the Cox datafit takes the target as the pair (tm, s)
    y = (tm, s)

    # skglm side: compiled datafit/penalty solved with L-BFGS
    datafit = compiled_clone(Cox(use_efron))
    penalty = compiled_clone(L2(alpha))
    datafit.initialize(X, y)
    w, *_ = LBFGS().solve(X, y, datafit, penalty)

    # lifelines side: it consumes a single frame whose column 0 is `tm`,
    # column 1 is `s`, and the remaining columns are the features of X
    df = pd.DataFrame(np.hstack((tm[:, None], s[:, None], X)))
    estimator = CoxPHFitter(penalizer=alpha, l1_ratio=0.).fit(
        df, duration_col=0, event_col=1
    )
    w_ll = estimator.params_.values

    def primal_objective(coef):
        # datafit value plus penalty value at `coef`
        return datafit.value(y, coef, X @ coef) + penalty.value(coef)

    p_obj_skglm = primal_objective(w)
    p_obj_ll = primal_objective(w_ll)

    # The coefficients are only compared loosely: even with a tighter tol in
    # lifelines the two solutions stay quite far apart. lifelines is suspected
    # (https://github.com/CamDavidsonPilon/lifelines/pull/1534), as our
    # solution gives the lowest objective value.
    np.testing.assert_allclose(w, w_ll, rtol=1e-1)
    np.testing.assert_allclose(p_obj_skglm, p_obj_ll, rtol=1e-6)
3779
3880
3981if __name__ == "__main__" :
0 commit comments