1+ """
2+ ===================================
3+ Group Logistic regression in python
4+ ===================================
5+ Scikit-learn is missing a Group Logistic regression estimator. We show how to implement
6+ one with ``skglm``.
7+ """
8+
9+ # Author: Mathurin Massias
10+
11+ import numpy as np
12+
13+ from skglm import GeneralizedLinearEstimator
14+ from skglm .datafits import LogisticGroup
15+ from skglm .penalties import WeightedGroupL2
16+ from skglm .solvers import GroupProxNewton
17+ from skglm .utils .data import make_correlated_data , grp_converter
18+
19+ n_features = 30
20+ X , y , _ = make_correlated_data (
21+ n_samples = 10 , n_features = 30 , random_state = 0 )
22+ y = np .sign (y )
23+
24+
25+ # %%
26+ # Classifier creation: combination of penalty, datafit and solver.
27+ #
28+ grp_size = 3 # groups are made of groups of 3 consecutive features
29+ n_groups = n_features // grp_size
30+ grp_indices , grp_ptr = grp_converter (grp_size , n_features = n_features )
31+ alpha = 0.01
32+ weights = np .ones (n_groups )
33+ penalty = WeightedGroupL2 (alpha , weights , grp_ptr , grp_indices )
34+ datafit = LogisticGroup (grp_ptr , grp_indices )
35+ solver = GroupProxNewton (verbose = 2 )
36+
37+ # %%
38+ # Train the model
39+ clf = GeneralizedLinearEstimator (datafit , penalty , solver )
40+ clf .fit (X , y )
41+
42+ # %%
43+ # Fit check that groups are either all 0 or all non zero
44+ print (clf .coef_ .reshape (- 1 , grp_size ))
0 commit comments