Skip to content

Commit 0868b0f

Browse files
committed
POC FISTA
1 parent 8b08c09 commit 0868b0f

File tree

3 files changed

+90
-1
lines changed

3 files changed

+90
-1
lines changed

skglm/solvers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from .anderson_cd import AndersonCD
22
from .base import BaseSolver
3+
from .fista import FISTA
34
from .gram_cd import GramCD
45
from .group_bcd import GroupBCD
56
from .multitask_bcd import MultiTaskBCD
67
from .prox_newton import ProxNewton
78

89

9-
__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
10+
# ``__all__`` must contain the *names* of the public objects as strings;
# listing the classes themselves makes ``from skglm.solvers import *`` fail.
__all__ = [
    "AndersonCD",
    "BaseSolver",
    "FISTA",
    "GramCD",
    "GroupBCD",
    "MultiTaskBCD",
    "ProxNewton",
]

skglm/solvers/fista.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import numpy as np
2+
from numba import njit
3+
from skglm.solvers.base import BaseSolver
4+
5+
6+
class FISTA(BaseSolver):
    r"""ISTA solver with Nesterov acceleration (FISTA).

    Proof-of-concept implementation restricted to the quadratic datafit
    ``||y - X w||^2 / (2 * n_samples)``: the gradient is evaluated through
    the precomputed Gram matrix ``X.T @ X`` and the vector ``X.T @ y``.

    Parameters
    ----------
    max_iter : int, default=100
        Maximum number of FISTA iterations.

    tol : float, default=1e-4
        The solver stops when the maximal optimality violation, as reported
        by ``penalty.subdiff_distance``, falls below ``tol``.

    fit_intercept : bool, default=False
        Stored on the instance; not used by ``solve`` in this implementation.

    warm_start : bool, default=False
        Stored on the instance; not used by ``solve`` in this implementation.

    opt_freq : int, default=50
        The stopping criterion is evaluated every ``opt_freq`` iterations.

    verbose : bool or int, default=0
        If truthy, print the objective and the stopping criterion each time
        the latter is evaluated.
    """

    def __init__(self, max_iter=100, tol=1e-4, fit_intercept=False, warm_start=False,
                 opt_freq=50, verbose=0):
        self.max_iter = max_iter
        self.tol = tol
        self.fit_intercept = fit_intercept
        self.warm_start = warm_start
        self.opt_freq = opt_freq
        self.verbose = verbose

    def solve(self, X, y, penalty, w_init=None, weights=None):
        r"""Run FISTA on a quadratic datafit plus ``penalty``.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            Design matrix.

        y : array, shape (n_samples,)
            Target vector.

        penalty : penalty object
            Must expose ``prox_1d(value, stepsize, j)``,
            ``subdiff_distance(w, grad, ws)`` and, when ``verbose`` is set,
            ``alpha`` and ``value(w)``.

        w_init : array, shape (n_features,), optional
            Starting point. It is copied, never modified in place.
            Defaults to the zero vector.

        weights : array, shape (n_features,), optional
            Accepted for interface compatibility; currently ignored.

        Returns
        -------
        w : array, shape (n_features,)
            The last iterate.
        """
        # needs a quadratic datafit, but works with L1, WeightedL1, SLOPE
        n_samples, n_features = X.shape
        all_features = np.arange(n_features)
        t_new = 1

        w = w_init.copy() if w_init is not None else np.zeros(n_features)
        # extrapolated point of the accelerated scheme, starts at the iterate
        z = w.copy()

        # FISTA with Gram update: precompute everything the gradient needs
        G = X.T @ X
        Xty = X.T @ y
        # squared spectral norm of X, scaled like the datafit
        lipschitz = np.linalg.norm(X, ord=2) ** 2 / n_samples
        step = 1 / lipschitz

        for n_iter in range(self.max_iter):
            t_old = t_new
            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
            w_old = w.copy()

            # gradient step taken from the extrapolated point
            grad = (G @ z - Xty) / n_samples
            z -= grad / lipschitz

            # TODO: to discuss, a full (vectorized) prox update should replace
            # this coordinate-wise application of the 1D prox
            for j in range(n_features):
                w[j] = penalty.prox_1d(z[j], step, j)

            # Nesterov extrapolation for the next iteration
            z = w + (t_old - 1.) / t_new * (w - w_old)

            if n_iter % self.opt_freq == 0:
                # Optimality is a property of the iterate ``w``, so the
                # gradient is evaluated at ``w`` rather than reusing ``grad``,
                # which was computed at the previous extrapolated point.
                grad_w = (G @ w - Xty) / n_samples
                opt = penalty.subdiff_distance(w, grad_w, all_features)
                stop_crit = np.max(opt)

                if self.verbose:
                    # NOTE(review): if ``penalty.value`` already includes the
                    # ``alpha`` factor, this term counts it twice -- confirm
                    # against the penalty implementation.
                    p_obj = (np.sum((y - X @ w) ** 2) / (2 * n_samples)
                             + penalty.alpha * penalty.value(w))
                    print(
                        f"Iteration {n_iter+1}: {p_obj:.10f}, "
                        f"stopping crit: {stop_crit:.2e}"
                    )

                if stop_crit < self.tol:
                    if self.verbose:
                        print(f"Stopping criterion max violation: {stop_crit:.2e}")
                    break
        return w

toy_fista.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
from numpy.linalg import norm
3+
from skglm.solvers import FISTA
4+
from skglm.penalties import L1
5+
from skglm.estimators import Lasso
6+
from skglm.utils import make_correlated_data, compiled_clone
7+
8+
9+
# Toy check: FISTA and the reference Lasso estimator should agree on a small
# correlated regression problem.
X, y, _ = make_correlated_data(n_samples=200, n_features=100, random_state=24)
n_samples, n_features = X.shape

# Regularization strength: a tenth of ||X^T y||_inf / n_samples.
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
alpha = alpha_max / 10

# Solver settings; the same tolerance is used for both solvers.
tol = 1e-10
fista_params = dict(max_iter=1000, tol=tol, opt_freq=100, verbose=1)

# Solve with FISTA on a compiled L1 penalty.
l1_penalty = compiled_clone(L1(alpha))
fista_solver = FISTA(**fista_params)
w_fista = fista_solver.solve(X, y, l1_penalty)

# Reference solution from the Lasso estimator, without intercept.
lasso = Lasso(alpha=alpha, tol=tol, fit_intercept=False)
lasso.fit(X, y)

np.testing.assert_allclose(w_fista, lasso.coef_, rtol=1e-5)

0 commit comments

Comments
 (0)