1+ import inspect
12import warnings
23
34from traitlets import TraitType , TraitError , Undefined
@@ -122,22 +123,24 @@ def make_dynamic_default(self):
122123 return np .copy (self .default_value )
123124
124125
125- class DataFrame (SciType ):
126+ class PandasType (SciType ):
126127
127128 """A pandas dataframe trait type."""
128129
129130 info_text = 'a pandas dataframe'
130131
132+ klass = None
133+
131134 def validate (self , obj , value ):
132135 if value is None and not self .allow_none :
133136 self .error (obj , value )
134137 if value is None or value is Undefined :
135- return super (DataFrame , self ).validate (obj , value )
138+ return super (PandasType , self ).validate (obj , value )
136139 try :
137- value = pd . DataFrame (value )
140+ value = self . klass (value )
138141 except (ValueError , TypeError ) as e :
139142 raise TraitError (e )
140- return super (DataFrame , self ).validate (obj , value )
143+ return super (PandasType , self ).validate (obj , value )
141144
142145 def set (self , obj , value ):
143146 new_value = self ._validate (obj , value )
@@ -146,14 +149,20 @@ def set(self, obj, value):
146149 if (old_value is None and new_value is not None ) or not old_value .equals (new_value ):
147150 obj ._notify_trait (self .name , old_value , new_value )
148151
149- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
150- import pandas as pd
152+ def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , klass = None , ** kwargs ):
153+ if klass is None :
154+ klass = self .klass
155+ if (klass is not None ) and inspect .isclass (klass ):
156+ self .klass = klass
157+ else :
158+ raise TraitError ('The klass attribute must be a class'
159+ ' not: %r' % klass )
151160 self .dtype = dtype
152161 if default_value is Undefined :
153- default_value = pd . DataFrame ()
162+ default_value = klass ()
154163 elif default_value is not None :
155- default_value = pd . DataFrame (default_value )
156- super (DataFrame , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
164+ default_value = klass (default_value )
165+ super (PandasType , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
157166
158167 def make_dynamic_default (self ):
159168 if self .default_value is None :
@@ -162,41 +171,29 @@ def make_dynamic_default(self):
162171 return self .default_value .copy ()
163172
164173
165- class Series ( SciType ):
174+ class DataFrame ( PandasType ):
166175
167- """A pandas series trait type."""
176+ """A pandas dataframe trait type."""
168177
169- info_text = 'a pandas series '
178+ info_text = 'a pandas dataframe '
170179
171- def validate (self , obj , value ):
172- if value is None and not self .allow_none :
173- self .error (obj , value )
174- if value is None or value is Undefined :
175- return super (Series , self ).validate (obj , value )
176- try :
177- value = pd .Series (value )
178- except (ValueError , TypeError ) as e :
179- raise TraitError (e )
180- return super (Series , self ).validate (obj , value )
180+ def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
181+ if 'klass' not in kwargs and self .klass is None :
182+ import pandas as pd
183+ kwargs ['klass' ] = pd .DataFrame
184+ super (DataFrame , self ).__init__ (
185+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
181186
182- def set (self , obj , value ):
183- new_value = self ._validate (obj , value )
184- old_value = obj ._trait_values .get (self .name , self .default_value )
185- obj ._trait_values [self .name ] = new_value
186- if (old_value is None and new_value is not None ) or not old_value .equals (new_value ):
187- obj ._notify_trait (self .name , old_value , new_value )
188187
189- def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
190- import pandas as pd
191- self .dtype = dtype
192- if default_value is Undefined :
193- default_value = pd .Series ()
194- elif default_value is not None :
195- default_value = pd .Series (default_value )
196- super (Series , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
188+ class Series (PandasType ):
197189
198- def make_dynamic_default (self ):
199- if self .default_value is None :
200- return self .default_value
201- else :
202- return self .default_value .copy ()
190+ """A pandas series trait type."""
191+
192+ info_text = 'a pandas series'
193+
194+ def __init__ (self , default_value = Undefined , allow_none = False , dtype = None , ** kwargs ):
195+ if 'klass' not in kwargs and self .klass is None :
196+ import pandas as pd
197+ kwargs ['klass' ] = pd .Series
198+ super (Series , self ).__init__ (
199+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
0 commit comments