11#!/usr/bin/env python
22# -*- coding: utf-8; -*-
33
4- # Copyright (c) 2020, 2022 Oracle and/or its affiliates.
4+ # Copyright (c) 2020, 2023 Oracle and/or its affiliates.
55# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
77import bisect
8- from collections import defaultdict
8+ import numpy as np
99
10+ from collections import defaultdict
1011from sklearn .base import TransformerMixin
1112from sklearn .preprocessing import LabelEncoder
1213
1314
1415class DataFrameLabelEncoder (TransformerMixin ):
1516 """
16- Label encoder for pandas.dataframe. dask.dataframe.core.DataFrame
17+ Label encoder for `pandas.DataFrame` and `dask.dataframe.core.DataFrame`.
18+
19+ Attributes
20+ ----------
21+ label_encoders : defaultdict
22+ Holds the label encoder for each column.
23+
24+ Examples
25+ --------
26+ >>> import pandas as pd
27+ >>> from ads.dataset.label_encoder import DataFrameLabelEncoder
28+
29+ >>> df = pd.DataFrame(data={'col1': [1, 2], 'col2': [3, 4]})
30+ >>> le = DataFrameLabelEncoder()
31+ >>> le.fit_transform(X=df)
32+
1733 """
1834
1935 def __init__ (self ):
36+ """Initialize an instance of DataFrameLabelEncoder."""
2037 self .label_encoders = defaultdict (LabelEncoder )
2138
22- def fit (self , X ):
39+ def fit (self , X : "pandas.DataFrame" ):
2340 """
24- Fits a DataFrameLAbelEncoder.
41+ Fits a DataFrameLabelEncoder.
42+
43+ Parameters
44+ ----------
45+ X : pandas.DataFrame
46+ Target values.
47+
48+ Returns
49+ -------
50+ self : returns an instance of self.
51+ Fitted label encoder.
52+
2553 """
2654 for column in X .columns :
2755 if X [column ].dtype .name in ["object" , "category" ]:
@@ -33,12 +61,24 @@ def fit(self, X):
3361 for class_ in self .label_encoders [column ].classes_ .tolist ()
3462 ]
3563 bisect .insort_left (label_encoder_classes_ , "unknown" )
64+ label_encoder_classes_ = np .asarray (label_encoder_classes_ )
3665 self .label_encoders [column ].classes_ = label_encoder_classes_
3766 return self
3867
39- def transform (self , X ):
68+ def transform (self , X : "pandas.DataFrame" ):
4069 """
41- Transforms a dataset using the DataFrameLAbelEncoder.
70+ Transforms a dataset using the DataFrameLabelEncoder.
71+
72+ Parameters
73+ ----------
74+ X : pandas.DataFrame
75+ Target values.
76+
77+ Returns
78+ -------
79+ pandas.DataFrame
80+ Labels as normalized encodings.
81+
4282 """
4383 categorical_columns = list (self .label_encoders .keys ())
4484 if len (categorical_columns ) == 0 :
0 commit comments