@@ -24,6 +24,9 @@ class ProxNewton(BaseSolver):
     tol : float, default 1e-4
         Tolerance for convergence.
 
+    fit_intercept : bool, default True
+        If ``True``, fits an unpenalized intercept.
+
     verbose : bool, default False
        Amount of verbosity. 0/False is silent.
 
@@ -53,7 +56,8 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
 
     def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         n_samples, n_features = X.shape
-        w = np.zeros(n_features) if w_init is None else w_init
+        fit_intercept = self.fit_intercept
+        w = np.zeros(n_features + fit_intercept) if w_init is None else w_init
         Xw = np.zeros(n_samples) if Xw_init is None else Xw_init
         all_features = np.arange(n_features)
         stop_crit = 0.
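
A quick illustration of the layout this hunk introduces: `fit_intercept` is a bool, so summing it with `n_features` allocates exactly one extra trailing slot, and that last entry holds the intercept. A minimal sketch (names assumed, not part of the patch):

```python
import numpy as np

n_features, fit_intercept = 5, True
w = np.zeros(n_features + fit_intercept)   # bool sums as 0/1 -> shape (6,)
coefs, intercept = w[:n_features], w[-1]   # trailing entry is the intercept
```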
@@ -63,20 +67,38 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         if is_sparse:
             X_bundles = (X.data, X.indptr, X.indices)
 
+        if len(w) != n_features + self.fit_intercept:
+            if self.fit_intercept:
+                val_error_message = (
+                    "w should be of size n_features + 1 when using fit_intercept=True: "
+                    f"expected {n_features + 1}, got {len(w)}.")
+            else:
+                val_error_message = (
+                    "w should be of size n_features: "
+                    f"expected {n_features}, got {len(w)}.")
+            raise ValueError(val_error_message)
+
         for t in range(self.max_iter):
             # compute scores
             if is_sparse:
                 grad = _construct_grad_sparse(
-                    *X_bundles, y, w, Xw, datafit, all_features)
+                    *X_bundles, y, w[:n_features], Xw, datafit, all_features)
             else:
-                grad = _construct_grad(X, y, w, Xw, datafit, all_features)
 
-            opt = penalty.subdiff_distance(w, grad, all_features)
+                grad = _construct_grad(X, y, w[:n_features], Xw, datafit, all_features)
+
+            opt = penalty.subdiff_distance(w[:n_features], grad, all_features)
+
+            # optimality of intercept
+            if fit_intercept:
+                # gradient w.r.t. intercept (constant feature of ones)
+                intercept_opt = np.abs(np.sum(datafit.raw_grad(y, Xw)))
+            else:
+                intercept_opt = 0.
 
             # check convergence
-            stop_crit = np.max(opt)
+            stop_crit = max(np.max(opt), intercept_opt)
             if self.verbose:
-                p_obj = datafit.value(y, w, Xw) + penalty.value(w)
+                p_obj = datafit.value(y, w, Xw) + penalty.value(w[:n_features])
                 print(
                     "Iteration {}: {:.10f}, ".format(t + 1, p_obj) +
                     "stopping crit: {:.2e}".format(stop_crit)
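
Why `np.abs(np.sum(datafit.raw_grad(y, Xw)))` measures intercept optimality: as the patch's comment says, the intercept behaves like an unpenalized feature whose column is all ones, so the datafit's partial derivative with respect to it is the sum of the sample-wise raw gradients, which must vanish at an optimum. A self-contained finite-difference check on a quadratic datafit (the `raw_grad` here is a stand-in, not the library's class):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50
y, Xw, b = rng.normal(size=n), rng.normal(size=n), 0.3

def raw_grad(y, z):
    # sample-wise gradient of the quadratic datafit F(z) = ||y - z||^2 / (2n)
    return (z - y) / len(y)

def F(b):
    return np.sum((Xw + b - y) ** 2) / (2 * n)

eps = 1e-6
fd = (F(b + eps) - F(b - eps)) / (2 * eps)     # dF/db by central differences
assert np.isclose(np.sum(raw_grad(y, Xw + b)), fd)
```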
@@ -101,20 +123,22 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
             # find descent direction
             if is_sparse:
                 delta_w_ws, X_delta_w_ws = _descent_direction_s(
-                    *X_bundles, y, w, Xw, grad_ws, datafit,
+                    *X_bundles, y, w, Xw, fit_intercept, grad_ws, datafit,
                     penalty, ws, tol=EPS_TOL * tol_in)
             else:
                 delta_w_ws, X_delta_w_ws = _descent_direction(
-                    X, y, w, Xw, grad_ws, datafit, penalty, ws, tol=EPS_TOL * tol_in)
+                    X, y, w, Xw, fit_intercept, grad_ws, datafit,
+                    penalty, ws, tol=EPS_TOL * tol_in)
 
             # backtracking line search with inplace update of w, Xw
             if is_sparse:
                 grad_ws[:] = _backtrack_line_search_s(
-                    *X_bundles, y, w, Xw, datafit, penalty, delta_w_ws,
-                    X_delta_w_ws, ws)
+                    *X_bundles, y, w, Xw, fit_intercept, datafit, penalty,
+                    delta_w_ws, X_delta_w_ws, ws)
             else:
                 grad_ws[:] = _backtrack_line_search(
-                    X, y, w, Xw, datafit, penalty, delta_w_ws, X_delta_w_ws, ws)
+                    X, y, w, Xw, fit_intercept, datafit, penalty,
+                    delta_w_ws, X_delta_w_ws, ws)
 
             # check convergence
             opt_in = penalty.subdiff_distance(w, grad_ws, ws)
@@ -138,7 +162,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
 
 
 @njit
-def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit,
+def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
                        penalty, ws, tol):
     # Given:
     # 1) b = \nabla F(X w_epoch)
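
For context (a gloss on the truncated comment above, not part of the patch): the inner solver minimizes the penalized quadratic model of the datafit around `w_epoch`, whose smooth part is `b @ (X @ delta) + 0.5 * (X @ delta) @ (raw_hess * (X @ delta))`, with `b` the raw gradient and `raw_hess` the diagonal Hessian. A toy finite-difference check of the per-coordinate gradient this model yields (hypothetical names):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
b = rng.normal(size=20)                    # raw gradient at Xw_epoch
raw_hess = rng.uniform(0.1, 1.0, size=20)  # diagonal of the raw Hessian
delta = rng.normal(size=3)

def q(d):
    # smooth part of the quadratic model
    Xd = X @ d
    return b @ Xd + 0.5 * Xd @ (raw_hess * Xd)

j, eps = 1, 1e-6
e = np.zeros(3)
e[j] = eps
fd = (q(delta + e) - q(delta - e)) / (2 * eps)
analytic = X[:, j] @ b + X[:, j] @ (raw_hess * (X @ delta))
assert np.isclose(fd, analytic)
```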
@@ -152,11 +176,16 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit,
     for idx, j in enumerate(ws):
         lipschitz[idx] = raw_hess @ X[:, j] ** 2
 
-    # for a less costly stopping criterion, we do no compute the exact gradient,
-    # but store each coordinate-wise gradient every time we upate one coordinate:
+    # for a less costly stopping criterion, we do not compute the exact gradient,
+    # but store each coordinate-wise gradient every time we update one coordinate
     past_grads = np.zeros(len(ws))
     X_delta_w_ws = np.zeros(X.shape[0])
-    w_ws = w_epoch[ws]
+    ws_intercept = np.append(ws, -1) if fit_intercept else ws
+    w_ws = w_epoch[ws_intercept]
+
+    if fit_intercept:
+        lipschitz_intercept = np.sum(raw_hess)
+        grad_intercept = np.sum(datafit.raw_grad(y, Xw_epoch))
 
     for cd_iter in range(MAX_CD_ITER):
         for idx, j in enumerate(ws):
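
A note on the two intercept quantities just introduced: treating the intercept as an implicit column of ones, its curvature in the quadratic model is `ones @ (raw_hess * ones) = np.sum(raw_hess)`, the exact analogue of the per-feature rule `raw_hess @ X[:, j] ** 2`, and its gradient is likewise the sum of the raw gradients. A one-line sanity check of the reduction:

```python
import numpy as np

raw_hess = np.random.default_rng(0).uniform(0.1, 1.0, size=30)
ones_col = np.ones_like(raw_hess)          # the implicit intercept "feature"
assert np.isclose(raw_hess @ ones_col ** 2, np.sum(raw_hess))
```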
@@ -174,22 +203,35 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit,
             if w_ws[idx] != old_w_idx:
                 X_delta_w_ws += (w_ws[idx] - old_w_idx) * X[:, j]
 
+        if fit_intercept:
+            past_grads_intercept = grad_intercept + raw_hess @ X_delta_w_ws
+            old_intercept = w_ws[-1]
+            w_ws[-1] -= past_grads_intercept / lipschitz_intercept
+
+            if w_ws[-1] != old_intercept:
+                X_delta_w_ws += w_ws[-1] - old_intercept
+
         if cd_iter % 5 == 0:
             # TODO: can be improved by passing in w_ws but breaks for WeightedL1
             current_w = w_epoch.copy()
-            current_w[ws] = w_ws
+            current_w[ws_intercept] = w_ws
             opt = penalty.subdiff_distance(current_w, past_grads, ws)
-            if np.max(opt) <= tol:
+            stop_crit = np.max(opt)
+
+            if fit_intercept:
+                stop_crit = max(stop_crit, np.abs(past_grads_intercept))
+
+            if stop_crit <= tol:
                 break
 
     # descent direction
-    return w_ws - w_epoch[ws], X_delta_w_ws
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws
 
 
-# sparse version of _compute_descent_direction
+# sparse version of _descent_direction
 @njit
 def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
-                         Xw_epoch, grad_ws, datafit, penalty, ws, tol):
+                         Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol):
     raw_hess = datafit.raw_hessian(y, Xw_epoch)
 
     lipschitz = np.zeros(len(ws))
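
Two details of the intercept update in the hunk above worth spelling out: since the intercept is unpenalized, `w_ws[-1] -= past_grads_intercept / lipschitz_intercept` is a plain Newton step that exactly minimizes the 1-D quadratic model in the intercept; and because its feature column is all ones, the change in `X @ delta_w` is a constant shift, so adding the scalar to `X_delta_w_ws` broadcasts correctly. A minimal sketch of both points:

```python
import numpy as np

g, h = 0.8, 2.0                       # toy model gradient / curvature
shift = -g / h                        # Newton step on a 1-D quadratic
assert np.isclose(g + h * shift, 0.)  # model gradient vanishes afterwards

X_delta_w_ws = np.zeros(5)
X_delta_w_ws += shift                 # scalar broadcast == shift * ones(5)
assert np.allclose(X_delta_w_ws, shift * np.ones(5))
```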
@@ -201,7 +243,12 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
     # see _descent_direction() comment
     past_grads = np.zeros(len(ws))
     X_delta_w_ws = np.zeros(len(y))
-    w_ws = w_epoch[ws]
+    ws_intercept = np.append(ws, -1) if fit_intercept else ws
+    w_ws = w_epoch[ws_intercept]
+
+    if fit_intercept:
+        lipschitz_intercept = np.sum(raw_hess)
+        grad_intercept = np.sum(datafit.raw_grad(y, Xw_epoch))
 
     for cd_iter in range(MAX_CD_ITER):
         for idx, j in enumerate(ws):
@@ -224,39 +271,57 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
                 _update_X_delta_w(X_data, X_indptr, X_indices, X_delta_w_ws,
                                   w_ws[idx] - old_w_idx, j)
 
+        if fit_intercept:
+            past_grads_intercept = grad_intercept + raw_hess @ X_delta_w_ws
+            old_intercept = w_ws[-1]
+            w_ws[-1] -= past_grads_intercept / lipschitz_intercept
+
+            if w_ws[-1] != old_intercept:
+                X_delta_w_ws += w_ws[-1] - old_intercept
+
         if cd_iter % 5 == 0:
             # TODO: could be improved by passing in w_ws
             current_w = w_epoch.copy()
-            current_w[ws] = w_ws
+            current_w[ws_intercept] = w_ws
             opt = penalty.subdiff_distance(current_w, past_grads, ws)
-            if np.max(opt) <= tol:
+            stop_crit = np.max(opt)
+
+            if fit_intercept:
+                stop_crit = max(stop_crit, np.abs(past_grads_intercept))
+
+            if stop_crit <= tol:
                 break
 
     # descent direction
-    return w_ws - w_epoch[ws], X_delta_w_ws
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws
 
 
 @njit
-def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws,
+def _backtrack_line_search(X, y, w, Xw, fit_intercept, datafit, penalty, delta_w_ws,
                            X_delta_w_ws, ws):
     # 1) find step in [0, 1] such that:
     #       penalty(w + step * delta_w) - penalty(w) +
     #       step * \nabla datafit(w + step * delta_w) @ delta_w < 0
     # ref: https://www.di.ens.fr/~aspremon/PDF/ENSAE/Newton.pdf
     # 2) inplace update of w and Xw and return grad_ws of the last w and Xw
     step, prev_step = 1., 0.
+    n_features = X.shape[1]
+    ws_intercept = np.append(ws, -1) if fit_intercept else ws
     # TODO: could be improved by passing in w[ws]
-    old_penalty_val = penalty.value(w)
+    old_penalty_val = penalty.value(w[:n_features])
 
     # try step = 1, 1/2, 1/4, ...
     for _ in range(MAX_BACKTRACK_ITER):
-        w[ws] += (step - prev_step) * delta_w_ws
+        w[ws_intercept] += (step - prev_step) * delta_w_ws
         Xw += (step - prev_step) * X_delta_w_ws
 
-        grad_ws = _construct_grad(X, y, w, Xw, datafit, ws)
+        grad_ws = _construct_grad(X, y, w[:n_features], Xw, datafit, ws)
         # TODO: could be improved by passing in w[ws]
-        stop_crit = penalty.value(w) - old_penalty_val
-        stop_crit += step * grad_ws @ delta_w_ws
+        stop_crit = penalty.value(w[:n_features]) - old_penalty_val
+        stop_crit += step * grad_ws @ delta_w_ws[:len(ws)]
+
+        if fit_intercept:
+            stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
 
         if stop_crit < 0:
             break
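
For readers skimming the hunk, the search itself is unchanged: halve the step until the criterion from the comment in `_backtrack_line_search` turns negative; the patch only extends the directional term with the intercept's contribution `delta_w_ws[-1] * sum(raw_grad)`. A stripped-down sketch of the loop with hypothetical callables (`penalty_value`, `datafit_grad`), not the library API:

```python
import numpy as np

def halving_line_search(penalty_value, datafit_grad, w, delta_w, max_iter=20):
    """Try step = 1, 1/2, 1/4, ... until
    penalty(w + step*dw) - penalty(w) + step * grad(w + step*dw) @ dw < 0,
    updating w in place as the solver does."""
    step, prev_step = 1.0, 0.0
    old_penalty = penalty_value(w)
    for _ in range(max_iter):
        w += (step - prev_step) * delta_w  # move to the new trial point
        crit = penalty_value(w) - old_penalty + step * datafit_grad(w) @ delta_w
        if crit < 0:                       # sufficient decrease reached
            return step
        prev_step, step = step, step / 2
    return step
```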
@@ -272,21 +337,26 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws,
 
 
 # sparse version of _backtrack_line_search
 @njit
-def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, datafit,
-                             penalty, delta_w_ws, X_delta_w_ws, ws):
+def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, fit_intercept,
+                             datafit, penalty, delta_w_ws, X_delta_w_ws, ws):
     step, prev_step = 1., 0.
+    n_features = len(X_indptr) - 1
+    ws_intercept = np.append(ws, -1) if fit_intercept else ws
     # TODO: could be improved by passing in w[ws]
-    old_penalty_val = penalty.value(w)
+    old_penalty_val = penalty.value(w[:n_features])
 
     for _ in range(MAX_BACKTRACK_ITER):
-        w[ws] += (step - prev_step) * delta_w_ws
+        w[ws_intercept] += (step - prev_step) * delta_w_ws
         Xw += (step - prev_step) * X_delta_w_ws
 
         grad_ws = _construct_grad_sparse(X_data, X_indptr, X_indices,
-                                         y, w, Xw, datafit, ws)
+                                         y, w[:n_features], Xw, datafit, ws)
         # TODO: could be improved by passing in w[ws]
-        stop_crit = penalty.value(w) - old_penalty_val
-        stop_crit += step * grad_ws.T @ delta_w_ws
+        stop_crit = penalty.value(w[:n_features]) - old_penalty_val
+        stop_crit += step * grad_ws.T @ delta_w_ws[:len(ws)]
+
+        if fit_intercept:
+            stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
 
         if stop_crit < 0:
             break
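
One last small point: `len(X_indptr) - 1` recovers `n_features` because in CSC format `indptr` stores one offset per column plus a final terminator. A quick check with scipy:

```python
import numpy as np
from scipy.sparse import csc_matrix

X = csc_matrix(np.arange(12.).reshape(3, 4))
assert len(X.indptr) - 1 == X.shape[1] == 4
```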