Commit 82ce5ce

Feature constraints (#257)
* feature for constraints * lint related updates * work in progress * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * updated build actions * Update CHANGELOG.md * wip * wip * wip * wip
1 parent b28602d commit 82ce5ce

20 files changed: 1,125 additions & 8 deletions

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
```diff
@@ -5,9 +5,13 @@ All notable changes to the Databricks Labs Data Generator will be documented in
 
 ### Unreleased
 
-#### Changed
+### Changed
+* Modified data generator to allow specification of constraints to the data generation process
 * Updated documentation for generating text data.
 
+### Added
+* Added classes for constraints on the data generation via new package `dbldatagen.constraints`
+
 
 ### Version 0.3.6 Post 1
 
```
dbldatagen/constraints/__init__.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
```python
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This package defines the constraints classes for the `dbldatagen` library.

The constraints classes are used to define predefined constraints that may be used to constrain the generated data.

Constraining the generated data is implemented in several ways:

- Rejection of rows that do not meet the criteria
- Modifying the generated data to meet the constraint (including modifying the data generation parameters)

Some constraints may be implemented using a combination of the above approaches.

For implementations using the rejection approach, the data generation process will possibly generate less than the
requested number of rows.

For the current implementation, most of the constraint strategies will be implemented using rejection based criteria.
"""

from .chained_relation import ChainedRelation
from .constraint import Constraint
from .literal_range_constraint import LiteralRange
from .literal_relation_constraint import LiteralRelation
from .negative_values import NegativeValues
from .positive_values import PositiveValues
from .ranged_values_constraint import RangedValues
from .sql_expr import SqlExpr
from .unique_combinations import UniqueCombinations

__all__ = ["chained_relation",
           "constraint",
           "negative_values",
           "literal_range_constraint",
           "literal_relation_constraint",
           "positive_values",
           "ranged_values_constraint",
           "unique_combinations"]
```
dbldatagen/constraints/chained_relation.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
```python
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the ChainedRelation class
"""
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
from .constraint import Constraint, NoPrepareTransformMixin


class ChainedRelation(NoPrepareTransformMixin, Constraint):
    """ChainedRelation constraint

    Constrains one or more columns so that each column has a relationship to the next.

    For example if the constraint is defined as `ChainedRelation(['a', 'b', 'c'], "<")` then only rows that
    satisfy the condition `a < b < c` will be included in the output
    (where `a`, `b` and `c` represent the data values for the rows).

    This can be used to model time related transactions (for example in retail where the purchaseDate, shippingDate
    and returnDate all have a specific relationship) etc.

    Relations supported include <, <=, >=, >, !=, ==

    :param columns: column name or list of column names as string or list of strings
    :param relation: operator to check - should be one of <, >, =, >=, <=, ==, !=
    """
    def __init__(self, columns, relation):
        super().__init__(supportsStreaming=True)
        self._relation = relation
        self._columns = self._columnsFromListOrString(columns)

        if relation not in self.SUPPORTED_OPERATORS:
            raise ValueError(f"Parameter `relation` should be one of the operators: {self.SUPPORTED_OPERATORS}")

        if not isinstance(self._columns, list) or len(self._columns) <= 1:
            raise ValueError("ChainedRelation constraints must be defined across more than one column")

    def _generateFilterExpression(self):
        """ Generate composite filter expression for chained set of filter expressions

        I.e. if columns is ['a', 'b', 'c'] and relation is '<'

        create set of filters [ col('a') < col('b'), col('b') < col('c') ]
        and combine them as a single expression using logical and operation

        :return: filter expression for chained expressions
        """
        expressions = [F.col(colname) for colname in self._columns]

        filters = []
        # build set of filters for chained expressions
        for ix in range(1, len(expressions)):
            filters.append(self._generate_relation_expression(expressions[ix - 1], self._relation, expressions[ix]))

        # ... and combine them using logical `and` operation
        return self.mkCombinedConstraintExpression(filters)
```
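As an illustrative sketch (not part of the diff), the retail example from the docstring above can be exercised directly through the constraint's `filterExpression`; the constructor arguments are the ones defined in this file, while the SparkSession `spark` and the sample data are assumed.

```python
# Illustrative sketch - assumes an active SparkSession `spark`.
import datetime
from dbldatagen.constraints import ChainedRelation

df = spark.createDataFrame(
    [
        (datetime.date(2023, 1, 1), datetime.date(2023, 1, 3), datetime.date(2023, 1, 10)),
        (datetime.date(2023, 2, 1), datetime.date(2023, 1, 20), datetime.date(2023, 2, 5)),  # shipped before purchase
    ],
    "purchaseDate date, shippingDate date, returnDate date",
)

# Keep only rows where purchaseDate <= shippingDate <= returnDate; the second row is rejected.
constraint = ChainedRelation(["purchaseDate", "shippingDate", "returnDate"], "<=")
df.where(constraint.filterExpression).show()
```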
dbldatagen/constraints/constraint.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
```python
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the Constraint class
"""
import types
from abc import ABC, abstractmethod
from pyspark.sql import Column


class Constraint(ABC):
    """ Constraint object - base class for predefined and custom constraints

    This class is meant for internal use only.

    """
    SUPPORTED_OPERATORS = ["<", ">", ">=", "!=", "==", "=", "<=", "<>"]

    def __init__(self, supportsStreaming=False):
        """
        Initialize the constraint object
        """
        self._filterExpression = None
        self._calculatedFilterExpression = False
        self._supportsStreaming = supportsStreaming

    @staticmethod
    def _columnsFromListOrString(columns):
        """ Get columns as list of columns from string or list-like

        :param columns: string or list of strings representing column names
        """
        if isinstance(columns, str):
            return [columns]
        elif isinstance(columns, (list, set, tuple, types.GeneratorType)):
            return list(columns)
        else:
            raise ValueError("Columns must be a string or list of strings")

    @staticmethod
    def _generate_relation_expression(column, relation, valueExpression):
        """ Generate comparison expression

        :param column: Column to generate comparison against
        :param relation: relation to implement
        :param valueExpression: expression to compare to
        :return: relation expression as variation of Pyspark SQL columns
        """
        if relation == ">":
            return column > valueExpression
        elif relation == ">=":
            return column >= valueExpression
        elif relation == "<":
            return column < valueExpression
        elif relation == "<=":
            return column <= valueExpression
        elif relation in ["!=", "<>"]:
            return column != valueExpression
        elif relation in ["=", "=="]:
            return column == valueExpression
        else:
            raise ValueError(f"Unsupported relation type '{relation}'")

    @staticmethod
    def mkCombinedConstraintExpression(constraintExpressions):
        """ Generate a SQL expression that combines multiple constraints using AND

        :param constraintExpressions: list of Pyspark SQL Column constraint expression objects
        :return: combined constraint expression as Pyspark SQL Column object (or None if no valid expressions)

        """
        assert constraintExpressions is not None and isinstance(constraintExpressions, list), \
            "Constraints must be a list of Pyspark SQL Column instances"

        assert all(expr is None or isinstance(expr, Column) for expr in constraintExpressions), \
            "Constraint expressions must be Pyspark SQL columns or None"

        valid_constraint_expressions = [expr for expr in constraintExpressions if expr is not None]

        if len(valid_constraint_expressions) > 0:
            combined_constraint_expression = valid_constraint_expressions[0]

            for additional_constraint in valid_constraint_expressions[1:]:
                combined_constraint_expression = combined_constraint_expression & additional_constraint

            return combined_constraint_expression
        else:
            return None

    @abstractmethod
    def prepareDataGenerator(self, dataGenerator):
        """ Prepare the data generator to generate data that matches the constraint

        This method may modify the data generation rules to meet the constraint

        :param dataGenerator: Data generation object that will generate the dataframe
        :return: modified or unmodified data generator
        """
        raise NotImplementedError("Method prepareDataGenerator must be implemented in derived class")

    @abstractmethod
    def transformDataframe(self, dataGenerator, dataFrame):
        """ Transform the dataframe to make data conform to constraint if possible

        This method should not modify the dataGenerator - but may modify the dataframe

        :param dataGenerator: Data generation object that generated the dataframe
        :param dataFrame: generated dataframe
        :return: modified or unmodified Spark dataframe

        The default transformation returns the dataframe unmodified

        """
        raise NotImplementedError("Method transformDataframe must be implemented in derived class")

    @abstractmethod
    def _generateFilterExpression(self):
        """ Generate a Pyspark SQL expression that may be used for filtering"""
        raise NotImplementedError("Method _generateFilterExpression must be implemented in derived class")

    @property
    def supportsStreaming(self):
        """ Return True if the constraint supports streaming dataframes"""
        return self._supportsStreaming

    @property
    def filterExpression(self):
        """ Return the filter expression (as instance of type Column that evaluates to True or non-True)"""
        if not self._calculatedFilterExpression:
            self._filterExpression = self._generateFilterExpression()
            self._calculatedFilterExpression = True
        return self._filterExpression


class NoFilterMixin:
    """ Mixin class to indicate that constraint has no filter expression

    Intended to be used in implementation of the concrete constraint classes.

    Use of the mixin class is optional but when used with the Constraint class and multiple inheritance,
    it will provide a default implementation of the _generateFilterExpression method that satisfies
    the abstract method requirement of the Constraint class.

    When using mixins, place the mixin class first in the list of base classes.
    """
    def _generateFilterExpression(self):
        """ Generate a Pyspark SQL expression that may be used for filtering"""
        return None


class NoPrepareTransformMixin:
    """ Mixin class to indicate that constraint needs no data generator preparation or dataframe transformation

    Intended to be used in implementation of the concrete constraint classes.

    Use of the mixin class is optional but when used with the Constraint class and multiple inheritance,
    it will provide a default implementation of the `prepareDataGenerator` and `transformDataframe` methods
    that satisfies the abstract method requirements of the Constraint class.

    When using mixins, place the mixin class first in the list of base classes.
    """
    def prepareDataGenerator(self, dataGenerator):
        """ Prepare the data generator to generate data that matches the constraint

        This method may modify the data generation rules to meet the constraint

        :param dataGenerator: Data generation object that will generate the dataframe
        :return: modified or unmodified data generator
        """
        return dataGenerator

    def transformDataframe(self, dataGenerator, dataFrame):
        """ Transform the dataframe to make data conform to constraint if possible

        This method should not modify the dataGenerator - but may modify the dataframe

        :param dataGenerator: Data generation object that generated the dataframe
        :param dataFrame: generated dataframe
        :return: modified or unmodified Spark dataframe

        The default transformation returns the dataframe unmodified

        """
        return dataFrame
```
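The mixins above make it straightforward to write additional rejection-style constraints. The sketch below is hypothetical (the class name `NonNullValues` is not part of this commit) and shows the intended pattern: the mixin is listed first so its concrete `prepareDataGenerator` / `transformDataframe` methods satisfy the abstract methods of `Constraint`, and only `_generateFilterExpression` needs to be supplied.

```python
# Hypothetical custom constraint - illustrates the mixin pattern, not part of this commit.
import pyspark.sql.functions as F

from dbldatagen.constraints import Constraint
from dbldatagen.constraints.constraint import NoPrepareTransformMixin


class NonNullValues(NoPrepareTransformMixin, Constraint):
    """Rejection-based constraint that drops rows with nulls in any of the named columns."""

    def __init__(self, columns):
        super().__init__(supportsStreaming=True)
        self._columns = self._columnsFromListOrString(columns)

    def _generateFilterExpression(self):
        # one IS NOT NULL check per column, combined with logical AND
        filters = [F.col(name).isNotNull() for name in self._columns]
        return self.mkCombinedConstraintExpression(filters)
```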
dbldatagen/constraints/literal_range_constraint.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
```python
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the LiteralRange class
"""
import pyspark.sql.functions as F

from .constraint import Constraint, NoPrepareTransformMixin


class LiteralRange(NoPrepareTransformMixin, Constraint):
    """ LiteralRange Constraint object - validates that column value(s) are between 2 literal values

    :param columns: Name of column or list of column names
    :param lowValue: Tests that columns have values greater than low value (greater or equal if `strict` is False)
    :param highValue: Tests that columns have values less than high value (less or equal if `strict` is False)
    :param strict: If True, excludes low and high values from range. Defaults to False

    Note `lowValue` and `highValue` must be values that can be converted to a literal expression using the
    `pyspark.sql.functions.lit` function
    """

    def __init__(self, columns, lowValue, highValue, strict=False):
        super().__init__(supportsStreaming=True)
        self._columns = self._columnsFromListOrString(columns)
        self._lowValue = lowValue
        self._highValue = highValue
        self._strict = strict

    def _generateFilterExpression(self):
        """ Generate a SQL filter expression that may be used for filtering"""
        expressions = [F.col(colname) for colname in self._columns]
        minValue = F.lit(self._lowValue)
        maxValue = F.lit(self._highValue)

        # build ranged comparison expressions
        if self._strict:
            filters = [(column_expr > minValue) & (column_expr < maxValue) for column_expr in expressions]
        else:
            filters = [column_expr.between(minValue, maxValue) for column_expr in expressions]

        # ... and combine them using logical `and` operation
        return self.mkCombinedConstraintExpression(filters)
```
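A short illustrative sketch (not part of the diff) of the `strict` flag's effect, assuming an active SparkSession `spark`: the non-strict form uses `between` and keeps the endpoints, while the strict form excludes them.

```python
# Illustrative sketch - assumes an active SparkSession `spark`.
from dbldatagen.constraints import LiteralRange

df = spark.range(11).withColumnRenamed("id", "score")   # scores 0..10

inclusive = LiteralRange("score", 0, 10)                 # between(0, 10) - keeps the endpoints
exclusive = LiteralRange("score", 0, 10, strict=True)    # (score > 0) & (score < 10)

print(df.where(inclusive.filterExpression).count())      # 11
print(df.where(exclusive.filterExpression).count())      # 9
```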
dbldatagen/constraints/literal_relation_constraint.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
```python
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the LiteralRelation class
"""
import pyspark.sql.functions as F

from .constraint import Constraint, NoPrepareTransformMixin


class LiteralRelation(NoPrepareTransformMixin, Constraint):
    """LiteralRelation constraint

    Constrains one or more columns so that the columns have a relationship to a constant value

    :param columns: column name or list of column names
    :param relation: operator to check - should be one of <, >, =, >=, <=, ==, !=
    :param value: A literal value to compare against
    """

    def __init__(self, columns, relation, value):
        super().__init__(supportsStreaming=True)
        self._columns = self._columnsFromListOrString(columns)
        self._relation = relation
        self._value = value

        if relation not in self.SUPPORTED_OPERATORS:
            raise ValueError(f"Parameter `relation` should be one of the operators: {self.SUPPORTED_OPERATORS}")

    def _generateFilterExpression(self):
        expressions = [F.col(colname) for colname in self._columns]
        literalValue = F.lit(self._value)
        filters = [self._generate_relation_expression(col, self._relation, literalValue) for col in expressions]

        return self.mkCombinedConstraintExpression(filters)
```
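A final illustrative sketch (not part of the diff), assuming an existing DataFrame `df` with numeric columns `qty` and `backorderQty`: when several columns are given, the same literal comparison is applied to each column and the results are AND-ed together.

```python
# Illustrative sketch - assumes an existing DataFrame `df` with columns `qty` and `backorderQty`.
from dbldatagen.constraints import LiteralRelation

non_negative = LiteralRelation(["qty", "backorderQty"], ">=", 0)

# Equivalent to df.where("qty >= 0 AND backorderQty >= 0"); rows failing either check are rejected.
filtered_df = df.where(non_negative.filterExpression)
```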
