Skip to content

Commit 3635f24

Browse files
committed
FISTA with global lipschitz
1 parent e47c68a commit 3635f24

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

skglm/solvers/fista.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
3333
z = w_init.copy() if w_init is not None else np.zeros(n_features)
3434
Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)
3535

36-
# line search?
37-
# lipschitz = np.max(datafit.lipschitz)
38-
lipschitz = np.linalg.norm(X, ord=2) ** 2 / n_samples
36+
# TODO: OR line search
37+
lipschitz = datafit.global_lipschitz
3938

4039
for n_iter in range(self.max_iter):
4140
t_old = t_new

toy_fista.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import numpy as np
22
from numpy.linalg import norm
3-
from skglm.datafits.single_task import Quadratic
3+
from skglm.datafits.single_task import Quadratic, Logistic
44
from skglm.solvers import FISTA
55
from skglm.penalties import L1
6-
from skglm.estimators import Lasso
6+
from skglm.estimators import SparseLogisticRegression, Lasso
77
from skglm.utils import make_correlated_data, compiled_clone
88

99

@@ -19,6 +19,10 @@
1919
tol = 1e-10
2020

2121
solver = FISTA(max_iter=max_iter, tol=tol, opt_freq=obj_freq, verbose=1)
22+
23+
##############
24+
# Quadratic
25+
##############
2226
datafit = compiled_clone(Quadratic())
2327
datafit.initialize(X, y)
2428
penalty = compiled_clone(L1(alpha))
@@ -28,3 +32,20 @@
2832
clf.fit(X, y)
2933

3034
np.testing.assert_allclose(w, clf.coef_, rtol=1e-5)
35+
36+
##############
37+
# Logistic
38+
##############
39+
y = np.sign(y)
40+
alpha_max = norm(X.T @ y, ord=np.inf) / (4 * n_samples)
41+
alpha = alpha_max / 10
42+
43+
datafit = compiled_clone(Logistic())
44+
datafit.initialize(X, y)
45+
penalty = compiled_clone(L1(alpha))
46+
w = solver.solve(X, y, datafit, penalty)
47+
48+
clf = SparseLogisticRegression(alpha=alpha, tol=tol, fit_intercept=False)
49+
clf.fit(X, y)
50+
51+
np.testing.assert_allclose(w, np.squeeze(clf.coef_), rtol=1e-3)

0 commit comments

Comments
 (0)