Skip to content

Commit 4940a0d

Browse files
committed
WIP Lipschitz
1 parent c82e32e commit 4940a0d

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

skglm/solvers/fista.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from numba import njit
33
from skglm.solvers.base import BaseSolver
4+
from skglm.solvers.common import construct_grad
45

56

67
@njit
@@ -23,37 +24,35 @@ def __init__(self, max_iter=100, tol=1e-4, fit_intercept=False, warm_start=False
2324
self.opt_freq = opt_freq
2425
self.verbose = verbose
2526

26-
def solve(self, X, y, penalty, w_init=None, weights=None):
27-
# needs a quadratic datafit, but works with L1, WeightedL1, SLOPE
27+
def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
2828
n_samples, n_features = X.shape
2929
all_features = np.arange(n_features)
3030
t_new = 1
3131

3232
w = w_init.copy() if w_init is not None else np.zeros(n_features)
3333
z = w_init.copy() if w_init is not None else np.zeros(n_features)
34-
weights = weights if weights is not None else np.ones(n_features)
34+
Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)
3535

36-
# FISTA with Gram update
37-
G = X.T @ X
38-
Xty = X.T @ y
36+
# line search?
37+
# lipschitz = np.max(datafit.lipschitz)
3938
lipschitz = np.linalg.norm(X, ord=2) ** 2 / n_samples
4039

4140
for n_iter in range(self.max_iter):
4241
t_old = t_new
4342
t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
4443
w_old = w.copy()
45-
grad = (G @ z - Xty) / n_samples
44+
grad = construct_grad(X, y, z, X @ z, datafit, all_features)
4645
z -= grad / lipschitz
4746
w = _prox_vec(w, z, penalty, lipschitz)
47+
Xw = X @ w
4848
z = w + (t_old - 1.) / t_new * (w - w_old)
4949

5050
if n_iter % self.opt_freq == 0:
5151
opt = penalty.subdiff_distance(w, grad, all_features)
5252
stop_crit = np.max(opt)
5353

5454
if self.verbose:
55-
p_obj = (np.sum((y - X @ w) ** 2) / (2 * n_samples)
56-
+ penalty.value(w))
55+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
5756
print(
5857
f"Iteration {n_iter+1}: {p_obj:.10f}, "
5958
f"stopping crit: {stop_crit:.2e}"

toy_fista.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from numpy.linalg import norm
3+
from skglm.datafits.single_task import Quadratic
34
from skglm.solvers import FISTA
45
from skglm.penalties import L1
56
from skglm.estimators import Lasso
@@ -18,8 +19,10 @@
1819
tol = 1e-10
1920

2021
solver = FISTA(max_iter=max_iter, tol=tol, opt_freq=obj_freq, verbose=1)
22+
datafit = compiled_clone(Quadratic())
23+
datafit.initialize(X, y)
2124
penalty = compiled_clone(L1(alpha))
22-
w = solver.solve(X, y, penalty)
25+
w = solver.solve(X, y, datafit, penalty)
2326

2427
clf = Lasso(alpha=alpha, tol=tol, fit_intercept=False)
2528
clf.fit(X, y)

0 commit comments

Comments
 (0)