@@ -5,18 +5,22 @@ import (
55 "errors"
66 "fmt"
77 "time"
8+
9+ "github.com/zgsm-ai/codebase-indexer/internal/svc"
810 "github.com/zgsm-ai/codebase-indexer/internal/types"
911)
1012
1113// TokenLogic token生成逻辑
1214type TokenLogic struct {
1315 ctx context.Context
16+ svcCtx * svc.ServiceContext
1417}
1518
1619// NewTokenLogic 创建TokenLogic实例
17- func NewTokenLogic (ctx context.Context ) * TokenLogic {
20+ func NewTokenLogic (ctx context.Context , svcCtx * svc. ServiceContext ) * TokenLogic {
1821 return & TokenLogic {
1922 ctx : ctx ,
23+ svcCtx : svcCtx ,
2024 }
2125}
2226
@@ -26,6 +30,31 @@ func (l *TokenLogic) GenerateToken(req *types.TokenRequest) (*types.TokenRespons
2630 return nil , fmt .Errorf ("invalid request: %w" , err )
2731 }
2832
33+ // 1. 读取限流配置文件
34+ tokenLimit := l .svcCtx .Config .TokenLimit
35+ if ! tokenLimit .Enabled {
36+ // 限流未启用,直接生成token
37+ return l .generateToken (req )
38+ }
39+
40+ // 2. 查询任务池正运行任务
41+ runningTasks , err := l .getRunningTasksCount ()
42+ if err != nil {
43+ return nil , fmt .Errorf ("查询运行中任务失败: %w" , err )
44+ }
45+
46+ // 3. 判断是否到达限流配置
47+ if runningTasks >= tokenLimit .MaxRunningTasks {
48+ // 4. 生成失败
49+ return nil , types .ErrRateLimitReached
50+ }
51+
52+ // 5. 根据ClientId生成Token
53+ return l .generateToken (req )
54+ }
55+
56+ // generateToken 生成token
57+ func (l * TokenLogic ) generateToken (req * types.TokenRequest ) (* types.TokenResponseData , error ) {
2958 // 使用clientId和codebasePath生成token
3059 // 这里使用简单的哈希组合,实际生产环境应使用更安全的JWT实现
3160 token := fmt .Sprintf ("%s_%s_%s" , req .ClientId , req .CodebasePath , l .generateRandomString (16 ))
@@ -37,6 +66,16 @@ func (l *TokenLogic) GenerateToken(req *types.TokenRequest) (*types.TokenRespons
3766 }, nil
3867}
3968
69+ // getRunningTasksCount 获取运行中任务数量
70+ func (l * TokenLogic ) getRunningTasksCount () (int , error ) {
71+ // 获取任务池中正在运行的任务数量
72+ if l .svcCtx .TaskPool == nil {
73+ return 0 , fmt .Errorf ("任务池未初始化" )
74+ }
75+
76+ return l .svcCtx .TaskPool .Running (), nil
77+ }
78+
4079// validateRequest 验证请求参数
4180func (l * TokenLogic ) validateRequest (req * types.TokenRequest ) error {
4281 if req .ClientId == "" {
@@ -51,27 +90,16 @@ func (l *TokenLogic) validateRequest(req *types.TokenRequest) error {
5190 return nil
5291}
5392
54- // generateJTI 生成唯一的JWT ID
55- func (l * TokenLogic ) generateJTI () string {
56- return fmt .Sprintf ("%d-%s" , time .Now ().UnixNano (), l .generateRandomString (8 ))
57- }
58-
5993// generateRandomString 生成随机字符串
6094func (l * TokenLogic ) generateRandomString (length int ) string {
6195 const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
6296 result := make ([]byte , length )
63-
97+
6498 // 使用更稳定的随机源
6599 seed := time .Now ().UnixNano ()
66100 for i := range result {
67- seed = (seed * 1103515245 + 12345 ) & 0x7fffffff
101+ seed = (seed * 1103515245 + 12345 ) & 0x7fffffff
68102 result [i ] = charset [seed % int64 (len (charset ))]
69103 }
70104 return string (result )
71105}
72-
73- // getSecretKey 获取JWT签名密钥
74- func (l * TokenLogic ) getSecretKey () string {
75- // 默认密钥(仅用于开发环境)
76- return "default-secret-key-change-in-production"
77- }
0 commit comments