From 08e40ea080143d545caf84c443fbdd55ab509cad Mon Sep 17 00:00:00 2001 From: Juho Saarinen Date: Sat, 8 Nov 2025 13:26:16 +0200 Subject: [PATCH 1/3] feat: handle nil values in rules --- ast/Expression.go | 132 ++++++++++++++++++++++++++----------- ast/ExpressionAtom.go | 29 ++++++-- ast/Variable.go | 22 ++++++- engine/GruleEngine.go | 19 +++++- engine/GruleEngine_test.go | 23 ++++++- 5 files changed, 176 insertions(+), 49 deletions(-) diff --git a/ast/Expression.go b/ast/Expression.go index e6fdc8f2..20d25d2c 100755 --- a/ast/Expression.go +++ b/ast/Expression.go @@ -17,10 +17,11 @@ package ast import ( "errors" "fmt" - "github.com/hyperjumptech/grule-rule-engine/ast/unique" "reflect" "strings" + "github.com/hyperjumptech/grule-rule-engine/ast/unique" + "github.com/hyperjumptech/grule-rule-engine/pkg" ) @@ -79,7 +80,8 @@ type Expression struct { Negated bool Value reflect.Value - Evaluated bool + Evaluated bool + CompareNilValues bool } // MakeCatalog will create a catalog entry from Expression node. @@ -279,7 +281,18 @@ func (e *Expression) SetGrlText(grlText string) { // Evaluate will evaluate this AST graph for when scope evaluation func (e *Expression) Evaluate(dataContext IDataContext, memory *WorkingMemory) (reflect.Value, error) { - if e.Evaluated == true { + compareNode := dataContext.Get("COMPARE_NILS") + if compareNode != nil { + if rv := compareNode.Value(); rv.IsValid() && rv.Kind() == reflect.Bool { + e.CompareNilValues = rv.Bool() + } else { + e.CompareNilValues = false + } + } else { + e.CompareNilValues = false + } + + if e.Evaluated { return e.Value, nil } @@ -350,41 +363,84 @@ func (e *Expression) Evaluate(dataContext IDataContext, memory *WorkingMemory) ( return reflect.Value{}, fmt.Errorf("right hand expression error. got %v", rerr) } - switch e.Operator { - case OpMul: - val, opErr = pkg.EvaluateMultiplication(lval, rval) - case OpDiv: - val, opErr = pkg.EvaluateDivision(lval, rval) - case OpMod: - val, opErr = pkg.EvaluateModulo(lval, rval) - case OpAdd: - val, opErr = pkg.EvaluateAddition(lval, rval) - case OpSub: - val, opErr = pkg.EvaluateSubtraction(lval, rval) - case OpBitAnd: - val, opErr = pkg.EvaluateBitAnd(lval, rval) - case OpBitOr: - val, opErr = pkg.EvaluateBitOr(lval, rval) - case OpGT: - val, opErr = pkg.EvaluateGreaterThan(lval, rval) - case OpLT: - val, opErr = pkg.EvaluateLesserThan(lval, rval) - case OpGTE: - val, opErr = pkg.EvaluateGreaterThanEqual(lval, rval) - case OpLTE: - val, opErr = pkg.EvaluateLesserThanEqual(lval, rval) - case OpEq: - val, opErr = pkg.EvaluateEqual(lval, rval) - case OpNEq: - val, opErr = pkg.EvaluateNotEqual(lval, rval) - case OpAnd: - val, opErr = pkg.EvaluateLogicAnd(lval, rval) - case OpOr: - val, opErr = pkg.EvaluateLogicOr(lval, rval) - } - if opErr == nil { - e.Value = val - e.Evaluated = true + if e.CompareNilValues && (!lval.IsValid() || !rval.IsValid()) { + if e.CompareNilValues { + AstLog.Debugf("Values have invalid value (%v and %v) but continuing with null handling", lval, rval) + switch e.Operator { + case OpMul, OpDiv, OpBitAnd, OpBitOr, OpMod: + // Can be left as nil, as these operators with Nil are not defined + e.Evaluated = true + case OpAdd: + if lval.IsValid() { + val = lval + } else if rval.IsValid() { + val = rval + } + e.Evaluated = true + case OpSub: + if lval.IsValid() { + val = lval + } else if rval.IsValid() { + val, _ = pkg.EvaluateSubtraction(reflect.ValueOf(0), rval) + } + e.Evaluated = true + case OpOr: + lvale := pkg.GetValueElem(lval) + rvale := pkg.GetValueElem(rval) + if (lvale.IsValid() && lvale.Kind() == reflect.Bool && lvale.Bool()) || (rvale.IsValid() && rvale.Kind() == reflect.Bool && rvale.Bool()) { + val = reflect.ValueOf(true) + } else { + val = reflect.ValueOf(false) + } + e.Value = val + e.Evaluated = true + case OpEq, OpGTE, OpLTE, OpGT, OpLT, OpAnd: + val = reflect.ValueOf(false) + e.Value = val + e.Evaluated = true + case OpNEq: + val = reflect.ValueOf(true) + e.Value = val + e.Evaluated = true + } + } + } else { + switch e.Operator { + case OpMul: + val, opErr = pkg.EvaluateMultiplication(lval, rval) + case OpDiv: + val, opErr = pkg.EvaluateDivision(lval, rval) + case OpMod: + val, opErr = pkg.EvaluateModulo(lval, rval) + case OpAdd: + val, opErr = pkg.EvaluateAddition(lval, rval) + case OpSub: + val, opErr = pkg.EvaluateSubtraction(lval, rval) + case OpBitAnd: + val, opErr = pkg.EvaluateBitAnd(lval, rval) + case OpBitOr: + val, opErr = pkg.EvaluateBitOr(lval, rval) + case OpGT: + val, opErr = pkg.EvaluateGreaterThan(lval, rval) + case OpLT: + val, opErr = pkg.EvaluateLesserThan(lval, rval) + case OpGTE: + val, opErr = pkg.EvaluateGreaterThanEqual(lval, rval) + case OpLTE: + val, opErr = pkg.EvaluateLesserThanEqual(lval, rval) + case OpEq: + val, opErr = pkg.EvaluateEqual(lval, rval) + case OpNEq: + val, opErr = pkg.EvaluateNotEqual(lval, rval) + case OpAnd: + val, opErr = pkg.EvaluateLogicAnd(lval, rval) + case OpOr: + val, opErr = pkg.EvaluateLogicOr(lval, rval) + } + if opErr == nil { + e.Value = val + e.Evaluated = true + } } return val, opErr diff --git a/ast/ExpressionAtom.go b/ast/ExpressionAtom.go index dae3fdad..c8e8ee73 100755 --- a/ast/ExpressionAtom.go +++ b/ast/ExpressionAtom.go @@ -17,11 +17,12 @@ package ast import ( "errors" "fmt" - "github.com/hyperjumptech/grule-rule-engine/ast/unique" - "github.com/hyperjumptech/grule-rule-engine/model" "reflect" "strings" + "github.com/hyperjumptech/grule-rule-engine/ast/unique" + "github.com/hyperjumptech/grule-rule-engine/model" + "github.com/hyperjumptech/grule-rule-engine/pkg" ) @@ -49,7 +50,8 @@ type ExpressionAtom struct { Value reflect.Value ValueNode model.ValueNode - Evaluated bool + Evaluated bool + CompareNilValues bool } // MakeCatalog will create a catalog entry from ExpressionAtom node. @@ -265,10 +267,22 @@ func (e *ExpressionAtom) SetGrlText(grlText string) { // Evaluate will evaluate this AST graph for when scope evaluation func (e *ExpressionAtom) Evaluate(dataContext IDataContext, memory *WorkingMemory) (val reflect.Value, err error) { - if e.Evaluated == true { + // Extract COMPARE_NILS from dataContext as a bool, defaulting to false when unavailable or not boolean. + compareNode := dataContext.Get("COMPARE_NILS") + if compareNode != nil { + if rv := compareNode.Value(); rv.IsValid() && rv.Kind() == reflect.Bool { + e.CompareNilValues = rv.Bool() + } else { + e.CompareNilValues = false + } + } else { + e.CompareNilValues = false + } + if e.Evaluated { return e.Value, nil } + if e.Constant != nil { val, err := e.Constant.Evaluate(dataContext, memory) if err != nil { @@ -368,6 +382,13 @@ func (e *ExpressionAtom) Evaluate(dataContext IDataContext, memory *WorkingMemor } valueNode, err := e.ExpressionAtom.ValueNode.GetChildNodeByField(e.VariableName) if err != nil { + if e.CompareNilValues { + e.ValueNode = model.NewGoValueNode(reflect.ValueOf(nil), fmt.Sprintf("%s.%s->nil", e.ExpressionAtom.ValueNode.IdentifiedAs(), e.VariableName)) + e.Value = e.ValueNode.Value() + e.Evaluated = true + + return e.Value, nil + } return reflect.Value{}, err } diff --git a/ast/Variable.go b/ast/Variable.go index 65aeef3e..8c888866 100755 --- a/ast/Variable.go +++ b/ast/Variable.go @@ -16,11 +16,12 @@ package ast import ( "fmt" - "github.com/hyperjumptech/grule-rule-engine/ast/unique" - "github.com/hyperjumptech/grule-rule-engine/model" "reflect" "strings" + "github.com/hyperjumptech/grule-rule-engine/ast/unique" + "github.com/hyperjumptech/grule-rule-engine/model" + "github.com/hyperjumptech/grule-rule-engine/pkg" ) @@ -43,6 +44,8 @@ type Variable struct { ValueNode model.ValueNode Value reflect.Value + + CompareNilValues bool } // MakeCatalog create a catalog entry for this AST Node @@ -219,6 +222,17 @@ func (e *Variable) Assign(newVal reflect.Value, dataContext IDataContext, memory // Evaluate will evaluate this AST graph for when scope evaluation func (e *Variable) Evaluate(dataContext IDataContext, memory *WorkingMemory) (reflect.Value, error) { + compareNode := dataContext.Get("COMPARE_NILS") + if compareNode != nil { + if rv := compareNode.Value(); rv.IsValid() && rv.Kind() == reflect.Bool { + e.CompareNilValues = rv.Bool() + } else { + e.CompareNilValues = false + } + } else { + e.CompareNilValues = false + } + if len(e.Name) > 0 && e.Variable == nil { valueNode := dataContext.Get(e.Name) if valueNode == nil { @@ -238,7 +252,9 @@ func (e *Variable) Evaluate(dataContext IDataContext, memory *WorkingMemory) (re } valueNode, err := e.Variable.ValueNode.GetChildNodeByField(e.Name) if err != nil { - + if e.CompareNilValues { + return reflect.ValueOf(nil), nil + } return reflect.Value{}, err } e.ValueNode = valueNode diff --git a/engine/GruleEngine.go b/engine/GruleEngine.go index b65f867d..e8740df2 100755 --- a/engine/GruleEngine.go +++ b/engine/GruleEngine.go @@ -17,11 +17,12 @@ package engine import ( "context" "fmt" + "sort" + "time" + "github.com/rs/zerolog" "github.com/sirupsen/logrus" "go.uber.org/zap" - "sort" - "time" "github.com/hyperjumptech/grule-rule-engine/ast" "github.com/hyperjumptech/grule-rule-engine/logger" @@ -87,6 +88,7 @@ func NewGruleEngine() *GruleEngine { type GruleEngine struct { MaxCycle uint64 ReturnErrOnFailedRuleEvaluation bool + CompareNilValues bool Listeners []GruleEngineListener } @@ -150,6 +152,13 @@ func (g *GruleEngine) ExecuteWithContext(ctx context.Context, dataCtx ast.IDataC return err } + err = dataCtx.Add("COMPARE_NILS", g.CompareNilValues) + if err != nil { + log.Error("COMPARE_NILS add err") + + return err + } + // Working memory need to be resetted. all Expression will be set as not evaluated. log.Debugf("Resetting Working memory") knowledge.WorkingMemory.ResetAll() @@ -279,6 +288,12 @@ func (g *GruleEngine) FetchMatchingRules(dataCtx ast.IDataContext, knowledge *as return nil, err } + err = dataCtx.Add("COMPARE_NILS", g.CompareNilValues) + if err != nil { + log.Error("COMPARE_NILS add err") + + return nil, err + } // Working memory need to be resetted. all Expression will be set as not evaluated. log.Debugf("Resetting Working memory") knowledge.WorkingMemory.ResetAll() diff --git a/engine/GruleEngine_test.go b/engine/GruleEngine_test.go index 38089235..799bd4ca 100755 --- a/engine/GruleEngine_test.go +++ b/engine/GruleEngine_test.go @@ -142,7 +142,7 @@ func getTypeOf(i interface{}) string { const ruleWithAccessErr = `rule AccessErrRule "test access error rule" salience 10 { when - TestStruct.NotExist == 1 + TestStruct.NotExist == 1 || TestStruct.OtherNonExists then Retract("AccessErrRule"); }` @@ -154,7 +154,7 @@ func TestEngine_ExecuteErr(t *testing.T) { lib := ast.NewKnowledgeLibrary() rb := builder.NewRuleBuilder(lib) - err = rb.BuildRuleFromResource("Test", "0.1.1", pkg.NewBytesResource([]byte(rules))) + err = rb.BuildRuleFromResource("Test", "0.1.1", pkg.NewBytesResource([]byte(ruleWithAccessErr))) assert.NoError(t, err) engine := NewGruleEngine() @@ -165,6 +165,25 @@ func TestEngine_ExecuteErr(t *testing.T) { assert.Error(t, err) } +func TestEngine_ExecuteHandleNils(t *testing.T) { + dctx := ast.NewDataContext() + err := dctx.Add("TestStruct", &TestStruct{}) + assert.NoError(t, err) + + lib := ast.NewKnowledgeLibrary() + rb := builder.NewRuleBuilder(lib) + err = rb.BuildRuleFromResource("Test", "0.1.1", pkg.NewBytesResource([]byte(ruleWithAccessErr))) + assert.NoError(t, err) + + engine := NewGruleEngine() + engine.ReturnErrOnFailedRuleEvaluation = true + engine.CompareNilValues = true + kb, err := lib.NewKnowledgeBaseInstance("Test", "0.1.1") + assert.NoError(t, err) + err = engine.Execute(dctx, kb) + assert.NoError(t, err) +} + func TestEmptyValueEquality(t *testing.T) { t1 := time.Time{} tv1 := reflect.ValueOf(t1) From 46d795daf40ad3e48e02027c30c82b5af37f706b Mon Sep 17 00:00:00 2001 From: Juho Saarinen Date: Sat, 8 Nov 2025 14:04:33 +0200 Subject: [PATCH 2/3] feat: handle function calls with nil --- ast/ExpressionAtom.go | 4 ++++ engine/GruleEngine_test.go | 21 ++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/ast/ExpressionAtom.go b/ast/ExpressionAtom.go index c8e8ee73..f43fdb69 100755 --- a/ast/ExpressionAtom.go +++ b/ast/ExpressionAtom.go @@ -360,6 +360,10 @@ func (e *ExpressionAtom) Evaluate(dataContext IDataContext, memory *WorkingMemor return reflect.ValueOf(nil), err } + if e.ExpressionAtom.ValueNode == nil && e.CompareNilValues { + return reflect.ValueOf(nil), nil + } + retVal, err := e.ExpressionAtom.ValueNode.CallFunction(e.FunctionCall.FunctionName, args...) if err != nil { diff --git a/engine/GruleEngine_test.go b/engine/GruleEngine_test.go index 799bd4ca..5c779b36 100755 --- a/engine/GruleEngine_test.go +++ b/engine/GruleEngine_test.go @@ -142,7 +142,7 @@ func getTypeOf(i interface{}) string { const ruleWithAccessErr = `rule AccessErrRule "test access error rule" salience 10 { when - TestStruct.NotExist == 1 || TestStruct.OtherNonExists + TestStruct.NotExist == 1 || TestStruct.OtherNonExists || TestStruct.ThirdNonExist.StrContains("included value") == true then Retract("AccessErrRule"); }` @@ -165,6 +165,25 @@ func TestEngine_ExecuteErr(t *testing.T) { assert.Error(t, err) } +func TestEngine_ExecuteHandleNilsJSON(t *testing.T) { + dctx := ast.NewDataContext() + err := dctx.AddJSON("TestStruct", []byte("{}")) + assert.NoError(t, err) + + lib := ast.NewKnowledgeLibrary() + rb := builder.NewRuleBuilder(lib) + err = rb.BuildRuleFromResource("Test", "0.1.1", pkg.NewBytesResource([]byte(ruleWithAccessErr))) + assert.NoError(t, err) + + engine := NewGruleEngine() + engine.ReturnErrOnFailedRuleEvaluation = true + engine.CompareNilValues = true + kb, err := lib.NewKnowledgeBaseInstance("Test", "0.1.1") + assert.NoError(t, err) + err = engine.Execute(dctx, kb) + assert.NoError(t, err) +} + func TestEngine_ExecuteHandleNils(t *testing.T) { dctx := ast.NewDataContext() err := dctx.Add("TestStruct", &TestStruct{}) From e771ad4070cc9161728f79cf0d971f3b75c2d3c6 Mon Sep 17 00:00:00 2001 From: Juho Saarinen Date: Sat, 8 Nov 2025 15:11:20 +0200 Subject: [PATCH 3/3] feat: handle function calls with nil arguments --- ast/ExpressionAtom.go | 3 +++ engine/GruleEngine_test.go | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ast/ExpressionAtom.go b/ast/ExpressionAtom.go index f43fdb69..85f8c428 100755 --- a/ast/ExpressionAtom.go +++ b/ast/ExpressionAtom.go @@ -366,6 +366,9 @@ func (e *ExpressionAtom) Evaluate(dataContext IDataContext, memory *WorkingMemor retVal, err := e.ExpressionAtom.ValueNode.CallFunction(e.FunctionCall.FunctionName, args...) if err != nil { + if e.CompareNilValues { + return reflect.ValueOf(nil), nil + } return reflect.ValueOf(nil), err } diff --git a/engine/GruleEngine_test.go b/engine/GruleEngine_test.go index 5c779b36..7eb6139a 100755 --- a/engine/GruleEngine_test.go +++ b/engine/GruleEngine_test.go @@ -140,9 +140,10 @@ func getTypeOf(i interface{}) string { return t.Name() } +// TODO: Add also tests when function argument(s) are nil pointers const ruleWithAccessErr = `rule AccessErrRule "test access error rule" salience 10 { when - TestStruct.NotExist == 1 || TestStruct.OtherNonExists || TestStruct.ThirdNonExist.StrContains("included value") == true + TestStruct.NotExist == 1 || TestStruct.OtherNonExists || TestStruct.ThirdNonExist.Contains("included value") == true || TestStruct.exist.Contains(TestStruct.NonExisting) == true then Retract("AccessErrRule"); }` @@ -167,7 +168,8 @@ func TestEngine_ExecuteErr(t *testing.T) { func TestEngine_ExecuteHandleNilsJSON(t *testing.T) { dctx := ast.NewDataContext() - err := dctx.AddJSON("TestStruct", []byte("{}")) + testJson := `{ "exist": "\"This field exist\"" }` + err := dctx.AddJSON("TestStruct", []byte(testJson)) assert.NoError(t, err) lib := ast.NewKnowledgeLibrary()