@@ -3,8 +3,8 @@ package epochkghandler
33import (
44 "bytes"
55 "context"
6- "math"
76
7+ lru "github.com/hashicorp/golang-lru/v2"
88 "github.com/jackc/pgx/v4"
99 "github.com/jackc/pgx/v4/pgxpool"
1010 pubsub "github.com/libp2p/go-libp2p-pubsub"
@@ -13,6 +13,7 @@ import (
1313 "github.com/shutter-network/shutter/shlib/shcrypto"
1414
1515 "github.com/shutter-network/rolling-shutter/rolling-shutter/keyper/database"
16+ "github.com/shutter-network/rolling-shutter/rolling-shutter/medley"
1617 "github.com/shutter-network/rolling-shutter/rolling-shutter/p2p"
1718 "github.com/shutter-network/rolling-shutter/rolling-shutter/p2pmsg"
1819 "github.com/shutter-network/rolling-shutter/rolling-shutter/shdb"
@@ -21,12 +22,16 @@ import (
2122const MaxNumKeysPerMessage = 128
2223
2324func NewDecryptionKeyHandler (config Config , dbpool * pgxpool.Pool ) p2p.MessageHandler {
24- return & DecryptionKeyHandler {config : config , dbpool : dbpool }
25+ // Not catching the error as it only can happen if non-positive size was applied
26+ cache , _ := lru.New [shcrypto.EpochSecretKey , []byte ](1024 )
27+ return & DecryptionKeyHandler {config : config , dbpool : dbpool , cache : cache }
2528}
2629
2730type DecryptionKeyHandler struct {
2831 config Config
2932 dbpool * pgxpool.Pool
33+ // keep 1024 verified keys in Cache to skip additional verifications
34+ cache * lru.Cache [shcrypto.EpochSecretKey , []byte ]
3035}
3136
3237func (* DecryptionKeyHandler ) MessagePrototypes () []p2pmsg.Message {
@@ -39,23 +44,23 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2
3944 return pubsub .ValidationReject ,
4045 errors .Errorf ("instance ID mismatch (want=%d, have=%d)" , handler .config .GetInstanceID (), key .GetInstanceID ())
4146 }
42- if key .Eon > math .MaxInt64 {
43- return pubsub .ValidationReject , errors .Errorf ("eon %d overflows int64" , key .Eon )
47+ eon , err := medley .Uint64ToInt64Safe (key .Eon )
48+ if err != nil {
49+ return pubsub .ValidationReject , errors .Wrapf (err , "overflow error while converting eon to int64 %d" , eon )
4450 }
45-
46- dkgResultDB , err := database .New (handler .dbpool ).GetDKGResultForKeyperConfigIndex (ctx , int64 (key .Eon ))
47- if err == pgx .ErrNoRows {
48- return pubsub .ValidationReject , errors .Errorf ("no DKG result found for eon %d" , key .Eon )
51+ dkgResultDB , err := database .New (handler .dbpool ).GetDKGResultForKeyperConfigIndex (ctx , eon )
52+ if errors .Is (err , pgx .ErrNoRows ) {
53+ return pubsub .ValidationReject , errors .Errorf ("no DKG result found for eon %d" , eon )
4954 }
5055 if err != nil {
51- return pubsub .ValidationReject , errors .Wrapf (err , "failed to get dkg result for eon %d from db" , key . Eon )
56+ return pubsub .ValidationReject , errors .Wrapf (err , "failed to get dkg result for eon %d from db" , eon )
5257 }
5358 if ! dkgResultDB .Success {
54- return pubsub .ValidationReject , errors .Errorf ("no successful DKG result found for eon %d" , key . Eon )
59+ return pubsub .ValidationReject , errors .Errorf ("no successful DKG result found for eon %d" , eon )
5560 }
5661 pureDKGResult , err := shdb .DecodePureDKGResult (dkgResultDB .PureResult )
5762 if err != nil {
58- return pubsub .ValidationReject , errors .Wrapf (err , "error while decoding pure DKG result for eon %d" , key . Eon )
63+ return pubsub .ValidationReject , errors .Wrapf (err , "error while decoding pure DKG result for eon %d" , eon )
5964 }
6065
6166 if len (key .Keys ) == 0 {
@@ -64,19 +69,26 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2
6469 if len (key .Keys ) > MaxNumKeysPerMessage {
6570 return pubsub .ValidationReject , errors .Errorf ("too many keys in message (%d > %d)" , len (key .Keys ), MaxNumKeysPerMessage )
6671 }
72+
6773 for i , k := range key .Keys {
6874 epochSecretKey , err := k .GetEpochSecretKey ()
6975 if err != nil {
7076 return pubsub .ValidationReject , err
7177 }
78+ identity , exists := handler .cache .Get (* epochSecretKey )
79+ if exists {
80+ if bytes .Equal (k .Identity , identity ) {
81+ continue
82+ }
83+ return pubsub .ValidationReject , errors .Errorf ("epoch secret key for identity %x is not valid" , k .Identity )
84+ }
7285 ok , err := shcrypto .VerifyEpochSecretKey (epochSecretKey , pureDKGResult .PublicKey , k .Identity )
7386 if err != nil {
7487 return pubsub .ValidationReject , errors .Wrapf (err , "error while checking epoch secret key for identity %x" , k .Identity )
7588 }
7689 if ! ok {
7790 return pubsub .ValidationReject , errors .Errorf ("epoch secret key for identity %x is not valid" , k .Identity )
7891 }
79-
8092 if i > 0 && bytes .Compare (k .Identity , key .Keys [i - 1 ].Identity ) < 0 {
8193 return pubsub .ValidationReject , errors .Errorf ("keys not ordered" )
8294 }
@@ -87,7 +99,15 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2
8799func (handler * DecryptionKeyHandler ) HandleMessage (ctx context.Context , msg p2pmsg.Message ) ([]p2pmsg.Message , error ) {
88100 metricsEpochKGDecryptionKeysReceived .Inc ()
89101 key := msg .(* p2pmsg.DecryptionKeys )
90- // Insert the key into the db. We assume that it's valid as it already passed the libp2p
91- // validator.
102+ // We assume that it's valid as it already passed the libp2p validator.
103+ // Insert the key into the cache.
104+ for _ , k := range key .Keys {
105+ epochSecretKey , err := k .GetEpochSecretKey ()
106+ if err != nil {
107+ return nil , err
108+ }
109+ handler .cache .Add (* epochSecretKey , k .Identity )
110+ }
111+ // Insert the key into the db.
92112 return nil , database .New (handler .dbpool ).InsertDecryptionKeysMsg (ctx , key )
93113}
0 commit comments