Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 7a7cd4a

Browse files
authored
Merge pull request #723 from datafold/attrs-instead-of-runtype
Convert to `attrs`, remove `runtype`
2 parents 076fec2 + ee37648 commit 7a7cd4a

39 files changed

+580
-331
lines changed

data_diff/abcs/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from abc import ABC
22

3+
import attrs
34

5+
6+
@attrs.define(frozen=False)
47
class AbstractCompiler(ABC):
58
pass
69

710

11+
@attrs.define(frozen=False, eq=False)
812
class Compilable(ABC):
913
pass

data_diff/abcs/database_types.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Tuple, Union
44
from datetime import datetime
55

6-
from runtype import dataclass
6+
import attrs
77

88
from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown
99

@@ -13,55 +13,66 @@
1313
DbTime = datetime
1414

1515

16-
@dataclass
16+
@attrs.define(frozen=True)
1717
class ColType:
18-
supported = True
18+
@property
19+
def supported(self) -> bool:
20+
return True
1921

2022

21-
@dataclass
23+
@attrs.define(frozen=True)
2224
class PrecisionType(ColType):
2325
precision: int
2426
rounds: Union[bool, Unknown] = Unknown
2527

2628

29+
@attrs.define(frozen=True)
2730
class Boolean(ColType):
2831
precision = 0
2932

3033

34+
@attrs.define(frozen=True)
3135
class TemporalType(PrecisionType):
3236
pass
3337

3438

39+
@attrs.define(frozen=True)
3540
class Timestamp(TemporalType):
3641
pass
3742

3843

44+
@attrs.define(frozen=True)
3945
class TimestampTZ(TemporalType):
4046
pass
4147

4248

49+
@attrs.define(frozen=True)
4350
class Datetime(TemporalType):
4451
pass
4552

4653

54+
@attrs.define(frozen=True)
4755
class Date(TemporalType):
4856
pass
4957

5058

51-
@dataclass
59+
@attrs.define(frozen=True)
5260
class NumericType(ColType):
5361
# 'precision' signifies how many fractional digits (after the dot) we want to compare
5462
precision: int
5563

5664

65+
@attrs.define(frozen=True)
5766
class FractionalType(NumericType):
5867
pass
5968

6069

70+
@attrs.define(frozen=True)
6171
class Float(FractionalType):
6272
python_type = float
6373

6474

75+
@attrs.define(frozen=True)
6576
class IKey(ABC):
6677
"Interface for ColType, for using a column as a key in a table."
6778

@@ -74,6 +85,7 @@ def make_value(self, value):
7485
return self.python_type(value)
7586

7687

88+
@attrs.define(frozen=True)
7789
class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key
7890
@property
7991
def python_type(self) -> type:
@@ -82,27 +94,32 @@ def python_type(self) -> type:
8294
return decimal.Decimal
8395

8496

85-
@dataclass
97+
@attrs.define(frozen=True)
8698
class StringType(ColType):
8799
python_type = str
88100

89101

102+
@attrs.define(frozen=True)
90103
class ColType_UUID(ColType, IKey):
91104
python_type = ArithUUID
92105

93106

107+
@attrs.define(frozen=True)
94108
class ColType_Alphanum(ColType, IKey):
95109
python_type = ArithAlphanumeric
96110

97111

112+
@attrs.define(frozen=True)
98113
class Native_UUID(ColType_UUID):
99114
pass
100115

101116

117+
@attrs.define(frozen=True)
102118
class String_UUID(ColType_UUID, StringType):
103119
pass
104120

105121

122+
@attrs.define(frozen=True)
106123
class String_Alphanum(ColType_Alphanum, StringType):
107124
@staticmethod
108125
def test_value(value: str) -> bool:
@@ -116,11 +133,12 @@ def make_value(self, value):
116133
return self.python_type(value)
117134

118135

136+
@attrs.define(frozen=True)
119137
class String_VaryingAlphanum(String_Alphanum):
120138
pass
121139

122140

123-
@dataclass
141+
@attrs.define(frozen=True)
124142
class String_FixedAlphanum(String_Alphanum):
125143
length: int
126144

@@ -130,18 +148,20 @@ def make_value(self, value):
130148
return self.python_type(value, max_len=self.length)
131149

132150

133-
@dataclass
151+
@attrs.define(frozen=True)
134152
class Text(StringType):
135-
supported = False
153+
@property
154+
def supported(self) -> bool:
155+
return False
136156

137157

138158
# In the majority of DBMSes, this type is called JSON/JSONB. Only in Snowflake is it called OBJECT.
139-
@dataclass
159+
@attrs.define(frozen=True)
140160
class JSON(ColType):
141161
pass
142162

143163

144-
@dataclass
164+
@attrs.define(frozen=True)
145165
class Array(ColType):
146166
item_type: ColType
147167

@@ -151,22 +171,24 @@ class Array(ColType):
151171
# For example, in BigQuery:
152172
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type
153173
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals
154-
@dataclass
174+
@attrs.define(frozen=True)
155175
class Struct(ColType):
156176
pass
157177

158178

159-
@dataclass
179+
@attrs.define(frozen=True)
160180
class Integer(NumericType, IKey):
161181
precision: int = 0
162182
python_type: type = int
163183

164-
def __post_init__(self):
184+
def __attrs_post_init__(self):
165185
assert self.precision == 0
166186

167187

168-
@dataclass
188+
@attrs.define(frozen=True)
169189
class UnknownColType(ColType):
170190
text: str
171191

172-
supported = False
192+
@property
193+
def supported(self) -> bool:
194+
return False

data_diff/abcs/mixins.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from abc import ABC, abstractmethod
2+
3+
import attrs
4+
25
from data_diff.abcs.database_types import (
36
Array,
47
TemporalType,
@@ -13,10 +16,12 @@
1316
from data_diff.abcs.compiler import Compilable
1417

1518

19+
@attrs.define(frozen=False)
1620
class AbstractMixin(ABC):
1721
"A mixin for a database dialect"
1822

1923

24+
@attrs.define(frozen=False)
2025
class AbstractMixin_NormalizeValue(AbstractMixin):
2126
@abstractmethod
2227
def to_comparable(self, value: str, coltype: ColType) -> str:
@@ -108,6 +113,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
108113
return self.to_string(value)
109114

110115

116+
@attrs.define(frozen=False)
111117
class AbstractMixin_MD5(AbstractMixin):
112118
"""Methods for calculating an MD5 hash as an integer."""
113119

@@ -116,6 +122,7 @@ def md5_as_int(self, s: str) -> str:
116122
"Provide SQL for computing md5 and returning an int"
117123

118124

125+
@attrs.define(frozen=False)
119126
class AbstractMixin_Schema(AbstractMixin):
120127
"""Methods for querying the database schema
121128
@@ -134,6 +141,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
134141
"""
135142

136143

144+
@attrs.define(frozen=False)
137145
class AbstractMixin_RandomSample(AbstractMixin):
138146
@abstractmethod
139147
def random_sample_n(self, tbl: str, size: int) -> str:
@@ -151,6 +159,7 @@ def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str:
151159
# """
152160

153161

162+
@attrs.define(frozen=False)
154163
class AbstractMixin_TimeTravel(AbstractMixin):
155164
@abstractmethod
156165
def time_travel(
@@ -173,6 +182,7 @@ def time_travel(
173182
"""
174183

175184

185+
@attrs.define(frozen=False)
176186
class AbstractMixin_OptimizerHints(AbstractMixin):
177187
@abstractmethod
178188
def optimizer_hints(self, optimizer_hints: str) -> str:

data_diff/cloud/datafold_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import base64
2-
import dataclasses
32
import enum
43
import time
54
from typing import Any, Dict, List, Optional, Type, Tuple
65

6+
import attrs
77
import pydantic
88
import requests
99
from typing_extensions import Self
@@ -178,13 +178,13 @@ class TCloudApiDataSourceTestResult(pydantic.BaseModel):
178178
result: Optional[TCloudDataSourceTestResult]
179179

180180

181-
@dataclasses.dataclass
181+
@attrs.define(frozen=True)
182182
class DatafoldAPI:
183183
api_key: str
184184
host: str = "https://app.datafold.com"
185185
timeout: int = 30
186186

187-
def __post_init__(self):
187+
def __attrs_post_init__(self):
188188
self.host = self.host.rstrip("/")
189189
self.headers = {
190190
"Authorization": f"Key {self.api_key}",

data_diff/databases/_connect.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from itertools import zip_longest
44
from contextlib import suppress
55
import weakref
6+
7+
import attrs
68
import dsnparse
79
import toml
810

9-
from runtype import dataclass
1011
from typing_extensions import Self
1112

1213
from data_diff.databases.base import Database, ThreadedDatabase
@@ -25,7 +26,7 @@
2526
from data_diff.databases.mssql import MsSQL
2627

2728

28-
@dataclass
29+
@attrs.define(frozen=True)
2930
class MatchUriPath:
3031
database_cls: Type[Database]
3132

@@ -92,12 +93,16 @@ def match_path(self, dsn):
9293
}
9394

9495

96+
@attrs.define(frozen=False, init=False)
9597
class Connect:
9698
"""Provides methods for connecting to a supported database using a URL or connection dict."""
9799

100+
database_by_scheme: Dict[str, Database]
101+
match_uri_path: Dict[str, MatchUriPath]
98102
conn_cache: MutableMapping[Hashable, Database]
99103

100104
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
105+
super().__init__()
101106
self.database_by_scheme = database_by_scheme
102107
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
103108
self.conn_cache = weakref.WeakValueDictionary()
@@ -284,6 +289,7 @@ def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
284289
return db_conf
285290

286291

292+
@attrs.define(frozen=False, init=False)
287293
class Connect_SetUTC(Connect):
288294
"""Provides methods for connecting to a supported database using a URL or connection dict.
289295

0 commit comments

Comments
 (0)