@@ -52,7 +52,7 @@ class BalancedBatchGenerator(ParentClass):
5252 batch_size : int, optional (default=32)
5353 Number of samples per gradient update.
5454
55- sparse : bool, optional (default=False)
55+ keep_sparse : bool, optional (default=False)
5656 Either or not to conserve or not the sparsity of the input (i.e. ``X``,
5757 ``y``, ``sample_weight``). By default, the returned batches will be
5858 dense.
@@ -98,15 +98,15 @@ class BalancedBatchGenerator(ParentClass):
9898
9999 """
100100 def __init__ (self , X , y , sample_weight = None , sampler = None , batch_size = 32 ,
101- sparse = False , random_state = None ):
101+ keep_sparse = False , random_state = None ):
102102 if not HAS_KERAS :
103103 raise ImportError ("'No module named 'keras'" )
104104 self .X = X
105105 self .y = y
106106 self .sample_weight = sample_weight
107107 self .sampler = sampler
108108 self .batch_size = batch_size
109- self .sparse = sparse
109+ self .keep_sparse = keep_sparse
110110 self .random_state = random_state
111111 self ._sample ()
112112
@@ -138,7 +138,7 @@ def __getitem__(self, index):
138138 y_resampled = safe_indexing (
139139 self .y , self .indices_ [index * self .batch_size :
140140 (index + 1 ) * self .batch_size ])
141- if issparse (X_resampled ) and not self .sparse :
141+ if issparse (X_resampled ) and not self .keep_sparse :
142142 X_resampled = X_resampled .toarray ()
143143 if self .sample_weight is not None :
144144 sample_weight_resampled = safe_indexing (
@@ -154,7 +154,8 @@ def __getitem__(self, index):
154154
155155@Substitution (random_state = _random_state_docstring )
156156def balanced_batch_generator (X , y , sample_weight = None , sampler = None ,
157- batch_size = 32 , sparse = False , random_state = None ):
157+ batch_size = 32 , keep_sparse = False ,
158+ random_state = None ):
158159 """Create a balanced batch generator to train keras model.
159160
160161 Returns a generator --- as well as the number of step per epoch --- which
@@ -181,7 +182,7 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
181182 batch_size : int, optional (default=32)
182183 Number of samples per gradient update.
183184
184- sparse : bool, optional (default=False)
185+ keep_sparse : bool, optional (default=False)
185186 Either or not to conserve or not the sparsity of the input (i.e. ``X``,
186187 ``y``, ``sample_weight``). By default, the returned batches will be
187188 dense.
@@ -226,4 +227,4 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
226227
227228 return tf_bbg (X = X , y = y , sample_weight = sample_weight ,
228229 sampler = sampler , batch_size = batch_size ,
229- sparse = sparse , random_state = random_state )
230+ keep_sparse = keep_sparse , random_state = random_state )
0 commit comments