import numpy as np
from numba import njit
from skglm.solvers.base import BaseSolver
+from skglm.solvers.common import construct_grad


@njit
@@ -23,37 +24,35 @@ def __init__(self, max_iter=100, tol=1e-4, fit_intercept=False, warm_start=False
        self.opt_freq = opt_freq
        self.verbose = verbose

-    def solve(self, X, y, penalty, w_init=None, weights=None):
-        # needs a quadratic datafit, but works with L1, WeightedL1, SLOPE
+    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
        n_samples, n_features = X.shape
        all_features = np.arange(n_features)
        t_new = 1

        w = w_init.copy() if w_init is not None else np.zeros(n_features)
        z = w_init.copy() if w_init is not None else np.zeros(n_features)
-        weights = weights if weights is not None else np.ones(n_features)
+        Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

-        # FISTA with Gram update
-        G = X.T @ X
-        Xty = X.T @ y
+        # line search?
+        # lipschitz = np.max(datafit.lipschitz)
        lipschitz = np.linalg.norm(X, ord=2) ** 2 / n_samples

        for n_iter in range(self.max_iter):
            t_old = t_new
            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
            w_old = w.copy()
-            grad = (G @ z - Xty) / n_samples
+            grad = construct_grad(X, y, z, X @ z, datafit, all_features)
            z -= grad / lipschitz
            w = _prox_vec(w, z, penalty, lipschitz)
+            Xw = X @ w
            z = w + (t_old - 1.) / t_new * (w - w_old)

            if n_iter % self.opt_freq == 0:
                opt = penalty.subdiff_distance(w, grad, all_features)
                stop_crit = np.max(opt)

                if self.verbose:
-                    p_obj = (np.sum((y - X @ w) ** 2) / (2 * n_samples)
-                             + penalty.value(w))
+                    p_obj = datafit.value(y, w, Xw) + penalty.value(w)
                    print(
                        f"Iteration {n_iter + 1}: {p_obj:.10f}, "
                        f"stopping crit: {stop_crit:.2e}"
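To check the update scheme in isolation, here is a self-contained FISTA iteration for the Lasso (quadratic datafit, L1 penalty) in plain numpy; `soft_threshold` and `fista_lasso` are illustrative names, not skglm API:

import numpy as np


def soft_threshold(x, tau):
    # Prox of tau * ||.||_1, applied element-wise.
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.)


def fista_lasso(X, y, alpha, max_iter=100):
    # Minimizes ||y - X w||^2 / (2 * n_samples) + alpha * ||w||_1.
    n_samples, n_features = X.shape
    # Same spectral-norm Lipschitz bound as in the diff above.
    lipschitz = np.linalg.norm(X, ord=2) ** 2 / n_samples
    w = np.zeros(n_features)
    z = w.copy()
    t_new = 1.
    for _ in range(max_iter):
        t_old = t_new
        t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
        w_old = w.copy()
        # Gradient of the quadratic datafit at the extrapolated point z.
        grad = X.T @ (X @ z - y) / n_samples
        # Forward-backward step: gradient step, then prox of the L1 penalty.
        w = soft_threshold(z - grad / lipschitz, alpha / lipschitz)
        # Nesterov extrapolation, as in the solve() loop above.
        z = w + (t_old - 1.) / t_new * (w - w_old)
    return w

As in the diff, the momentum uses the classical sequence t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2, which gives FISTA its O(1/k^2) convergence rate over plain proximal gradient descent.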