@@ -26,7 +26,7 @@ def __getattribute__(self, name):
2626
2727class SciType (TraitType ):
2828
29- """A base trait type for numpy arrays, pandas dataframes, pandas series and xarray datasets ."""
29+ """A base trait type for numpy arrays, pandas dataframes, pandas series, xarray datasets and xarray dataarrays ."""
3030
3131 def __init__ (self , ** kwargs ):
3232 super (SciType , self ).__init__ (** kwargs )
@@ -128,9 +128,9 @@ def make_dynamic_default(self):
128128
129129class PandasType (SciType ):
130130
131- """A pandas dataframe trait type."""
131+ """A pandas dataframe or series trait type."""
132132
133- info_text = 'a pandas dataframe'
133+ info_text = 'a pandas dataframe or series '
134134
135135 klass = None
136136
@@ -154,15 +154,14 @@ def set(self, obj, value):
154154 not old_value .equals (new_value )):
155155 obj ._notify_trait (self .name , old_value , new_value )
156156
157- def __init__ (self , default_value = Empty , allow_none = False , dtype = None , klass = None , ** kwargs ):
157+ def __init__ (self , default_value = Empty , allow_none = False , klass = None , ** kwargs ):
158158 if klass is None :
159159 klass = self .klass
160160 if (klass is not None ) and inspect .isclass (klass ):
161161 self .klass = klass
162162 else :
163163 raise TraitError ('The klass attribute must be a class'
164164 ' not: %r' % klass )
165- self .dtype = dtype
166165 if default_value is Empty :
167166 default_value = klass ()
168167 elif default_value is not None and default_value is not Undefined :
@@ -195,20 +194,22 @@ class Series(PandasType):
195194 """A pandas series trait type."""
196195
197196 info_text = 'a pandas series'
197+ dtype = None
198198
199199 def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
200200 if 'klass' not in kwargs and self .klass is None :
201201 import pandas as pd
202202 kwargs ['klass' ] = pd .Series
203203 super (Series , self ).__init__ (
204204 default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
205+ self .dtype = dtype
205206
206207
207208class XarrayType (SciType ):
208209
209- """An xarray dataset trait type."""
210+ """An xarray dataset or dataarray trait type."""
210211
211- info_text = 'an xarray dataset'
212+ info_text = 'an xarray dataset or dataarray '
212213
213214 klass = None
214215
@@ -232,15 +233,14 @@ def set(self, obj, value):
232233 not old_value .equals (new_value )):
233234 obj ._notify_trait (self .name , old_value , new_value )
234235
235- def __init__ (self , default_value = Empty , allow_none = False , dtype = None , klass = None , ** kwargs ):
236+ def __init__ (self , default_value = Empty , allow_none = False , klass = None , ** kwargs ):
236237 if klass is None :
237238 klass = self .klass
238239 if (klass is not None ) and inspect .isclass (klass ):
239240 self .klass = klass
240241 else :
241242 raise TraitError ('The klass attribute must be a class'
242243 ' not: %r' % klass )
243- self .dtype = dtype
244244 if default_value is Empty :
245245 default_value = klass ()
246246 elif default_value is not None and default_value is not Undefined :
@@ -266,3 +266,19 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
266266 kwargs ['klass' ] = xr .Dataset
267267 super (Dataset , self ).__init__ (
268268 default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
269+
270+
271+ class DataArray (XarrayType ):
272+
273+ """An xarray dataarray trait type."""
274+
275+ info_text = 'an xarray dataarray'
276+ dtype = None
277+
278+ def __init__ (self , default_value = Empty , allow_none = False , dtype = None , ** kwargs ):
279+ if 'klass' not in kwargs and self .klass is None :
280+ import xarray as xr
281+ kwargs ['klass' ] = xr .DataArray
282+ super (DataArray , self ).__init__ (
283+ default_value = default_value , allow_none = allow_none , dtype = dtype , ** kwargs )
284+ self .dtype = dtype
0 commit comments