Skip to content

Commit 4880112

Browse files
committed
writing tests
1 parent 3635f24 commit 4880112

File tree

3 files changed

+80
-12
lines changed

3 files changed

+80
-12
lines changed

skglm/datafits/single_task.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,13 @@ def initialize_sparse(
6363
for j in range(n_features):
6464
nrm2 = 0.
6565
xty = 0
66-
x2 = 0.
6766
for idx in range(X_indptr[j], X_indptr[j + 1]):
6867
nrm2 += X_data[idx] ** 2
6968
xty += X_data[idx] * y[X_indices[idx]]
70-
x2 += X_data[idx] ** 2 / len(y)
7169

7270
self.lipschitz[j] = nrm2 / len(y)
7371
self.Xty[j] = xty
74-
self.global_lipschitz += x2
72+
self.global_lipschitz += nrm2 / len(y)
7573

7674
def value(self, y, w, Xw):
7775
return np.sum((y - Xw) ** 2) / (2 * len(Xw))
@@ -233,10 +231,9 @@ def params_to_dict(self):
233231
def initialize(self, yXT, y):
234232
n_features = yXT.shape[1]
235233
self.lipschitz = np.zeros(n_features, dtype=yXT.dtype)
236-
self.global_lipschitz = 0.
234+
self.global_lipschitz = norm(yXT, ord=2) ** 2 / len(y)
237235
for j in range(n_features):
238236
self.lipschitz[j] = norm(yXT[:, j]) ** 2
239-
self.global_lipschitz += norm(yXT[:, j]) ** 2
240237

241238
def initialize_sparse(
242239
self, yXT_data, yXT_indptr, yXT_indices, y):

skglm/solvers/fista.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,35 @@ def _prox_vec(w, z, penalty, lipschitz):
1313

1414

1515
class FISTA(BaseSolver):
16-
r"""ISTA solver with Nesterov acceleration (FISTA)."""
16+
r"""ISTA solver with Nesterov acceleration (FISTA).
1717
18-
def __init__(self, max_iter=100, tol=1e-4, fit_intercept=False, warm_start=False,
19-
opt_freq=10, verbose=0):
18+
This solver implements accelerated proximal gradient descent for linear problems.
19+
20+
Attributes
21+
----------
22+
max_iter : int, default 100
23+
Maximum number of iterations.
24+
25+
tol : float, default 1e-4
26+
Tolerance for convergence.
27+
28+
opt_freq : int, default 10
29+
Frequency for optimality condition check.
30+
31+
verbose : bool, default False
32+
Amount of verbosity. 0/False is silent.
33+
34+
References
35+
----------
36+
.. [1] Beck, A. and Teboulle M.
37+
"A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
38+
problems", 2009, SIAM J. Imaging Sci.
39+
https://epubs.siam.org/doi/10.1137/080716542
40+
"""
41+
42+
def __init__(self, max_iter=100, tol=1e-4, opt_freq=10, verbose=0):
2043
self.max_iter = max_iter
2144
self.tol = tol
22-
self.fit_intercept = fit_intercept
23-
self.warm_start = warm_start
2445
self.opt_freq = opt_freq
2546
self.verbose = verbose
2647

@@ -33,8 +54,11 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
3354
z = w_init.copy() if w_init is not None else np.zeros(n_features)
3455
Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)
3556

36-
# TODO: OR line search
37-
lipschitz = datafit.global_lipschitz
57+
if hasattr(datafit, "global_lipschitz"):
58+
lipschitz = datafit.global_lipschitz
59+
else:
60+
# TODO: OR line search
61+
raise Exception("Line search is not yet implemented for FISTA solver.")
3862

3963
for n_iter in range(self.max_iter):
4064
t_old = t_new

skglm/tests/test_fista.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pytest
2+
3+
import numpy as np
4+
from numpy.linalg import norm
5+
6+
from skglm.datafits import Quadratic, Logistic, QuadraticSVC
7+
from skglm.estimators import Lasso, LinearSVC, SparseLogisticRegression
8+
from skglm.penalties import L1, IndicatorBox
9+
from skglm.solvers import FISTA
10+
from skglm.utils import make_correlated_data, compiled_clone
11+
12+
13+
n_samples, n_features = 10, 20
14+
X, y, _ = make_correlated_data(
15+
n_samples=n_samples, n_features=n_features, random_state=0)
16+
y_classif = np.sign(y)
17+
18+
alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
19+
alpha = alpha_max / 100
20+
21+
tol = 1e-8
22+
23+
# TODO: use GeneralizedLinearEstimator (to test global lipschtiz constants of every datafit)
24+
# TODO: test sparse matrices (global lipschitz constants)
25+
@pytest.mark.parametrize("Datafit, Penalty, Estimator", [
26+
(Quadratic, L1, Lasso),
27+
(Logistic, L1, SparseLogisticRegression),
28+
(QuadraticSVC, IndicatorBox, LinearSVC),
29+
])
30+
def test_fista_solver(Datafit, Penalty, Estimator):
31+
_y = y if isinstance(Datafit, Quadratic) else y_classif
32+
datafit = compiled_clone(Datafit())
33+
_init = y @ X.T if isinstance(Datafit, QuadraticSVC) else X
34+
datafit.initialize(_init, _y)
35+
penalty = compiled_clone(Penalty(alpha))
36+
37+
solver = FISTA(max_iter=1000, tol=tol)
38+
w = solver.solve(X, _y, datafit, penalty)
39+
40+
estimator = Estimator(alpha, tol=tol, fit_intercept=False)
41+
estimator.fit(X, _y)
42+
43+
np.testing.assert_allclose(w, estimator.coef_.flatten(), rtol=1e-3)
44+
45+
46+
if __name__ == '__main__':
47+
pass

0 commit comments

Comments
 (0)