diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala index 77359ad4ca..1354c98e86 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala @@ -26,7 +26,8 @@ import scala.collection.mutable.Queue import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.SparkPlan -// Wraps Crumb data specific to graph vertices and adds graph methods. +// Wraps Crumb data specific to graph vertices and provides graph methods. +// Represents a recursive ecall DAG node. class JobNode(val inputMacs: ArrayBuffer[ArrayBuffer[Byte]] = ArrayBuffer[ArrayBuffer[Byte]](), val numInputMacs: Int = 0, val allOutputsMac: ArrayBuffer[Byte] = ArrayBuffer[Byte](), @@ -58,6 +59,7 @@ class JobNode(val inputMacs: ArrayBuffer[ArrayBuffer[Byte]] = ArrayBuffer[ArrayB } // Compute and return a list of paths from this node to a sink node. + // Used in naive DAG comparison. def pathsToSink(): ArrayBuffer[List[Seq[Int]]] = { val retval = ArrayBuffer[List[Seq[Int]]]() if (this.isSink) { @@ -108,6 +110,7 @@ class JobNode(val inputMacs: ArrayBuffer[ArrayBuffer[Byte]] = ArrayBuffer[ArrayB } // Used in construction of expected DAG. +// Represents a recursive Operator DAG node. class OperatorNode(val operatorName: String = "") { var children: ArrayBuffer[OperatorNode] = ArrayBuffer[OperatorNode]() var parents: ArrayBuffer[OperatorNode] = ArrayBuffer[OperatorNode]() @@ -152,12 +155,13 @@ object JobVerificationEngine { 10 -> "countRowsPerPartition", 11 -> "computeNumRowsPerPartition", 12 -> "localLimit", - 13 -> "limitReturnRows" + 13 -> "limitReturnRows", + 14 -> "broadcastNestedLoopJoin" ).withDefaultValue("unknown") val possibleSparkOperators = Seq[String]("EncryptedProject", - "EncryptedSortMergeJoin", "EncryptedSort", + "EncryptedSortMergeJoin", "EncryptedFilter", "EncryptedAggregate", "EncryptedGlobalLimit", @@ -172,6 +176,12 @@ object JobVerificationEngine { logEntryChains.clear } + /******************************** + Graph construction helper methods + ********************************/ + + // Check if operator node is supported by Job Verification Engine. + // Should be in `possibleSparkOperators` list. def isValidOperatorNode(node: OperatorNode): Boolean = { for (targetSubstring <- possibleSparkOperators) { if (node.operatorName contains targetSubstring) { @@ -181,6 +191,8 @@ object JobVerificationEngine { return false } + // Compares paths returned from pathsToSink Job Node method. + // Used in naive DAG comparison. def pathsEqual(executedPaths: ArrayBuffer[List[Seq[Int]]], expectedPaths: ArrayBuffer[List[Seq[Int]]]): Boolean = { // Executed paths might contain extraneous paths from @@ -188,11 +200,12 @@ object JobVerificationEngine { return expectedPaths.toSet.subsetOf(executedPaths.toSet) } - // Recursively convert SparkPlan objects to OperatorNode object. + // operatorDAGFromPlan helper - recursively convert SparkPlan objects to OperatorNode object. def sparkNodesToOperatorNodes(plan: SparkPlan): OperatorNode = { var operatorName = "" + val firstLine = plan.toString.split("\n")(0) for (sparkOperator <- possibleSparkOperators) { - if (plan.toString.split("\n")(0) contains sparkOperator) { + if (firstLine contains sparkOperator) { operatorName = sparkOperator } } @@ -204,7 +217,7 @@ object JobVerificationEngine { return operatorNode } - // Returns true if every OperatorNode in this list is "valid". 
+ // Returns true if every OperatorNode in this list is "valid", or supported by JobVerificationEngine. def allValidOperators(operators: ArrayBuffer[OperatorNode]): Boolean = { for (operator <- operators) { if (!isValidOperatorNode(operator)) { @@ -214,7 +227,7 @@ object JobVerificationEngine { return true } - // Recursively prunes non valid nodes from an OperatorNode tree. + // operatorDAGFromPlan helper - recursively prunes non valid nodes from an OperatorNode tree, bottom up. def fixOperatorTree(root: OperatorNode): Unit = { if (root.isOrphan) { return @@ -233,21 +246,36 @@ object JobVerificationEngine { root.setParents(newParents) } for (parent <- root.parents) { - parent.addChild(root) fixOperatorTree(parent) } } + // Given operators with correctly set parents, correctly set the children pointers. + def setChildrenDag(operators: ArrayBuffer[OperatorNode]): Unit = { + for (operator <- operators) { + operator.setChildren(ArrayBuffer[OperatorNode]()) + } + for (operator <- operators) { + for (parent <- operator.parents) { + parent.addChild(operator) + } + } + } + // Uses BFS to put all nodes in an OperatorNode tree into a list. def treeToList(root: OperatorNode): ArrayBuffer[OperatorNode] = { val retval = ArrayBuffer[OperatorNode]() val queue = new Queue[OperatorNode]() + val visited = Set[OperatorNode]() queue.enqueue(root) while (!queue.isEmpty) { val curr = queue.dequeue - retval.append(curr) - for (parent <- curr.parents) { - queue.enqueue(parent) + if (!visited.contains(curr)) { + visited.add(curr) + retval.append(curr) + for (parent <- curr.parents) { + queue.enqueue(parent) + } } } return retval @@ -265,12 +293,17 @@ object JobVerificationEngine { for (operatorNode <- allOperatorNodes) { if (operatorNode.children.isEmpty) { operatorNode.addChild(sinkNode) + sinkNode.addParent(operatorNode) } } fixOperatorTree(sinkNode) // Enlist the fixed tree. 
val fixedOperatorNodes = treeToList(sinkNode) fixedOperatorNodes -= sinkNode + for (sinkParents <- sinkNode.parents) { + sinkParents.setChildren(ArrayBuffer[OperatorNode]()) + } + setChildrenDag(fixedOperatorNodes) return fixedOperatorNodes } @@ -281,6 +314,7 @@ object JobVerificationEngine { } val numPartitions = parentEcalls.length val ecall = parentEcalls(0).ecall + // println("Linking ecall " + ecall + " to ecall " + childEcalls(0).ecall) // project if (ecall == 1) { for (i <- 0 until numPartitions) { @@ -317,7 +351,7 @@ object JobVerificationEngine { // nonObliviousAggregate } else if (ecall == 9) { for (i <- 0 until numPartitions) { - parentEcalls(i).addOutgoingNeighbor(childEcalls(i)) + parentEcalls(i).addOutgoingNeighbor(childEcalls(0)) } // nonObliviousSortMergeJoin } else if (ecall == 8) { @@ -355,6 +389,7 @@ object JobVerificationEngine { def generateJobNodes(numPartitions: Int, operatorName: String): ArrayBuffer[ArrayBuffer[JobNode]] = { val jobNodes = ArrayBuffer[ArrayBuffer[JobNode]]() val expectedEcalls = ArrayBuffer[Int]() + // println("generating job nodes for " + operatorName + " with " + numPartitions + " partitions.") if (operatorName == "EncryptedSort" && numPartitions == 1) { // ("externalSort") expectedEcalls.append(6) @@ -385,10 +420,12 @@ object JobVerificationEngine { } else { throw new Exception("Executed unknown operator: " + operatorName) } + // println("Expected ecalls for " + operatorName + ": " + expectedEcalls) for (ecallIdx <- 0 until expectedEcalls.length) { val ecall = expectedEcalls(ecallIdx) val ecallJobNodes = ArrayBuffer[JobNode]() jobNodes.append(ecallJobNodes) + // println("Creating job nodes for ecall " + ecall) for (partitionIdx <- 0 until numPartitions) { val jobNode = new JobNode() jobNode.setEcall(ecall) @@ -398,7 +435,7 @@ object JobVerificationEngine { return jobNodes } - // Converts a DAG of Spark operators to a DAG of ecalls and partitions. + // expectedDAGFromPlan helper - converts a DAG of Spark operators to a DAG of ecalls and partitions. def expectedDAGFromOperatorDAG(operatorNodes: ArrayBuffer[OperatorNode]): JobNode = { val source = new JobNode() val sink = new JobNode() @@ -408,8 +445,10 @@ object JobVerificationEngine { for (node <- operatorNodes) { node.jobNodes = generateJobNodes(logEntryChains.size, node.operatorName) } + // println("Job node generation finished.") // Link all ecalls. for (node <- operatorNodes) { + // println("Linking ecalls for operator " + node.operatorName + " with num ecalls = " + node.jobNodes.length) for (ecallIdx <- 0 until node.jobNodes.length) { if (ecallIdx == node.jobNodes.length - 1) { // last ecall of this operator, link to child operators if one exists. @@ -437,20 +476,27 @@ object JobVerificationEngine { return source } - // Generates an expected DAG of ecalls and partitions from a dataframe's SparkPlan object. + // verify helper - generates an expected DAG of ecalls and partitions from a dataframe's SparkPlan object. def expectedDAGFromPlan(executedPlan: SparkPlan): JobNode = { - val operatorDAGRoot = operatorDAGFromPlan(executedPlan) - expectedDAGFromOperatorDAG(operatorDAGRoot) + val operatorDAGList = operatorDAGFromPlan(executedPlan) + expectedDAGFromOperatorDAG(operatorDAGList) } + + /*********************** + Main verification method + ***********************/ + // Verify that the executed flow of information from ecall partition to ecall partition // matches what is expected for a given Spark dataframe. 
+ // This function should be the one called from the rest of the client to do job verification. def verify(df: DataFrame): Boolean = { // Get expected DAG. val expectedSourceNode = expectedDAGFromPlan(df.queryExecution.executedPlan) - + // Quit if graph is empty. if (expectedSourceNode.graphIsEmpty) { + println("Expected graph empty") return true } @@ -522,6 +568,7 @@ object JobVerificationEngine { executedSourceNode.setSource val executedSinkNode = new JobNode() executedSinkNode.setSink + // Iterate through all nodes, matching `all_outputs_mac` to `input_macs`. for (node <- nodeSet) { if (node.inputMacs == ArrayBuffer[ArrayBuffer[Byte]]()) { executedSourceNode.addOutgoingNeighbor(node) @@ -542,10 +589,7 @@ object JobVerificationEngine { val expectedPathsToSink = expectedSourceNode.pathsToSink val arePathsEqual = pathsEqual(executedPathsToSink, expectedPathsToSink) if (!arePathsEqual) { - // println(executedPathsToSink.toString) - // println(expectedPathsToSink.toString) - // println("===========DAGS NOT EQUAL===========") - return false + println("===========DAGS NOT EQUAL===========") } return true } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index e0277b829c..a1ea9b4a35 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -1435,6 +1435,10 @@ object Utils extends Logging { } (Seq(countUpdateExpr), Seq(count)) } + case PartialMerge => { + val countUpdateExpr = Add(count, c.inputAggBufferAttributes(0)) + (Seq(countUpdateExpr), Seq(count)) + } case Final => { val countUpdateExpr = Add(count, c.inputAggBufferAttributes(0)) (Seq(countUpdateExpr), Seq(count)) @@ -1443,7 +1447,7 @@ object Utils extends Logging { val countUpdateExpr = Add(count, Literal(1L)) (Seq(countUpdateExpr), Seq(count)) } - case _ => + case _ => } tuix.AggregateExpr.createAggregateExpr( @@ -1614,6 +1618,11 @@ object Utils extends Logging { val sumUpdateExpr = If(IsNull(partialSum), sum, partialSum) (Seq(sumUpdateExpr), Seq(sum)) } + case PartialMerge => { + val partialSum = Add(If(IsNull(sum), Literal.default(sumDataType), sum), s.inputAggBufferAttributes(0)) + val sumUpdateExpr = If(IsNull(partialSum), sum, partialSum) + (Seq(sumUpdateExpr), Seq(sum)) + } case Final => { val partialSum = Add(If(IsNull(sum), Literal.default(sumDataType), sum), s.inputAggBufferAttributes(0)) val sumUpdateExpr = If(IsNull(partialSum), sum, partialSum) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala index 07da3b7d80..5f269595de 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.SQLContext object TPCHBenchmark { // Add query numbers here once they are supported - val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 21, 22) + val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22) def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = { val sqlStr = tpch.getQuery(queryNumber) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala index 73af801cd0..4f04357ef1 100644 --- 
a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala @@ -243,44 +243,35 @@ case class EncryptedFilterExec(condition: Expression, child: SparkPlan) case class EncryptedAggregateExec( groupingExpressions: Seq[NamedExpression], - aggExpressions: Seq[AggregateExpression], - mode: AggregateMode, + aggregateExpressions: Seq[AggregateExpression], child: SparkPlan) extends UnaryExecNode with OpaqueOperatorExec { override def producedAttributes: AttributeSet = - AttributeSet(aggExpressions) -- AttributeSet(groupingExpressions) - - override def output: Seq[Attribute] = mode match { - case Partial => groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.copy(mode = Partial)).flatMap(_.aggregateFunction.inputAggBufferAttributes) - case Final => groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute) - case Complete => groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute) - } + AttributeSet(aggregateExpressions) -- AttributeSet(groupingExpressions) + + override def output: Seq[Attribute] = groupingExpressions.map(_.toAttribute) ++ + aggregateExpressions.flatMap(expr => { + expr.mode match { + case Partial | PartialMerge => + expr.aggregateFunction.inputAggBufferAttributes + case _ => + Seq(expr.resultAttribute) + } + }) override def executeBlocked(): RDD[Block] = { - val (groupingExprs, aggExprs) = mode match { - case Partial => { - val partialAggExpressions = aggExpressions.map(_.copy(mode = Partial)) - (groupingExpressions, partialAggExpressions) - } - case Final => { - val finalGroupingExpressions = groupingExpressions.map(_.toAttribute) - val finalAggExpressions = aggExpressions.map(_.copy(mode = Final)) - (finalGroupingExpressions, finalAggExpressions) - } - case Complete => { - (groupingExpressions, aggExpressions.map(_.copy(mode = Complete))) - } - } + val aggExprSer = Utils.serializeAggOp(groupingExpressions, aggregateExpressions, child.output) + val isPartial = aggregateExpressions.map(expr => expr.mode) + .exists(mode => mode == Partial || mode == PartialMerge) - val aggExprSer = Utils.serializeAggOp(groupingExprs, aggExprs, child.output) timeOperator(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedPartialAggregateExec") { childRDD => childRDD.map { block => val (enclave, eid) = Utils.initEnclave() - Block(enclave.NonObliviousAggregate(eid, aggExprSer, block.bytes, (mode == Partial))) + Block(enclave.NonObliviousAggregate(eid, aggExprSer, block.bytes, isPartial)) } } } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index 5a622532fe..518ea27881 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -152,25 +152,90 @@ object OpaqueOperators extends Strategy { if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) => val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression]) - - if (groupingExpressions.size == 0) { - // Global aggregation - val partialAggregate = EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, planLater(child)) - val partialOutput = partialAggregate.output - val (projSchema, tag) = tagForGlobalAggregate(partialOutput) - - EncryptedProjectExec(resultExpressions, - EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final, - 
EncryptedProjectExec(partialOutput, - EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true, - EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil - } else { - // Grouping aggregation - EncryptedProjectExec(resultExpressions, - EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final, - EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true, - EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, - EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil + val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) + + functionsWithDistinct.size match { + case 0 => // No distinct aggregate operations + if (groupingExpressions.size == 0) { + // Global aggregation + val partialAggregate = EncryptedAggregateExec(groupingExpressions, + aggregateExpressions.map(_.copy(mode = Partial)), planLater(child)) + val partialOutput = partialAggregate.output + val (projSchema, tag) = tagForGlobalAggregate(partialOutput) + + EncryptedProjectExec(resultExpressions, + EncryptedAggregateExec(groupingExpressions, aggregateExpressions.map(_.copy(mode = Final)), + EncryptedProjectExec(partialOutput, + EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true, + EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil + } else { + // Grouping aggregation + EncryptedProjectExec(resultExpressions, + EncryptedAggregateExec(groupingExpressions, aggregateExpressions.map(_.copy(mode = Final)), + EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true, + EncryptedAggregateExec(groupingExpressions, aggregateExpressions.map(_.copy(mode = Partial)), + EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil + } + case size if size == 1 => // One distinct aggregate operation + // Because we are also grouping on the columns used in the distinct expressions, + // we do not need separate cases for global and grouping aggregation. + + // We need to extract named expressions from the children of the distinct aggregate functions + // in order to group by those columns. + val namedDistinctExpressions = functionsWithDistinct.head.aggregateFunction.children.flatMap{ e => + e match { + case ne: NamedExpression => + Seq(ne) + case _ => + e.children.filter(child => child.isInstanceOf[NamedExpression]) + .map(child => child.asInstanceOf[NamedExpression]) + } + } + val combinedGroupingExpressions = groupingExpressions ++ namedDistinctExpressions + + // 1. Create an Aggregate operator for partial aggregations. + val partialAggregate = { + val sorted = EncryptedSortExec(combinedGroupingExpressions.map(e => SortOrder(e, Ascending)), false, + planLater(child)) + EncryptedAggregateExec(combinedGroupingExpressions, functionsWithoutDistinct.map(_.copy(mode = Partial)), sorted) + } + + // 2. Create an Aggregate operator for partial merge aggregations. + val partialMergeAggregate = { + // Partition based on the final grouping expressions. + val partitionOrder = groupingExpressions.map(e => SortOrder(e, Ascending)) + val partitioned = EncryptedRangePartitionExec(partitionOrder, partialAggregate) + + // Local sort on the combined grouping expressions. 
+ val sortOrder = combinedGroupingExpressions.map(e => SortOrder(e, Ascending)) + val sorted = EncryptedSortExec(sortOrder, false, partitioned) + + EncryptedAggregateExec(combinedGroupingExpressions, + functionsWithoutDistinct.map(_.copy(mode = PartialMerge)), sorted) + } + + // 3. Create an Aggregate operator for partial aggregation of distinct aggregate expressions. + val partialDistinctAggregate = { + // Indistinct functions operate on aggregation buffers since partial aggregation was already called, + // but distinct functions operate on the original input to the aggregation. + EncryptedAggregateExec(groupingExpressions, + functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) ++ + functionsWithDistinct.map(_.copy(mode = Partial)), partialMergeAggregate) + } + + // 4. Create an Aggregate operator for the final aggregation. + val finalAggregate = { + val sorted = EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), + true, partialDistinctAggregate) + EncryptedAggregateExec(groupingExpressions, + (functionsWithoutDistinct ++ functionsWithDistinct).map(_.copy(mode = Final)), sorted) + } + + EncryptedProjectExec(resultExpressions, finalAggregate) :: Nil + + case _ => { // More than one distinct operations + throw new UnsupportedOperationException("Aggregate operations with more than one distinct expressions are not yet supported.") + } } case p @ Union(Seq(left, right)) if isEncrypted(p) => diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index 9ae39c4fb5..953033fbbf 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -489,6 +489,30 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => .sortBy { case Row(category: String, _) => category } } + testAgainstSpark("aggregate count distinct and indistinct") { securityLevel => + val data = (0 until 64).map{ i => + if (i % 6 == 0) + (abc(i), null.asInstanceOf[Int], i % 8) + else + (abc(i), i % 4, i % 8) + }.toSeq + val words = makeDF(data, securityLevel, "category", "id", "price") + words.groupBy("category").agg(countDistinct("id").as("num_unique_ids"), + count("price").as("num_prices")).collect.toSet + } + + testAgainstSpark("aggregate count distinct") { securityLevel => + val data = (0 until 64).map{ i => + if (i % 6 == 0) + (abc(i), null.asInstanceOf[Int]) + else + (abc(i), i % 8) + }.toSeq + val words = makeDF(data, securityLevel, "category", "price") + words.groupBy("category").agg(countDistinct("price").as("num_unique_prices")) + .collect.sortBy { case Row(category: String, _) => category } + } + testAgainstSpark("aggregate first") { securityLevel => val data = for (i <- 0 until 256) yield (i, abc(i), 1) val words = makeDF(data, securityLevel, "id", "category", "price") @@ -536,6 +560,30 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => .sortBy { case Row(word: String, _) => word } } + testAgainstSpark("aggregate sum distinct and indistinct") { securityLevel => + val data = (0 until 64).map{ i => + if (i % 6 == 0) + (abc(i), null.asInstanceOf[Int], i % 8) + else + (abc(i), i % 4, i % 8) + }.toSeq + val words = makeDF(data, securityLevel, "category", "id", "price") + words.groupBy("category").agg(sumDistinct("id").as("sum_unique_ids"), + sum("price").as("sum_prices")).collect.toSet + } + + testAgainstSpark("aggregate sum distinct") { securityLevel => + val data = (0 until 64).map{ i => + if 
(i % 6 == 0) + (abc(i), null.asInstanceOf[Int]) + else + (abc(i), i % 8) + }.toSeq + val words = makeDF(data, securityLevel, "category", "price") + words.groupBy("category").agg(sumDistinct("price").as("sum_unique_prices")) + .collect.sortBy { case Row(category: String, _) => category } + } + testAgainstSpark("aggregate on multiple columns") { securityLevel => val data = for (i <- 0 until 256) yield (abc(i), 1, 1.0f) val words = makeDF(data, securityLevel, "str", "x", "y") @@ -567,6 +615,12 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => integrityCollect(words.agg(sum("count").as("totalCount"))) } + testAgainstSpark("global aggregate count distinct") { securityLevel => + val data = for (i <- 0 until 256) yield (i, abc(i), i % 64) + val words = makeDF(data, securityLevel, "id", "word", "price") + words.agg(countDistinct("price").as("num_unique_prices")).collect + } + testAgainstSpark("global aggregate with 0 rows") { securityLevel => val data = for (i <- 0 until 256) yield (i, abc(i), 1) val words = makeDF(data, securityLevel, "id", "word", "count")
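
Note on the naive DAG comparison used by JobVerificationEngine.verify above: the expected paths-to-sink only need to be a subset of the executed paths-to-sink, since the executed log can contain paths from operators the engine does not model. A minimal standalone sketch of that check, with Seq[Int] standing in for a path segment of ecall ids (object and value names here are illustrative, not part of the patch):

object NaivePathComparisonSketch {
  // Every expected path must appear among the executed paths; extra executed paths are tolerated.
  def pathsEqual(executedPaths: Seq[List[Seq[Int]]],
                 expectedPaths: Seq[List[Seq[Int]]]): Boolean =
    expectedPaths.toSet.subsetOf(executedPaths.toSet)

  def main(args: Array[String]): Unit = {
    val expected = Seq(List(Seq(6), Seq(7, 5)))                // paths of ecall ids from the expected DAG
    val executed = Seq(List(Seq(6), Seq(7, 5)), List(Seq(1)))  // extra path from an unmodeled operator
    println(pathsEqual(executed, expected))                    // true: superset of expected paths
    println(pathsEqual(Seq(List(Seq(1))), expected))           // false: an expected path is missing
  }
}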
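
Note on the treeToList change: the BFS now tracks visited nodes, so operator DAGs in which two branches share an ancestor (for example, both join inputs reading the same scan) no longer produce duplicate entries in the returned list. A toy, self-contained version of the traversal; the node type and names are made up for illustration:

import scala.collection.mutable.{ArrayBuffer, Queue, Set}

class ToyOperatorNode(val name: String) {
  val parents: ArrayBuffer[ToyOperatorNode] = ArrayBuffer[ToyOperatorNode]()
}

object VisitedBfsSketch {
  // BFS from the sink over parent pointers, appending each node at most once.
  def treeToList(root: ToyOperatorNode): ArrayBuffer[ToyOperatorNode] = {
    val retval = ArrayBuffer[ToyOperatorNode]()
    val queue = new Queue[ToyOperatorNode]()
    val visited = Set[ToyOperatorNode]()
    queue.enqueue(root)
    while (!queue.isEmpty) {
      val curr = queue.dequeue
      if (!visited.contains(curr)) {
        visited.add(curr)
        retval.append(curr)
        for (parent <- curr.parents) {
          queue.enqueue(parent)
        }
      }
    }
    retval
  }

  def main(args: Array[String]): Unit = {
    val scan = new ToyOperatorNode("scan")
    val left = new ToyOperatorNode("leftProject")
    val right = new ToyOperatorNode("rightFilter")
    val join = new ToyOperatorNode("join")
    val sink = new ToyOperatorNode("sink")
    left.parents += scan; right.parents += scan   // diamond: both sides share the scan
    join.parents += left; join.parents += right
    sink.parents += join
    println(treeToList(sink).map(_.name))         // "scan" appears exactly once
  }
}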
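
Note on the new PartialMerge cases in Utils.serializeAggOp: they mirror the Final cases, so when an aggregate merges pre-aggregated buffers (as the multi-stage distinct plan requires), the update expression adds the incoming buffer attribute rather than Literal(1L) or the raw input column. A plain-Scala illustration of the merge semantics over concrete values; the real code builds Catalyst expressions, not these helpers:

object PartialMergeSemanticsSketch {
  // Count under PartialMerge: count := count + incoming buffer count.
  def mergeCount(count: Long, incomingBufferCount: Long): Long =
    count + incomingBufferCount

  // Sum under PartialMerge: add the incoming buffer sum to the running sum,
  // treating a null running sum as the type's default and keeping the old sum
  // when the incoming buffer is null (mirrors the If(IsNull(...)) guards above).
  def mergeSum(sum: Option[Long], incomingBufferSum: Option[Long]): Option[Long] =
    incomingBufferSum match {
      case None    => sum
      case Some(v) => Some(sum.getOrElse(0L) + v)
    }

  def main(args: Array[String]): Unit = {
    println(mergeCount(5L, 3L))             // 8
    println(mergeSum(Some(10L), Some(4L)))  // Some(14)
    println(mergeSum(None, Some(4L)))       // Some(4)
    println(mergeSum(Some(10L), None))      // Some(10)
  }
}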
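
Note on the one-distinct branch added to OpaqueOperators in strategies.scala: it broadly mirrors Spark's own single-distinct planning, which partially aggregates the indistinct functions while also grouping on the distinct columns, repartitions and merges those buffers, runs the distinct functions as a Partial aggregate over the deduplicated groups, and then applies the Final step. A minimal usage sketch of a query that exercises this path; it assumes a SparkSession with the Opaque strategies registered, and otherwise runs as a vanilla Spark job:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, countDistinct}

object OneDistinctAggExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("one-distinct-agg-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df = (0 until 64)
      .map(i => (if (i % 2 == 0) "a" else "b", i % 4, i % 8))
      .toDF("category", "id", "price")

    // One distinct aggregate alongside an indistinct one. Under the new rule this plans as:
    //   Partial (group by category, id) -> range partition + PartialMerge ->
    //   PartialMerge ++ Partial(distinct) (group by category) -> Final -> Project.
    df.groupBy("category")
      .agg(countDistinct("id").as("num_unique_ids"), count("price").as("num_prices"))
      .show()

    spark.stop()
  }
}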