@@ -19,12 +19,16 @@ import (
1919 "testing"
2020
2121 "github.com/aws/aws-sdk-go/aws"
22+ "github.com/aws/aws-sdk-go/aws/credentials"
2223 "github.com/aws/aws-sdk-go/aws/request"
2324 "github.com/aws/aws-sdk-go/aws/session"
25+ "github.com/aws/aws-sdk-go/service/kinesis"
26+ "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
2427 "github.com/aws/aws-sdk-go/service/sqs"
2528 "github.com/aws/aws-sdk-go/service/sqs/sqsiface"
2629 "github.com/stretchr/testify/assert"
2730 "github.com/stretchr/testify/require"
31+ "github.com/vmware/vmware-go-kcl/clientlibrary/config"
2832)
2933
3034type mockedSQS struct {
@@ -36,7 +40,14 @@ func (m *mockedSQS) GetQueueUrlWithContext(ctx context.Context, input *sqs.GetQu
3640 return m .GetQueueURLFn (ctx , input )
3741}
3842
43+ type mockedKinesis struct {
44+ kinesisiface.KinesisAPI
45+ DescribeStreamFn func (ctx context.Context , input * kinesis.DescribeStreamInput ) (* kinesis.DescribeStreamOutput , error )
46+ }
3947
48+ func (m * mockedKinesis ) DescribeStreamWithContext (ctx context.Context , input * kinesis.DescribeStreamInput , opts ... request.Option ) (* kinesis.DescribeStreamOutput , error ) {
49+ return m .DescribeStreamFn (ctx , input )
50+ }
4051
4152func TestS3Clients_New (t * testing.T ) {
4253 tests := []struct {
@@ -116,4 +127,139 @@ func TestSqsClients_QueueURL(t *testing.T) {
116127 }
117128}
118129
130+ func TestKinesisClients_Stream (t * testing.T ) {
131+ tests := []struct {
132+ name string
133+ kinesisClient * KinesisClients
134+ streamName string
135+ mockStreamARN * string
136+ mockError error
137+ expectedStream * string
138+ expectedErr error
139+ }{
140+ {
141+ name : "successfully retrieves stream ARN" ,
142+ kinesisClient : & KinesisClients {
143+ Kinesis : & mockedKinesis {DescribeStreamFn : func (ctx context.Context , input * kinesis.DescribeStreamInput ) (* kinesis.DescribeStreamOutput , error ) {
144+ return & kinesis.DescribeStreamOutput {
145+ StreamDescription : & kinesis.StreamDescription {
146+ StreamARN : aws .String ("arn:aws:kinesis:some-region:123456789012:stream/some-stream" ),
147+ },
148+ }, nil
149+ }},
150+ Region : "us-west-1" ,
151+ Credentials : credentials .NewStaticCredentials ("accessKey" , "secretKey" , "" ),
152+ },
153+ streamName : "some-stream" ,
154+ expectedStream : aws .String ("arn:aws:kinesis:some-region:123456789012:stream/some-stream" ),
155+ expectedErr : nil ,
156+ },
157+ {
158+ name : "returns error when stream not found" ,
159+ kinesisClient : & KinesisClients {
160+ Kinesis : & mockedKinesis {DescribeStreamFn : func (ctx context.Context , input * kinesis.DescribeStreamInput ) (* kinesis.DescribeStreamOutput , error ) {
161+ return nil , errors .New ("stream not found" )
162+ }},
163+ Region : "us-west-1" ,
164+ Credentials : credentials .NewStaticCredentials ("accessKey" , "secretKey" , "" ),
165+ },
166+ streamName : "nonexistent-stream" ,
167+ expectedStream : nil ,
168+ expectedErr : errors .New ("unable to get stream arn due to empty client" ),
169+ },
170+ }
171+
172+ for _ , tt := range tests {
173+ t .Run (tt .name , func (t * testing.T ) {
174+ got , err := tt .kinesisClient .Stream (t .Context (), tt .streamName )
175+ if tt .expectedErr != nil {
176+ require .Error (t , err )
177+ assert .Equal (t , tt .expectedErr .Error (), err .Error ())
178+ } else {
179+ require .NoError (t , err )
180+ assert .Equal (t , tt .expectedStream , got )
181+ }
182+ })
183+ }
184+ }
185+
186+ func TestKinesisClients_WorkerCfg (t * testing.T ) {
187+ testCreds := credentials .NewStaticCredentials ("accessKey" , "secretKey" , "" )
188+ tests := []struct {
189+ name string
190+ kinesisClient * KinesisClients
191+ streamName string
192+ consumer string
193+ mode string
194+ expectedConfig * config.KinesisClientLibConfiguration
195+ }{
196+ {
197+ name : "successfully creates shared mode worker config" ,
198+ kinesisClient : & KinesisClients {
199+ Kinesis : & mockedKinesis {
200+ DescribeStreamFn : func (ctx context.Context , input * kinesis.DescribeStreamInput ) (* kinesis.DescribeStreamOutput , error ) {
201+ return & kinesis.DescribeStreamOutput {
202+ StreamDescription : & kinesis.StreamDescription {
203+ StreamARN : aws .String ("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream" ),
204+ },
205+ }, nil
206+ },
207+ },
208+ Region : "us-west-1" ,
209+ Credentials : testCreds ,
210+ },
211+ streamName : "existing-stream" ,
212+ consumer : "consumer1" ,
213+ mode : "shared" ,
214+ expectedConfig : config .NewKinesisClientLibConfigWithCredential (
215+ "consumer1" , "existing-stream" , "us-west-1" , "consumer1" , testCreds ,
216+ ),
217+ },
218+ {
219+ name : "returns nil when mode is not shared" ,
220+ kinesisClient : & KinesisClients {
221+ Kinesis : & mockedKinesis {
222+ DescribeStreamFn : func (ctx context.Context , input * kinesis.DescribeStreamInput ) (* kinesis.DescribeStreamOutput , error ) {
223+ return & kinesis.DescribeStreamOutput {
224+ StreamDescription : & kinesis.StreamDescription {
225+ StreamARN : aws .String ("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream" ),
226+ },
227+ }, nil
228+ },
229+ },
230+ Region : "us-west-1" ,
231+ Credentials : testCreds ,
232+ },
233+ streamName : "existing-stream" ,
234+ consumer : "consumer1" ,
235+ mode : "exclusive" ,
236+ expectedConfig : nil ,
237+ },
238+ {
239+ name : "returns nil when client is nil" ,
240+ kinesisClient : & KinesisClients {
241+ Kinesis : nil ,
242+ Region : "us-west-1" ,
243+ Credentials : credentials .NewStaticCredentials ("accessKey" , "secretKey" , "" ),
244+ },
245+ streamName : "existing-stream" ,
246+ consumer : "consumer1" ,
247+ mode : "shared" ,
248+ expectedConfig : nil ,
249+ },
250+ }
119251
252+ for _ , tt := range tests {
253+ t .Run (tt .name , func (t * testing.T ) {
254+ cfg := tt .kinesisClient .WorkerCfg (t .Context (), tt .streamName , tt .consumer , tt .mode )
255+ if tt .expectedConfig == nil {
256+ assert .Equal (t , tt .expectedConfig , cfg )
257+ return
258+ }
259+ assert .Equal (t , tt .expectedConfig .StreamName , cfg .StreamName )
260+ assert .Equal (t , tt .expectedConfig .EnhancedFanOutConsumerName , cfg .EnhancedFanOutConsumerName )
261+ assert .Equal (t , tt .expectedConfig .EnableEnhancedFanOutConsumer , cfg .EnableEnhancedFanOutConsumer )
262+ assert .Equal (t , tt .expectedConfig .RegionName , cfg .RegionName )
263+ })
264+ }
265+ }
0 commit comments