Skip to content

Commit 9f0653a

Browse files
committed
support sparse matrices
1 parent 46a9a76 commit 9f0653a

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

skglm/solvers/fista.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
2+
from scipy.sparse import issparse
23
from numba import njit
34
from skglm.solvers.base import BaseSolver
4-
from skglm.solvers.common import construct_grad
5+
from skglm.solvers.common import construct_grad, construct_grad_sparse
56

67

78
@njit
@@ -64,7 +65,11 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
6465
t_old = t_new
6566
t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
6667
w_old = w.copy()
67-
grad = construct_grad(X, y, z, X @ z, datafit, all_features)
68+
if issparse(X):
69+
grad = construct_grad_sparse(
70+
X.data, X.indptr, X.indices, y, z, X @ z, datafit, all_features)
71+
else:
72+
grad = construct_grad(X, y, z, X @ z, datafit, all_features)
6873
z -= grad / lipschitz
6974
w = _prox_vec(w, z, penalty, lipschitz)
7075
Xw = X @ w

0 commit comments

Comments
 (0)