Skip to content

Commit 60d124b

Browse files
committed
feat: add isValidNextFunc test helper function.
1 parent f14fa8e commit 60d124b

File tree

3 files changed

+41
-58
lines changed

3 files changed

+41
-58
lines changed

async.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,38 @@ func makeNonVariadicFuncIn(
215215

216216
return in
217217
}
218+
219+
// isValidNextFunc checks the current function's return values and the next function's parameters,
220+
// and returns a boolean value to indicates whether the functions are match or not
221+
func isValidNextFunc(cur, next reflect.Type) bool {
222+
isTakeContext := isFuncTakesContext(next)
223+
numOut := cur.NumOut()
224+
numIn := next.NumIn()
225+
226+
if isTakeContext {
227+
numIn--
228+
}
229+
if numOut < numIn {
230+
return false
231+
}
232+
233+
i := 0
234+
j := 0
235+
236+
if isTakeContext {
237+
if numOut > 0 && isContextType(cur.Out(0)) {
238+
i++
239+
}
240+
numIn++
241+
j++
242+
}
243+
for i < numOut && j < numIn {
244+
if cur.Out(i) != next.In(j) {
245+
return false
246+
}
247+
i++
248+
j++
249+
}
250+
251+
return true
252+
}

seq.go

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -66,45 +66,10 @@ func validateSeqFuncs(funcs ...AsyncFn) error {
6666
}
6767

6868
for i := 1; i < len(types); i++ {
69-
err := validateSeqFuncParams(types[i-1], types[i])
70-
if err != nil {
71-
return err
72-
}
73-
}
74-
75-
return nil
76-
}
77-
78-
// validateSeqFuncParams checks the previous function's return values and the current function's
79-
// parameters, and returns an error if they are not match.
80-
func validateSeqFuncParams(prev, cur reflect.Type) error {
81-
isTakeContext := isFuncTakesContext(cur)
82-
numIn := cur.NumIn()
83-
numOut := prev.NumOut()
84-
85-
if isTakeContext {
86-
numIn--
87-
}
88-
if numOut < numIn {
89-
return ErrInvalidSeqFuncs
90-
}
91-
92-
i := 0
93-
j := 0
94-
95-
if isTakeContext {
96-
if numOut > 0 && isContextType(prev.Out(0)) {
97-
i++
98-
}
99-
numIn++
100-
j++
101-
}
102-
for i < numOut && j < numIn {
103-
if prev.Out(i) != cur.In(j) {
69+
isValid := isValidNextFunc(types[i-1], types[i])
70+
if !isValid {
10471
return ErrInvalidSeqFuncs
10572
}
106-
i++
107-
j++
10873
}
10974

11075
return nil

until.go

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -70,31 +70,14 @@ func validateUntilFuncs(testFn, fn AsyncFn) (isNoParam bool) {
7070
panic(ErrInvalidTestFunc)
7171
}
7272

73-
ii := 0 // index of the test function input parameters list
74-
oi := 0 // index of the function return values list
7573
numIn := tft.NumIn()
76-
isTakeContext := isFuncTakesContext(tft)
77-
if isTakeContext {
78-
numIn--
79-
ii++
80-
}
81-
if numIn != 0 && numIn != ft.NumOut() {
82-
panic(ErrInvalidTestFunc)
83-
}
84-
if numIn == 0 {
74+
if numIn == 0 || (numIn == 1 && isFuncTakesContext(tft)) {
8575
return true
8676
}
8777

88-
for oi < numIn {
89-
it := tft.In(ii) // type of the value in the test function input parameters list
90-
ot := ft.Out(oi) // type of the value in the function return values list
91-
92-
if it != ot && !it.ConvertibleTo(ot) {
93-
panic(ErrInvalidTestFunc)
94-
}
95-
96-
ii++
97-
oi++
78+
isValid := isValidNextFunc(ft, tft)
79+
if !isValid {
80+
panic(ErrInvalidTestFunc)
9881
}
9982

10083
return false

0 commit comments

Comments
 (0)