Skip to content

Commit e88af69

Browse files
committed
Add RetryAfter and ResetAfter to AllowAtMostN
1 parent 1e4c1e2 commit e88af69

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

lua.go

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,21 @@ local allow_at = new_tat - burst_offset
4545
local diff = now - allow_at
4646
local remaining = math.floor(diff / emission_interval + 0.5)
4747
48-
if remaining >= 0 then
49-
local reset_after = new_tat - now
50-
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
51-
local retry_after = -1
52-
return {cost, remaining, tostring(retry_after), tostring(reset_after)}
48+
if remaining < 0 then
49+
local reset_after = tat - now
50+
local retry_after = diff * -1
51+
return {
52+
0, -- allowed
53+
0, -- remaining
54+
tostring(retry_after),
55+
tostring(reset_after),
56+
}
5357
end
5458
55-
remaining = 0
56-
local reset_after = tat - now
57-
local retry_after = diff * -1
58-
return {0, remaining, tostring(retry_after), tostring(reset_after)}
59+
local reset_after = new_tat - now
60+
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
61+
local retry_after = -1
62+
return {cost, remaining, tostring(retry_after), tostring(reset_after)}
5963
`)
6064

6165
var allowAtMost = redis.NewScript(`
@@ -96,9 +100,13 @@ local diff = now - (tat - burst_offset)
96100
local remaining = math.floor(diff / emission_interval + 0.5)
97101
98102
if remaining == 0 then
103+
local reset_after = tat - now
104+
local retry_after = emission_interval - diff
99105
return {
100106
0, -- allowed
101107
0, -- remaining
108+
tostring(retry_after),
109+
tostring(reset_after),
102110
}
103111
end
104112
@@ -118,5 +126,7 @@ redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
118126
return {
119127
cost,
120128
remaining,
129+
tostring(-1),
130+
tostring(reset_after),
121131
}
122132
`)

rate.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func (l *Limiter) AllowN(
102102
}
103103

104104
// AllowAtMostN reports whether at most n events may happen at time now.
105-
// It returns number of allowed events. RetryAfter and ResetAfter are not set.
105+
// It returns number of allowed events that is less than or equal to n.
106106
func (l *Limiter) AllowAtMostN(
107107
ctx context.Context,
108108
key string,
@@ -117,10 +117,22 @@ func (l *Limiter) AllowAtMostN(
117117

118118
values = v.([]interface{})
119119

120+
retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
121+
if err != nil {
122+
return nil, err
123+
}
124+
125+
resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
126+
if err != nil {
127+
return nil, err
128+
}
129+
120130
res := &Result{
121-
Limit: limit,
122-
Allowed: int(values[0].(int64)),
123-
Remaining: int(values[1].(int64)),
131+
Limit: limit,
132+
Allowed: int(values[0].(int64)),
133+
Remaining: int(values[1].(int64)),
134+
RetryAfter: dur(retryAfter),
135+
ResetAfter: dur(resetAfter),
124136
}
125137
return res, nil
126138
}

rate_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ func TestAllowAtMostN(t *testing.T) {
7373
assert.Nil(t, err)
7474
assert.Equal(t, res.Allowed, 2)
7575
assert.Equal(t, res.Remaining, 7)
76+
assert.Equal(t, res.RetryAfter, time.Duration(-1))
77+
assert.InDelta(t, res.ResetAfter, 300*time.Millisecond, float64(10*time.Millisecond))
7678

7779
res, err = l.AllowN(ctx, "test_id", limit, 0)
7880
assert.Nil(t, err)
@@ -85,6 +87,8 @@ func TestAllowAtMostN(t *testing.T) {
8587
assert.Nil(t, err)
8688
assert.Equal(t, res.Allowed, 7)
8789
assert.Equal(t, res.Remaining, 0)
90+
assert.Equal(t, res.RetryAfter, time.Duration(-1))
91+
assert.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond))
8892

8993
res, err = l.AllowN(ctx, "test_id", limit, 0)
9094
assert.Nil(t, err)
@@ -97,6 +101,8 @@ func TestAllowAtMostN(t *testing.T) {
97101
assert.Nil(t, err)
98102
assert.Equal(t, res.Allowed, 0)
99103
assert.Equal(t, res.Remaining, 0)
104+
assert.InDelta(t, res.RetryAfter, 99*time.Millisecond, float64(10*time.Millisecond))
105+
assert.InDelta(t, res.ResetAfter, 999*time.Millisecond, float64(10*time.Millisecond))
100106

101107
res, err = l.AllowN(ctx, "test_id", limit, 1000)
102108
assert.Nil(t, err)

0 commit comments

Comments
 (0)