Commit a9a7778

Fix data ranges for random data generation (#298)

* Set default ranges when values are unspecified
* Add tests
* Get numeric datatype ranges via a static method
* Add tests for better coverage of NRange
1 parent f99882d
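As a reading aid only, here is a minimal sketch of the behavior this commit introduces; it is not part of the diff, and it assumes NRange is importable from dbldatagen.nrange. When minValue or maxValue is left unspecified, adjustForColumnDatatype now fills in the missing bound from a per-type default range, rather than handling only decimal columns:

# Hypothetical usage sketch; only NRange.adjustForColumnDatatype and
# NRange._getNumericDataTypeRange come from the diff below.
from pyspark.sql.types import IntegerType, DecimalType

from dbldatagen.nrange import NRange

rng = NRange(maxValue=100)                  # no minValue supplied
rng.adjustForColumnDatatype(IntegerType())
print(rng.minValue, rng.maxValue)           # 0 (type default) and 100 (kept as given)

dec = NRange()                              # neither bound supplied
dec.adjustForColumnDatatype(DecimalType(5, 2))
print(dec.minValue, dec.maxValue)           # 0.0 and 10**(5 - 2) - 1 = 999.0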

File tree

3 files changed: +97 -8 lines


dbldatagen/nrange.py

Lines changed: 21 additions & 7 deletions
@@ -9,7 +9,7 @@
 import math
 
 from pyspark.sql.types import LongType, FloatType, IntegerType, DoubleType, ShortType, \
-    ByteType
+    ByteType, DecimalType
 
 from .datarange import DataRange
 
@@ -83,13 +83,12 @@ def adjustForColumnDatatype(self, ctype):
         :param ctype: Spark SQL type instance to adjust range for
         :returns: No return value - executes for effect only
         """
-        if ctype.typeName() == 'decimal':
+        numeric_types = [DecimalType, FloatType, DoubleType, ByteType, ShortType, IntegerType, LongType]
+        if type(ctype) in numeric_types:
             if self.minValue is None:
-                self.minValue = 0.0
+                self.minValue = NRange._getNumericDataTypeRange(ctype)[0]
             if self.maxValue is None:
-                self.maxValue = math.pow(10, ctype.precision - ctype.scale) - 1.0
-            if self.step is None:
-                self.step = 1.0
+                self.maxValue = NRange._getNumericDataTypeRange(ctype)[1]
 
         if type(ctype) is ShortType and self.maxValue is not None:
             assert self.maxValue <= 65536, "`maxValue` must be in range of short"
@@ -145,7 +144,8 @@ def getScale(self):
         # return maximum scale of components
         return max(smin, smax, sstep)
 
-    def _precision_and_scale(self, x):
+    @staticmethod
+    def _precision_and_scale(x):
         max_digits = 14
         int_part = int(abs(x))
         magnitude = 1 if int_part == 0 else int(math.log10(int_part)) + 1
@@ -158,3 +158,17 @@ def _precision_and_scale(self, x):
             frac_digits /= 10
         scale = int(math.log10(frac_digits))
         return (magnitude + scale, scale)
+
+    @staticmethod
+    def _getNumericDataTypeRange(ctype):
+        value_ranges = {
+            ByteType: (0, (2 ** 4 - 1)),
+            ShortType: (0, (2 ** 8 - 1)),
+            IntegerType: (0, (2 ** 16 - 1)),
+            LongType: (0, (2 ** 32 - 1)),
+            FloatType: (0.0, 3.402e38),
+            DoubleType: (0.0, 1.79769e308)
+        }
+        if type(ctype) is DecimalType:
+            return 0.0, math.pow(10, ctype.precision - ctype.scale) - 1.0
+        return value_ranges.get(type(ctype), None)
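For quick reference, a sketch (same importability assumption as above) of what the new static lookup returns; the values follow the table in the diff, and non-numeric types fall through to None:

from pyspark.sql.types import ByteType, LongType, DecimalType, StringType

from dbldatagen.nrange import NRange

print(NRange._getNumericDataTypeRange(ByteType()))         # (0, 15), i.e. 2 ** 4 - 1
print(NRange._getNumericDataTypeRange(LongType()))         # (0, 4294967295), i.e. 2 ** 32 - 1
print(NRange._getNumericDataTypeRange(DecimalType(5, 2)))  # (0.0, 999.0), computed from precision and scale
print(NRange._getNumericDataTypeRange(StringType()))       # None: no default range for non-numeric types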

tests/test_options.py

Lines changed: 19 additions & 0 deletions
@@ -218,6 +218,25 @@ def test_random2(self):
         colSpec3 = ds.getColumnSpec("code3")
         assert colSpec3.random is True
 
+    def test_random3(self):
+        # will have implied column `id` for ordinal of row
+        ds = (
+            dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=500, partitions=1, random=True)
+            .withIdOutput()
+            .withColumn("val1", "decimal(5,2)", maxValue=20.0, step=0.01, random=True)
+            .withColumn("val2", "float", maxValue=20.0, random=True)
+            .withColumn("val3", "double", maxValue=20.0, random=True)
+            .withColumn("val4", "byte", maxValue=15, random=True)
+            .withColumn("val5", "short", maxValue=31, random=True)
+            .withColumn("val6", "integer", maxValue=63, random=True)
+            .withColumn("val7", "long", maxValue=127, random=True)
+        )
+
+        df = ds.build()
+        cols = ["val1", "val2", "val3", "val4", "val5", "val6", "val7"]
+        for col in cols:
+            assert df.collect() != df.orderBy(col).collect(), f"Random values were not generated for {col}"
+
     def test_random_multiple_columns(self):
         # will have implied column `id` for ordinal of row
         ds = (
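The assertions in test_random3 check randomness indirectly: the test's premise is that a non-random column is generated as a monotone sequence derived from the row ordinal, so sorting by it would reproduce the original row order. A self-contained illustration of the same check in plain PySpark (independent of dbldatagen):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# A monotonically generated column: sorting by it is a no-op, so the
# inequality used in test_random3 would (correctly) fail to hold here.
df = spark.range(100).withColumnRenamed("id", "val")
assert df.collect() == df.orderBy("val").collect()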

tests/test_quick_tests.py

Lines changed: 57 additions & 1 deletion
@@ -1,7 +1,11 @@
 from datetime import timedelta, datetime
 
 import pytest
-from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType, DateType
+from pyspark.sql.types import (
+    StructType, StructField, IntegerType, StringType, FloatType, DateType, DecimalType, DoubleType, ByteType,
+    ShortType, LongType
+)
+
 
 import dbldatagen as dg
 from dbldatagen import DataGenerator
@@ -403,6 +407,28 @@ def test_basic_prefix(self):
         rowCount = formattedDF.count()
         assert rowCount == 1000
 
+    def test_missing_range_values(self):
+        column_types = [FloatType(), DoubleType(), ByteType(), ShortType(), IntegerType(), LongType()]
+        for column_type in column_types:
+            range_no_min = NRange(maxValue=1.0)
+            range_no_max = NRange(minValue=0.0)
+            range_no_min.adjustForColumnDatatype(column_type)
+            assert range_no_min.min == NRange._getNumericDataTypeRange(column_type)[0]
+            assert range_no_min.step == 1
+            range_no_max.adjustForColumnDatatype(column_type)
+            assert range_no_max.max == NRange._getNumericDataTypeRange(column_type)[1]
+            assert range_no_max.step == 1
+
+    def test_range_with_until(self):
+        range_until = NRange(step=2, until=100)
+        range_until.adjustForColumnDatatype(IntegerType())
+        assert range_until.minValue == 0
+        assert range_until.maxValue == 101
+
+    def test_empty_range(self):
+        empty_range = NRange()
+        assert empty_range.isEmpty()
+
     def test_reversed_ranges(self):
         testDataSpec = (dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000,
                                          partitions=4)
@@ -695,6 +721,36 @@ def test_strings_from_numeric_string_field4(self):
         rowCount = nullRowsDF.count()
         assert rowCount == 0
 
+    @pytest.mark.parametrize("columnSpecOptions", [
+        {"dataType": "byte", "minValue": 1, "maxValue": None},
+        {"dataType": "byte", "minValue": None, "maxValue": 10},
+        {"dataType": "short", "minValue": 1, "maxValue": None},
+        {"dataType": "short", "minValue": None, "maxValue": 100},
+        {"dataType": "integer", "minValue": 1, "maxValue": None},
+        {"dataType": "integer", "minValue": None, "maxValue": 100},
+        {"dataType": "long", "minValue": 1, "maxValue": None},
+        {"dataType": "long", "minValue": None, "maxValue": 100},
+        {"dataType": "float", "minValue": 1.0, "maxValue": None},
+        {"dataType": "float", "minValue": None, "maxValue": 100.0},
+        {"dataType": "double", "minValue": 1, "maxValue": None},
+        {"dataType": "double", "minValue": None, "maxValue": 100.0}
+    ])
+    def test_random_generation_without_range_values(self, columnSpecOptions):
+        dataType = columnSpecOptions.get("dataType", None)
+        minValue = columnSpecOptions.get("minValue", None)
+        maxValue = columnSpecOptions.get("maxValue", None)
+        testDataSpec = (dg.DataGenerator(sparkSession=spark, name="randomGenerationWithoutRangeValues", rows=100,
+                                         partitions=4)
+                        .withIdOutput()
+                        # default column type is string
+                        .withColumn("randCol", colType=dataType, minValue=minValue, maxValue=maxValue, random=True)
+                        )
+
+        df = testDataSpec.build(withTempView=True)
+        sortedDf = df.orderBy("randCol")
+        sortedVals = sortedDf.select("randCol").collect()
+        assert sortedVals != df.select("randCol").collect()
+
     def test_version_info(self):
         # test access to version info without explicit import
         print("Data generator version", dg.__version__)
