33from skglm .solvers .base import BaseSolver
44
55
6+ @njit
7+ def _prox_vec (w , z , penalty , lipschitz ):
8+ # XXX: TO DISCUSS: should add a vectorized prox update
9+ n_features = w .shape [0 ]
10+ for j in range (n_features ):
11+ w [j ] = penalty .prox_1d (z [j ], 1 / lipschitz , j )
12+ return w
13+
14+
615class FISTA (BaseSolver ):
716 r"""ISTA solver with Nesterov acceleration (FISTA)."""
817
918 def __init__ (self , max_iter = 100 , tol = 1e-4 , fit_intercept = False , warm_start = False ,
10- opt_freq = 50 , verbose = 0 ):
19+ opt_freq = 100 , verbose = 0 ):
1120 self .max_iter = max_iter
1221 self .tol = tol
1322 self .fit_intercept = fit_intercept
1423 self .warm_start = warm_start
15- self .opt_freq = opt_freq
24+ self .opt_freq = opt_freq
1625 self .verbose = verbose
1726
1827 def solve (self , X , y , penalty , w_init = None , weights = None ):
@@ -36,10 +45,7 @@ def solve(self, X, y, penalty, w_init=None, weights=None):
3645 w_old = w .copy ()
3746 grad = (G @ z - Xty ) / n_samples
3847 z -= grad / lipschitz
39- # TODO: TO DISCUSS!
40- # XXX: should add a full prox update
41- for j in range (n_features ):
42- w [j ] = penalty .prox_1d (z [j ], 1 / lipschitz , j )
48+ w = _prox_vec (w , z , penalty , lipschitz )
4349 z = w + (t_old - 1. ) / t_new * (w - w_old )
4450
4551 if n_iter % self .opt_freq == 0 :
@@ -48,7 +54,7 @@ def solve(self, X, y, penalty, w_init=None, weights=None):
4854
4955 if self .verbose :
5056 p_obj = (np .sum ((y - X @ w ) ** 2 ) / (2 * n_samples )
51- + penalty .alpha * penalty . value (w ))
57+ + penalty .value (w ))
5258 print (
5359 f"Iteration { n_iter + 1 } : { p_obj :.10f} , "
5460 f"stopping crit: { stop_crit :.2e} "
0 commit comments