Skip to content

Commit c3fda75

Browse files
committed
feat: add Equal and NotEqual.
1 parent da67004 commit c3fda75

File tree

4 files changed

+138
-0
lines changed

4 files changed

+138
-0
lines changed

assertion.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package assert
2+
3+
import "testing"
4+
5+
type Assertion struct {
6+
t *testing.T
7+
}
8+
9+
// New returns an assertion instance for verifying invariants.
10+
func New(t *testing.T) *Assertion {
11+
a := new(Assertion)
12+
13+
a.t = t
14+
15+
return a
16+
}
17+
18+
// Equal tests equality between actual and expect parameters.
19+
func (a *Assertion) Equal(actual, expect any, message ...string) error {
20+
return Equal(a.t, actual, expect, message...)
21+
}
22+
23+
// NotEqual tests inequality between actual and expected parameters.
24+
func (a *Assertion) NotEqual(actual, expect any, message ...string) error {
25+
return NotEqual(a.t, actual, expect, message...)
26+
}

equal.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package assert
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
)
7+
8+
// Equal tests equality between actual and expect parameters.
9+
func Equal(t *testing.T, actual, expect any, message ...string) error {
10+
if reflect.DeepEqual(actual, expect) {
11+
return nil
12+
}
13+
14+
err := newAssertionError("==", actual, expect, message...)
15+
16+
t.Error(err)
17+
18+
return err
19+
}
20+
21+
// NotEqual tests inequality between actual and expected parameters.
22+
func NotEqual(t *testing.T, actual, expect any, message ...string) error {
23+
if !reflect.DeepEqual(actual, expect) {
24+
return nil
25+
}
26+
27+
err := newAssertionError("!=", actual, expect, message...)
28+
29+
t.Error(err)
30+
31+
return err
32+
}

equal_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package assert
2+
3+
import "testing"
4+
5+
type testStruct struct {
6+
v int
7+
}
8+
9+
func TestEqualAndNotEqual(t *testing.T) {
10+
mockT := new(testing.T)
11+
assertion := New(mockT)
12+
13+
testEqualAndNotEqual(t, assertion, 1, 1, true)
14+
testEqualAndNotEqual(t, assertion, 1, 2, false)
15+
testEqualAndNotEqual(t, assertion, 1, 1.0, false)
16+
testEqualAndNotEqual(t, assertion, 1, "1", false)
17+
testEqualAndNotEqual(t, assertion, 1, '1', false)
18+
testEqualAndNotEqual(t, assertion, 1, []int{1}, false)
19+
testEqualAndNotEqual(t, assertion, []int{1}, []int{1}, true)
20+
21+
obj1 := testStruct{v: 1}
22+
obj2 := testStruct{v: 1}
23+
24+
testEqualAndNotEqual(t, assertion, obj1, obj2, true)
25+
testEqualAndNotEqual(t, assertion, obj1, &obj2, false)
26+
testEqualAndNotEqual(t, assertion, &obj1, &obj2, true)
27+
28+
obj2.v = 2
29+
testEqualAndNotEqual(t, assertion, obj1, obj2, false)
30+
testEqualAndNotEqual(t, assertion, &obj1, &obj2, false)
31+
}
32+
33+
func testEqualAndNotEqual(t *testing.T, assertion *Assertion, v1, v2 any, isEqual bool) {
34+
err := assertion.Equal(v1, v2)
35+
if isEqual && err != nil {
36+
t.Errorf("Equal(%v, %v) = %v, want = nil", v1, v2, err)
37+
} else if !isEqual && err == nil {
38+
t.Errorf("Equal(%v, %v) = nil, want = error", v1, v2)
39+
}
40+
41+
err = assertion.NotEqual(v1, v2)
42+
if isEqual && err == nil {
43+
t.Errorf("NotEqual(%v, %v) = nil, want = error", v1, v2)
44+
} else if !isEqual && err != nil {
45+
t.Errorf("NotEqual(%v, %v) = %v, want = nil", v1, v2, err)
46+
}
47+
}

error.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package assert
2+
3+
import "fmt"
4+
5+
// AssertionError indicates the failure of an assertion.
6+
type AssertionError struct {
7+
message *string
8+
actual any
9+
expect any
10+
operator string
11+
}
12+
13+
func newAssertionError(operator string, actual, expect any, message ...string) AssertionError {
14+
err := AssertionError{
15+
actual: actual,
16+
expect: expect,
17+
operator: operator,
18+
}
19+
20+
if len(message) > 0 {
21+
err.message = &message[0]
22+
}
23+
24+
return err
25+
}
26+
27+
func (err AssertionError) Error() string {
28+
if err.message != nil {
29+
return *err.message
30+
}
31+
32+
return fmt.Sprintf("%v %s %v", err.actual, err.operator, err.expect)
33+
}

0 commit comments

Comments
 (0)