Skip to content

Commit c98b9a3

Browse files
committed
feat: isFuncTakesContexts returns context numbers.
1 parent 60d124b commit c98b9a3

File tree

4 files changed

+48
-15
lines changed

4 files changed

+48
-15
lines changed

async.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,25 @@ func isContextType(ty reflect.Type) bool {
6565
ty.Implements(contextType) && contextType.Implements(ty)
6666
}
6767

68-
// isFuncTakesContext checks the function takes a Context as the first argument.
69-
func isFuncTakesContext(fn reflect.Type) bool {
68+
// isFuncTakesContexts checks the function takes Contexts as the arguments.
69+
func isFuncTakesContexts(fn reflect.Type) (bool, int) {
7070
if fn.NumIn() <= 0 {
71-
return false
71+
return false, 0
7272
}
7373

74-
in := fn.In(0)
74+
hasContext := false
75+
contextNum := 0
76+
for i := 0; i < fn.NumIn(); i++ {
77+
ok := isContextType(fn.In(i))
78+
if ok {
79+
hasContext = true
80+
contextNum++
81+
} else {
82+
break
83+
}
84+
}
7585

76-
return isContextType(in)
86+
return hasContext, contextNum
7787
}
7888

7989
// isFuncReturnsError checks the last return value of the function is an error if the function
@@ -143,7 +153,7 @@ func invokeAsyncFn(fn AsyncFn, ctx context.Context, params []any) ([]any, error)
143153

144154
// makeFuncIn makes a reflected values list of the parameters to call the function.
145155
func makeFuncIn(ft reflect.Type, ctx context.Context, params []any) []reflect.Value {
146-
isTakeContext := isFuncTakesContext(ft)
156+
isTakeContext, _ := isFuncTakesContexts(ft)
147157
isContextParam := isTakeContext && isFirstParamContext(params, ft.NumIn())
148158

149159
if !ft.IsVariadic() {
@@ -219,7 +229,7 @@ func makeNonVariadicFuncIn(
219229
// isValidNextFunc checks the current function's return values and the next function's parameters,
220230
// and returns a boolean value to indicates whether the functions are match or not
221231
func isValidNextFunc(cur, next reflect.Type) bool {
222-
isTakeContext := isFuncTakesContext(next)
232+
isTakeContext, _ := isFuncTakesContexts(next)
223233
numOut := cur.NumOut()
224234
numIn := next.NumIn()
225235

async_test.go

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,36 @@ func TestValidateAsyncFuncs(t *testing.T) {
3939
}, ErrNotFunction)
4040
}
4141

42-
func TestIsFuncTakesContext(t *testing.T) {
42+
func TestIsFuncTakesContexts(t *testing.T) {
4343
a := assert.New(t)
4444

45-
a.TrueNow(isFuncTakesContext(reflect.TypeOf(func(context.Context) {})))
46-
a.TrueNow(isFuncTakesContext(reflect.TypeOf(func(context.Context, int) {})))
47-
a.NotTrueNow(isFuncTakesContext(reflect.TypeOf(func() {})))
48-
a.NotTrueNow(isFuncTakesContext(reflect.TypeOf(func(int) {})))
49-
a.NotTrueNow(isFuncTakesContext(reflect.TypeOf(func(int, context.Context) {})))
45+
isTakeContext, contextNum := isFuncTakesContexts(reflect.TypeOf(func(context.Context) {}))
46+
a.TrueNow(isTakeContext)
47+
a.EqualNow(contextNum, 1)
48+
49+
isTakeContext, contextNum = isFuncTakesContexts(reflect.TypeOf(func(context.Context, int) {}))
50+
a.TrueNow(isTakeContext)
51+
a.EqualNow(contextNum, 1)
52+
53+
isTakeContext, contextNum = isFuncTakesContexts(reflect.TypeOf(func(context.Context, context.Context, int) {}))
54+
a.TrueNow(isTakeContext)
55+
a.EqualNow(contextNum, 2)
56+
57+
isTakeContext, contextNum = isFuncTakesContexts(reflect.TypeOf(func(context.Context, int, context.Context) {}))
58+
a.TrueNow(isTakeContext)
59+
a.EqualNow(contextNum, 1)
60+
61+
isTakeContext, contextNum = isFuncTakesContexts(reflect.TypeOf(func() {}))
62+
a.NotTrueNow(isTakeContext)
63+
a.EqualNow(contextNum, 0)
64+
65+
isTakeContext, contextNum = isFuncTakesContexts(reflect.TypeOf(func(int) {}))
66+
a.NotTrueNow(isTakeContext)
67+
a.EqualNow(contextNum, 0)
68+
69+
isTakeContext, contextNum = isFuncTakesContexts(reflect.TypeOf(func(int, context.Context) {}))
70+
a.NotTrueNow(isTakeContext)
71+
a.EqualNow(contextNum, 0)
5072
}
5173

5274
func TestIsFuncReturnsError(t *testing.T) {

until.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ func validateUntilFuncs(testFn, fn AsyncFn) (isNoParam bool) {
7171
}
7272

7373
numIn := tft.NumIn()
74-
if numIn == 0 || (numIn == 1 && isFuncTakesContext(tft)) {
74+
_, contextNum := isFuncTakesContexts(tft)
75+
if numIn == contextNum {
7576
return true
7677
}
7778

while.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func validateWhileFuncs(testFn, fn AsyncFn) {
7070
}
7171

7272
numIn := tft.NumIn()
73-
isTakeContext := isFuncTakesContext(tft)
73+
isTakeContext, _ := isFuncTakesContexts(tft)
7474
if isTakeContext {
7575
numIn--
7676
}

0 commit comments

Comments
 (0)