|
1 | 1 | from datetime import timedelta, datetime |
2 | 2 |
|
3 | 3 | import pytest |
4 | | -from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType, DateType |
| 4 | +from pyspark.sql.types import ( |
| 5 | + StructType, StructField, IntegerType, StringType, FloatType, DateType, DecimalType, DoubleType, ByteType, |
| 6 | + ShortType, LongType |
| 7 | +) |
| 8 | + |
5 | 9 |
|
6 | 10 | import dbldatagen as dg |
7 | 11 | from dbldatagen import DataGenerator |
@@ -403,6 +407,28 @@ def test_basic_prefix(self): |
403 | 407 | rowCount = formattedDF.count() |
404 | 408 | assert rowCount == 1000 |
405 | 409 |
|
| 410 | + def test_missing_range_values(self): |
| 411 | + column_types = [FloatType(), DoubleType(), ByteType(), ShortType(), IntegerType(), LongType()] |
| 412 | + for column_type in column_types: |
| 413 | + range_no_min = NRange(maxValue=1.0) |
| 414 | + range_no_max = NRange(minValue=0.0) |
| 415 | + range_no_min.adjustForColumnDatatype(column_type) |
| 416 | + assert range_no_min.min == NRange._getNumericDataTypeRange(column_type)[0] |
| 417 | + assert range_no_min.step == 1 |
| 418 | + range_no_max.adjustForColumnDatatype(column_type) |
| 419 | + assert range_no_max.max == NRange._getNumericDataTypeRange(column_type)[1] |
| 420 | + assert range_no_max.step == 1 |
| 421 | + |
| 422 | + def test_range_with_until(self): |
| 423 | + range_until = NRange(step=2, until=100) |
| 424 | + range_until.adjustForColumnDatatype(IntegerType()) |
| 425 | + assert range_until.minValue == 0 |
| 426 | + assert range_until.maxValue == 101 |
| 427 | + |
| 428 | + def test_empty_range(self): |
| 429 | + empty_range = NRange() |
| 430 | + assert empty_range.isEmpty() |
| 431 | + |
406 | 432 | def test_reversed_ranges(self): |
407 | 433 | testDataSpec = (dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, |
408 | 434 | partitions=4) |
@@ -695,6 +721,36 @@ def test_strings_from_numeric_string_field4(self): |
695 | 721 | rowCount = nullRowsDF.count() |
696 | 722 | assert rowCount == 0 |
697 | 723 |
|
| 724 | + @pytest.mark.parametrize("columnSpecOptions", [ |
| 725 | + {"dataType": "byte", "minValue": 1, "maxValue": None}, |
| 726 | + {"dataType": "byte", "minValue": None, "maxValue": 10}, |
| 727 | + {"dataType": "short", "minValue": 1, "maxValue": None}, |
| 728 | + {"dataType": "short", "minValue": None, "maxValue": 100}, |
| 729 | + {"dataType": "integer", "minValue": 1, "maxValue": None}, |
| 730 | + {"dataType": "integer", "minValue": None, "maxValue": 100}, |
| 731 | + {"dataType": "long", "minValue": 1, "maxValue": None}, |
| 732 | + {"dataType": "long", "minValue": None, "maxValue": 100}, |
| 733 | + {"dataType": "float", "minValue": 1.0, "maxValue": None}, |
| 734 | + {"dataType": "float", "minValue": None, "maxValue": 100.0}, |
| 735 | + {"dataType": "double", "minValue": 1, "maxValue": None}, |
| 736 | + {"dataType": "double", "minValue": None, "maxValue": 100.0} |
| 737 | + ]) |
| 738 | + def test_random_generation_without_range_values(self, columnSpecOptions): |
| 739 | + dataType = columnSpecOptions.get("dataType", None) |
| 740 | + minValue = columnSpecOptions.get("minValue", None) |
| 741 | + maxValue = columnSpecOptions.get("maxValue", None) |
| 742 | + testDataSpec = (dg.DataGenerator(sparkSession=spark, name="randomGenerationWithoutRangeValues", rows=100, |
| 743 | + partitions=4) |
| 744 | + .withIdOutput() |
| 745 | + # default column type is string |
| 746 | + .withColumn("randCol", colType=dataType, minValue=minValue, maxValue=maxValue, random=True) |
| 747 | + ) |
| 748 | + |
| 749 | + df = testDataSpec.build(withTempView=True) |
| 750 | + sortedDf = df.orderBy("randCol") |
| 751 | + sortedVals = sortedDf.select("randCol").collect() |
| 752 | + assert sortedVals != df.select("randCol").collect() |
| 753 | + |
698 | 754 | def test_version_info(self): |
699 | 755 | # test access to version info without explicit import |
700 | 756 | print("Data generator version", dg.__version__) |
0 commit comments