1+ import inspect
12import warnings
23
3- from traitlets import TraitType , TraitError , Undefined
4+ from traitlets import TraitType , TraitError , Undefined , Sentinel
45
56class _DelayedImportError (object ):
67 def __init__ (self , package_name ):
@@ -20,6 +21,13 @@ def __getattribute__(self, name):
2021 pd = _DelayedImportError ('pandas' )
2122
2223
24+ Empty = Sentinel ('Empty' , 'traittypes' ,
25+ """
26+ Used in traittypes to specify that the default value should
27+ be an empty dataset
28+ """ )
29+
30+
2331class SciType (TraitType ):
2432
2533 """A base trait type for numpy arrays, pandas dataframes and series."""
@@ -107,96 +115,94 @@ def set(self, obj, value):
107115 if not np .array_equal (old_value , new_value ):
108116 obj ._notify_trait (self .name , old_value , new_value )
109117
110- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
118+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
111119 self .dtype = dtype
112- if default_value is Undefined :
120+ if default_value is Empty :
113121 default_value = np .array (0 , dtype = self .dtype )
114- elif default_value is not None :
122+ elif default_value is not None and default_value is not Undefined :
115123 default_value = np .asarray (default_value , dtype = self .dtype )
116124 super (Array , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
117125
118126 def make_dynamic_default (self ):
119- if self .default_value is None :
127+ if self .default_value is None or self . default_value is Undefined :
120128 return self .default_value
121129 else :
122130 return np .copy (self .default_value )
123131
124132
125- class DataFrame (SciType ):
133+ class PandasType (SciType ):
126134
127135 """A pandas dataframe trait type."""
128136
129137 info_text = 'a pandas dataframe'
130138
139+ klass = None
140+
131141 def validate (self , obj , value ):
132142 if value is None and not self .allow_none :
133143 self .error (obj , value )
134144 if value is None or value is Undefined :
135- return super (DataFrame , self ).validate (obj , value )
145+ return super (PandasType , self ).validate (obj , value )
136146 try :
137- value = pd . DataFrame (value )
147+ value = self . klass (value )
138148 except (ValueError , TypeError ) as e :
139149 raise TraitError (e )
140- return super (DataFrame , self ).validate (obj , value )
150+ return super (PandasType , self ).validate (obj , value )
141151
142152 def set (self , obj , value ):
143153 new_value = self ._validate (obj , value )
144154 old_value = obj ._trait_values .get (self .name , self .default_value )
145155 obj ._trait_values [self .name ] = new_value
146- if (old_value is None and new_value is not None ) or not old_value .equals (new_value ):
156+ if ((old_value is None and new_value is not None ) or
157+ (old_value is Undefined and new_value is not Undefined ) or
158+ not old_value .equals (new_value )):
147159 obj ._notify_trait (self .name , old_value , new_value )
148160
149- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
150- import pandas as pd
161+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , klass = None , ** kwargs ):
162+ if klass is None :
163+ klass = self .klass
164+ if (klass is not None ) and inspect .isclass (klass ):
165+ self .klass = klass
166+ else :
167+ raise TraitError ('The klass attribute must be a class'
168+ ' not: %r' % klass )
151169 self .dtype = dtype
152- if default_value is Undefined :
153- default_value = pd . DataFrame ()
154- elif default_value is not None :
155- default_value = pd . DataFrame (default_value )
156- super (DataFrame , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
170+ if default_value is Empty :
171+ default_value = klass ()
172+ elif default_value is not None and default_value is not Undefined :
173+ default_value = klass (default_value )
174+ super (PandasType , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
157175
158176 def make_dynamic_default (self ):
159- if self .default_value is None :
177+ if self .default_value is None or self . default_value is Undefined :
160178 return self .default_value
161179 else :
162180 return self .default_value .copy ()
163181
164182
165- class Series ( SciType ):
183+ class DataFrame ( PandasType ):
166184
167- """A pandas series trait type."""
185+ """A pandas dataframe trait type."""
168186
169- info_text = 'a pandas series '
187+ info_text = 'a pandas dataframe '
170188
171- def validate (self , obj , value ):
172- if value is None and not self .allow_none :
173- self .error (obj , value )
174- if value is None or value is Undefined :
175- return super (Series , self ).validate (obj , value )
176- try :
177- value = pd .Series (value )
178- except (ValueError , TypeError ) as e :
179- raise TraitError (e )
180- return super (Series , self ).validate (obj , value )
189+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
190+ if 'klass' not in kwargs and self .klass is None :
191+ import pandas as pd
192+ kwargs ['klass' ] = pd .DataFrame
193+ super (DataFrame , self ).__init__ (
194+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
181195
182- def set (self , obj , value ):
183- new_value = self ._validate (obj , value )
184- old_value = obj ._trait_values .get (self .name , self .default_value )
185- obj ._trait_values [self .name ] = new_value
186- if (old_value is None and new_value is not None ) or not old_value .equals (new_value ):
187- obj ._notify_trait (self .name , old_value , new_value )
188196
189- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
190- import pandas as pd
191- self .dtype = dtype
192- if default_value is Undefined :
193- default_value = pd .Series ()
194- elif default_value is not None :
195- default_value = pd .Series (default_value )
196- super (Series , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
197+ class Series (PandasType ):
197198
198- def make_dynamic_default (self ):
199- if self .default_value is None :
200- return self .default_value
201- else :
202- return self .default_value .copy ()
199+ """A pandas series trait type."""
200+
201+ info_text = 'a pandas series'
202+
203+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
204+ if 'klass' not in kwargs and self .klass is None :
205+ import pandas as pd
206+ kwargs ['klass' ] = pd .Series
207+ super (Series , self ).__init__ (
208+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
0 commit comments