66
77import dataclasses
88import inspect
9- import re
109import typing as ty
11- from abc import ABC , abstractmethod
1210from dataclasses import dataclass , is_dataclass , replace
1311
1412import numpy as np
1513import torcharrow ._torcharrow
1614import typing_inspect
1715
16+ from .dtypes_core import (
17+ Boolean ,
18+ DType ,
19+ Field ,
20+ Float32 ,
21+ Float64 ,
22+ Int16 ,
23+ Int32 ,
24+ Int64 ,
25+ Int8 ,
26+ List ,
27+ Map ,
28+ MetaData ,
29+ NL ,
30+ String ,
31+ Struct ,
32+ )
33+
1834# -----------------------------------------------------------------------------
1935# Aux
2036
21- # Pretty printing constants; reused everywhere
22- OPEN = "{"
23- CLOSE = "}"
24- NL = "\n "
25-
2637# Handy Type abbreviations; reused everywhere
2738ScalarTypes = ty .Union [int , float , bool , str ]
2839
2940
30- # -----------------------------------------------------------------------------
31- # Schema and Field
32-
33- MetaData = ty .Dict [str , str ]
34-
35-
@dataclass(frozen=True)
class Field:
    """Immutable (name, dtype) pair with optional string-to-string metadata."""

    name: str
    dtype: "DType"
    metadata: ty.Optional[MetaData] = None

    def __str__(self):
        """Render as ``Field('<name>', <dtype>)``, appending metadata when set."""
        meta = ""
        if self.metadata is not None:
            # Iterate items(): looping over the dict itself yields only keys,
            # so unpacking each entry into (k, v) either raised or split a
            # two-character key into garbage.
            pairs = ", ".join(f"{k}: {v}" for k, v in self.metadata.items())
            # Leading ", " keeps the metadata visually separate from the dtype.
            meta = f", meta = {{{pairs}}}"
        return f"Field('{self.name}', {self.dtype}{meta})"
49-
50-
51- # -----------------------------------------------------------------------------
52- # Immutable Types with structural equality...
53-
54-
@dataclass(frozen=True)  # type: ignore
class DType(ABC):
    """Abstract base for immutable dtype descriptors with structural equality.

    The methods below read ``self.name`` and ``type(self).default``; neither
    is defined here, so concrete subclasses have to supply them (or override
    ``__str__`` / ``default_value``).
    """

    typecode: ty.ClassVar[str] = "__TO_BE_DEFINED_IN_SUBCLASS__"
    arraycode: ty.ClassVar[str] = "__TO_BE_DEFINED_IN_SUBCLASS__"

    @property
    @abstractmethod
    def nullable(self):
        """Whether the dtype admits nulls; subclasses must provide it."""
        return False

    @abstractmethod
    def constructor(self, nullable):
        """Return a dtype like ``self`` with the given nullability."""
        pass

    def with_null(self, nullable=True):
        """Return this dtype rebuilt via ``constructor(nullable)``."""
        return self.constructor(nullable)

    def default_value(self):
        """Return the class-level ``default`` attribute."""
        # must be overridden by all non primitive types!
        cls = type(self)
        return cls.default

    @property
    def py_type(self):
        """Python type of the value returned by ``default_value()``."""
        sample = self.default_value()
        return type(sample)

    def __str__(self):
        # Non-nullable dtypes print as their bare name; nullable ones print
        # the title-cased name followed by the nullable flag.
        if not self.nullable:
            return self.name
        return f"{self.name.title()}(nullable=True)"

    def __qualstr__(self):
        return "torcharrow.dtypes"
88-
89-
90- # for now: no float16, and all date and time stuff, no categorical, (and Null is called Void)
91-
92-
9341@dataclass (frozen = True )
9442class Void (DType ):
9543 nullable : bool = True
@@ -102,268 +50,6 @@ def constructor(self, nullable):
10250 return Void (nullable )
10351
10452
@dataclass(frozen=True)  # type: ignore
class Numeric(DType):
    """Marker base class for the integer and floating-point dtypes."""
108-
109-
@dataclass(frozen=True)
class Boolean(DType):
    """Boolean dtype; non-nullable by default, default value ``False``."""

    nullable: bool = False
    name: ty.ClassVar[str] = "boolean"
    default: ty.ClassVar[bool] = False
    typecode: ty.ClassVar[str] = "b"
    arraycode: ty.ClassVar[str] = "b"

    def constructor(self, nullable):
        """Build a ``Boolean`` with the requested nullability."""
        return Boolean(nullable=nullable)
120-
121-
@dataclass(frozen=True)
class Int8(Numeric):
    """``int8`` dtype; non-nullable by default, default value ``0``."""

    nullable: bool = False
    name: ty.ClassVar[str] = "int8"
    default: ty.ClassVar[int] = 0
    typecode: ty.ClassVar[str] = "c"
    arraycode: ty.ClassVar[str] = "b"

    def constructor(self, nullable):
        """Build an ``Int8`` with the requested nullability."""
        return Int8(nullable=nullable)
132-
133-
@dataclass(frozen=True)
class Int16(Numeric):
    """``int16`` dtype; non-nullable by default, default value ``0``."""

    nullable: bool = False
    name: ty.ClassVar[str] = "int16"
    default: ty.ClassVar[int] = 0
    typecode: ty.ClassVar[str] = "s"
    arraycode: ty.ClassVar[str] = "h"

    def constructor(self, nullable):
        """Build an ``Int16`` with the requested nullability."""
        return Int16(nullable=nullable)
144-
145-
@dataclass(frozen=True)
class Int32(Numeric):
    """``int32`` dtype; non-nullable by default, default value ``0``."""

    nullable: bool = False
    name: ty.ClassVar[str] = "int32"
    default: ty.ClassVar[int] = 0
    typecode: ty.ClassVar[str] = "i"
    arraycode: ty.ClassVar[str] = "i"

    def constructor(self, nullable):
        """Build an ``Int32`` with the requested nullability."""
        return Int32(nullable=nullable)
156-
157-
@dataclass(frozen=True)
class Int64(Numeric):
    """``int64`` dtype; non-nullable by default, default value ``0``."""

    nullable: bool = False
    name: ty.ClassVar[str] = "int64"
    default: ty.ClassVar[int] = 0
    typecode: ty.ClassVar[str] = "l"
    arraycode: ty.ClassVar[str] = "l"

    def constructor(self, nullable):
        """Build an ``Int64`` with the requested nullability."""
        return Int64(nullable=nullable)
168-
169-
170- # Not all Arrow types are supported. We don't have a backend to support unsigned
171- # integer types right now so they are removed to not confuse users. Feel free to
172- # add unsigned int types when we have a supporting backend.
173-
174-
@dataclass(frozen=True)
class Float32(Numeric):
    """``float32`` dtype; non-nullable by default, default value ``0.0``."""

    nullable: bool = False
    name: ty.ClassVar[str] = "float32"
    default: ty.ClassVar[float] = 0.0
    typecode: ty.ClassVar[str] = "f"
    arraycode: ty.ClassVar[str] = "f"

    def constructor(self, nullable):
        """Build a ``Float32`` with the requested nullability."""
        return Float32(nullable=nullable)
185-
186-
@dataclass(frozen=True)
class Float64(Numeric):
    """``float64`` dtype; non-nullable by default, default value ``0.0``."""

    nullable: bool = False
    name: ty.ClassVar[str] = "float64"
    default: ty.ClassVar[float] = 0.0
    typecode: ty.ClassVar[str] = "g"
    arraycode: ty.ClassVar[str] = "d"

    def constructor(self, nullable):
        """Build a ``Float64`` with the requested nullability."""
        return Float64(nullable=nullable)
197-
198-
@dataclass(frozen=True)
class String(DType):
    """``string`` dtype; non-nullable by default, default value ``""``."""

    nullable: bool = False
    name: ty.ClassVar[str] = "string"
    default: ty.ClassVar[str] = ""
    typecode: ty.ClassVar[str] = "u"  # utf8 string (n byte)
    arraycode: ty.ClassVar[str] = "w"  # wchar_t (2 byte)

    def constructor(self, nullable):
        """Build a ``String`` with the requested nullability."""
        return String(nullable=nullable)
209-
210-
@dataclass(frozen=True)
class Map(DType):
    """Map dtype, parameterised by a key dtype and an item (value) dtype."""

    key_dtype: DType
    item_dtype: DType
    nullable: bool = False
    keys_sorted: bool = False
    name: ty.ClassVar[str] = "Map"
    typecode: ty.ClassVar[str] = "+m"
    arraycode: ty.ClassVar[str] = ""

    @property
    def py_type(self):
        """``typing.Dict`` over the Python types of the key and item dtypes."""
        return ty.Dict[self.key_dtype.py_type, self.item_dtype.py_type]

    def constructor(self, nullable):
        """Return a copy of this map dtype with the given nullability.

        ``keys_sorted`` is carried over; previously it was dropped, so
        changing nullability silently reset it to ``False``.
        """
        return Map(self.key_dtype, self.item_dtype, nullable, self.keys_sorted)

    def __str__(self):
        # The nullable flag is only printed when it is set.
        nullable = ", nullable=" + str(self.nullable) if self.nullable else ""
        return f"Map({self.key_dtype}, {self.item_dtype}{nullable})"

    def default_value(self):
        """Return a fresh empty dict."""
        return {}
234-
235-
@dataclass(frozen=True)
class List(DType):
    """List dtype over one item dtype; ``fixed_size`` of -1 is the unset default."""

    item_dtype: DType
    nullable: bool = False
    fixed_size: int = -1
    name: ty.ClassVar[str] = "List"
    typecode: ty.ClassVar[str] = "+l"
    arraycode: ty.ClassVar[str] = ""

    @property
    def py_type(self):
        """``typing.List`` over the Python type of the item dtype."""
        return ty.List[self.item_dtype.py_type]

    def constructor(self, nullable, fixed_size=-1):
        """Build a ``List`` of the same item dtype with the given settings."""
        return List(self.item_dtype, nullable, fixed_size)

    def with_null(self, nullable=True):
        """Return this dtype with the given nullability, keeping ``fixed_size``."""
        # Forward fixed_size explicitly: constructor() defaults it to -1, so
        # the inherited with_null() would reset fixed_size on every call.
        return self.constructor(nullable, self.fixed_size)

    def __str__(self):
        # Optional parts are only printed when they differ from the defaults.
        nullable = ", nullable=" + str(self.nullable) if self.nullable else ""
        fixed_size = (
            ", fixed_size=" + str(self.fixed_size) if self.fixed_size >= 0 else ""
        )
        return f"List({self.item_dtype}{nullable}{fixed_size})"

    def default_value(self):
        """Return a fresh empty list."""
        return []
261-
262-
@dataclass(frozen=True)
class Struct(DType):
    """Struct dtype: an ordered list of ``Field`` objects plus optional metadata.

    A nullable struct requires every field dtype to be nullable too; this is
    enforced in ``__post_init__``.
    """

    fields: ty.List[Field]
    nullable: bool = False
    is_dataframe: bool = False
    metadata: ty.Optional[MetaData] = None
    name: ty.ClassVar[str] = "Struct"
    typecode: ty.ClassVar[str] = "+s"
    arraycode: ty.ClassVar[str] = ""

    # For generating NamedTuple class name for cached _py_type (done in __post__init__)
    _global_py_type_id: ty.ClassVar[int] = 0
    _local_py_type_id: int = dataclasses.field(compare=False, default=-1)

    # TODO: perhaps this should be a private method
    def get_index(self, name: str) -> ty.Optional[int]:
        """Return the position of the first field called ``name``, or ``None``."""
        for idx, field in enumerate(self.fields):
            if field.name == name:
                return idx
        return None

    def __getstate__(self):
        # _py_type is NamedTuple which is not pickle-able, skip it
        return (self.fields, self.nullable, self.is_dataframe, self.metadata)

    def __setstate__(self, state):
        # Restore state, __setattr__ hack is needed due to the frozen dataclass
        object.__setattr__(self, "fields", state[0])
        object.__setattr__(self, "nullable", state[1])
        object.__setattr__(self, "is_dataframe", state[2])
        object.__setattr__(self, "metadata", state[3])

        # re-run validation and assign a fresh _local_py_type_id
        self.__post_init__()

    def __post_init__(self):
        """Validate nullability and assign this instance its NamedTuple id."""
        if self.nullable:
            for f in self.fields:
                if not f.dtype.nullable:
                    raise TypeError(
                        f"nullable structs require each field (like {f.name}) to be nullable as well."
                    )
        object.__setattr__(self, "_local_py_type_id", type(self)._global_py_type_id)
        type(self)._global_py_type_id += 1

    def _set_py_type(self):
        # cache the type instance, __setattr__ hack is needed due to the frozen dataclass
        # the _py_type is not listed above to avoid participation in equality check

        def fix_name(name, idx):
            # Anonymous Row
            if name == "":
                return "f_" + str(idx)

            # Remove invalid character for NamedTuple
            # TODO: this might cause name duplicates, do disambiguation
            name = re.sub("[^a-zA-Z0-9_]", "_", name)
            if name == "" or name[0].isdigit() or name[0] == "_":
                name = "f_" + name
            return name

        object.__setattr__(
            self,
            "_py_type",
            ty.NamedTuple(
                "TorchArrowGeneratedStruct_" + str(self._local_py_type_id),
                [
                    (fix_name(f.name, idx), f.dtype.py_type)
                    for (idx, f) in enumerate(self.fields)
                ],
            ),
        )

    @property
    def py_type(self):
        """Generated NamedTuple type for this struct, built on first access."""
        if not hasattr(self, "_py_type"):
            # this call is expensive due to the namedtuple creation, so
            # do it lazily
            self._set_py_type()
        return self._py_type

    def constructor(self, nullable):
        """Return a copy of this struct dtype with the given nullability.

        ``is_dataframe`` and ``metadata`` are carried over; previously they
        were dropped, so changing nullability silently reset them.
        """
        return Struct(self.fields, nullable, self.is_dataframe, self.metadata)

    def get(self, name):
        """Return the dtype of the first field called ``name``.

        Raises:
            KeyError: if no field has that name.
        """
        for f in self.fields:
            if f.name == name:
                return f.dtype
        raise KeyError(f"{name} not among fields")

    def __str__(self):
        nullable = ", nullable=" + str(self.nullable) if self.nullable else ""
        fields = f"[{', '.join(str(f) for f in self.fields)}]"
        meta = ""
        if self.metadata is not None:
            # items() yields (key, value) pairs; iterating the dict directly
            # yields only keys.
            pairs = ", ".join(f"{k}: {v}" for k, v in self.metadata.items())
            meta = f", meta = {{{pairs}}}"
        # Return unconditionally: the old `else: return ...` made __str__
        # return None (a TypeError under str()) whenever metadata was set.
        return f"Struct({fields}{nullable}{meta})"

    def default_value(self):
        """Return a tuple holding each field dtype's default value, in order."""
        return tuple(f.dtype.default_value() for f in self.fields)
365-
366-
36753# only used internally for type inference -------------------------------------
36854
36955
0 commit comments