Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 13b94bb

Browse files
committed
Fixed PR. All tests passing.
1 parent 963eba8 commit 13b94bb

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

data_diff/databases/databricks.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@ class Databricks(Database):
2222
"FLOAT": Float,
2323
"DOUBLE": Float,
2424
"DECIMAL": Decimal,
25-
2625
# Timestamps
2726
"TIMESTAMP": Timestamp,
28-
2927
# Text
3028
"STRING": Text,
3129
}
@@ -67,8 +65,8 @@ def to_string(self, s: str) -> str:
6765
return f"cast({s} as string)"
6866

6967
def _convert_db_precision_to_digits(self, p: int) -> int:
70-
# Subtracting 2 due to wierd precision issues in Databricks for the FLOAT type
71-
return super()._convert_db_precision_to_digits(p) - 2
68+
# Subtracting 1 due to wierd precision issues
69+
return max(super()._convert_db_precision_to_digits(p) - 1, 0)
7270

7371
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
7472
# Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
@@ -88,19 +86,19 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str
8886

8987
resulted_rows = []
9088
for row in rows:
91-
row_type = 'DECIMAL' if row.DATA_TYPE == 3 else row.TYPE_NAME
89+
row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME
9290
type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType)
9391

9492
if issubclass(type_cls, Integer):
9593
row = (row.COLUMN_NAME, row_type, None, None, 0)
9694

9795
elif issubclass(type_cls, Float):
98-
numeric_precision = math.ceil(row.DECIMAL_DIGITS / math.log(2, 10))
96+
numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS)
9997
row = (row.COLUMN_NAME, row_type, None, numeric_precision, None)
10098

10199
elif issubclass(type_cls, Decimal):
102100
# TYPE_NAME has a format DECIMAL(x,y)
103-
items = row.TYPE_NAME[8:].rstrip(')').split(',')
101+
items = row.TYPE_NAME[8:].rstrip(")").split(",")
104102
numeric_precision, numeric_scale = int(items[0]), int(items[1])
105103
row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale)
106104

@@ -123,7 +121,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
123121
timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)"
124122
return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')"
125123
else:
126-
precision_format = 'S' * coltype.precision + '0' * (6 - coltype.precision)
124+
precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision)
127125
return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')"
128126

129127
def normalize_number(self, value: str, coltype: NumericType) -> str:

0 commit comments

Comments
 (0)