
 import numpy as np
 from numpy.linalg import norm
+from scipy.sparse import csc_matrix, issparse

 from skglm.datafits import Quadratic, Logistic, QuadraticSVC
-from skglm.estimators import Lasso, LinearSVC, SparseLogisticRegression
 from skglm.penalties import L1, IndicatorBox
-from skglm.solvers import FISTA
+from skglm.solvers import FISTA, AndersonCD
 from skglm.utils import make_correlated_data, compiled_clone


-n_samples, n_features = 10, 20
+np.random.seed(0)
+n_samples, n_features = 50, 60
 X, y, _ = make_correlated_data(
     n_samples=n_samples, n_features=n_features, random_state=0)
+X_sparse = csc_matrix(X * np.random.binomial(1, 0.1, X.shape))
 y_classif = np.sign(y)

 alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
 alpha = alpha_max / 100

-tol = 1e-8
+tol = 1e-10

-# TODO: use GeneralizedLinearEstimator (to test global lipschtiz constants of every datafit)
-# TODO: test sparse matrices (global lipschitz constants)
-@pytest.mark.parametrize("Datafit, Penalty, Estimator", [
-    (Quadratic, L1, Lasso),
-    (Logistic, L1, SparseLogisticRegression),
-    (QuadraticSVC, IndicatorBox, LinearSVC),
+
+@pytest.mark.parametrize("X", [X, X_sparse])
+@pytest.mark.parametrize("Datafit, Penalty", [
+    (Quadratic, L1),
+    (Logistic, L1),
+    (QuadraticSVC, IndicatorBox),
 ])
-def test_fista_solver(Datafit, Penalty, Estimator):
+def test_fista_solver(X, Datafit, Penalty):
     _y = y if isinstance(Datafit, Quadratic) else y_classif
     datafit = compiled_clone(Datafit())
     _init = y @ X.T if isinstance(Datafit, QuadraticSVC) else X
-    datafit.initialize(_init, _y)
+    if issparse(X):
+        datafit.initialize_sparse(_init.data, _init.indptr, _init.indices, _y)
+    else:
+        datafit.initialize(_init, _y)
     penalty = compiled_clone(Penalty(alpha))

-    solver = FISTA(max_iter=1000, tol=tol)
+    solver = FISTA(max_iter=1000, tol=tol, opt_freq=1)
     w = solver.solve(X, _y, datafit, penalty)

-    estimator = Estimator(alpha, tol=tol, fit_intercept=False)
-    estimator.fit(X, _y)
+    solver_cd = AndersonCD(tol=tol, fit_intercept=False)
+    w_cd = solver_cd.solve(X, _y, datafit, penalty)[0]

-    np.testing.assert_allclose(w, estimator.coef_.flatten(), rtol=1e-3)
+    np.testing.assert_allclose(w, w_cd, rtol=1e-3)


 if __name__ == '__main__':
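
For reference, here is a minimal standalone sketch of the comparison this test performs, restricted to the dense Quadratic/L1 case and reusing only the calls that appear in the diff above; the opt_freq argument and the [0] indexing of AndersonCD.solve's return value are taken from this diff and may differ in other skglm versions.

import numpy as np
from numpy.linalg import norm

from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import FISTA, AndersonCD
from skglm.utils import make_correlated_data, compiled_clone

# Same data generation and regularization strength as in the test above.
X, y, _ = make_correlated_data(n_samples=50, n_features=60, random_state=0)
alpha = norm(X.T @ y, ord=np.inf) / len(y) / 100

# Compile the datafit and penalty; initialize() precomputes the Lipschitz
# constants mentioned in the removed TODOs, which FISTA needs for its steps.
datafit = compiled_clone(Quadratic())
datafit.initialize(X, y)
penalty = compiled_clone(L1(alpha))

# Both solvers minimize the same objective, so their solutions should agree.
w_fista = FISTA(max_iter=1000, tol=1e-10, opt_freq=1).solve(X, y, datafit, penalty)
w_cd = AndersonCD(tol=1e-10, fit_intercept=False).solve(X, y, datafit, penalty)[0]
np.testing.assert_allclose(w_fista, w_cd, rtol=1e-3)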
|