|
| 1 | +import numpy as np |
| 2 | +from numba import njit |
| 3 | +from skglm.solvers.base import BaseSolver |
| 4 | + |
| 5 | + |
| 6 | +class FISTA(BaseSolver): |
| 7 | + r"""ISTA solver with Nesterov acceleration (FISTA).""" |
| 8 | + |
| 9 | + def __init__(self, max_iter=100, tol=1e-4, fit_intercept=False, warm_start=False, |
| 10 | + opt_freq=50, verbose=0): |
| 11 | + self.max_iter = max_iter |
| 12 | + self.tol = tol |
| 13 | + self.fit_intercept = fit_intercept |
| 14 | + self.warm_start = warm_start |
| 15 | + self.opt_freq = opt_freq |
| 16 | + self.verbose = verbose |
| 17 | + |
| 18 | + def solve(self, X, y, penalty, w_init=None, weights=None): |
| 19 | + # needs a quadratic datafit, but works with L1, WeightedL1, SLOPE |
| 20 | + n_samples, n_features = X.shape |
| 21 | + all_features = np.arange(n_features) |
| 22 | + t_new = 1 |
| 23 | + |
| 24 | + w = w_init.copy() if w_init is not None else np.zeros(n_features) |
| 25 | + z = w_init.copy() if w_init is not None else np.zeros(n_features) |
| 26 | + weights = weights if weights is not None else np.ones(n_features) |
| 27 | + |
| 28 | + # FISTA with Gram update |
| 29 | + G = X.T @ X |
| 30 | + Xty = X.T @ y |
| 31 | + lipschitz = np.linalg.norm(X, ord=2) ** 2 / n_samples |
| 32 | + |
| 33 | + for n_iter in range(self.max_iter): |
| 34 | + t_old = t_new |
| 35 | + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 |
| 36 | + w_old = w.copy() |
| 37 | + grad = (G @ z - Xty) / n_samples |
| 38 | + z -= grad / lipschitz |
| 39 | + # TODO: TO DISCUSS! |
| 40 | + # XXX: should add a full prox update |
| 41 | + for j in range(n_features): |
| 42 | + w[j] = penalty.prox_1d(z[j], 1 / lipschitz, j) |
| 43 | + z = w + (t_old - 1.) / t_new * (w - w_old) |
| 44 | + |
| 45 | + if n_iter % self.opt_freq == 0: |
| 46 | + opt = penalty.subdiff_distance(w, grad, all_features) |
| 47 | + stop_crit = np.max(opt) |
| 48 | + |
| 49 | + if self.verbose: |
| 50 | + p_obj = (np.sum((y - X @ w) ** 2) / (2 * n_samples) |
| 51 | + + penalty.alpha * penalty.value(w)) |
| 52 | + print( |
| 53 | + f"Iteration {n_iter+1}: {p_obj:.10f}, " |
| 54 | + f"stopping crit: {stop_crit:.2e}" |
| 55 | + ) |
| 56 | + |
| 57 | + if stop_crit < self.tol: |
| 58 | + if self.verbose: |
| 59 | + print(f"Stopping criterion max violation: {stop_crit:.2e}") |
| 60 | + break |
| 61 | + return w |
0 commit comments