@@ -59,10 +59,12 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
         self.verbose = verbose
 
     def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+        dtype = X.dtype
         n_samples, n_features = X.shape
         fit_intercept = self.fit_intercept
-        w = np.zeros(n_features + fit_intercept) if w_init is None else w_init
-        Xw = np.zeros(n_samples) if Xw_init is None else Xw_init
+
+        w = np.zeros(n_features + fit_intercept, dtype) if w_init is None else w_init
+        Xw = np.zeros(n_samples, dtype) if Xw_init is None else Xw_init
         all_features = np.arange(n_features)
         stop_crit = 0.
         p_objs_out = []
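
A quick NumPy illustration, independent of this code, of why the explicit dtype matters: without it, w and Xw default to float64, and any later product with a float32 X silently upcasts intermediate results.

    import numpy as np

    X = np.random.randn(100, 10).astype(np.float32)

    w_default = np.zeros(10)               # float64, the previous behaviour
    w_typed = np.zeros(10, dtype=X.dtype)  # float32, matching X

    print((X @ w_default).dtype)  # float64 -- silent upcast
    print((X @ w_typed).dtype)    # float32 -- input dtype preserved
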
@@ -181,16 +183,17 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
     # Minimize quadratic approximation for delta_w = w - w_epoch:
     # b.T @ X @ delta_w + \
     #     1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w)
+    dtype = X.dtype
     raw_hess = datafit.raw_hessian(y, Xw_epoch)
 
-    lipschitz = np.zeros(len(ws))
+    lipschitz = np.zeros(len(ws), dtype)
     for idx, j in enumerate(ws):
         lipschitz[idx] = raw_hess @ X[:, j] ** 2
 
     # for a less costly stopping criterion, we do not compute the exact gradient,
     # but store each coordinate-wise gradient every time we update one coordinate
-    past_grads = np.zeros(len(ws))
-    X_delta_w_ws = np.zeros(X.shape[0])
+    past_grads = np.zeros(len(ws), dtype)
+    X_delta_w_ws = np.zeros(X.shape[0], dtype)
     ws_intercept = np.append(ws, -1) if fit_intercept else ws
     w_ws = w_epoch[ws_intercept]
 
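
The lipschitz entries computed above are the diagonal of the Hessian of the quadratic model, i.e. (X.T @ D @ X)[j, j] with D = diag(raw_hess). A small NumPy check of that identity, with raw_hess standing in for datafit.raw_hessian(y, Xw_epoch):

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((30, 5))
    raw_hess = rng.random(30)        # stand-in for datafit.raw_hessian(y, Xw_epoch)
    ws = np.array([0, 2, 3])         # working set

    lipschitz = np.zeros(len(ws), X.dtype)
    for idx, j in enumerate(ws):
        lipschitz[idx] = raw_hess @ X[:, j] ** 2

    H = X.T @ np.diag(raw_hess) @ X  # Hessian of the quadratic model
    assert np.allclose(lipschitz, np.diag(H)[ws])
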
@@ -243,17 +246,18 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
 @njit
 def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
                          Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol):
+    dtype = X_data.dtype
     raw_hess = datafit.raw_hessian(y, Xw_epoch)
 
-    lipschitz = np.zeros(len(ws))
+    lipschitz = np.zeros(len(ws), dtype)
     for idx, j in enumerate(ws):
         # equivalent to: lipschitz[idx] += raw_hess * X[:, j] ** 2
         lipschitz[idx] = _sparse_squared_weighted_norm(
             X_data, X_indptr, X_indices, j, raw_hess)
 
     # see _descent_direction() comment
-    past_grads = np.zeros(len(ws))
-    X_delta_w_ws = np.zeros(Xw_epoch.shape[0])
+    past_grads = np.zeros(len(ws), dtype)
+    X_delta_w_ws = np.zeros(Xw_epoch.shape[0], dtype)
     ws_intercept = np.append(ws, -1) if fit_intercept else ws
     w_ws = w_epoch[ws_intercept]
 
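
The helper _sparse_squared_weighted_norm itself is outside this diff; as a hedged sketch of what the inline comment describes, it amounts to a weighted squared norm of column j computed from the CSC arrays (the function below is an illustrative reimplementation, not the library's code):

    import numpy as np
    from scipy import sparse

    def sparse_squared_weighted_norm(X_data, X_indptr, X_indices, j, weights):
        # equivalent to weights @ X[:, j] ** 2, visiting only column j's nonzeros
        total = 0.
        for ptr in range(X_indptr[j], X_indptr[j + 1]):
            total += weights[X_indices[ptr]] * X_data[ptr] ** 2
        return total

    X = sparse.random(30, 5, density=0.3, format="csc", random_state=0)
    weights = np.random.default_rng(0).random(30)
    dense_value = weights @ (X.toarray()[:, 2] ** 2)
    assert np.isclose(
        sparse_squared_weighted_norm(X.data, X.indptr, X.indices, 2, weights), dense_value)
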
@@ -329,7 +333,11 @@ def _backtrack_line_search(X, y, w, Xw, fit_intercept, datafit, penalty, delta_w
         grad_ws = _construct_grad(X, y, w[:n_features], Xw, datafit, ws)
         # TODO: could be improved by passing in w[ws]
         stop_crit = penalty.value(w[:n_features]) - old_penalty_val
-        stop_crit += step * grad_ws @ delta_w_ws[:len(ws)]
+
+        # it is mandatory to split the two operations, otherwise numba raises an error
+        # cf. https://github.com/numba/numba/issues/9025
+        dot = grad_ws @ delta_w_ws[:len(ws)]
+        stop_crit += step * dot
 
         if fit_intercept:
             stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
@@ -364,7 +372,11 @@ def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, fit_intercep
             y, w[:n_features], Xw, datafit, ws)
         # TODO: could be improved by passing in w[ws]
         stop_crit = penalty.value(w[:n_features]) - old_penalty_val
-        stop_crit += step * grad_ws.T @ delta_w_ws[:len(ws)]
+
+        # it is mandatory to split the two operations, otherwise numba raises an error
+        # cf. https://github.com/numba/numba/issues/9025
+        dot = grad_ws.T @ delta_w_ws[:len(ws)]
+        stop_crit += step * dot
 
         if fit_intercept:
             stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
@@ -385,7 +397,7 @@ def _construct_grad(X, y, w, Xw, datafit, ws):
     # Compute grad of datafit restricted to ws. This function avoids
     # recomputing raw_grad for every j, which is costly for logreg
     raw_grad = datafit.raw_grad(y, Xw)
-    grad = np.zeros(len(ws))
+    grad = np.zeros(len(ws), dtype=X.dtype)
     for idx, j in enumerate(ws):
         grad[idx] = X[:, j] @ raw_grad
     return grad
@@ -395,7 +407,7 @@ def _construct_grad(X, y, w, Xw, datafit, ws):
 def _construct_grad_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, ws):
     # Compute grad of datafit restricted to ws in case X sparse
     raw_grad = datafit.raw_grad(y, Xw)
-    grad = np.zeros(len(ws))
+    grad = np.zeros(len(ws), dtype=X_data.dtype)
     for idx, j in enumerate(ws):
         grad[idx] = _sparse_xj_dot(X_data, X_indptr, X_indices, j, raw_grad)
     return grad
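
To see what the patched _construct_grad computes, and that its output now follows X.dtype, a standalone NumPy sketch: the loop is equivalent to X[:, ws].T @ raw_grad restricted to the working set.

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((20, 8)).astype(np.float32)
    raw_grad = rng.standard_normal(20).astype(np.float32)  # stand-in for datafit.raw_grad(y, Xw)
    ws = np.array([1, 4, 6])                               # working set

    grad = np.zeros(len(ws), dtype=X.dtype)
    for idx, j in enumerate(ws):
        grad[idx] = X[:, j] @ raw_grad

    assert np.allclose(grad, X[:, ws].T @ raw_grad)
    assert grad.dtype == np.float32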