Skip to content

Commit 425c9e6

Browse files
committed
add rate limit
1 parent f7966f0 commit 425c9e6

File tree

10 files changed

+94
-19
lines changed

10 files changed

+94
-19
lines changed

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
.idea
2525
main
2626
main.exe
27-
.codebase_index/
2827
*.iml
2928
*.scip
3029
bin/*
3130
.roo
3231
docs/task.md
33-
testfiles
32+
testfiles
33+
cmd.sh

deploy/deployment.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ data:
8787
log_level: "info"
8888
max_concurrency: 5
8989
skip_patterns: []
90+
TokenLimit:
91+
max_running_tasks: 1
92+
enabled: true
9093
---
9194
apiVersion: apps/v1
9295
kind: Deployment

docs/charts/gettoken.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2+
```mermaid
3+
graph TD
4+
A[生成Token接口] --> B[读取限流配置文件]
5+
B --> C[查询任务池正运行任务]
6+
C --> D{到达限流配置?}
7+
D -->|是| E[生成失败]
8+
D -->|否| F[根据ClientId生成Token]
9+
```

etc/conf.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,7 @@ Validation:
8989
log_level: "info"
9090
max_concurrency: 5
9191
skip_patterns: []
92+
93+
TokenLimit:
94+
max_running_tasks: 10
95+
enabled: true

internal/config/config.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ type Config struct {
1717
VectorStore VectorStoreConf
1818
Cleaner CleanerConf
1919
Validation ValidationConfig
20+
TokenLimit TokenLimitConf
21+
}
22+
23+
// TokenLimitConf token限流配置
24+
type TokenLimitConf struct {
25+
MaxRunningTasks int `json:"max_running_tasks" yaml:"max_running_tasks"`
26+
Enabled bool `json:"enabled" yaml:"enabled"`
2027
}
2128

2229
// Validate 实现 Validator 接口

internal/handler/token.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package handler
22

33
import (
4+
"errors"
45
"net/http"
56

67
"github.com/zgsm-ai/codebase-indexer/internal/logic"
@@ -33,11 +34,16 @@ func (h *tokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
3334
}
3435

3536
// 创建token逻辑
36-
tokenLogic := logic.NewTokenLogic(r.Context())
37+
tokenLogic := logic.NewTokenLogic(r.Context(), h.svcCtx)
3738
tokenResp, err := tokenLogic.GenerateToken(&req)
3839
if err != nil {
40+
// 检查是否为限流错误
41+
if errors.Is(err, types.ErrRateLimitReached) {
42+
response.RateLimit(w, err)
43+
return
44+
}
3945
response.Error(w, err)
4046
return
4147
}
4248
response.Json(w, tokenResp)
43-
}
49+
}

internal/logic/token.go

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,22 @@ import (
55
"errors"
66
"fmt"
77
"time"
8+
9+
"github.com/zgsm-ai/codebase-indexer/internal/svc"
810
"github.com/zgsm-ai/codebase-indexer/internal/types"
911
)
1012

1113
// TokenLogic token生成逻辑
1214
type TokenLogic struct {
1315
ctx context.Context
16+
svcCtx *svc.ServiceContext
1417
}
1518

1619
// NewTokenLogic 创建TokenLogic实例
17-
func NewTokenLogic(ctx context.Context) *TokenLogic {
20+
func NewTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TokenLogic {
1821
return &TokenLogic{
1922
ctx: ctx,
23+
svcCtx: svcCtx,
2024
}
2125
}
2226

@@ -26,6 +30,31 @@ func (l *TokenLogic) GenerateToken(req *types.TokenRequest) (*types.TokenRespons
2630
return nil, fmt.Errorf("invalid request: %w", err)
2731
}
2832

33+
// 1. 读取限流配置文件
34+
tokenLimit := l.svcCtx.Config.TokenLimit
35+
if !tokenLimit.Enabled {
36+
// 限流未启用,直接生成token
37+
return l.generateToken(req)
38+
}
39+
40+
// 2. 查询任务池正运行任务
41+
runningTasks, err := l.getRunningTasksCount()
42+
if err != nil {
43+
return nil, fmt.Errorf("查询运行中任务失败: %w", err)
44+
}
45+
46+
// 3. 判断是否到达限流配置
47+
if runningTasks >= tokenLimit.MaxRunningTasks {
48+
// 4. 生成失败
49+
return nil, types.ErrRateLimitReached
50+
}
51+
52+
// 5. 根据ClientId生成Token
53+
return l.generateToken(req)
54+
}
55+
56+
// generateToken 生成token
57+
func (l *TokenLogic) generateToken(req *types.TokenRequest) (*types.TokenResponseData, error) {
2958
// 使用clientId和codebasePath生成token
3059
// 这里使用简单的哈希组合,实际生产环境应使用更安全的JWT实现
3160
token := fmt.Sprintf("%s_%s_%s", req.ClientId, req.CodebasePath, l.generateRandomString(16))
@@ -37,6 +66,16 @@ func (l *TokenLogic) GenerateToken(req *types.TokenRequest) (*types.TokenRespons
3766
}, nil
3867
}
3968

69+
// getRunningTasksCount 获取运行中任务数量
70+
func (l *TokenLogic) getRunningTasksCount() (int, error) {
71+
// 获取任务池中正在运行的任务数量
72+
if l.svcCtx.TaskPool == nil {
73+
return 0, fmt.Errorf("任务池未初始化")
74+
}
75+
76+
return l.svcCtx.TaskPool.Running(), nil
77+
}
78+
4079
// validateRequest 验证请求参数
4180
func (l *TokenLogic) validateRequest(req *types.TokenRequest) error {
4281
if req.ClientId == "" {
@@ -51,27 +90,16 @@ func (l *TokenLogic) validateRequest(req *types.TokenRequest) error {
5190
return nil
5291
}
5392

54-
// generateJTI 生成唯一的JWT ID
55-
func (l *TokenLogic) generateJTI() string {
56-
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), l.generateRandomString(8))
57-
}
58-
5993
// generateRandomString 生成随机字符串
6094
func (l *TokenLogic) generateRandomString(length int) string {
6195
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
6296
result := make([]byte, length)
63-
97+
6498
// 使用更稳定的随机源
6599
seed := time.Now().UnixNano()
66100
for i := range result {
67-
seed = (seed * 1103515245 + 12345) & 0x7fffffff
101+
seed = (seed*1103515245 + 12345) & 0x7fffffff
68102
result[i] = charset[seed%int64(len(charset))]
69103
}
70104
return string(result)
71105
}
72-
73-
// getSecretKey 获取JWT签名密钥
74-
func (l *TokenLogic) getSecretKey() string {
75-
// 默认密钥(仅用于开发环境)
76-
return "default-secret-key-change-in-production"
77-
}

internal/response/code_msg.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,8 @@ func NewAuthError(msg string) error {
3030
func NewPermissionError(msg string) error {
3131
return &codeMsg{Code: 403, Message: msg}
3232
}
33+
34+
// NewRateLimitError creates a new rate limit error.
35+
func NewRateLimitError(msg string) error {
36+
return &codeMsg{Code: 429, Message: msg}
37+
}

internal/response/resp.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ package response
22

33
import (
44
"context"
5+
"net/http"
6+
57
"github.com/zeromicro/go-zero/core/logx"
68
"github.com/zeromicro/go-zero/rest/httpx"
7-
"net/http"
89
)
910

1011
const (
@@ -31,6 +32,11 @@ func Error(w http.ResponseWriter, e error) {
3132
httpx.WriteJson(w, http.StatusBadRequest, wrapResponse(e)) // TODO 500会触发go-zero breaker
3233
}
3334

35+
func RateLimit(w http.ResponseWriter, e error) {
36+
logx.WithCallerSkip(2).Errorf("rate limit error: %v", e)
37+
httpx.WriteJson(w, http.StatusTooManyRequests, wrapResponse(e))
38+
}
39+
3440
func Bytes(w http.ResponseWriter, v []byte) {
3541
w.WriteHeader(http.StatusOK)
3642
_, _ = w.Write(v)

internal/types/errors.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,8 @@
11
package types
2+
3+
import "errors"
4+
5+
var (
6+
// ErrRateLimitReached 限流达到上限错误
7+
ErrRateLimitReached = errors.New("The system is busy. Please try again later (maximum number of concurrent tasks reached).")
8+
)

0 commit comments

Comments
 (0)