@@ -65,6 +65,9 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
6565 self .verbose = verbose
6666
6767 def solve (self , X , y , datafit , penalty , w_init = None , Xw_init = None ):
68+ if self .ws_strategy not in ("subdiff" , "fixpoint" ):
69+ raise ValueError ("ws_strategy must be `subdiff` or `fixpoint`, "
70+ f"got { self .ws_strategy } ." )
6871 dtype = X .dtype
6972 n_samples , n_features = X .shape
7073 fit_intercept = self .fit_intercept
@@ -206,9 +209,9 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
206209 dtype = X .dtype
207210 raw_hess = datafit .raw_hessian (y , Xw_epoch )
208211
209- lipschitz = np .zeros (len (ws ), dtype )
212+ lipschitz_ws = np .zeros (len (ws ), dtype )
210213 for idx , j in enumerate (ws ):
211- lipschitz [idx ] = raw_hess @ X [:, j ] ** 2
214+ lipschitz_ws [idx ] = raw_hess @ X [:, j ] ** 2
212215
213216 # for a less costly stopping criterion, we do not compute the exact gradient,
214217 # but store each coordinate-wise gradient every time we update one coordinate
@@ -224,12 +227,12 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
224227 for cd_iter in range (MAX_CD_ITER ):
225228 for idx , j in enumerate (ws ):
226229 # skip when X[:, j] == 0
227- if lipschitz [idx ] == 0 :
230+ if lipschitz_ws [idx ] == 0 :
228231 continue
229232
230233 past_grads [idx ] = grad_ws [idx ] + X [:, j ] @ (raw_hess * X_delta_w_ws )
231234 old_w_idx = w_ws [idx ]
232- stepsize = 1 / lipschitz [idx ]
235+ stepsize = 1 / lipschitz_ws [idx ]
233236
234237 w_ws [idx ] = penalty .prox_1d (
235238 old_w_idx - stepsize * past_grads [idx ], stepsize , j )
@@ -253,7 +256,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
253256 opt = penalty .subdiff_distance (current_w , past_grads , ws )
254257 elif ws_strategy == "fixpoint" :
255258 opt = dist_fix_point_cd (
256- current_w , past_grads , lipschitz , datafit , penalty , ws
259+ current_w , past_grads , lipschitz_ws , datafit , penalty , ws
257260 )
258261 stop_crit = np .max (opt )
259262
@@ -264,7 +267,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
264267 break
265268
266269 # descent direction
267- return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz
270+ return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz_ws
268271
269272
270273# sparse version of _descent_direction
@@ -275,10 +278,10 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
275278 dtype = X_data .dtype
276279 raw_hess = datafit .raw_hessian (y , Xw_epoch )
277280
278- lipschitz = np .zeros (len (ws ), dtype )
281+ lipschitz_ws = np .zeros (len (ws ), dtype )
279282 for idx , j in enumerate (ws ):
280- # equivalent to: lipschitz [idx] += raw_hess * X[:, j] ** 2
281- lipschitz [idx ] = _sparse_squared_weighted_norm (
283+ # equivalent to: lipschitz_ws [idx] += raw_hess * X[:, j] ** 2
284+ lipschitz_ws [idx ] = _sparse_squared_weighted_norm (
282285 X_data , X_indptr , X_indices , j , raw_hess )
283286
284287 # see _descent_direction() comment
@@ -294,7 +297,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
294297 for cd_iter in range (MAX_CD_ITER ):
295298 for idx , j in enumerate (ws ):
296299 # skip when X[:, j] == 0
297- if lipschitz [idx ] == 0 :
300+ if lipschitz_ws [idx ] == 0 :
298301 continue
299302
300303 past_grads [idx ] = grad_ws [idx ]
@@ -303,7 +306,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
303306 X_data , X_indptr , X_indices , j , X_delta_w_ws , raw_hess )
304307
305308 old_w_idx = w_ws [idx ]
306- stepsize = 1 / lipschitz [idx ]
309+ stepsize = 1 / lipschitz_ws [idx ]
307310
308311 w_ws [idx ] = penalty .prox_1d (
309312 old_w_idx - stepsize * past_grads [idx ], stepsize , j )
@@ -328,7 +331,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
328331 opt = penalty .subdiff_distance (current_w , past_grads , ws )
329332 elif ws_strategy == "fixpoint" :
330333 opt = dist_fix_point_cd (
331- current_w , past_grads , lipschitz , datafit , penalty , ws
334+ current_w , past_grads , lipschitz_ws , datafit , penalty , ws
332335 )
333336 stop_crit = np .max (opt )
334337
@@ -339,7 +342,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
339342 break
340343
341344 # descent direction
342- return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz
345+ return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz_ws
343346
344347
345348@njit
0 commit comments