1+ import warnings
2+
13from traitlets import TraitType , TraitError , Undefined
24
35class _DelayedImportError (object ):
@@ -22,6 +24,10 @@ class SciType(TraitType):
2224
2325 """A base trait type for numpy arrays, pandas dataframes and series."""
2426
27+ def __init__ (self , ** kwargs ):
28+ super (SciType , self ).__init__ (** kwargs )
29+ self .validators = []
30+
2531 def valid (self , * validators ):
2632 """
2733 Register new trait validators
@@ -59,6 +65,15 @@ class Foo(HasTraits):
5965 self .validators .extend (validators )
6066 return self
6167
68+ def validate (self , obj , value ):
69+ """Validate the value against registered validators."""
70+ try :
71+ for validator in self .validators :
72+ value = validator (self , value )
73+ return value
74+ except (ValueError , TypeError ) as e :
75+ raise TraitError (e )
76+
6277
6378class Array (SciType ):
6479
@@ -70,13 +85,20 @@ class Array(SciType):
7085 def validate (self , obj , value ):
7186 if value is None and not self .allow_none :
7287 self .error (obj , value )
88+ if value is None or value is Undefined :
89+ return super (Array , self ).validate (obj , value )
7390 try :
74- value = np .asarray (value , dtype = self .dtype )
75- for validator in self .validators :
76- value = validator (self , value )
77- return value
91+ r = np .asarray (value , dtype = self .dtype )
92+ if isinstance (value , np .ndarray ) and r is not value :
93+ warnings .warn (
94+ 'Given trait value dtype "%s" does not match required type "%s". '
95+ 'A coerced copy has been created.' % (
96+ np .dtype (value .dtype ).name ,
97+ np .dtype (self .dtype ).name ))
98+ value = r
7899 except (ValueError , TypeError ) as e :
79100 raise TraitError (e )
101+ return super (Array , self ).validate (obj , value )
80102
81103 def set (self , obj , value ):
82104 new_value = self ._validate (obj , value )
@@ -91,7 +113,6 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
91113 default_value = np .array (0 , dtype = self .dtype )
92114 elif default_value is not None :
93115 default_value = np .asarray (default_value , dtype = self .dtype )
94- self .validators = []
95116 super (Array , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
96117
97118 def make_dynamic_default (self ):
@@ -110,13 +131,13 @@ class DataFrame(SciType):
110131 def validate (self , obj , value ):
111132 if value is None and not self .allow_none :
112133 self .error (obj , value )
134+ if value is None or value is Undefined :
135+ return super (DataFrame , self ).validate (obj , value )
113136 try :
114137 value = pd .DataFrame (value )
115- for validator in self .validators :
116- value = validator (self , value )
117- return value
118138 except (ValueError , TypeError ) as e :
119139 raise TraitError (e )
140+ return super (DataFrame , self ).validate (obj , value )
120141
121142 def set (self , obj , value ):
122143 new_value = self ._validate (obj , value )
@@ -132,7 +153,6 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
132153 default_value = pd .DataFrame ()
133154 elif default_value is not None :
134155 default_value = pd .DataFrame (default_value )
135- self .validators = []
136156 super (DataFrame , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
137157
138158 def make_dynamic_default (self ):
@@ -151,13 +171,13 @@ class Series(SciType):
151171 def validate (self , obj , value ):
152172 if value is None and not self .allow_none :
153173 self .error (obj , value )
174+ if value is None or value is Undefined :
175+ return super (Series , self ).validate (obj , value )
154176 try :
155177 value = pd .Series (value )
156- for validator in self .validators :
157- value = validator (self , value )
158- return value
159178 except (ValueError , TypeError ) as e :
160179 raise TraitError (e )
180+ return super (Series , self ).validate (obj , value )
161181
162182 def set (self , obj , value ):
163183 new_value = self ._validate (obj , value )
@@ -173,7 +193,6 @@ def __init__(self, default_value=Undefined, allow_none=False, dtype=None, **kwar
173193 default_value = pd .Series ()
174194 elif default_value is not None :
175195 default_value = pd .Series (default_value )
176- self .validators = []
177196 super (Series , self ).__init__ (default_value = default_value , allow_none = allow_none , ** kwargs )
178197
179198 def make_dynamic_default (self ):
0 commit comments