Skip to content

Commit e47c68a

Browse files
committed
ADD global lipschitz constants
1 parent 4940a0d commit e47c68a

File tree

1 file changed

+43
-2
lines changed

1 file changed

+43
-2
lines changed

skglm/datafits/single_task.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ class Quadratic(BaseDatafit):
2222
The coordinatewise gradient Lipschitz constants. Equal to
2323
norm(X, axis=0) ** 2 / n_samples.
2424
25+
global_lipschitz : float
26+
Global Lipschitz constant. Equal to
27+
norm(X, ord=2) ** 2 / n_samples.
28+
2529
Note
2630
----
2731
The class is jit compiled at fit time using Numba compiler.
@@ -35,6 +39,7 @@ def get_spec(self):
3539
spec = (
3640
('Xty', float64[:]),
3741
('lipschitz', float64[:]),
42+
('global_lipschitz', float64),
3843
)
3944
return spec
4045

@@ -44,6 +49,7 @@ def params_to_dict(self):
4449
def initialize(self, X, y):
4550
self.Xty = X.T @ y
4651
n_features = X.shape[1]
52+
self.global_lipschitz = norm(X, ord=2) ** 2 / len(y)
4753
self.lipschitz = np.zeros(n_features, dtype=X.dtype)
4854
for j in range(n_features):
4955
self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
@@ -53,15 +59,19 @@ def initialize_sparse(
5359
n_features = len(X_indptr) - 1
5460
self.Xty = np.zeros(n_features, dtype=X_data.dtype)
5561
self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
62+
self.global_lipschitz = 0.
5663
for j in range(n_features):
5764
nrm2 = 0.
5865
xty = 0
66+
x2 = 0.
5967
for idx in range(X_indptr[j], X_indptr[j + 1]):
6068
nrm2 += X_data[idx] ** 2
6169
xty += X_data[idx] * y[X_indices[idx]]
70+
x2 += X_data[idx] ** 2 / len(y)
6271

6372
self.lipschitz[j] = nrm2 / len(y)
6473
self.Xty[j] = xty
74+
self.global_lipschitz += x2
6575

6676
def value(self, y, w, Xw):
6777
return np.sum((y - Xw) ** 2) / (2 * len(Xw))
@@ -111,6 +121,10 @@ class Logistic(BaseDatafit):
111121
The coordinatewise gradient Lipschitz constants. Equal to
112122
norm(X, axis=0) ** 2 / (4 * n_samples).
113123
124+
global_lipschitz : float
125+
Global Lipschitz constant. Equal to
126+
norm(X, ord=2) ** 2 / (4 * n_samples).
127+
114128
Note
115129
----
116130
The class is jit compiled at fit time using Numba compiler.
@@ -123,6 +137,7 @@ def __init__(self):
123137
def get_spec(self):
124138
spec = (
125139
('lipschitz', float64[:]),
140+
('global_lipschitz', float64),
126141
)
127142
return spec
128143

@@ -140,13 +155,16 @@ def raw_hessian(self, y, Xw):
140155

141156
def initialize(self, X, y):
142157
self.lipschitz = (X ** 2).sum(axis=0) / (len(y) * 4)
158+
self.global_lipschitz = norm(X, ord=2) ** 2 / (len(y) * 4)
143159

144160
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
145161
n_features = len(X_indptr) - 1
146162
self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
163+
self.global_lipschitz = 0.
147164
for j in range(n_features):
148165
Xj = X_data[X_indptr[j]:X_indptr[j+1]]
149166
self.lipschitz[j] = (Xj ** 2).sum() / (len(y) * 4)
167+
self.global_lipschitz += (Xj ** 2).sum() / (len(y) * 4)
150168

151169
def value(self, y, w, Xw):
152170
return np.log(1. + np.exp(- y * Xw)).sum() / len(y)
@@ -187,6 +205,11 @@ class QuadraticSVC(BaseDatafit):
187205
----------
188206
lipschitz : array, shape (n_features,)
189207
The coordinatewise gradient Lipschitz constants.
208+
Equal to norm(yXT, axis=0) ** 2.
209+
210+
global_lipschitz : float
211+
Global Lipschitz constant. Equal to
212+
norm(yXT, ord=2) ** 2.
190213
191214
Note
192215
----
@@ -200,6 +223,7 @@ def __init__(self):
200223
def get_spec(self):
201224
spec = (
202225
('lipschitz', float64[:]),
226+
('global_lipschitz', float64),
203227
)
204228
return spec
205229

@@ -209,18 +233,22 @@ def params_to_dict(self):
209233
def initialize(self, yXT, y):
210234
n_features = yXT.shape[1]
211235
self.lipschitz = np.zeros(n_features, dtype=yXT.dtype)
236+
self.global_lipschitz = 0.
212237
for j in range(n_features):
213238
self.lipschitz[j] = norm(yXT[:, j]) ** 2
239+
self.global_lipschitz += norm(yXT[:, j]) ** 2
214240

215241
def initialize_sparse(
216242
self, yXT_data, yXT_indptr, yXT_indices, y):
217243
n_features = len(yXT_indptr) - 1
218244
self.lipschitz = np.zeros(n_features, dtype=yXT_data.dtype)
245+
self.global_lipschitz = 0.
219246
for j in range(n_features):
220247
nrm2 = 0.
221248
for idx in range(yXT_indptr[j], yXT_indptr[j + 1]):
222249
nrm2 += yXT_data[idx] ** 2
223250
self.lipschitz[j] = nrm2
251+
self.global_lipschitz += nrm2
224252

225253
def value(self, y, w, yXTw):
226254
return (yXTw ** 2).sum() / 2 - np.sum(w)
@@ -264,8 +292,16 @@ class Huber(BaseDatafit):
264292
265293
Attributes
266294
----------
295+
delta : float
296+
Shape hyperparameter.
297+
267298
lipschitz : array, shape (n_features,)
268-
The coordinatewise gradient Lipschitz constants.
299+
The coordinatewise gradient Lipschitz constants. Equal to
300+
norm(X, axis=0) ** 2 / n_samples.
301+
302+
global_lipschitz : float
303+
Global Lipschitz constant. Equal to
304+
norm(X, ord=2) ** 2 / n_samples.
269305
270306
Note
271307
----
@@ -279,7 +315,8 @@ def __init__(self, delta):
279315
def get_spec(self):
280316
spec = (
281317
('delta', float64),
282-
('lipschitz', float64[:])
318+
('lipschitz', float64[:]),
319+
('global_lipschitz', float64),
283320
)
284321
return spec
285322

@@ -289,18 +326,22 @@ def params_to_dict(self):
289326
def initialize(self, X, y):
290327
n_features = X.shape[1]
291328
self.lipschitz = np.zeros(n_features, dtype=X.dtype)
329+
self.global_lipschitz = 0.
292330
for j in range(n_features):
293331
self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
332+
self.global_lipschitz += (X[:, j] ** 2).sum() / len(y)
294333

295334
def initialize_sparse(
296335
self, X_data, X_indptr, X_indices, y):
297336
n_features = len(X_indptr) - 1
298337
self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
338+
self.global_lipschitz = 0.
299339
for j in range(n_features):
300340
nrm2 = 0.
301341
for idx in range(X_indptr[j], X_indptr[j + 1]):
302342
nrm2 += X_data[idx] ** 2
303343
self.lipschitz[j] = nrm2 / len(y)
344+
self.global_lipschitz += nrm2 / len(y)
304345

305346
def value(self, y, w, Xw):
306347
n_samples = len(y)

0 commit comments

Comments
 (0)