Skip to content

Commit 8136ccf

Browse files
Feature distribution changes - migrated tests to Pytest, use of abstract base classes (#277)
* added use of ABC, refactored tests * wip * fixed base method in Distribution
1 parent 82ce5ce commit 8136ccf

File tree

3 files changed

+95
-123
lines changed

3 files changed

+95
-123
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ All notable changes to the Databricks Labs Data Generator will be documented in
88
### Changed
99
* Modified data generator to allow specification of constraints to the data generation process
1010
* Updated documentation for generating text data.
11+
* Modified data distributions to use abstract base classes
12+
* Migrated data distribution tests to use `pytest`
1113

1214
### Added
1315
* Added classes for constraints on the data generation via new package `dbldatagen.constraints`

dbldatagen/distributions/data_distribution.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020
and no further scaling is needed.
2121
"""
2222
import copy
23-
import pyspark.sql.functions as F
23+
from abc import ABC, abstractmethod
24+
2425
import numpy as np
26+
import pyspark.sql.functions as F
2527

2628

27-
class DataDistribution(object):
29+
class DataDistribution(ABC):
2830
""" Base class for all distributions"""
31+
2932
def __init__(self):
3033
self._rounding = False
3134
self._randomSeed = None
@@ -37,8 +40,8 @@ def get_np_random_generator(random_seed):
3740
:param random_seed: Numeric random seed to use. If < 0, then no random
3841
:return:
3942
"""
40-
assert random_seed is None or type(random_seed) in [ np.int32, np.int64, int],\
41-
f"`randomSeed` must be int or int-like not {type(random_seed)}"
43+
assert random_seed is None or type(random_seed) in [np.int32, np.int64, int], \
44+
f"`randomSeed` must be int or int-like not {type(random_seed)}"
4245
from numpy.random import default_rng
4346
if random_seed not in (-1, -1.0):
4447
rng = default_rng(random_seed)
@@ -47,17 +50,17 @@ def get_np_random_generator(random_seed):
4750

4851
return rng
4952

53+
@abstractmethod
5054
def generateNormalizedDistributionSample(self):
5155
""" Generate sample of data for distribution
5256
5357
:return: random samples from distribution scaled to values between 0 and 1
58+
59+
Note: implementors should provide an implementation for this.
60+
61+
Return value is expected to be a Pyspark SQL column expression such as F.expr("rand()")
5462
"""
55-
if self.randomSeed == -1 or self.randomSeed is None:
56-
newDef = F.expr("rand()")
57-
else:
58-
assert type(self.randomSeed) in [int, float], "random seed should be numeric"
59-
newDef = F.expr(f"rand({self.randomSeed})")
60-
return newDef
63+
pass
6164

6265
def withRounding(self, rounding):
6366
""" Create copy of object and set the rounding attribute

0 commit comments

Comments
 (0)