|
| 1 | +""" |
| 2 | +================================================================= |
| 3 | +Timing comparison between direct prox computation and reweighting |
| 4 | +================================================================= |
| 5 | +Compare time and objective value of L0_5-regularized problem with |
| 6 | +direct proximal computation and iterative reweighting. |
| 7 | +""" |
| 8 | +# Author: Pierre-Antoine Bannier <pierreantoine.bannier@gmail.com> |
| 9 | + |
| 10 | +import time |
| 11 | +import numpy as np |
| 12 | +import pandas as pd |
| 13 | +from numpy.linalg import norm |
| 14 | +import matplotlib.pyplot as plt |
| 15 | + |
| 16 | +from skglm.penalties.separable import L0_5 |
| 17 | +from skglm.utils import make_correlated_data |
| 18 | +from skglm.estimators import GeneralizedLinearEstimator |
| 19 | +from skglm.experimental import IterativeReweightedL1 |
| 20 | +from skglm.solvers import AndersonCD |
| 21 | + |
| 22 | + |
# Synthetic regression problem: more features than samples, seeded so the
# benchmark is reproducible.
n_samples, n_features = 200, 500
X, y, w_true = make_correlated_data(
    n_samples=n_samples, n_features=n_features, random_state=24)

# Reference regularization level; the benchmarked strengths are fixed
# fractions (1/10, 1/100, 1/1000) of it.
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
alphas = [alpha_max / divisor for divisor in (10, 100, 1000)]

# Tolerance handed to every solver.
tol = 1e-10
| 30 | + |
| 31 | + |
def _obj(w):
    """Return the L0.5-penalized least-squares objective evaluated at ``w``.

    The value is ``||y - X w||^2 / (2 * n_samples) + alpha * sum(sqrt(|w_j|))``.

    Note that ``X``, ``y``, ``n_samples`` and ``alpha`` are not arguments:
    they are read from module scope at call time, so ``alpha`` must already
    be bound to the regularization strength being evaluated.
    """
    residual = y - X @ w
    datafit_value = np.sum(residual ** 2) / (2 * n_samples)
    penalty_value = alpha * np.sum(np.sqrt(np.abs(w)))
    return datafit_value + penalty_value
| 35 | + |
| 36 | + |
def fit_l05(alpha):
    """Fit the L0_5-penalized problem with both strategies and time each fit.

    The problem data ``X``, ``y`` and the solver tolerance ``tol`` are read
    from module scope.

    Parameters
    ----------
    alpha : float
        Regularization strength passed to the ``L0_5`` penalty.

    Returns
    -------
    results : dict
        Two entries, ``"iterative"`` (``IterativeReweightedL1``) and
        ``"direct"`` (``GeneralizedLinearEstimator``). Each maps to a tuple
        ``(estimator, elapsed)`` where ``estimator`` is the object returned
        by ``.fit(X, y)`` and ``elapsed`` is the duration in seconds of
        building and fitting it.
    """
    # time.perf_counter() is a monotonic, high-resolution clock meant for
    # measuring intervals; time.time() can jump with system clock updates.
    start = time.perf_counter()
    iterative_l05 = IterativeReweightedL1(
        penalty=L0_5(alpha),
        solver=AndersonCD(tol=tol, fit_intercept=False)).fit(X, y)
    iterative_time = time.perf_counter() - start

    # `subdiff` strategy for WS is uninformative for L0_5
    start = time.perf_counter()
    direct_l05 = GeneralizedLinearEstimator(
        penalty=L0_5(alpha),
        solver=AndersonCD(tol=tol, fit_intercept=False,
                          ws_strategy="fixpoint")).fit(X, y)
    direct_time = time.perf_counter() - start

    results = {
        "iterative": (iterative_l05, iterative_time),
        "direct": (direct_l05, direct_time),
    }
    return results
| 57 | + |
| 58 | + |
# Warm-up call: triggers (and caches) Numba compilation so that it is not
# counted in the timings recorded below.
fit_l05(alpha_max/10)

# Row 0 holds the iterative-reweighting results, row 1 the direct-prox ones;
# there is one column per regularization strength.
time_results = np.zeros((2, len(alphas)))
obj_results = np.zeros_like(time_results)

# Actual run. The loop variable must be named `alpha`: `_obj` reads it from
# module scope when evaluating the objective.
for i, alpha in enumerate(alphas):
    results = fit_l05(alpha=alpha)
    iterative_l05, iterative_time = results["iterative"]
    direct_l05, direct_time = results["direct"]

    iterative_obj = _obj(iterative_l05.coef_)
    direct_obj = _obj(direct_l05.coef_)

    time_results[:, i] = (iterative_time, direct_time)
    obj_results[:, i] = (iterative_obj, direct_obj)
| 76 | + |
# One row per regularization ratio, one column per strategy.
time_df = pd.DataFrame(
    time_results.T, columns=["Iterative", "Direct"], index=[1e-1, 1e-2, 1e-3])
obj_df = pd.DataFrame(
    obj_results.T, columns=["Iterative", "Direct"], index=[1e-1, 1e-2, 1e-3])

fig, axarr = plt.subplots(1, 2, figsize=(8, 3.5), constrained_layout=True)

# Left panel: fit time; right panel: objective value at the returned solution.
panels = (
    (time_df, "time (in s)", "Time to fit"),
    (obj_df, "obj. value", "Objective at solution"),
)
for ax, (frame, y_label, title) in zip(axarr, panels):
    frame.plot.bar(rot=0, ax=ax)
    ax.set_xlabel(r"$\lambda/\lambda_{max}$")
    ax.set_ylabel(y_label)
    ax.set_title(title)

plt.show(block=False)
0 commit comments