-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 import numpy as np
 
     TASK_TYPES_TO_STRING,
 )
 from autoPyTorch.data.tabular_validator import TabularInputValidator
+from autoPyTorch.data.utils import (
+    get_dataset_compression_mapping
+)
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
     HoldoutValTypes,
@@ -163,6 +166,7 @@ def _get_dataset_input_validator(
         resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         dataset_name: Optional[str] = None,
+        dataset_compression: Optional[Mapping[str, Any]] = None,
     ) -> Tuple[TabularDataset, TabularInputValidator]:
         """
         Returns an object of `TabularDataset` and an object of
@@ -199,26 +203,27 @@ def _get_dataset_input_validator(
 
         # Create a validator object to make sure that the data provided by
         # the user matches the autopytorch requirements
-        InputValidator = TabularInputValidator(
+        input_validator = TabularInputValidator(
             is_classification=True,
             logger_port=self._logger_port,
+            dataset_compression=dataset_compression
         )
 
         # Fit an input validator to check the provided data
         # Also, an encoder is fit to both train and test data,
         # to prevent unseen categories during inference
-        InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
+        input_validator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
 
         dataset = TabularDataset(
             X=X_train, Y=y_train,
             X_test=X_test, Y_test=y_test,
-            validator=InputValidator,
+            validator=input_validator,
             resampling_strategy=resampling_strategy,
             resampling_strategy_args=resampling_strategy_args,
             dataset_name=dataset_name
         )
 
-        return dataset, InputValidator
+        return dataset, input_validator
 
     def search(
         self,
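The comment in this hunk about fitting the encoder on both train and test data is the important detail: it keeps categories that appear only in the test split from being unseen at inference time. A minimal standalone sketch of the idea (toy pandas data, not the autoPyTorch validator itself):

```python
import pandas as pd

# Hypothetical toy data: "green" appears only in the test split, so an
# encoder fit on the training split alone would not know it.
X_train = pd.DataFrame({"color": ["red", "blue", "red"]})
X_test = pd.DataFrame({"color": ["green", "blue"]})

# Building the category set from both splits, as the validator's fit()
# does with X_train and X_test, gives every test category a stable code.
categories = sorted(set(X_train["color"]) | set(X_test["color"]))
X_train["color"] = pd.Categorical(X_train["color"], categories=categories).codes
X_test["color"] = pd.Categorical(X_test["color"], categories=categories).codes
print(X_test["color"].tolist())  # [1, 0] -- "green" encodes cleanly
```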
@@ -234,14 +239,15 @@ def search(
         total_walltime_limit: int = 100,
         func_eval_time_limit_secs: Optional[int] = None,
         enable_traditional_pipeline: bool = True,
-        memory_limit: Optional[int] = 4096,
+        memory_limit: int = 4096,
         smac_scenario_args: Optional[Dict[str, Any]] = None,
         get_smac_object_callback: Optional[Callable] = None,
         all_supported_metrics: bool = True,
         precision: int = 32,
         disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
         load_models: bool = True,
         portfolio_selection: Optional[str] = None,
+        dataset_compression: Union[Mapping[str, Any], bool] = False,
     ) -> 'BaseTask':
         """
         Search for the best pipeline configuration for the given dataset.
@@ -310,7 +316,7 @@ def search(
                 feature by turning this flag to False. All machine learning
                 algorithms that are fitted during search() are considered for
                 ensemble building.
-            memory_limit (Optional[int]: default=4096):
+            memory_limit (int: default=4096):
                 Memory limit in MB for the machine learning algorithm.
                 Autopytorch will stop fitting the machine learning algorithm
                 if it tries to allocate more than memory_limit MB. If None
@@ -368,20 +374,52 @@ def search(
                 Additionally, the keyword 'greedy' is supported,
                 which would use the default portfolio from
                 `AutoPyTorch Tabular <https://arxiv.org/abs/2006.13799>`_.
+            dataset_compression: Union[bool, Mapping[str, Any]] = False
+                We compress datasets so that they fit into some predefined amount of memory.
+                **NOTE**
+
+                Default configuration when set to ``True``:
+
+                .. code-block:: python
+
+                    {
+                        "memory_allocation": 0.1,
+                        "methods": ["precision"]
+                    }
+
+                You can also pass your own configuration with the same keys, choosing
+                from the available ``"methods"``.
+                The available options are described here:
+
+                **memory_allocation**
+                    By default, we attempt to fit the dataset into ``0.1 * memory_limit``. This
+                    float value can be set with ``"memory_allocation": 0.1``. We also allow
+                    specifying absolute memory in MB, e.g. 10MB is ``"memory_allocation": 10``.
+                    The memory used by the dataset is checked after each reduction method is
+                    performed. If the dataset fits into the allocated memory, any further methods
+                    listed in ``"methods"`` will not be performed.
+
+                **methods**
+                    We currently provide the following methods for reducing the dataset size.
+                    These can be provided in a list and are performed in the order given.
+
+                    * ``"precision"`` - We reduce floating point precision as follows:
+
+                      * ``np.float128 -> np.float64``
+                      * ``np.float96 -> np.float64``
+                      * ``np.float64 -> np.float32``
+
+                    * pandas DataFrames are reduced using the ``downcast`` option of
+                      ``pd.to_numeric`` to the lowest possible precision.
 
             Returns:
                 self
 
         """
+        self._dataset_compression = get_dataset_compression_mapping(memory_limit, dataset_compression)
 
-        self.dataset, self.InputValidator = self._get_dataset_input_validator(
+        self.dataset, self.input_validator = self._get_dataset_input_validator(
             X_train=X_train,
             y_train=y_train,
             X_test=X_test,
             y_test=y_test,
             resampling_strategy=self.resampling_strategy,
             resampling_strategy_args=self.resampling_strategy_args,
-            dataset_name=dataset_name)
+            dataset_name=dataset_name,
+            dataset_compression=self._dataset_compression)
 
         return self._search(
             dataset=self.dataset,
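To make the new option concrete, here is a hedged usage sketch. The data, limits, and metric are illustrative; only `dataset_compression` and its keys come from the docstring above:

```python
import numpy as np
from autoPyTorch.api.tabular_classification import TabularClassificationTask

# Illustrative random data; any tabular X/y would do.
X = np.random.rand(500, 20).astype(np.float64)
y = np.random.randint(0, 2, size=500)

api = TabularClassificationTask()
api.search(
    X_train=X, y_train=y,
    optimize_metric="accuracy",
    total_walltime_limit=300,  # illustrative budget in seconds
    memory_limit=4096,
    # Try to fit the dataset into 0.1 * 4096 MB, downcasting float
    # precision first; same shape as the default configuration above.
    dataset_compression={"memory_allocation": 0.1, "methods": ["precision"]},
)
```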
@@ -418,28 +456,28 @@ def predict(
         Returns:
             Array with estimator predictions.
         """
-        if self.InputValidator is None or not self.InputValidator._is_fitted:
+        if self.input_validator is None or not self.input_validator._is_fitted:
             raise ValueError("predict() is only supported after calling search. Kindly call first "
                              "the estimator search() method.")
 
-        X_test = self.InputValidator.feature_validator.transform(X_test)
+        X_test = self.input_validator.feature_validator.transform(X_test)
         predicted_probabilities = super().predict(X_test, batch_size=batch_size,
                                                   n_jobs=n_jobs)
 
-        if self.InputValidator.target_validator.is_single_column_target():
+        if self.input_validator.target_validator.is_single_column_target():
             predicted_indexes = np.argmax(predicted_probabilities, axis=1)
         else:
             predicted_indexes = (predicted_probabilities > 0.5).astype(int)
 
         # Allow to predict in the original domain -- that is, the user is not interested
         # in our encoded values
-        return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
+        return self.input_validator.target_validator.inverse_transform(predicted_indexes)
 
     def predict_proba(self,
                       X_test: Union[np.ndarray, pd.DataFrame, List],
                       batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
-        if self.InputValidator is None or not self.InputValidator._is_fitted:
+        if self.input_validator is None or not self.input_validator._is_fitted:
             raise ValueError("predict() is only supported after calling search. Kindly call first "
                              "the estimator search() method.")
-        X_test = self.InputValidator.feature_validator.transform(X_test)
+        X_test = self.input_validator.feature_validator.transform(X_test)
         return super().predict(X_test, batch_size=batch_size, n_jobs=n_jobs)
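The branch in `predict()` above deserves a standalone illustration: a single-column (multiclass) target takes the argmax over class probabilities, while a multi-column (multilabel) target thresholds each probability at 0.5. A small sketch with made-up probabilities:

```python
import numpy as np

# Multiclass: one probability per class; pick the most probable class.
proba_multiclass = np.array([[0.7, 0.2, 0.1],
                             [0.1, 0.3, 0.6]])
print(np.argmax(proba_multiclass, axis=1))  # [0 2]

# Multilabel: one independent probability per label; threshold at 0.5.
proba_multilabel = np.array([[0.9, 0.4, 0.6],
                             [0.2, 0.8, 0.5]])
print((proba_multilabel > 0.5).astype(int))  # [[1 0 1]
                                             #  [0 1 0]]
```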