 
 from skglm.penalties.base import BasePenalty
 from skglm.utils.prox_funcs import (
-    ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP, value_MCP)
+    ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP,
+    value_MCP, value_weighted_MCP)
 
 
 class L1(BasePenalty):
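The import list now pulls in `value_weighted_MCP`, whose body is added to `skglm.utils.prox_funcs` in this PR but is not shown in this hunk. A minimal sketch of what it plausibly computes, assuming it mirrors `value_MCP` with a per-feature weight (the function name is from the diff; the body below is illustrative, not the PR's implementation):

```python
import numpy as np

def value_weighted_mcp_sketch(w, alpha, gamma, weights):
    # weighted MCP value: sum_j weights[j] * pen(|w[j]|),
    # with pen as in the MCPenalty docstring below
    value = 0.
    for wj, weight_j in zip(np.abs(w), weights):
        if wj <= alpha * gamma:
            value += weight_j * (alpha * wj - wj ** 2 / (2. * gamma))
        else:
            value += weight_j * (gamma * alpha ** 2 / 2.)
    return value
```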
@@ -216,48 +217,57 @@ class MCPenalty(BasePenalty):
     With :math:`x >= 0`:
 
     .. math::
-        "pen"(x) = {(alpha x - x^2 / (2 gamma), if x =< alpha gamma),
+        "pen"(x) = {(alpha x - x^2 / (2 gamma), if x <= alpha gamma),
         (gamma alpha^2 / 2, if x > alpha gamma):}
     .. math::
         "value" = sum_(j=1)^(n_"features") "pen"(abs(w_j))
     """
 
-    def __init__(self, alpha, gamma):
+    def __init__(self, alpha, gamma, positive=False):
         self.alpha = alpha
         self.gamma = gamma
+        self.positive = positive
 
     def get_spec(self):
         spec = (
             ('alpha', float64),
             ('gamma', float64),
+            ('positive', bool_)
         )
         return spec
 
     def params_to_dict(self):
         return dict(alpha=self.alpha,
-                    gamma=self.gamma)
+                    gamma=self.gamma,
+                    positive=self.positive)
 
     def value(self, w):
         return value_MCP(w, self.alpha, self.gamma)
 
     def prox_1d(self, value, stepsize, j):
         """Compute the proximal operator of MCP."""
-        return prox_MCP(value, stepsize, self.alpha, self.gamma)
+        return prox_MCP(value, stepsize, self.alpha, self.gamma, self.positive)
 
     def subdiff_distance(self, w, grad, ws):
         """Compute distance of negative gradient to the subdifferential at w."""
         subdiff_dist = np.zeros_like(grad)
         for idx, j in enumerate(ws):
-            if w[j] == 0:
-                # distance of -grad to alpha * [-1, 1]
-                subdiff_dist[idx] = max(0, np.abs(grad[idx]) - self.alpha)
-            elif np.abs(w[j]) < self.alpha * self.gamma:
-                # distance of -grad_j to (alpha * sign(w[j]) - w[j] / gamma)
-                subdiff_dist[idx] = np.abs(
-                    grad[idx] + self.alpha * np.sign(w[j]) - w[j] / self.gamma)
+            if self.positive and w[j] < 0:
+                subdiff_dist[idx] = np.inf
+            elif self.positive and w[j] == 0:
+                # distance of -grad to (-infty, alpha]
+                subdiff_dist[idx] = max(0, -grad[idx] - self.alpha)
             else:
-                # distance of grad to 0
-                subdiff_dist[idx] = np.abs(grad[idx])
+                if w[j] == 0:
+                    # distance of -grad to [-alpha, alpha]
+                    subdiff_dist[idx] = max(0, np.abs(grad[idx]) - self.alpha)
+                elif np.abs(w[j]) < self.alpha * self.gamma:
+                    # distance of -grad to {alpha * sign(w[j]) - w[j] / gamma}
+                    subdiff_dist[idx] = np.abs(
+                        grad[idx] + self.alpha * np.sign(w[j]) - w[j] / self.gamma)
+                else:
+                    # distance of grad to 0
+                    subdiff_dist[idx] = np.abs(grad[idx])
         return subdiff_dist
 
     def is_penalized(self, n_features):
@@ -273,6 +283,89 @@ def alpha_max(self, gradient0):
         return np.max(np.abs(gradient0))
 
 
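`prox_MCP` now takes a `positive` flag (and, for the weighted class below, an optional per-feature weight). For reference, a minimal sketch of a `positive`-aware scalar MCP prox (firm thresholding), assuming `gamma > stepsize`; this is illustrative, not skglm's actual `prox_MCP`:

```python
import numpy as np

def prox_mcp_sketch(x, stepsize, alpha, gamma, positive=False):
    # under the nonnegativity constraint, any negative input maps to 0
    if positive and x < 0:
        return 0.
    tau = alpha * stepsize
    if abs(x) <= tau:
        return 0.  # thresholded to zero
    if abs(x) <= alpha * gamma:
        # firm thresholding: soft-threshold, then rescale (needs gamma > stepsize)
        return np.sign(x) * (abs(x) - tau) / (1. - stepsize / gamma)
    return x  # penalty is flat past alpha * gamma: identity
```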
+class WeightedMCPenalty(BasePenalty):
+    """Weighted Minimax Concave Penalty (MCP), a non-convex sparse penalty.
+
+    Notes
+    -----
+    With :math:`x >= 0`:
+
+    .. math::
+        "pen"(x) = {(alpha x - x^2 / (2 gamma), if x <= alpha gamma),
+        (gamma alpha^2 / 2, if x > alpha gamma):}
+    .. math::
+        "value" = sum_(j=1)^(n_"features") "weights"_j xx "pen"(abs(w_j))
+    """
+
+    def __init__(self, alpha, gamma, weights, positive=False):
+        self.alpha = alpha
+        self.gamma = gamma
+        self.weights = weights.astype(np.float64)
+        self.positive = positive
+
+    def get_spec(self):
+        spec = (
+            ('alpha', float64),
+            ('gamma', float64),
+            ('weights', float64[:]),
+            ('positive', bool_)
+        )
+        return spec
+
+    def params_to_dict(self):
+        return dict(alpha=self.alpha,
+                    gamma=self.gamma,
+                    weights=self.weights,
+                    positive=self.positive)
+
+    def value(self, w):
+        return value_weighted_MCP(w, self.alpha, self.gamma, self.weights)
+
+    def prox_1d(self, value, stepsize, j):
+        """Compute the proximal operator of the weighted MCP."""
+        return prox_MCP(
+            value, stepsize, self.alpha, self.gamma, self.positive, self.weights[j])
+
+    def subdiff_distance(self, w, grad, ws):
+        """Compute distance of negative gradient to the subdifferential at w."""
+        subdiff_dist = np.zeros_like(grad)
+        for idx, j in enumerate(ws):
+            if self.positive and w[j] < 0:
+                subdiff_dist[idx] = np.inf
+            elif self.positive and w[j] == 0:
+                # distance of -grad to (-infty, alpha * weights[j]]
+                subdiff_dist[idx] = max(
+                    0, -grad[idx] - self.alpha * self.weights[j])
+            else:
+                if w[j] == 0:
+                    # distance of -grad to weights[j] * [-alpha, alpha]
+                    subdiff_dist[idx] = max(
+                        0, np.abs(grad[idx]) - self.alpha * self.weights[j])
+                elif np.abs(w[j]) < self.alpha * self.gamma:
+                    # distance of -grad to
+                    # {weights[j] * alpha * sign(w[j]) - w[j] / gamma}
+                    subdiff_dist[idx] = np.abs(
+                        grad[idx] + self.alpha * self.weights[j] * np.sign(w[j])
+                        - self.weights[j] * w[j] / self.gamma)
+                else:
+                    # distance of grad to 0
+                    subdiff_dist[idx] = np.abs(grad[idx])
+        return subdiff_dist
+
+    def is_penalized(self, n_features):
+        """Return a binary mask with the penalized features."""
+        return np.ones(n_features, bool_)
+
+    def generalized_support(self, w):
+        """Return a mask with non-zero coefficients."""
+        return w != 0
+
+    def alpha_max(self, gradient0):
+        """Return penalization value for which 0 is solution."""
+        nnz_weights = self.weights != 0
+        return np.max(np.abs(gradient0[nnz_weights] / self.weights[nnz_weights]))
+
+
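To exercise the new class end to end, something along these lines should work with skglm's estimator API (a sketch: `GeneralizedLinearEstimator` and `Quadratic` are existing skglm classes, `WeightedMCPenalty` is the class added here, and the hyperparameter values are arbitrary):

```python
import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.datafits import Quadratic
from skglm.penalties import WeightedMCPenalty

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10))
y = X @ rng.standard_normal(10)

# a zero weight leaves the corresponding coefficient unpenalized
weights = np.ones(X.shape[1])
penalty = WeightedMCPenalty(alpha=0.1, gamma=3., weights=weights, positive=True)
model = GeneralizedLinearEstimator(Quadratic(), penalty).fit(X, y)
print(model.coef_)
```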
 class SCAD(BasePenalty):
     r"""Smoothly Clipped Absolute Deviation.
 