Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit 0796272

Browse files
wenleixfacebook-github-bot
authored andcommitted
Refactor DType definition into dtypes_core.py (#468)
Summary: Pull Request resolved: #468 dtypes.py contains of two parts: (1) Standard Arrow-compatible DataType definition (Int8/16/32/64, Float32/64, String, List, Map, Struct) (2) Utility functions around DType The DType definition is quite standalone and stable; Refactor DataType definition into seperate file so allows reuse in next gen TorchArrow, such as TorchArrow-UPM or Tensor-based TA In theory we could let dtypes.py only contain DType definition, and move other things into `dtype_util.py`. Starting with this first step. Reviewed By: dracifer Differential Revision: D38358389 fbshipit-source-id: 632037498290222a60794e26d0db0c98c73f1391
1 parent deb3a33 commit 0796272

File tree

2 files changed

+360
-332
lines changed

2 files changed

+360
-332
lines changed

torcharrow/dtypes.py

Lines changed: 18 additions & 332 deletions
Original file line numberDiff line numberDiff line change
@@ -6,90 +6,38 @@
66

77
import dataclasses
88
import inspect
9-
import re
109
import typing as ty
11-
from abc import ABC, abstractmethod
1210
from dataclasses import dataclass, is_dataclass, replace
1311

1412
import numpy as np
1513
import torcharrow._torcharrow
1614
import typing_inspect
1715

16+
from .dtypes_core import (
17+
Boolean,
18+
DType,
19+
Field,
20+
Float32,
21+
Float64,
22+
Int16,
23+
Int32,
24+
Int64,
25+
Int8,
26+
List,
27+
Map,
28+
MetaData,
29+
NL,
30+
String,
31+
Struct,
32+
)
33+
1834
# -----------------------------------------------------------------------------
1935
# Aux
2036

21-
# Pretty printing constants; reused everywhere
22-
OPEN = "{"
23-
CLOSE = "}"
24-
NL = "\n"
25-
2637
# Handy Type abbreviations; reused everywhere
2738
ScalarTypes = ty.Union[int, float, bool, str]
2839

2940

30-
# -----------------------------------------------------------------------------
31-
# Schema and Field
32-
33-
MetaData = ty.Dict[str, str]
34-
35-
36-
@dataclass(frozen=True)
37-
class Field:
38-
name: str
39-
dtype: "DType"
40-
metadata: ty.Optional[MetaData] = None
41-
42-
def __str__(self):
43-
meta = ""
44-
if self.metadata is not None:
45-
meta = (
46-
f"meta = {OPEN}{', '.join(f'{k}: {v}' for k,v in self.metadata)}{CLOSE}"
47-
)
48-
return f"Field('{self.name}', {str(self.dtype)}{meta})"
49-
50-
51-
# -----------------------------------------------------------------------------
52-
# Immutable Types with structural equality...
53-
54-
55-
@dataclass(frozen=True) # type: ignore
56-
class DType(ABC):
57-
typecode: ty.ClassVar[str] = "__TO_BE_DEFINED_IN_SUBCLASS__"
58-
arraycode: ty.ClassVar[str] = "__TO_BE_DEFINED_IN_SUBCLASS__"
59-
60-
@property
61-
@abstractmethod
62-
def nullable(self):
63-
return False
64-
65-
@property
66-
def py_type(self):
67-
return type(self.default_value())
68-
69-
def __str__(self):
70-
if self.nullable:
71-
return f"{self.name.title()}(nullable=True)"
72-
else:
73-
return self.name
74-
75-
@abstractmethod
76-
def constructor(self, nullable):
77-
pass
78-
79-
def with_null(self, nullable=True):
80-
return self.constructor(nullable)
81-
82-
def default_value(self):
83-
# must be overridden by all non primitive types!
84-
return type(self).default
85-
86-
def __qualstr__(self):
87-
return "torcharrow.dtypes"
88-
89-
90-
# for now: no float16, and all date and time stuff, no categorical, (and Null is called Void)
91-
92-
9341
@dataclass(frozen=True)
9442
class Void(DType):
9543
nullable: bool = True
@@ -102,268 +50,6 @@ def constructor(self, nullable):
10250
return Void(nullable)
10351

10452

105-
@dataclass(frozen=True) # type: ignore
106-
class Numeric(DType):
107-
pass
108-
109-
110-
@dataclass(frozen=True)
111-
class Boolean(DType):
112-
nullable: bool = False
113-
typecode: ty.ClassVar[str] = "b"
114-
arraycode: ty.ClassVar[str] = "b"
115-
name: ty.ClassVar[str] = "boolean"
116-
default: ty.ClassVar[bool] = False
117-
118-
def constructor(self, nullable):
119-
return Boolean(nullable)
120-
121-
122-
@dataclass(frozen=True)
123-
class Int8(Numeric):
124-
nullable: bool = False
125-
typecode: ty.ClassVar[str] = "c"
126-
arraycode: ty.ClassVar[str] = "b"
127-
name: ty.ClassVar[str] = "int8"
128-
default: ty.ClassVar[int] = 0
129-
130-
def constructor(self, nullable):
131-
return Int8(nullable)
132-
133-
134-
@dataclass(frozen=True)
135-
class Int16(Numeric):
136-
nullable: bool = False
137-
typecode: ty.ClassVar[str] = "s"
138-
arraycode: ty.ClassVar[str] = "h"
139-
name: ty.ClassVar[str] = "int16"
140-
default: ty.ClassVar[int] = 0
141-
142-
def constructor(self, nullable):
143-
return Int16(nullable)
144-
145-
146-
@dataclass(frozen=True)
147-
class Int32(Numeric):
148-
nullable: bool = False
149-
typecode: ty.ClassVar[str] = "i"
150-
arraycode: ty.ClassVar[str] = "i"
151-
name: ty.ClassVar[str] = "int32"
152-
default: ty.ClassVar[int] = 0
153-
154-
def constructor(self, nullable):
155-
return Int32(nullable)
156-
157-
158-
@dataclass(frozen=True)
159-
class Int64(Numeric):
160-
nullable: bool = False
161-
typecode: ty.ClassVar[str] = "l"
162-
arraycode: ty.ClassVar[str] = "l"
163-
name: ty.ClassVar[str] = "int64"
164-
default: ty.ClassVar[int] = 0
165-
166-
def constructor(self, nullable):
167-
return Int64(nullable)
168-
169-
170-
# Not all Arrow types are supported. We don't have a backend to support unsigned
171-
# integer types right now so they are removed to not confuse users. Feel free to
172-
# add unsigned int types when we have a supporting backend.
173-
174-
175-
@dataclass(frozen=True)
176-
class Float32(Numeric):
177-
nullable: bool = False
178-
typecode: ty.ClassVar[str] = "f"
179-
arraycode: ty.ClassVar[str] = "f"
180-
name: ty.ClassVar[str] = "float32"
181-
default: ty.ClassVar[float] = 0.0
182-
183-
def constructor(self, nullable):
184-
return Float32(nullable)
185-
186-
187-
@dataclass(frozen=True)
188-
class Float64(Numeric):
189-
nullable: bool = False
190-
typecode: ty.ClassVar[str] = "g"
191-
arraycode: ty.ClassVar[str] = "d"
192-
name: ty.ClassVar[str] = "float64"
193-
default: ty.ClassVar[float] = 0.0
194-
195-
def constructor(self, nullable):
196-
return Float64(nullable)
197-
198-
199-
@dataclass(frozen=True)
200-
class String(DType):
201-
nullable: bool = False
202-
typecode: ty.ClassVar[str] = "u" # utf8 string (n byte)
203-
arraycode: ty.ClassVar[str] = "w" # wchar_t (2 byte)
204-
name: ty.ClassVar[str] = "string"
205-
default: ty.ClassVar[str] = ""
206-
207-
def constructor(self, nullable):
208-
return String(nullable)
209-
210-
211-
@dataclass(frozen=True)
212-
class Map(DType):
213-
key_dtype: DType
214-
item_dtype: DType
215-
nullable: bool = False
216-
keys_sorted: bool = False
217-
name: ty.ClassVar[str] = "Map"
218-
typecode: ty.ClassVar[str] = "+m"
219-
arraycode: ty.ClassVar[str] = ""
220-
221-
@property
222-
def py_type(self):
223-
return ty.Dict[self.key_dtype.py_type, self.item_dtype.py_type]
224-
225-
def constructor(self, nullable):
226-
return Map(self.key_dtype, self.item_dtype, nullable)
227-
228-
def __str__(self):
229-
nullable = ", nullable=" + str(self.nullable) if self.nullable else ""
230-
return f"Map({self.key_dtype}, {self.item_dtype}{nullable})"
231-
232-
def default_value(self):
233-
return {}
234-
235-
236-
@dataclass(frozen=True)
237-
class List(DType):
238-
item_dtype: DType
239-
nullable: bool = False
240-
fixed_size: int = -1
241-
name: ty.ClassVar[str] = "List"
242-
typecode: ty.ClassVar[str] = "+l"
243-
arraycode: ty.ClassVar[str] = ""
244-
245-
@property
246-
def py_type(self):
247-
return ty.List[self.item_dtype.py_type]
248-
249-
def constructor(self, nullable, fixed_size=-1):
250-
return List(self.item_dtype, nullable, fixed_size)
251-
252-
def __str__(self):
253-
nullable = ", nullable=" + str(self.nullable) if self.nullable else ""
254-
fixed_size = (
255-
", fixed_size=" + str(self.fixed_size) if self.fixed_size >= 0 else ""
256-
)
257-
return f"List({self.item_dtype}{nullable}{fixed_size})"
258-
259-
def default_value(self):
260-
return []
261-
262-
263-
@dataclass(frozen=True)
264-
class Struct(DType):
265-
fields: ty.List[Field]
266-
nullable: bool = False
267-
is_dataframe: bool = False
268-
metadata: ty.Optional[MetaData] = None
269-
name: ty.ClassVar[str] = "Struct"
270-
typecode: ty.ClassVar[str] = "+s"
271-
arraycode: ty.ClassVar[str] = ""
272-
273-
# For generating NamedTuple class name for cached _py_type (done in __post__init__)
274-
_global_py_type_id: ty.ClassVar[int] = 0
275-
_local_py_type_id: int = dataclasses.field(compare=False, default=-1)
276-
277-
# TODO: perhaps this should be a private method
278-
def get_index(self, name: str) -> int:
279-
for idx, field in enumerate(self.fields):
280-
if field.name == name:
281-
return idx
282-
# pyre-fixme[7]: Expected `int` but got `None`.
283-
return None
284-
285-
def __getstate__(self):
286-
# _py_type is NamedTuple which is not pickle-able, skip it
287-
return (self.fields, self.nullable, self.is_dataframe, self.metadata)
288-
289-
def __setstate__(self, state):
290-
# Restore state, __setattr__ hack is needed due to the frozen dataclass
291-
object.__setattr__(self, "fields", state[0])
292-
object.__setattr__(self, "nullable", state[1])
293-
object.__setattr__(self, "is_dataframe", state[2])
294-
object.__setattr__(self, "metadata", state[3])
295-
296-
# reconstruct _py_type
297-
self.__post_init__()
298-
299-
def __post_init__(self):
300-
if self.nullable:
301-
for f in self.fields:
302-
if not f.dtype.nullable:
303-
raise TypeError(
304-
f"nullable structs require each field (like {f.name}) to be nullable as well."
305-
)
306-
object.__setattr__(self, "_local_py_type_id", type(self)._global_py_type_id)
307-
type(self)._global_py_type_id += 1
308-
309-
def _set_py_type(self):
310-
# cache the type instance, __setattr__ hack is needed due to the frozen dataclass
311-
# the _py_type is not listed above to avoid participation in equality check
312-
313-
def fix_name(name, idx):
314-
# Anonomous Row
315-
if name == "":
316-
return "f_" + str(idx)
317-
318-
# Remove invalid character for NamedTuple
319-
# TODO: this might cause name duplicates, do disambiguation
320-
name = re.sub("[^a-zA-Z0-9_]", "_", name)
321-
if name == "" or name[0].isdigit() or name[0] == "_":
322-
name = "f_" + name
323-
return name
324-
325-
object.__setattr__(
326-
self,
327-
"_py_type",
328-
ty.NamedTuple(
329-
"TorchArrowGeneratedStruct_" + str(self._local_py_type_id),
330-
[
331-
(fix_name(f.name, idx), f.dtype.py_type)
332-
for (idx, f) in enumerate(self.fields)
333-
],
334-
),
335-
)
336-
337-
@property
338-
def py_type(self):
339-
if not hasattr(self, "_py_type"):
340-
# this call is expensive due to the namedtuple creation, so
341-
# do it lazily
342-
self._set_py_type()
343-
return self._py_type
344-
345-
def constructor(self, nullable):
346-
return Struct(self.fields, nullable)
347-
348-
def get(self, name):
349-
for f in self.fields:
350-
if f.name == name:
351-
return f.dtype
352-
raise KeyError(f"{name} not among fields")
353-
354-
def __str__(self):
355-
nullable = ", nullable=" + str(self.nullable) if self.nullable else ""
356-
fields = f"[{', '.join(str(f) for f in self.fields)}]"
357-
meta = ""
358-
if self.metadata is not None:
359-
meta = f", meta = {OPEN}{', '.join(f'{k}: {v}' for k,v in self.metadata)}{CLOSE}"
360-
else:
361-
return f"Struct({fields}{nullable}{meta})"
362-
363-
def default_value(self):
364-
return tuple(f.dtype.default_value() for f in self.fields)
365-
366-
36753
# only used internally for type inference -------------------------------------
36854

36955

0 commit comments

Comments
 (0)