@@ -4,10 +4,12 @@
 from sklearn.utils import check_array
 from skglm.solvers.common import construct_grad, construct_grad_sparse, dist_fix_point
 
+from skglm.utils import AndersonAcceleration
+
 
 def cd_solver_path(X, y, datafit, penalty, alphas=None,
                    coef_init=None, max_iter=20, max_epochs=50_000,
-                   p0=10, tol=1e-4, use_acc=True, return_n_iter=False,
+                   p0=10, tol=1e-4, return_n_iter=False,
                    ws_strategy="subdiff", verbose=0):
     r"""Compute optimization path with Anderson accelerated coordinate descent.
 
@@ -47,9 +49,6 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
     tol : float, optional
         The tolerance for the optimization.
 
-    use_acc : bool, optional
-        Usage of Anderson acceleration for faster convergence.
-
     return_n_iter : bool, optional
         If True, number of iterations along the path are returned.
 
@@ -148,7 +147,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
         sol = cd_solver(
             X, y, datafit, penalty, w, Xw,
             max_iter=max_iter, max_epochs=max_epochs, p0=p0, tol=tol,
-            use_acc=use_acc, verbose=verbose, ws_strategy=ws_strategy)
+            verbose=verbose, ws_strategy=ws_strategy)
 
         coefs[:, t] = w
         stop_crits[t] = sol[-1]
@@ -165,7 +164,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
 
 def cd_solver(
         X, y, datafit, penalty, w, Xw, max_iter=50, max_epochs=50_000, p0=10,
-        tol=1e-4, use_acc=True, K=5, ws_strategy="subdiff", verbose=0):
+        tol=1e-4, ws_strategy="subdiff", verbose=0):
     r"""Run a coordinate descent solver.
 
     Parameters
@@ -201,12 +200,6 @@ def cd_solver(
     tol : float, optional
         The tolerance for the optimization.
 
-    use_acc : bool, optional
-        Usage of Anderson acceleration for faster convergence.
-
-    K : int, optional
-        The number of past primal iterates used to build an extrapolated point.
-
     ws_strategy : ('subdiff'|'fixpoint'), optional
         The score used to build the working set.
 
@@ -226,13 +219,14 @@ def cd_solver(
     """
     if ws_strategy not in ("subdiff", "fixpoint"):
         raise ValueError(f'Unsupported value for ws_strategy: {ws_strategy}')
-    n_features = X.shape[1]
+    n_samples, n_features = X.shape
     pen = penalty.is_penalized(n_features)
     unpen = ~pen
     n_unpen = unpen.sum()
     obj_out = []
     all_feats = np.arange(n_features)
     stop_crit = np.inf  # initialize for case n_iter=0
+    w_acc, Xw_acc = np.zeros(n_features), np.zeros(n_samples)
 
     is_sparse = sparse.issparse(X)
     for t in range(max_iter):
@@ -259,14 +253,12 @@ def cd_solver(
         opt[unpen] = np.inf  # always include unpenalized features
         opt[penalty.generalized_support(w)] = np.inf
 
-        # here use topk instead of sorting the full array
-        # ie the following line
+        # here use topk instead of np.argsort(opt)[-ws_size:]
         ws = np.argpartition(opt, -ws_size)[-ws_size:]
-        # is equivalent to ws = np.argsort(opt)[-ws_size:]
 
-        if use_acc:
-            last_K_w = np.zeros([K + 1, ws_size])
-            U = np.zeros([K, ws_size])
+        # re init AA at every iter to consider ws
+        accelerator = AndersonAcceleration(K=5)
+        w_acc[:] = 0.
 
         if verbose:
             print(f'Iteration {t + 1}, {ws_size} feats in subpb.')
@@ -283,45 +275,18 @@ def cd_solver(
 
             # 3) do Anderson acceleration on smaller problem
             # TODO optimize computation using ws
-            if use_acc:
-                last_K_w[epoch % (K + 1)] = w[ws]
-
-                if epoch % (K + 1) == K:
-                    for k in range(K):
-                        U[k] = last_K_w[k + 1] - last_K_w[k]
-                    C = np.dot(U, U.T)
-
-                    try:
-                        z = np.linalg.solve(C, np.ones(K))
-                        # When C is ill-conditioned, z can take very large finite
-                        # positive and negative values (1e35 and -1e35), which leads
-                        # to z.sum() being null.
-                        if z.sum() == 0:
-                            raise np.linalg.LinAlgError
-                    except np.linalg.LinAlgError:
-                        if max(verbose - 1, 0):
-                            print("----------Linalg error")
-                    else:
-                        c = z / z.sum()
-                        w_acc = np.zeros(n_features)
-                        w_acc[ws] = np.sum(
-                            last_K_w[:-1] * c[:, None], axis=0)
-                        # TODO create a p_obj function ?
-                        # TODO : managed penalty.value(w[ws])
-                        p_obj = datafit.value(y, w, Xw) + penalty.value(w)
-                        # p_obj = datafit.value(y, w, Xw) + penalty.value(w[ws])
-                        Xw_acc = X[:, ws] @ w_acc[ws]
-                        # TODO : managed penalty.value(w[ws])
-                        p_obj_acc = datafit.value(
-                            y, w_acc, Xw_acc) + penalty.value(w_acc)
-                        if p_obj_acc < p_obj:
-                            w[:] = w_acc
-                            Xw[:] = Xw_acc
+            w_acc[ws], Xw_acc[:], is_extrapolated = accelerator.extrapolate(w[ws], Xw)
 
-            if epoch % 10 == 0:
+            if is_extrapolated:  # avoid computing p_obj for un-extrapolated w, Xw
                 # TODO : manage penalty.value(w, ws) for weighted Lasso
-                p_obj = datafit.value(y, w[ws], Xw) + penalty.value(w)
+                p_obj = datafit.value(y, w, Xw) + penalty.value(w)
+                p_obj_acc = datafit.value(y, w_acc, Xw_acc) + penalty.value(w_acc)
 
+                if p_obj_acc < p_obj:
+                    w[:], Xw[:] = w_acc, Xw_acc
+                    p_obj = p_obj_acc
+
+            if epoch % 10 == 0:
                 if is_sparse:
                     grad_ws = construct_grad_sparse(
                         X.data, X.indptr, X.indices, y, w, Xw, datafit, ws)
@@ -334,6 +299,7 @@ def cd_solver(
 
             stop_crit_in = np.max(opt_ws)
             if max(verbose - 1, 0):
+                p_obj = datafit.value(y, w, Xw) + penalty.value(w)
                 print(f"Epoch {epoch + 1}, objective {p_obj:.10f}, "
                       f"stopping crit {stop_crit_in:.2e}")
             if ws_size == n_features:
@@ -344,6 +310,7 @@ def cd_solver(
                 if max(verbose - 1, 0):
                     print("Early exit")
                 break
+        p_obj = datafit.value(y, w, Xw) + penalty.value(w)
         obj_out.append(p_obj)
     return w, np.array(obj_out), stop_crit
 
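The new code path delegates extrapolation to an `AndersonAcceleration` helper imported from `skglm.utils`: the solver re-creates it at every outer iteration (since the working set `ws` changes), then calls `extrapolate(w[ws], Xw)` once per epoch, getting back the extrapolated coefficients, the extrapolated `Xw`, and a flag saying whether an extrapolated point was actually produced. As a reading aid, here is a minimal sketch of what such a helper could look like; it just repackages the inline logic removed above (buffer the `K + 1` most recent iterates, solve the small `K x K` system `C z = 1`, combine the first `K` iterates with weights `c = z / z.sum()`). The class name and `extrapolate` signature match the calls in the diff, but the actual implementation shipped in `skglm.utils` may differ.

```python
import numpy as np


class AndersonAcceleration:
    """Sketch of the extrapolation helper assumed by this diff.

    Buffers the K + 1 most recent iterates and, once the buffer is full,
    returns a linear combination of the first K of them (same logic as the
    inline block removed above). Hypothetical: the real class in
    ``skglm.utils`` may differ in implementation details.
    """

    def __init__(self, K):
        self.K = K
        self.current_iter = 0
        self.arr_w_ = None
        self.arr_Xw_ = None

    def extrapolate(self, w, Xw):
        # (re)allocate buffers on first call or if the working set size changed
        if self.arr_w_ is None or self.arr_w_.shape[0] != w.shape[0]:
            self.arr_w_ = np.zeros((w.shape[0], self.K + 1))
            self.arr_Xw_ = np.zeros((Xw.shape[0], self.K + 1))
            self.current_iter = 0

        # store the current iterate
        self.arr_w_[:, self.current_iter] = w
        self.arr_Xw_[:, self.current_iter] = Xw

        if self.current_iter < self.K:
            # not enough iterates buffered yet: no extrapolation this epoch
            self.current_iter += 1
            return w, Xw, False

        # U holds the K successive differences between stored iterates
        U = np.diff(self.arr_w_, axis=1)  # shape (n, K)
        C = U.T @ U
        self.current_iter = 0

        try:
            z = np.linalg.solve(C, np.ones(self.K))
            if z.sum() == 0:  # ill-conditioned C: weights are meaningless
                raise np.linalg.LinAlgError("singular C")
        except np.linalg.LinAlgError:
            return w, Xw, False

        c = z / z.sum()
        # extrapolated point: combination of the first K stored iterates
        return self.arr_w_[:, :-1] @ c, self.arr_Xw_[:, :-1] @ c, True
```

Re-instantiating the accelerator at each outer iteration, as the diff does, keeps these buffers sized to the current working set, and the caller-owned `w_acc` / `Xw_acc` arrays let the solver compare the extrapolated objective against the plain one before committing to `w` and `Xw`.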