@@ -14,6 +14,7 @@ import (
1414 "context"
1515 "crypto/tls"
1616 "encoding/base64"
17+ "encoding/json"
1718 "fmt"
1819 "io/ioutil"
1920 "net"
@@ -30,6 +31,7 @@ import (
3031 "go.mongodb.org/mongo-driver/v2/internal/handshake"
3132 "go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
3233 "go.mongodb.org/mongo-driver/v2/internal/integtest"
34+ "go.mongodb.org/mongo-driver/v2/internal/require"
3335 "go.mongodb.org/mongo-driver/v2/mongo"
3436 "go.mongodb.org/mongo-driver/v2/mongo/options"
3537 "go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
@@ -2925,7 +2927,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
29252927 }
29262928 })
29272929
2928- mt .RunOpts ("22 . range explicit encryption applies defaults" , qeRunOpts22 , func (mt * mtest.T ) {
2930+ mt .RunOpts ("23 . range explicit encryption applies defaults" , qeRunOpts22 , func (mt * mtest.T ) {
29292931 err := mt .Client .Database ("keyvault" ).Collection ("datakeys" ).Drop (context .Background ())
29302932 assert .Nil (mt , err , "error on Drop: %v" , err )
29312933
@@ -2986,6 +2988,147 @@ func TestClientSideEncryptionProse(t *testing.T) {
29862988 assert .Greater (t , len (payload .Data ), len (payloadDefaults .Data ), "the returned payload size is expected to be greater than %d" , len (payloadDefaults .Data ))
29872989 })
29882990 })
2991+
2992+ mt .RunOpts ("24. kms retry tests" , noClientOpts , func (mt * mtest.T ) {
2993+ kmsTlsTestcase := os .Getenv ("KMS_FAILPOINT_SERVER_RUNNING" )
2994+ if kmsTlsTestcase == "" {
2995+ mt .Skipf ("Skipping test as KMS_FAILPOINT_SERVER_RUNNING is not set" )
2996+ }
2997+
2998+ mt .Parallel ()
2999+
3000+ tlsCAFile := os .Getenv ("KMS_FAILPOINT_CA_FILE" )
3001+ require .NotEqual (mt , tlsCAFile , "" , "failed to load CA file" )
3002+
3003+ clientAndCATlsMap := map [string ]interface {}{
3004+ "tlsCAFile" : tlsCAFile ,
3005+ }
3006+ tlsCfg , err := options .BuildTLSConfig (clientAndCATlsMap )
3007+ require .NoError (mt , err , "BuildTLSConfig error: %v" , err )
3008+
3009+ setFailPoint := func (failure string , count int ) error {
3010+ url := fmt .Sprintf ("https://localhost:9003/set_failpoint/%s" , failure )
3011+ var payloadBuf bytes.Buffer
3012+ body := map [string ]int {"count" : count }
3013+ json .NewEncoder (& payloadBuf ).Encode (body )
3014+ req , err := http .NewRequest (http .MethodPost , url , & payloadBuf )
3015+ if err != nil {
3016+ return err
3017+ }
3018+
3019+ client := & http.Client {
3020+ Transport : & http.Transport {TLSClientConfig : tlsCfg },
3021+ }
3022+ res , err := client .Do (req )
3023+ if err != nil {
3024+ return err
3025+ }
3026+ return res .Body .Close ()
3027+ }
3028+
3029+ kmsProviders := map [string ]map [string ]interface {}{
3030+ "aws" : {
3031+ "accessKeyId" : awsAccessKeyID ,
3032+ "secretAccessKey" : awsSecretAccessKey ,
3033+ },
3034+ "azure" : {
3035+ "tenantId" : azureTenantID ,
3036+ "clientId" : azureClientID ,
3037+ "clientSecret" : azureClientSecret ,
3038+ "identityPlatformEndpoint" : "127.0.0.1:9003" ,
3039+ },
3040+ "gcp" : {
3041+ "email" : gcpEmail ,
3042+ "privateKey" : gcpPrivateKey ,
3043+ "endpoint" : "127.0.0.1:9003" ,
3044+ },
3045+ }
3046+
3047+ dataKeys := []struct {
3048+ provider string
3049+ masterKey interface {}
3050+ }{
3051+ {"aws" , bson.D {
3052+ {"region" , "foo" },
3053+ {"key" , "bar" },
3054+ {"endpoint" , "127.0.0.1:9003" },
3055+ }},
3056+ {"azure" , bson.D {
3057+ {"keyVaultEndpoint" , "127.0.0.1:9003" },
3058+ {"keyName" , "foo" },
3059+ }},
3060+ {"gcp" , bson.D {
3061+ {"projectId" , "foo" },
3062+ {"location" , "bar" },
3063+ {"keyRing" , "baz" },
3064+ {"keyName" , "qux" },
3065+ {"endpoint" , "127.0.0.1:9003" },
3066+ }},
3067+ }
3068+
3069+ testCases := []struct {
3070+ name string
3071+ failure string
3072+ }{
3073+ {"Case 1: createDataKey and encrypt with TCP retry" , "network" },
3074+ {"Case 2: createDataKey and encrypt with HTTP retry" , "http" },
3075+ }
3076+
3077+ for _ , tc := range testCases {
3078+ for _ , dataKey := range dataKeys {
3079+ mt .Run (fmt .Sprintf ("%s_%s" , tc .name , dataKey .provider ), func (mt * mtest.T ) {
3080+ keyVaultClient , err := mongo .Connect (options .Client ().ApplyURI (mtest .ClusterURI ()))
3081+ require .NoError (mt , err , "error on Connect: %v" , err )
3082+
3083+ ceo := options .ClientEncryption ().
3084+ SetKeyVaultNamespace (kvNamespace ).
3085+ SetKmsProviders (kmsProviders ).
3086+ SetTLSConfig (map [string ]* tls.Config {dataKey .provider : tlsCfg })
3087+ clientEncryption , err := mongo .NewClientEncryption (keyVaultClient , ceo )
3088+ require .NoError (mt , err , "error on NewClientEncryption: %v" , err )
3089+
3090+ err = setFailPoint (tc .failure , 1 )
3091+ require .NoError (mt , err , "mock server error: %v" , err )
3092+
3093+ dkOpts := options .DataKey ().SetMasterKey (dataKey .masterKey )
3094+ var keyID bson.Binary
3095+ keyID , err = clientEncryption .CreateDataKey (context .Background (), dataKey .provider , dkOpts )
3096+ require .NoError (mt , err , "error in CreateDataKey: %v" , err )
3097+
3098+ err = setFailPoint (tc .failure , 1 )
3099+ require .NoError (mt , err , "mock server error: %v" , err )
3100+
3101+ testVal := bson.RawValue {Type : bson .TypeInt32 , Value : bsoncore .AppendInt32 (nil , 123 )}
3102+ eo := options .Encrypt ().
3103+ SetKeyID (keyID ).
3104+ SetAlgorithm ("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" )
3105+ _ , err = clientEncryption .Encrypt (context .Background (), testVal , eo )
3106+ require .NoError (mt , err , "error in Encrypt: %v" , err )
3107+ })
3108+ }
3109+ }
3110+
3111+ for _ , dataKey := range dataKeys {
3112+ mt .Run (fmt .Sprintf ("Case 3: createDataKey fails after too many retries_%s" , dataKey .provider ), func (mt * mtest.T ) {
3113+ keyVaultClient , err := mongo .Connect (options .Client ().ApplyURI (mtest .ClusterURI ()))
3114+ require .NoError (mt , err , "error on Connect: %v" , err )
3115+
3116+ ceo := options .ClientEncryption ().
3117+ SetKeyVaultNamespace (kvNamespace ).
3118+ SetKmsProviders (kmsProviders ).
3119+ SetTLSConfig (map [string ]* tls.Config {dataKey .provider : tlsCfg })
3120+ clientEncryption , err := mongo .NewClientEncryption (keyVaultClient , ceo )
3121+ require .NoError (mt , err , "error on NewClientEncryption: %v" , err )
3122+
3123+ err = setFailPoint ("network" , 4 )
3124+ require .NoError (mt , err , "mock server error: %v" , err )
3125+
3126+ dkOpts := options .DataKey ().SetMasterKey (dataKey .masterKey )
3127+ _ , err = clientEncryption .CreateDataKey (context .Background (), dataKey .provider , dkOpts )
3128+ require .ErrorContains (mt , err , "KMS request failed after 3 retries due to a network error" )
3129+ })
3130+ }
3131+ })
29893132}
29903133
29913134func getWatcher (mt * mtest.T , streamType mongo.StreamType , cpt * cseProseTest ) watcher {
0 commit comments