Skip to content

Commit 8c55381

Browse files
authored
(dsl): Support Filter aggregation (#349)
1 parent c50fc01 commit 8c55381

File tree

9 files changed

+414
-26
lines changed

9 files changed

+414
-26
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
---
2+
id: elastic_aggregation_filter
3+
title: "Filter Aggregation"
4+
---
5+
6+
The `Filter` aggregation is a single bucket aggregation that narrows down the entire set of documents to a specific set that matches a [query](https://lambdaworks.github.io/zio-elasticsearch/overview/elastic_query).
7+
8+
In order to use the `Filter` aggregation import the following:
9+
```scala
10+
import zio.elasticsearch.aggregation.FilterAggregation
11+
import zio.elasticsearch.ElasticAggregation.filterAggregation
12+
```
13+
14+
You can create a `Filter` aggregation using the `filterAggregation` method in the following manner:
15+
```scala
16+
import zio.elasticsearch.ElasticQuery.term
17+
18+
val aggregation: FilterAggregation = filterAggregation(name = "filterAggregation", query = term(field = Document.stringField, value = "test"))
19+
```
20+
21+
If you want to add aggregation (on the same level), you can use `withAgg` method:
22+
```scala
23+
import zio.elasticsearch.ElasticQuery.term
24+
25+
val multipleAggregations: MultipleAggregations = filterAggregation(name = "filterAggregation", query = term(field = Document.stringField, value = "test")).withAgg(maxAggregation(name = "maxAggregation", field = Document.doubleField))
26+
```
27+
28+
If you want to add another sub-aggregation, you can use `withSubAgg` method:
29+
```scala
30+
import zio.elasticsearch.ElasticQuery.term
31+
import zio.elasticsearch.ElasticAggregation.maxAggregation
32+
33+
val aggregationWithSubAgg: FilterAggregation = filterAggregation(name = "filterAggregation", query = term(field = Document.stringField, value = "test")).withSubAgg(maxAggregation(name = "maxAggregation", field = Document.intField))
34+
```
35+
36+
You can find more information about `Filter` aggregation [here](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-bucket-filter-aggregation.html).

modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import zio.elasticsearch.query.sort.SortOrder._
3333
import zio.elasticsearch.query.sort.SourceType.NumberType
3434
import zio.elasticsearch.query.{Distance, FunctionScoreBoostMode, FunctionScoreFunction, InnerHits}
3535
import zio.elasticsearch.request.{CreationOutcome, DeletionOutcome}
36-
import zio.elasticsearch.result.{Item, MaxAggregationResult, UpdateByQueryResult}
36+
import zio.elasticsearch.result.{FilterAggregationResult, Item, MaxAggregationResult, UpdateByQueryResult}
3737
import zio.elasticsearch.script.{Painless, Script}
3838
import zio.json.ast.Json.{Arr, Str}
3939
import zio.schema.codec.JsonCodec
@@ -146,6 +146,65 @@ object HttpExecutorSpec extends IntegrationSpec {
146146
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
147147
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
148148
),
149+
test("aggregate using filter aggregation with max aggregation as a sub aggregation") {
150+
val expectedResult = (
151+
"aggregation",
152+
FilterAggregationResult(
153+
docCount = 2,
154+
subAggregations = Map(
155+
"subAggregation" -> MaxAggregationResult(value = 5.0)
156+
)
157+
)
158+
)
159+
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
160+
(firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) =>
161+
for {
162+
_ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll))
163+
firstDocumentUpdated = firstDocument.copy(stringField = "test", intField = 7)
164+
secondDocumentUpdated =
165+
secondDocument.copy(stringField = "filterAggregation", intField = 3)
166+
thirdDocumentUpdated =
167+
thirdDocument.copy(stringField = "filterAggregation", intField = 5)
168+
_ <- Executor.execute(
169+
ElasticRequest.upsert[TestDocument](
170+
firstSearchIndex,
171+
firstDocumentId,
172+
firstDocumentUpdated
173+
)
174+
)
175+
_ <- Executor.execute(
176+
ElasticRequest
177+
.upsert[TestDocument](
178+
firstSearchIndex,
179+
secondDocumentId,
180+
secondDocumentUpdated
181+
)
182+
)
183+
_ <- Executor.execute(
184+
ElasticRequest
185+
.upsert[TestDocument](
186+
firstSearchIndex,
187+
thirdDocumentId,
188+
thirdDocumentUpdated
189+
)
190+
.refreshTrue
191+
)
192+
query = term(field = TestDocument.stringField, value = secondDocumentUpdated.stringField.toLowerCase)
193+
aggregation =
194+
filterAggregation(name = "aggregation", query = query).withSubAgg(
195+
maxAggregation("subAggregation", TestDocument.intField)
196+
)
197+
aggsRes <-
198+
Executor
199+
.execute(ElasticRequest.aggregate(selectors = firstSearchIndex, aggregation = aggregation))
200+
.aggregations
201+
202+
} yield assert(aggsRes.head)(equalTo(expectedResult))
203+
}
204+
} @@ around(
205+
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
206+
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
207+
),
149208
test("aggregate using max aggregation") {
150209
val expectedResponse = ("aggregationInt", MaxAggregationResult(value = 20.0))
151210
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) {

modules/library/src/main/scala/zio/elasticsearch/ElasticAggregation.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package zio.elasticsearch
1818

1919
import zio.Chunk
2020
import zio.elasticsearch.aggregation._
21+
import zio.elasticsearch.query.ElasticQuery
2122
import zio.elasticsearch.script.Script
2223

2324
object ElasticAggregation {
@@ -113,6 +114,20 @@ object ElasticAggregation {
113114
final def cardinalityAggregation(name: String, field: String): CardinalityAggregation =
114115
Cardinality(name = name, field = field, missing = None)
115116

117+
/**
118+
* Constructs an instance of [[zio.elasticsearch.aggregation.FilterAggregation]] using the specified parameters.
119+
*
120+
* @param name
121+
* aggregation name
122+
* @param query
123+
* a query which the documents must match
124+
* @return
125+
* an instance of [[zio.elasticsearch.aggregation.FilterAggregation]] that represents filter aggregation to be
126+
* performed.
127+
*/
128+
final def filterAggregation(name: String, query: ElasticQuery[_]): FilterAggregation =
129+
Filter(name = name, query = query, subAggregations = Chunk.empty)
130+
116131
/**
117132
* Constructs a type-safe instance of [[zio.elasticsearch.aggregation.ExtendedStatsAggregation]] using the specified
118133
* parameters.

modules/library/src/main/scala/zio/elasticsearch/aggregation/Aggregations.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import zio.Chunk
2020
import zio.elasticsearch.ElasticAggregation.multipleAggregations
2121
import zio.elasticsearch.ElasticPrimitive.ElasticPrimitiveOps
2222
import zio.elasticsearch.aggregation.options._
23+
import zio.elasticsearch.query.ElasticQuery
2324
import zio.elasticsearch.query.sort.Sort
2425
import zio.elasticsearch.script.Script
2526
import zio.json.ast.Json
@@ -186,6 +187,31 @@ private[elasticsearch] final case class ExtendedStats(
186187
}
187188
}
188189

190+
sealed trait FilterAggregation extends SingleElasticAggregation with WithAgg with WithSubAgg[FilterAggregation]
191+
192+
private[elasticsearch] final case class Filter(
193+
name: String,
194+
query: ElasticQuery[_],
195+
subAggregations: Chunk[SingleElasticAggregation]
196+
) extends FilterAggregation { self =>
197+
198+
def withAgg(agg: SingleElasticAggregation): MultipleAggregations =
199+
multipleAggregations.aggregations(self, agg)
200+
201+
def withSubAgg(aggregation: SingleElasticAggregation): FilterAggregation =
202+
self.copy(subAggregations = aggregation +: subAggregations)
203+
204+
private[elasticsearch] def toJson: Json = {
205+
val subAggsJson: Obj =
206+
if (self.subAggregations.nonEmpty)
207+
Obj("aggs" -> self.subAggregations.map(_.toJson).reduce(_ merge _))
208+
else
209+
Obj()
210+
211+
Obj(name -> (Obj("filter" -> query.toJson(fieldPath = None)) merge subAggsJson))
212+
}
213+
}
214+
189215
sealed trait MaxAggregation extends SingleElasticAggregation with HasMissing[MaxAggregation] with WithAgg
190216

191217
private[elasticsearch] final case class Max(name: String, field: String, missing: Option[Double])

modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala

Lines changed: 132 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import zio.json.ast.Json
2222
import zio.json.ast.Json.Obj
2323
import zio.json.{DeriveJsonDecoder, JsonDecoder, jsonField}
2424

25+
private[elasticsearch] sealed trait AggregationBucket
26+
2527
sealed trait AggregationResponse
2628

2729
object AggregationResponse {
@@ -68,6 +70,13 @@ object AggregationResponse {
6870
lowerSampling = stdDeviationBoundsResponse.lowerSampling
6971
)
7072
)
73+
case FilterAggregationResponse(docCount, subAggregations) =>
74+
FilterAggregationResult(
75+
docCount = docCount,
76+
subAggregations = subAggregations.fold(Map[String, AggregationResult]())(_.map { case (key, response) =>
77+
(key, toResult(response))
78+
})
79+
)
7180
case MaxAggregationResponse(value) =>
7281
MaxAggregationResult(value)
7382
case MinAggregationResponse(value) =>
@@ -142,6 +151,123 @@ private[elasticsearch] object ExtendedStatsAggregationResponse {
142151
DeriveJsonDecoder.gen[ExtendedStatsAggregationResponse]
143152
}
144153

154+
private[elasticsearch] final case class FilterAggregationResponse(
155+
@jsonField("doc_count")
156+
docCount: Int,
157+
subAggregations: Option[Map[String, AggregationResponse]] = None
158+
) extends AggregationResponse
159+
160+
private[elasticsearch] object FilterAggregationResponse extends JsonDecoderOps {
161+
implicit val decoder: JsonDecoder[FilterAggregationResponse] = Obj.decoder.mapOrFail { case Obj(fields) =>
162+
val allFields = fields.flatMap { case (field, data) =>
163+
field match {
164+
case "doc_count" =>
165+
Some(field -> data.unsafeAs[Int])
166+
case _ =>
167+
val objFields = data.unsafeAs[Obj].fields.toMap
168+
169+
(field: @unchecked) match {
170+
case str if str.contains("weighted_avg#") =>
171+
Some(field -> WeightedAvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
172+
case str if str.contains("avg#") =>
173+
Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
174+
case str if str.contains("cardinality#") =>
175+
Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int]))
176+
case str if str.contains("extended_stats#") =>
177+
Some(
178+
field -> ExtendedStatsAggregationResponse(
179+
count = objFields("count").unsafeAs[Int],
180+
min = objFields("min").unsafeAs[Double],
181+
max = objFields("max").unsafeAs[Double],
182+
avg = objFields("avg").unsafeAs[Double],
183+
sum = objFields("sum").unsafeAs[Double],
184+
sumOfSquares = objFields("sum_of_squares").unsafeAs[Double],
185+
variance = objFields("variance").unsafeAs[Double],
186+
variancePopulation = objFields("variance_population").unsafeAs[Double],
187+
varianceSampling = objFields("variance_sampling").unsafeAs[Double],
188+
stdDeviation = objFields("std_deviation").unsafeAs[Double],
189+
stdDeviationPopulation = objFields("std_deviation_population").unsafeAs[Double],
190+
stdDeviationSampling = objFields("std_deviation_sampling").unsafeAs[Double],
191+
stdDeviationBoundsResponse = objFields("std_deviation_sampling").unsafeAs[StdDeviationBoundsResponse](
192+
StdDeviationBoundsResponse.decoder
193+
)
194+
)
195+
)
196+
case str if str.contains("filter#") =>
197+
Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder))
198+
case str if str.contains("max#") =>
199+
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
200+
case str if str.contains("min#") =>
201+
Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double]))
202+
case str if str.contains("missing#") =>
203+
Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int]))
204+
case str if str.contains("percentiles#") =>
205+
Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]))
206+
case str if str.contains("stats#") =>
207+
Some(
208+
field -> StatsAggregationResponse(
209+
count = objFields("count").unsafeAs[Int],
210+
min = objFields("min").unsafeAs[Double],
211+
max = objFields("max").unsafeAs[Double],
212+
avg = objFields("avg").unsafeAs[Double],
213+
sum = objFields("sum").unsafeAs[Double]
214+
)
215+
)
216+
case str if str.contains("sum#") =>
217+
Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
218+
case str if str.contains("terms#") =>
219+
Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder))
220+
case str if str.contains("value_count#") =>
221+
Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int]))
222+
}
223+
}
224+
}.toMap
225+
226+
val docCount = allFields("doc_count").asInstanceOf[Int]
227+
val subAggs = allFields.collect {
228+
case (field, data) if field != "doc_count" =>
229+
(field: @unchecked) match {
230+
case str if str.contains("weighted_avg#") =>
231+
(field.split("#")(1), data.asInstanceOf[WeightedAvgAggregationResponse])
232+
case str if str.contains("avg#") =>
233+
(field.split("#")(1), data.asInstanceOf[AvgAggregationResponse])
234+
case str if str.contains("cardinality#") =>
235+
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
236+
case str if str.contains("extended_stats#") =>
237+
(field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse])
238+
case str if str.contains("filter#") =>
239+
(field.split("#")(1), data.asInstanceOf[FilterAggregationResponse])
240+
case str if str.contains("max#") =>
241+
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
242+
case str if str.contains("min#") =>
243+
(field.split("#")(1), data.asInstanceOf[MinAggregationResponse])
244+
case str if str.contains("missing#") =>
245+
(field.split("#")(1), data.asInstanceOf[MissingAggregationResponse])
246+
case str if str.contains("percentiles#") =>
247+
(field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse])
248+
case str if str.contains("stats#") =>
249+
(field.split("#")(1), data.asInstanceOf[StatsAggregationResponse])
250+
case str if str.contains("sum#") =>
251+
(field.split("#")(1), data.asInstanceOf[SumAggregationResponse])
252+
case str if str.contains("terms#") =>
253+
(field.split("#")(1), data.asInstanceOf[TermsAggregationResponse])
254+
case str if str.contains("value_count#") =>
255+
(field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse])
256+
}
257+
}
258+
Right(FilterAggregationResponse.apply(docCount, Option(subAggs).filter(_.nonEmpty)))
259+
}
260+
}
261+
262+
private[elasticsearch] sealed trait JsonDecoderOps {
263+
implicit class JsonDecoderOps(json: Json) {
264+
def unsafeAs[A](implicit decoder: JsonDecoder[A]): A =
265+
(json.as[A]: @unchecked) match {
266+
case Right(decoded) => decoded
267+
}
268+
}
269+
}
270+
145271
private[elasticsearch] final case class MaxAggregationResponse(value: Double) extends AggregationResponse
146272

147273
private[elasticsearch] object MaxAggregationResponse {
@@ -217,16 +343,14 @@ private[elasticsearch] object TermsAggregationResponse {
217343
implicit val decoder: JsonDecoder[TermsAggregationResponse] = DeriveJsonDecoder.gen[TermsAggregationResponse]
218344
}
219345

220-
private[elasticsearch] sealed trait AggregationBucket
221-
222346
private[elasticsearch] final case class TermsAggregationBucket(
223347
key: String,
224348
@jsonField("doc_count")
225349
docCount: Int,
226350
subAggregations: Option[Map[String, AggregationResponse]] = None
227351
) extends AggregationBucket
228352

229-
private[elasticsearch] object TermsAggregationBucket {
353+
private[elasticsearch] object TermsAggregationBucket extends JsonDecoderOps {
230354
implicit val decoder: JsonDecoder[TermsAggregationBucket] = Obj.decoder.mapOrFail { case Obj(fields) =>
231355
val allFields = fields.flatMap { case (field, data) =>
232356
field match {
@@ -264,6 +388,8 @@ private[elasticsearch] object TermsAggregationBucket {
264388
)
265389
)
266390
)
391+
case str if str.contains("filter#") =>
392+
Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder))
267393
case str if str.contains("max#") =>
268394
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
269395
case str if str.contains("min#") =>
@@ -285,15 +411,7 @@ private[elasticsearch] object TermsAggregationBucket {
285411
case str if str.contains("sum#") =>
286412
Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
287413
case str if str.contains("terms#") =>
288-
Some(
289-
field -> TermsAggregationResponse(
290-
docErrorCount = objFields("doc_count_error_upper_bound").unsafeAs[Int],
291-
sumOtherDocCount = objFields("sum_other_doc_count").unsafeAs[Int],
292-
buckets = objFields("buckets")
293-
.unsafeAs[Chunk[Json]]
294-
.map(_.unsafeAs[TermsAggregationBucket](TermsAggregationBucket.decoder))
295-
)
296-
)
414+
Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder))
297415
case str if str.contains("value_count#") =>
298416
Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int]))
299417
}
@@ -313,6 +431,8 @@ private[elasticsearch] object TermsAggregationBucket {
313431
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
314432
case str if str.contains("extended_stats#") =>
315433
(field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse])
434+
case str if str.contains("filter#") =>
435+
(field.split("#")(1), data.asInstanceOf[FilterAggregationResponse])
316436
case str if str.contains("max#") =>
317437
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
318438
case str if str.contains("min#") =>
@@ -331,16 +451,8 @@ private[elasticsearch] object TermsAggregationBucket {
331451
(field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse])
332452
}
333453
}
334-
335454
Right(TermsAggregationBucket.apply(key, docCount, Option(subAggs).filter(_.nonEmpty)))
336455
}
337-
338-
final implicit class JsonDecoderOps(json: Json) {
339-
def unsafeAs[A](implicit decoder: JsonDecoder[A]): A =
340-
(json.as[A]: @unchecked) match {
341-
case Right(decoded) => decoded
342-
}
343-
}
344456
}
345457

346458
private[elasticsearch] final case class ValueCountAggregationResponse(value: Int) extends AggregationResponse

0 commit comments

Comments
 (0)