Skip to content

Commit c3755cb

Browse files
pavelvelikhovPavel Velikhov
andauthored
[NEW RBO] Added contant folding (#28398)
Co-authored-by: Pavel Velikhov <pavelvelikhov@localhost.localdomain>
1 parent 80918f9 commit c3755cb

14 files changed

+629
-279
lines changed

ydb/core/kqp/host/kqp_runner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ class TKqpRunner : public IKqpRunner {
373373
//.AddCommonOptimization()
374374

375375
.Add(CreateKqpPgRewriteTransformer(OptimizeCtx, *typesCtx), "RewritePgSelect")
376-
.Add(CreateKqpNewRBOTransformer(OptimizeCtx, *typesCtx, rboKqpTypeAnnTransformer, kqpTypeAnnTransformer, newRBOPhysicalPeepholeTransformer), "NewRBOTransformer")
376+
.Add(CreateKqpNewRBOTransformer(OptimizeCtx, *typesCtx, rboKqpTypeAnnTransformer, kqpTypeAnnTransformer, newRBOPhysicalPeepholeTransformer, funcRegistry), "NewRBOTransformer")
377377
.Add(CreateKqpRBOCleanupTransformer(*typesCtx), "RBOCleanupTransformer")
378378

379379
//.Add(CreatePhysicalDataProposalsInspector(*typesCtx), "ProvidersPhysicalOptimize")
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
#include "kqp_rbo_rules.h"
2+
#include <ydb/core/kqp/common/kqp_yql.h>
3+
#include <yql/essentials/core/yql_expr_optimize.h>
4+
#include <yql/essentials/utils/log/log.h>
5+
#include <yql/essentials/core/services/yql_transform_pipeline.h>
6+
#include <ydb/library/yql/dq/opt/dq_opt_stat.h>
7+
#include <typeinfo>
8+
9+
using namespace NYql;
10+
using namespace NYql::NNodes;
11+
using namespace NYql::NDq;
12+
13+
namespace {
14+
15+
THashSet<TString> notAllowedDataTypeForSafeCast{"JsonDocument", "DyNumber"};
16+
17+
bool IsSuitableToExtractExpr(const TExprNode::TPtr &input) {
18+
if (auto maybeSafeCast = TExprBase(input).Maybe<TCoSafeCast>()) {
19+
auto maybeDataType = maybeSafeCast.Cast().Type().Maybe<TCoDataType>();
20+
if (!maybeDataType) {
21+
if (const auto maybeOptionalType = maybeSafeCast.Cast().Type().Maybe<TCoOptionalType>()) {
22+
maybeDataType = maybeOptionalType.Cast().ItemType().Maybe<TCoDataType>();
23+
}
24+
}
25+
return (maybeDataType && !notAllowedDataTypeForSafeCast.contains(maybeDataType.Cast().Type().Value()));
26+
}
27+
return true;
28+
}
29+
30+
/**
31+
* Traverse a lambda and extract a list of constant expressions
32+
*/
33+
void ExtractConstantExprs(const TExprNode::TPtr& input, TVector<std::pair<TExprNode::TPtr, TExprNode::TPtr>>& exprs, TExprContext& ctx, bool foldUdfs = true) {
34+
if (!IsSuitableToExtractExpr(input)) {
35+
return;
36+
}
37+
38+
if (TCoLambda::Match(input.Get())) {
39+
auto lambda = TExprBase(input).Cast<TCoLambda>();
40+
return ExtractConstantExprs(lambda.Body().Ptr(), exprs, ctx);
41+
}
42+
43+
if (IsDataOrOptionalOfData(input->GetTypeAnn()) && !NeedCalc(TExprBase(input))) {
44+
return;
45+
}
46+
47+
if (IsConstantExpr(input, foldUdfs) && !input->IsCallable("PgConst")) {
48+
TNodeOnNodeOwnedMap deepClones;
49+
auto inputClone = ctx.DeepCopy(*input, ctx, deepClones, false, true, true);
50+
exprs.push_back(std::make_pair(input, inputClone));
51+
return;
52+
}
53+
54+
if (TCoAsStruct::Match(input.Get())) {
55+
for (auto child : TExprBase(input).Cast<TCoAsStruct>()) {
56+
ExtractConstantExprs(child.Item(1).Ptr(), exprs, ctx);
57+
}
58+
return;
59+
}
60+
61+
if (input->IsCallable() && input->Content() != "EvaluateExpr") {
62+
if (input->ChildrenSize() >= 1) {
63+
for (size_t i = 0; i < input->ChildrenSize(); i++) {
64+
ExtractConstantExprs(input->Child(i), exprs, ctx);
65+
}
66+
}
67+
}
68+
69+
return;
70+
}
71+
72+
}
73+
74+
namespace NKikimr {
75+
namespace NKqp {
76+
77+
void TConstantFoldingStage::RunStage(TOpRoot &root, TRBOContext &ctx) {
78+
TVector<TExprNode::TPtr> lambdasWithConstExpr;
79+
bool foldUdfs = ctx.KqpCtx.Config->EnableFoldUdfs;
80+
81+
// Iterate through all operators that contain lambdas with potential constant expression
82+
83+
// Internal map for remap operation
84+
TNodeOnNodeOwnedMap replaces;
85+
86+
// Actual map used in the optimizer
87+
TVector<std::pair<TExprNode::TPtr, TExprNode::TPtr>> globalExtractedExprs;
88+
TVector<std::shared_ptr<IOperator>> affectedOps;
89+
90+
for (auto it : root) {
91+
if (!it.Current->GetLambdas().empty()) {
92+
auto lambdas = it.Current->GetLambdas();
93+
bool affected = false;
94+
for (auto l : lambdas) {
95+
auto lambda = TCoLambda(l);
96+
TVector<std::pair<TExprNode::TPtr, TExprNode::TPtr>> extractedExprs;
97+
ExtractConstantExprs(lambda.Body().Ptr(), extractedExprs, ctx.ExprCtx, foldUdfs);
98+
if (!extractedExprs.empty()) {
99+
affected = true;
100+
globalExtractedExprs.insert(globalExtractedExprs.end(), extractedExprs.begin(), extractedExprs.end());
101+
}
102+
}
103+
104+
if (affected) {
105+
affectedOps.push_back(it.Current);
106+
}
107+
}
108+
}
109+
110+
if (globalExtractedExprs.empty()) {
111+
return;
112+
}
113+
114+
// Build a list of eval expressions
115+
116+
TExprNode::TListType lambdaList;
117+
TExprNode::TListType evalElements;
118+
for (auto & [k, v] : globalExtractedExprs) {
119+
lambdaList.push_back(k);
120+
evalElements.push_back(v);
121+
}
122+
123+
// Evaluate all the constant expressions at once
124+
auto evalList = ctx.ExprCtx.NewList(root.Pos, std::move(evalElements));
125+
evalList = ctx.ExprCtx.NewCallable(root.Pos, "EvaluateExpr", { evalList });
126+
127+
auto evaluator = TTransformationPipeline(&ctx.TypeCtx)
128+
.AddServiceTransformers()
129+
.AddPreTypeAnnotation()
130+
.AddExpressionEvaluation(ctx.FuncRegistry).Build(false);
131+
132+
ctx.ExprCtx.Step.Repeat(TExprStep::ExprEval);
133+
IGraphTransformer::TStatus status(IGraphTransformer::TStatus::Ok);
134+
do {
135+
status = evaluator->Transform(evalList, evalList, ctx.ExprCtx);
136+
} while (status == IGraphTransformer::TStatus::Repeat);
137+
138+
// Iterate over affected operators and modify their expressions with folded expressions
139+
for (size_t i=0; i<lambdaList.size(); i++) {
140+
replaces[lambdaList[i].Get()] = evalList->Child(i);
141+
}
142+
143+
for (auto op : affectedOps) {
144+
op->ApplyReplaceMap(replaces, ctx);
145+
}
146+
}
147+
}
148+
}

ydb/core/kqp/opt/rbo/kqp_operator.cpp

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "kqp_operator.h"
2+
#include <yql/essentials/core/yql_expr_optimize.h>
23

34
namespace {
45
using namespace NKikimr;
@@ -124,6 +125,14 @@ TExprNode::TPtr RenameMembers(TExprNode::TPtr input, const THashMap<TInfoUnit, T
124125
}
125126
}
126127

128+
} // namespace
129+
130+
namespace NKikimr {
131+
namespace NKqp {
132+
133+
using namespace NYql;
134+
using namespace NNodes;
135+
127136
TString PrintRBOExpression(TExprNode::TPtr expr, TExprContext & ctx) {
128137
try {
129138
TConvertToAstSettings settings;
@@ -142,14 +151,6 @@ TString PrintRBOExpression(TExprNode::TPtr expr, TExprContext & ctx) {
142151
}
143152
}
144153

145-
} // namespace
146-
147-
namespace NKikimr {
148-
namespace NKqp {
149-
150-
using namespace NYql;
151-
using namespace NNodes;
152-
153154
/**
154155
* Scan expression and retrieve all members
155156
*/
@@ -388,6 +389,10 @@ std::pair<TExprNode::TPtr, TVector<TExprNode::TPtr>> BuildSortKeySelector(TVecto
388389
}
389390

390391

392+
/**
393+
* Base class Operator methods
394+
*/
395+
391396
void IOperator::RenameIUs(const THashMap<TInfoUnit, TInfoUnit, TInfoUnit::THashFunction> &renameMap, TExprContext &ctx) {
392397
Y_UNUSED(renameMap);
393398
Y_UNUSED(ctx);
@@ -398,11 +403,19 @@ const TTypeAnnotationNode* IOperator::GetIUType(TInfoUnit iu) {
398403
return structType->FindItemType(iu.GetFullName());
399404
}
400405

406+
/**
407+
* EmptySource operator methods
408+
*/
409+
401410
TString TOpEmptySource::ToString(TExprContext& ctx) {
402411
Y_UNUSED(ctx);
403412
return "EmptySource";
404413
}
405414

415+
/**
416+
* OpRead operator methods
417+
*/
418+
406419
TOpRead::TOpRead(TExprNode::TPtr node) : IOperator(EOperator::Source, node->Pos()) {
407420
auto opSource = TKqpOpRead(node);
408421

@@ -440,6 +453,10 @@ TString TOpRead::ToString(TExprContext& ctx) {
440453
return res;
441454
}
442455

456+
/**
457+
* OpMap operator methods
458+
*/
459+
443460
TOpMap::TOpMap(std::shared_ptr<IOperator> input, TPositionHandle pos, TVector<std::pair<TInfoUnit, std::variant<TInfoUnit, TExprNode::TPtr>>> mapElements,
444461
bool project)
445462
: IUnaryOperator(EOperator::Map, pos, input), MapElements(mapElements), Project(project) {}
@@ -456,6 +473,16 @@ TVector<TInfoUnit> TOpMap::GetOutputIUs() {
456473
return res;
457474
}
458475

476+
TVector<TExprNode::TPtr> TOpMap::GetLambdas() {
477+
TVector<TExprNode::TPtr> result;
478+
for (auto &[_,body] : MapElements) {
479+
if (std::holds_alternative<TExprNode::TPtr>(body)) {
480+
result.push_back(std::get<TExprNode::TPtr>(body));
481+
}
482+
}
483+
return result;
484+
}
485+
459486
TVector<TInfoUnit> TOpMap::GetScalarSubplanIUs(TPlanProps& props) {
460487
TVector<TInfoUnit> allVars;
461488
TVector<TInfoUnit> res;
@@ -524,6 +551,19 @@ void TOpMap::RenameIUs(const THashMap<TInfoUnit, TInfoUnit, TInfoUnit::THashFunc
524551
MapElements = newMapElements;
525552
}
526553

554+
void TOpMap::ApplyReplaceMap(TNodeOnNodeOwnedMap map, TRBOContext & ctx) {
555+
TOptimizeExprSettings settings(&ctx.TypeCtx);
556+
for (size_t i=0; i<MapElements.size(); i++) {
557+
auto & body = MapElements[i].second;
558+
if (std::holds_alternative<TExprNode::TPtr>(body)) {
559+
auto bodyLambda = std::get<TExprNode::TPtr>(body);
560+
RemapExpr(bodyLambda, bodyLambda, map, ctx.ExprCtx, settings);
561+
MapElements[i].second = std::variant<TInfoUnit,TExprNode::TPtr>(bodyLambda);
562+
}
563+
}
564+
}
565+
566+
527567
TString TOpMap::ToString(TExprContext& ctx) {
528568
auto res = TStringBuilder();
529569
res << "Map [";
@@ -547,6 +587,10 @@ TString TOpMap::ToString(TExprContext& ctx) {
547587
return res;
548588
}
549589

590+
/**
591+
* OpProject methods
592+
*/
593+
550594
TOpProject::TOpProject(std::shared_ptr<IOperator> input, TPositionHandle pos, TVector<TInfoUnit> projectList)
551595
: IUnaryOperator(EOperator::Project, pos, input), ProjectList(projectList) {}
552596

@@ -584,6 +628,10 @@ TString TOpProject::ToString(TExprContext& ctx) {
584628
return res;
585629
}
586630

631+
/**
632+
* OpFilter operator methods
633+
*/
634+
587635
TOpFilter::TOpFilter(std::shared_ptr<IOperator> input, TPositionHandle pos, TExprNode::TPtr filterLambda)
588636
: IUnaryOperator(EOperator::Filter, pos, input), FilterLambda(filterLambda) {}
589637

@@ -593,6 +641,15 @@ void TOpFilter::RenameIUs(const THashMap<TInfoUnit, TInfoUnit, TInfoUnit::THashF
593641
FilterLambda = RenameMembers(FilterLambda, renameMap, ctx);
594642
}
595643

644+
TVector<TExprNode::TPtr> TOpFilter::GetLambdas() {
645+
return {FilterLambda};
646+
}
647+
648+
void TOpFilter::ApplyReplaceMap(TNodeOnNodeOwnedMap map, TRBOContext & ctx) {
649+
TOptimizeExprSettings settings(&ctx.TypeCtx);
650+
RemapExpr(FilterLambda, FilterLambda, map, ctx.ExprCtx, settings);
651+
}
652+
596653
TVector<TInfoUnit> TOpFilter::GetFilterIUs(TPlanProps& props) const {
597654
TVector<TInfoUnit> res;
598655

@@ -666,6 +723,10 @@ TString TOpFilter::ToString(TExprContext& ctx) {
666723
return TStringBuilder() << "Filter :" << PrintRBOExpression(FilterLambda, ctx);
667724
}
668725

726+
/**
727+
* OpJoin operator methods
728+
*/
729+
669730
TOpJoin::TOpJoin(std::shared_ptr<IOperator> leftInput, std::shared_ptr<IOperator> rightInput, TPositionHandle pos, TString joinKind,
670731
TVector<std::pair<TInfoUnit, TInfoUnit>> joinKeys)
671732
: IBinaryOperator(EOperator::Join, pos, leftInput, rightInput), JoinKind(joinKind), JoinKeys(joinKeys) {}
@@ -706,6 +767,10 @@ TString TOpJoin::ToString(TExprContext& ctx) {
706767
return res;
707768
}
708769

770+
/**
771+
* OpUnionAll operator methods
772+
*/
773+
709774
TOpUnionAll::TOpUnionAll(std::shared_ptr<IOperator> leftInput, std::shared_ptr<IOperator> rightInput, TPositionHandle pos, bool ordered)
710775
: IBinaryOperator(EOperator::UnionAll, pos, leftInput, rightInput), Ordered(ordered) {}
711776

@@ -718,6 +783,10 @@ TString TOpUnionAll::ToString(TExprContext& ctx) {
718783
return "UnionAll";
719784
}
720785

786+
/**
787+
* OpLimit operator methods
788+
*/
789+
721790
TOpLimit::TOpLimit(std::shared_ptr<IOperator> input, TPositionHandle pos, TExprNode::TPtr limitCond)
722791
: IUnaryOperator(EOperator::Limit, pos, input), LimitCond(limitCond) {}
723792

@@ -731,6 +800,10 @@ TString TOpLimit::ToString(TExprContext& ctx) {
731800
return TStringBuilder() << "Limit: " << PrintRBOExpression(LimitCond, ctx);
732801
}
733802

803+
/**
804+
* OpAggregate operator methods
805+
*/
806+
734807
TOpAggregate::TOpAggregate(std::shared_ptr<IOperator> input, TVector<TOpAggregationTraits>& aggTraitsList, TVector<TInfoUnit>& keyColumns,
735808
EAggregationPhase aggPhase, bool distinctAll, TPositionHandle pos)
736809
: IUnaryOperator(EOperator::Aggregate, pos, input), AggregationTraitsList(aggTraitsList), KeyColumns(keyColumns),
@@ -774,6 +847,10 @@ TString TOpAggregate::ToString(TExprContext& ctx) {
774847
return strBuilder;
775848
}
776849

850+
/**
851+
* OpRoot operator methods
852+
*/
853+
777854
TOpRoot::TOpRoot(std::shared_ptr<IOperator> input, TPositionHandle pos, TVector<TString> columnOrder) :
778855
IUnaryOperator(EOperator::Root, pos, input),
779856
ColumnOrder(columnOrder) {}

0 commit comments

Comments
 (0)