88
99from sklearn .utils .validation import check_is_fitted
1010from sklearn .utils import check_array , check_consistent_length
11- from sklearn .linear_model import MultiTaskLasso as MultiTaskLasso_sklearn
1211from sklearn .linear_model ._base import (
1312 LinearModel , RegressorMixin ,
1413 LinearClassifierMixin , SparseCoefMixin , BaseEstimator
@@ -1126,8 +1125,7 @@ def fit(self, X, y):
11261125 # TODO add predict_proba for LinearSVC
11271126
11281127
1129- # TODO we should no longer inherit from sklearn
1130- class MultiTaskLasso (MultiTaskLasso_sklearn ):
1128+ class MultiTaskLasso (LinearModel , RegressorMixin ):
11311129 r"""MultiTaskLasso estimator.
11321130
11331131 The optimization objective for MultiTaskLasso is::
@@ -1139,6 +1137,9 @@ class MultiTaskLasso(MultiTaskLasso_sklearn):
11391137 alpha : float, optional
11401138 Regularization strength (constant that multiplies the L21 penalty).
11411139
1140+ copy_X : bool, optional (default=True)
1141+ If True, X will be copied; else, it may be overwritten.
1142+
11421143 max_iter : int, optional
11431144 The maximum number of iterations (subproblem definitions).
11441145
@@ -1179,12 +1180,14 @@ class MultiTaskLasso(MultiTaskLasso_sklearn):
11791180 Number of subproblems solved by Celer to reach the specified tolerance.
11801181 """
11811182
1182- def __init__ (self , alpha = 1. , max_iter = 50 , max_epochs = 50_000 , p0 = 10 ,
1183+ def __init__ (self , alpha = 1. , copy_X = True , max_iter = 50 , max_epochs = 50_000 , p0 = 10 ,
11831184 verbose = 0 , tol = 1e-4 , fit_intercept = True , warm_start = False ,
11841185 ws_strategy = "subdiff" ):
1185- super ().__init__ (
1186- alpha = alpha , tol = tol ,
1187- fit_intercept = fit_intercept , warm_start = warm_start )
1186+ self .tol = tol
1187+ self .alpha = alpha
1188+ self .copy_X = copy_X
1189+ self .warm_start = warm_start
1190+ self .fit_intercept = fit_intercept
11881191 self .max_iter = max_iter
11891192 self .p0 = p0
11901193 self .ws_strategy = ws_strategy
0 commit comments