@@ -15,10 +15,6 @@ def __getattribute__(self, name):
1515 import numpy as np
1616except ImportError :
1717 np = _DelayedImportError ('numpy' )
18- try :
19- import pandas as pd
20- except ImportError :
21- pd = _DelayedImportError ('pandas' )
2218
2319
2420Empty = Sentinel ('Empty' , 'traittypes' ,
@@ -30,7 +26,7 @@ def __getattribute__(self, name):
3026
3127class SciType (TraitType ):
3228
33- """A base trait type for numpy arrays, pandas dataframes and series."""
29+ """A base trait type for numpy arrays, pandas dataframes, pandas series, xarray datasets and xarray dataarrays ."""
3430
3531 def __init__ (self , ** kwargs ):
3632 super (SciType , self ).__init__ (** kwargs )
@@ -132,9 +128,9 @@ def make_dynamic_default(self):
132128
133129class PandasType (SciType ):
134130
135- """A pandas dataframe trait type."""
131+ """A pandas dataframe or series trait type."""
136132
137- info_text = 'a pandas dataframe'
133+ info_text = 'a pandas dataframe or series '
138134
139135 klass = None
140136
@@ -158,15 +154,14 @@ def set(self, obj, value):
158154 not old_value .equals (new_value )):
159155 obj ._notify_trait (self .name , old_value , new_value )
160156
161- def __init__ (self , default_value = Empty , allow_none = False , dtype = None , klass = None , ** kwargs ):
157+ def __init__ (self , default_value = Empty , allow_none = False , klass = None , ** kwargs ):
162158 if klass is None :
163159 klass = self .klass
164160 if (klass is not None ) and inspect .isclass (klass ):
165161 self .klass = klass
166162 else :
167163 raise TraitError ('The klass attribute must be a class'
168164 ' not: %r' % klass )
169- self .dtype = dtype
170165 if default_value is Empty :
171166 default_value = klass ()
172167 elif default_value is not None and default_value is not Undefined :
@@ -199,10 +194,91 @@ class Series(PandasType):
199194 """A pandas series trait type."""
200195
201196 info_text = 'a pandas series'
197+ dtype = None
202198
203199 def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
204200 if 'klass' not in kwargs and self .klass is None :
205201 import pandas as pd
206202 kwargs ['klass' ] = pd .Series
207203 super (Series , self ).__init__ (
208204 default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
205+ self .dtype = dtype
206+
207+
208+ class XarrayType (SciType ):
209+
210+ """An xarray dataset or dataarray trait type."""
211+
212+ info_text = 'an xarray dataset or dataarray'
213+
214+ klass = None
215+
216+ def validate (self , obj , value ):
217+ if value is None and not self .allow_none :
218+ self .error (obj , value )
219+ if value is None or value is Undefined :
220+ return super (XarrayType , self ).validate (obj , value )
221+ try :
222+ value = self .klass (value )
223+ except (ValueError , TypeError ) as e :
224+ raise TraitError (e )
225+ return super (XarrayType , self ).validate (obj , value )
226+
227+ def set (self , obj , value ):
228+ new_value = self ._validate (obj , value )
229+ old_value = obj ._trait_values .get (self .name , self .default_value )
230+ obj ._trait_values [self .name ] = new_value
231+ if ((old_value is None and new_value is not None ) or
232+ (old_value is Undefined and new_value is not Undefined ) or
233+ not old_value .equals (new_value )):
234+ obj ._notify_trait (self .name , old_value , new_value )
235+
236+ def __init__ (self , default_value = Empty , allow_none = False , klass = None , ** kwargs ):
237+ if klass is None :
238+ klass = self .klass
239+ if (klass is not None ) and inspect .isclass (klass ):
240+ self .klass = klass
241+ else :
242+ raise TraitError ('The klass attribute must be a class'
243+ ' not: %r' % klass )
244+ if default_value is Empty :
245+ default_value = klass ()
246+ elif default_value is not None and default_value is not Undefined :
247+ default_value = klass (default_value )
248+ super (XarrayType , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
249+
250+ def make_dynamic_default (self ):
251+ if self .default_value is None or self .default_value is Undefined :
252+ return self .default_value
253+ else :
254+ return self .default_value .copy ()
255+
256+
257+ class Dataset (XarrayType ):
258+
259+ """An xarray dataset trait type."""
260+
261+ info_text = 'an xarray dataset'
262+
263+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
264+ if 'klass' not in kwargs and self .klass is None :
265+ import xarray as xr
266+ kwargs ['klass' ] = xr .Dataset
267+ super (Dataset , self ).__init__ (
268+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
269+
270+
271+ class DataArray (XarrayType ):
272+
273+ """An xarray dataarray trait type."""
274+
275+ info_text = 'an xarray dataarray'
276+ dtype = None
277+
278+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
279+ if 'klass' not in kwargs and self .klass is None :
280+ import xarray as xr
281+ kwargs ['klass' ] = xr .DataArray
282+ super (DataArray , self ).__init__ (
283+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
284+ self .dtype = dtype
0 commit comments