Skip to content

Commit 31ee1e0

Browse files
Merge pull request #617 from sqlkata/iss-507_support_filter_clause
Iss 507 support filter clause
2 parents 6337a57 + 4867541 commit 31ee1e0

File tree

8 files changed

+291
-7
lines changed

8 files changed

+291
-7
lines changed

QueryBuilder.Tests/MySqlExecutionTest.cs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,55 @@ public void ExistsShouldReturnTrueForNonEmptyTable()
200200
db.Drop("Transaction");
201201
}
202202

203+
[Fact]
204+
public void BasicSelectFilter()
205+
{
206+
var db = DB().Create("Transaction", new[] {
207+
"Id INT PRIMARY KEY AUTO_INCREMENT",
208+
"Date DATE NOT NULL",
209+
"Amount int NOT NULL",
210+
});
211+
212+
var data = new Dictionary<string, int> {
213+
// 2020
214+
{"2020-01-01", 10},
215+
{"2020-05-01", 20},
216+
217+
// 2021
218+
{"2021-01-01", 40},
219+
{"2021-02-01", 10},
220+
{"2021-04-01", -10},
221+
222+
// 2022
223+
{"2022-01-01", 80},
224+
{"2022-02-01", -30},
225+
{"2022-05-01", 50},
226+
};
227+
228+
foreach (var row in data)
229+
{
230+
db.Query("Transaction").Insert(new
231+
{
232+
Date = row.Key,
233+
Amount = row.Value
234+
});
235+
}
236+
237+
var query = db.Query("Transaction")
238+
.SelectSum("Amount as Total_2020", q => q.WhereDatePart("year", "date", 2020))
239+
.SelectSum("Amount as Total_2021", q => q.WhereDatePart("year", "date", 2021))
240+
.SelectSum("Amount as Total_2022", q => q.WhereDatePart("year", "date", 2022))
241+
;
242+
243+
var results = query.Get().ToList();
244+
Assert.Single(results);
245+
Assert.Equal(30, results[0].Total_2020);
246+
Assert.Equal(40, results[0].Total_2021);
247+
Assert.Equal(100, results[0].Total_2022);
248+
249+
db.Drop("Transaction");
250+
}
251+
203252
QueryFactory DB()
204253
{
205254
var host = System.Environment.GetEnvironmentVariable("SQLKATA_MYSQL_HOST");

QueryBuilder.Tests/SQLiteExecutionTest.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,57 @@ public void InlineTable()
207207
db.Drop("Transaction");
208208
}
209209

210+
211+
212+
[Fact]
213+
public void BasicSelectFilter()
214+
{
215+
var db = DB().Create("Transaction", new[] {
216+
"Id INTEGER PRIMARY KEY AUTOINCREMENT",
217+
"Date DATE NOT NULL",
218+
"Amount int NOT NULL",
219+
});
220+
221+
var data = new Dictionary<string, int> {
222+
// 2020
223+
{"2020-01-01", 10},
224+
{"2020-05-01", 20},
225+
226+
// 2021
227+
{"2021-01-01", 40},
228+
{"2021-02-01", 10},
229+
{"2021-04-01", -10},
230+
231+
// 2022
232+
{"2022-01-01", 80},
233+
{"2022-02-01", -30},
234+
{"2022-05-01", 50},
235+
};
236+
237+
foreach (var row in data)
238+
{
239+
db.Query("Transaction").Insert(new
240+
{
241+
Date = row.Key,
242+
Amount = row.Value
243+
});
244+
}
245+
246+
var query = db.Query("Transaction")
247+
.SelectSum("Amount as Total_2020", q => q.WhereDatePart("year", "date", 2020))
248+
.SelectSum("Amount as Total_2021", q => q.WhereDatePart("year", "date", 2021))
249+
.SelectSum("Amount as Total_2022", q => q.WhereDatePart("year", "date", 2022))
250+
;
251+
252+
var results = query.Get().ToList();
253+
Assert.Single(results);
254+
Assert.Equal(30, results[0].Total_2020);
255+
Assert.Equal(40, results[0].Total_2021);
256+
Assert.Equal(100, results[0].Total_2022);
257+
258+
db.Drop("Transaction");
259+
}
260+
210261
QueryFactory DB()
211262
{
212263
var cs = $"Data Source=file::memory:;Cache=Shared";

QueryBuilder.Tests/SelectTests.cs

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ public void EscapeClauseThrowsForMultipleCharacters()
819819
[Fact]
820820
public void BasicSelectRaw_WithNoTable()
821821
{
822-
var q = new Query().SelectRaw("somefunction() as c1");
822+
var q = new Query().SelectRaw("somefunction() as c1");
823823

824824
var c = Compilers.CompileFor(EngineCodes.SqlServer, q);
825825
Assert.Equal("SELECT somefunction() as c1", c.ToString());
@@ -848,6 +848,60 @@ public void BasicSelect_WithNoTableWhereRawClause()
848848
var c = Compilers.CompileFor(EngineCodes.SqlServer, q);
849849
Assert.Equal("SELECT [c1] WHERE 1 = 1", c.ToString());
850850
}
851-
851+
852+
[Fact]
853+
public void BasicSelectAggregate()
854+
{
855+
var q = new Query("Posts").Select("Title")
856+
.SelectAggregate("sum", "ViewCount");
857+
858+
var sqlServer = Compilers.CompileFor(EngineCodes.SqlServer, q);
859+
Assert.Equal("SELECT [Title], SUM([ViewCount]) FROM [Posts]", sqlServer.ToString());
860+
}
861+
862+
[Fact]
863+
public void SelectAggregateShouldIgnoreEmptyFilter()
864+
{
865+
var q = new Query("Posts").Select("Title")
866+
.SelectAggregate("sum", "ViewCount", q => q);
867+
868+
var sqlServer = Compilers.CompileFor(EngineCodes.SqlServer, q);
869+
Assert.Equal("SELECT [Title], SUM([ViewCount]) FROM [Posts]", sqlServer.ToString());
870+
}
871+
872+
[Fact]
873+
public void SelectAggregateShouldIgnoreEmptyQueryFilter()
874+
{
875+
var q = new Query("Posts").Select("Title")
876+
.SelectAggregate("sum", "ViewCount", new Query());
877+
878+
var sqlServer = Compilers.CompileFor(EngineCodes.SqlServer, q);
879+
Assert.Equal("SELECT [Title], SUM([ViewCount]) FROM [Posts]", sqlServer.ToString());
880+
}
881+
882+
[Fact]
883+
public void BasicSelectAggregateWithAlias()
884+
{
885+
var q = new Query("Posts").Select("Title")
886+
.SelectAggregate("sum", "ViewCount as TotalViews");
887+
888+
var sqlServer = Compilers.CompileFor(EngineCodes.SqlServer, q);
889+
Assert.Equal("SELECT [Title], SUM([ViewCount]) AS [TotalViews] FROM [Posts]", sqlServer.ToString());
890+
}
891+
892+
[Fact]
893+
public void SelectWithFilter()
894+
{
895+
var q = new Query("Posts").Select("Title")
896+
.SelectAggregate("sum", "ViewCount as Published_Jan", q => q.Where("Published_Month", "Jan"))
897+
.SelectAggregate("sum", "ViewCount as Published_Feb", q => q.Where("Published_Month", "Feb"));
898+
899+
var pgSql = Compilers.CompileFor(EngineCodes.PostgreSql, q);
900+
Assert.Equal("SELECT \"Title\", SUM(\"ViewCount\") FILTER (WHERE \"Published_Month\" = 'Jan') AS \"Published_Jan\", SUM(\"ViewCount\") FILTER (WHERE \"Published_Month\" = 'Feb') AS \"Published_Feb\" FROM \"Posts\"", pgSql.ToString());
901+
902+
var sqlServer = Compilers.CompileFor(EngineCodes.SqlServer, q);
903+
Assert.Equal("SELECT [Title], SUM(CASE WHEN [Published_Month] = 'Jan' THEN [ViewCount] END) AS [Published_Jan], SUM(CASE WHEN [Published_Month] = 'Feb' THEN [ViewCount] END) AS [Published_Feb] FROM [Posts]", sqlServer.ToString());
904+
}
905+
852906
}
853907
}

QueryBuilder/Clauses/ColumnClause.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,33 @@ public override AbstractClause Clone()
7777
};
7878
}
7979
}
80+
81+
/// <summary>
82+
/// Represents an aggregated column clause with an optional filter
83+
/// </summary>
84+
/// <seealso cref="AbstractColumn" />
85+
public class AggregatedColumn : AbstractColumn
86+
{
87+
/// <summary>
88+
/// Gets or sets the a query that used to filter the data,
89+
/// the compiler will consider only the `Where` clause.
90+
/// </summary>
91+
/// <value>
92+
/// The filter query.
93+
/// </value>
94+
public Query Filter { get; set; } = null;
95+
public string Aggregate { get; set; }
96+
public AbstractColumn Column { get; set; }
97+
public override AbstractClause Clone()
98+
{
99+
return new AggregatedColumn
100+
{
101+
Engine = Engine,
102+
Filter = Filter?.Clone(),
103+
Column = Column.Clone() as AbstractColumn,
104+
Aggregate = Aggregate,
105+
Component = Component,
106+
};
107+
}
108+
}
80109
}

QueryBuilder/Compilers/Compiler.cs

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ protected Compiler()
2424

2525
public virtual string EngineCode { get; }
2626

27+
/// <summary>
28+
/// Whether the compiler supports the `SELECT ... FILTER` syntax
29+
/// </summary>
30+
/// <value></value>
31+
public virtual bool SupportsFilterClause { get; set; } = false;
32+
2733
protected virtual string SingleRowDummyTableName { get => null; }
2834

2935
/// <summary>
@@ -512,10 +518,44 @@ public virtual string CompileColumn(SqlResult ctx, AbstractColumn column)
512518
return "(" + subCtx.RawSql + $"){alias}";
513519
}
514520

521+
if (column is AggregatedColumn aggregatedColumn)
522+
{
523+
string agg = aggregatedColumn.Aggregate.ToUpperInvariant();
524+
525+
var (col, alias) = SplitAlias(CompileColumn(ctx, aggregatedColumn.Column));
526+
527+
alias = string.IsNullOrEmpty(alias) ? "" : $" {ColumnAsKeyword}{alias}";
528+
529+
string filterCondition = CompileFilterConditions(ctx, aggregatedColumn);
530+
531+
if (string.IsNullOrEmpty(filterCondition))
532+
{
533+
return $"{agg}({col}){alias}";
534+
}
535+
536+
if (SupportsFilterClause)
537+
{
538+
return $"{agg}({col}) FILTER (WHERE {filterCondition}){alias}";
539+
}
540+
541+
return $"{agg}(CASE WHEN {filterCondition} THEN {col} END){alias}";
542+
}
543+
515544
return Wrap((column as Column).Name);
516545

517546
}
518547

548+
protected virtual string CompileFilterConditions(SqlResult ctx, AggregatedColumn aggregatedColumn)
549+
{
550+
if (aggregatedColumn.Filter == null)
551+
{
552+
return null;
553+
}
554+
555+
var wheres = aggregatedColumn.Filter.GetComponents<AbstractCondition>("where");
556+
557+
return CompileConditions(ctx, wheres);
558+
}
519559

520560
public virtual SqlResult CompileCte(AbstractFrom cte)
521561
{
@@ -872,9 +912,7 @@ public virtual string Wrap(string value)
872912

873913
if (value.ToLowerInvariant().Contains(" as "))
874914
{
875-
var index = value.ToLowerInvariant().IndexOf(" as ");
876-
var before = value.Substring(0, index);
877-
var after = value.Substring(index + 4);
915+
var (before, after) = SplitAlias(value);
878916

879917
return Wrap(before) + $" {ColumnAsKeyword}" + WrapValue(after);
880918
}
@@ -892,6 +930,20 @@ public virtual string Wrap(string value)
892930
return WrapValue(value);
893931
}
894932

933+
public virtual (string, string) SplitAlias(string value)
934+
{
935+
var index = value.LastIndexOf(" as ", StringComparison.OrdinalIgnoreCase);
936+
937+
if (index > 0)
938+
{
939+
var before = value.Substring(0, index);
940+
var after = value.Substring(index + 4);
941+
return (before, after);
942+
}
943+
944+
return (value, null);
945+
}
946+
895947
/// <summary>
896948
/// Wrap a single string in keyword identifiers.
897949
/// </summary>

QueryBuilder/Compilers/PostgresCompiler.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ public PostgresCompiler()
1111
}
1212

1313
public override string EngineCode { get; } = EngineCodes.PostgreSql;
14+
public override bool SupportsFilterClause { get; set; } = true;
1415

1516

1617
protected override string CompileBasicStringCondition(SqlResult ctx, BasicStringCondition x)

QueryBuilder/Compilers/SqliteCompiler.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
using System.Collections.Generic;
2-
using SqlKata;
3-
using SqlKata.Compilers;
42

53
namespace SqlKata.Compilers
64
{
@@ -10,6 +8,7 @@ public class SqliteCompiler : Compiler
108
protected override string OpeningIdentifier { get; set; } = "\"";
119
protected override string ClosingIdentifier { get; set; } = "\"";
1210
protected override string LastId { get; set; } = "select last_insert_rowid() as id";
11+
public override bool SupportsFilterClause { get; set; } = true;
1312

1413
public override string CompileTrue()
1514
{

QueryBuilder/Query.Select.cs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,54 @@ public Query Select(Func<Query, Query> callback, string alias)
6868
{
6969
return Select(callback.Invoke(NewChild()), alias);
7070
}
71+
72+
public Query SelectAggregate(string aggregate, string column, Query filter = null)
73+
{
74+
Method = "select";
75+
76+
AddComponent("select", new AggregatedColumn
77+
{
78+
Column = new Column { Name = column },
79+
Aggregate = aggregate,
80+
Filter = filter,
81+
});
82+
83+
return this;
84+
}
85+
86+
public Query SelectAggregate(string aggregate, string column, Func<Query, Query> filter)
87+
{
88+
if (filter == null)
89+
{
90+
return SelectAggregate(aggregate, column);
91+
}
92+
93+
return SelectAggregate(aggregate, column, filter.Invoke(NewChild()));
94+
}
95+
96+
public Query SelectSum(string column, Func<Query, Query> filter = null)
97+
{
98+
return SelectAggregate("sum", column, filter);
99+
}
100+
101+
public Query SelectCount(string column, Func<Query, Query> filter = null)
102+
{
103+
return SelectAggregate("count", column, filter);
104+
}
105+
106+
public Query SelectAvg(string column, Func<Query, Query> filter = null)
107+
{
108+
return SelectAggregate("avg", column, filter);
109+
}
110+
111+
public Query SelectMin(string column, Func<Query, Query> filter = null)
112+
{
113+
return SelectAggregate("min", column, filter);
114+
}
115+
116+
public Query SelectMax(string column, Func<Query, Query> filter = null)
117+
{
118+
return SelectAggregate("max", column, filter);
119+
}
71120
}
72121
}

0 commit comments

Comments
 (0)