Skip to content

Commit 8584299

Browse files
committed
CLN
1 parent 0868b0f commit 8584299

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

skglm/solvers/fista.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,25 @@
33
from skglm.solvers.base import BaseSolver
44

55

6+
@njit
7+
def _prox_vec(w, z, penalty, lipschitz):
8+
# XXX: TO DISCUSS: should add a vectorized prox update
9+
n_features = w.shape[0]
10+
for j in range(n_features):
11+
w[j] = penalty.prox_1d(z[j], 1 / lipschitz, j)
12+
return w
13+
14+
615
class FISTA(BaseSolver):
716
r"""ISTA solver with Nesterov acceleration (FISTA)."""
817

918
def __init__(self, max_iter=100, tol=1e-4, fit_intercept=False, warm_start=False,
10-
opt_freq=50, verbose=0):
19+
opt_freq=100, verbose=0):
1120
self.max_iter = max_iter
1221
self.tol = tol
1322
self.fit_intercept = fit_intercept
1423
self.warm_start = warm_start
15-
self.opt_freq = opt_freq
24+
self.opt_freq = opt_freq
1625
self.verbose = verbose
1726

1827
def solve(self, X, y, penalty, w_init=None, weights=None):
@@ -36,10 +45,7 @@ def solve(self, X, y, penalty, w_init=None, weights=None):
3645
w_old = w.copy()
3746
grad = (G @ z - Xty) / n_samples
3847
z -= grad / lipschitz
39-
# TODO: TO DISCUSS!
40-
# XXX: should add a full prox update
41-
for j in range(n_features):
42-
w[j] = penalty.prox_1d(z[j], 1 / lipschitz, j)
48+
w = _prox_vec(w, z, penalty, lipschitz)
4349
z = w + (t_old - 1.) / t_new * (w - w_old)
4450

4551
if n_iter % self.opt_freq == 0:
@@ -48,7 +54,7 @@ def solve(self, X, y, penalty, w_init=None, weights=None):
4854

4955
if self.verbose:
5056
p_obj = (np.sum((y - X @ w) ** 2) / (2 * n_samples)
51-
+ penalty.alpha * penalty.value(w))
57+
+ penalty.value(w))
5258
print(
5359
f"Iteration {n_iter+1}: {p_obj:.10f}, "
5460
f"stopping crit: {stop_crit:.2e}"

0 commit comments

Comments
 (0)