Skip to content

Commit 46a9a76

Browse files
committed
better tests
1 parent 4880112 commit 46a9a76

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

skglm/tests/test_fista.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,50 @@
22

33
import numpy as np
44
from numpy.linalg import norm
5+
from scipy.sparse import csc_matrix, issparse
56

67
from skglm.datafits import Quadratic, Logistic, QuadraticSVC
7-
from skglm.estimators import Lasso, LinearSVC, SparseLogisticRegression
88
from skglm.penalties import L1, IndicatorBox
9-
from skglm.solvers import FISTA
9+
from skglm.solvers import FISTA, AndersonCD
1010
from skglm.utils import make_correlated_data, compiled_clone
1111

1212

13-
n_samples, n_features = 10, 20
13+
np.random.seed(0)
14+
n_samples, n_features = 50, 60
1415
X, y, _ = make_correlated_data(
1516
n_samples=n_samples, n_features=n_features, random_state=0)
17+
X_sparse = csc_matrix(X * np.random.binomial(1, 0.1, X.shape))
1618
y_classif = np.sign(y)
1719

1820
alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
1921
alpha = alpha_max / 100
2022

21-
tol = 1e-8
23+
tol = 1e-10
2224

23-
# TODO: use GeneralizedLinearEstimator (to test global lipschtiz constants of every datafit)
24-
# TODO: test sparse matrices (global lipschitz constants)
25-
@pytest.mark.parametrize("Datafit, Penalty, Estimator", [
26-
(Quadratic, L1, Lasso),
27-
(Logistic, L1, SparseLogisticRegression),
28-
(QuadraticSVC, IndicatorBox, LinearSVC),
25+
26+
@pytest.mark.parametrize("X", [X, X_sparse])
27+
@pytest.mark.parametrize("Datafit, Penalty", [
28+
(Quadratic, L1),
29+
(Logistic, L1),
30+
(QuadraticSVC, IndicatorBox),
2931
])
30-
def test_fista_solver(Datafit, Penalty, Estimator):
32+
def test_fista_solver(X, Datafit, Penalty):
3133
_y = y if isinstance(Datafit, Quadratic) else y_classif
3234
datafit = compiled_clone(Datafit())
3335
_init = y @ X.T if isinstance(Datafit, QuadraticSVC) else X
34-
datafit.initialize(_init, _y)
36+
if issparse(X):
37+
datafit.initialize_sparse(_init.data, _init.indptr, _init.indices, _y)
38+
else:
39+
datafit.initialize(_init, _y)
3540
penalty = compiled_clone(Penalty(alpha))
3641

37-
solver = FISTA(max_iter=1000, tol=tol)
42+
solver = FISTA(max_iter=1000, tol=tol, opt_freq=1)
3843
w = solver.solve(X, _y, datafit, penalty)
3944

40-
estimator = Estimator(alpha, tol=tol, fit_intercept=False)
41-
estimator.fit(X, _y)
45+
solver_cd = AndersonCD(tol=tol, fit_intercept=False)
46+
w_cd = solver_cd.solve(X, _y, datafit, penalty)[0]
4247

43-
np.testing.assert_allclose(w, estimator.coef_.flatten(), rtol=1e-3)
48+
np.testing.assert_allclose(w, w_cd, rtol=1e-3)
4449

4550

4651
if __name__ == '__main__':

0 commit comments

Comments
 (0)