@@ -4,15 +4,18 @@ import (
44 "bytes"
55 "context"
66 "encoding/binary"
7+ "errors"
78 "fmt"
89 "io"
910 "net"
1011 "time"
1112
13+ "github.com/aws/aws-sdk-go-v2/aws"
1214 "github.com/aws/aws-sdk-go-v2/config"
15+ "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
16+ "github.com/aws/aws-sdk-go-v2/service/sts"
1317 proxyconfig "github.com/grepplabs/kafka-proxy/config"
1418 "github.com/grepplabs/kafka-proxy/proxy/protocol"
15- "github.com/pkg/errors"
1619 "github.com/sirupsen/logrus"
1720)
1821
@@ -44,6 +47,21 @@ func NewAwsMSKIamAuth(
4447 if err != nil {
4548 return nil , fmt .Errorf ("loading aws config: %v" , err )
4649 }
50+ if awsConfig .RoleArn != "" {
51+ stsClient := sts .NewFromConfig (cfg )
52+ assumeRoleProvider := stscreds .NewAssumeRoleProvider (stsClient , awsConfig .RoleArn )
53+ cfg .Credentials = aws .NewCredentialsCache (assumeRoleProvider )
54+ }
55+ if awsConfig .IdentityLookup {
56+ ctx := context .Background ()
57+ ctx , cancel := context .WithTimeout (ctx , 15 * time .Second )
58+ defer cancel ()
59+ output , err := sts .NewFromConfig (cfg ).GetCallerIdentity (ctx , & sts.GetCallerIdentityInput {})
60+ if err != nil {
61+ return nil , fmt .Errorf ("failed to get caller identity: %v" , err )
62+ }
63+ logrus .Infof ("AWS_MSK_IAM caller identity %s" , aws .ToString (output .Arn ))
64+ }
4765 return & AwsMSKIamAuth {
4866 clientID : clientId ,
4967 signer : newMechanism (cfg ),
@@ -56,11 +74,11 @@ func NewAwsMSKIamAuth(
5674// sendAndReceiveSASLAuth handles the entire SASL authentication process
5775func (a * AwsMSKIamAuth ) sendAndReceiveSASLAuth (conn DeadlineReaderWriter , brokerString string ) error {
5876 if err := a .saslHandshake (conn ); err != nil {
59- return errors . Wrap ( err , "handshake failed" )
77+ return fmt . Errorf ( "handshake failed: %w" , err )
6078 }
6179
6280 if err := a .saslAuthenticate (conn , brokerString ); err != nil {
63- return errors . Wrap ( err , "authenticate failed" )
81+ return fmt . Errorf ( "authenticate failed: %w" , err )
6482 }
6583
6684 return nil
@@ -76,21 +94,21 @@ func (a *AwsMSKIamAuth) saslHandshake(conn DeadlineReaderWriter) error {
7694 Body : rb ,
7795 }
7896 if err := a .write (conn , req ); err != nil {
79- return errors . Wrap ( err , "writing SASL handshake" )
97+ return fmt . Errorf ( "writing SASL handshake: %w" , err )
8098 }
8199
82100 payload , err := a .read (conn )
83101 if err != nil {
84- return errors . Wrap ( err , "reading SASL handshake" )
102+ return fmt . Errorf ( "reading SASL handshake: %w" , err )
85103 }
86104
87105 res := & protocol.SaslHandshakeResponseV0orV1 {}
88106 if err := protocol .Decode (payload , res ); err != nil {
89- return errors . Wrap ( err , "parsing SASL handshake response" )
107+ return fmt . Errorf ( "parsing SASL handshake response: %w" , err )
90108 }
91109
92- if res .Err != protocol .ErrNoError {
93- return errors . Wrap ( res . Err , "sasl handshake protocol error" )
110+ if ! errors . Is ( res .Err , protocol .ErrNoError ) {
111+ return fmt . Errorf ( "sasl handshake protocol error: %w" , res . Err )
94112 }
95113 logrus .Debugf ("Successful IAM SASL handshake. Available mechanisms: %v" , res .EnabledMechanisms )
96114 return nil
@@ -114,59 +132,59 @@ func (a *AwsMSKIamAuth) saslAuthenticate(conn DeadlineReaderWriter, brokerString
114132 Body : saslAuthReqV0 ,
115133 }
116134 if err := a .write (conn , req ); err != nil {
117- return errors . Wrap ( err , "writing SASL authentication request" )
135+ return fmt . Errorf ( "writing SASL authentication request: %w" , err )
118136 }
119137
120138 payload , err := a .read (conn )
121139 if err != nil {
122- return errors . Wrap ( err , "reading SASL authentication response" )
140+ return fmt . Errorf ( "reading SASL authentication response: %w" , err )
123141 }
124142
125143 res := & protocol.SaslAuthenticateResponseV0 {}
126144 err = protocol .Decode (payload , res )
127145 if err != nil {
128- return errors . Wrap ( err , "parsing SASL authentication response" )
146+ return fmt . Errorf ( "parsing SASL authentication response: %w" , err )
129147 }
130- if res .Err != protocol .ErrNoError {
131- return errors . Wrap ( res . Err , "sasl authentication protocol error" )
148+ if ! errors . Is ( res .Err , protocol .ErrNoError ) {
149+ return fmt . Errorf ( "sasl authentication protocol error: %w" , res . Err )
132150 }
133151 return nil
134152}
135153
136154func (a * AwsMSKIamAuth ) write (conn DeadlineReaderWriter , req * protocol.Request ) error {
137155 reqBuf , err := protocol .Encode (req )
138156 if err != nil {
139- return errors . Wrap ( err , "serializing request" )
157+ return fmt . Errorf ( "serializing request: %w" , err )
140158 }
141159
142160 sizeBuf := make ([]byte , 4 )
143161 binary .BigEndian .PutUint32 (sizeBuf , uint32 (len (reqBuf )))
144162
145163 if err := conn .SetWriteDeadline (time .Now ().Add (a .writeTimeout )); err != nil {
146- return errors . Wrap ( err , "setting write deadline" )
164+ return fmt . Errorf ( "setting write deadline: %w" , err )
147165 }
148166
149167 if _ , err := conn .Write (bytes .Join ([][]byte {sizeBuf , reqBuf }, nil )); err != nil {
150- return errors . Wrap ( err , "writing bytes" )
168+ return fmt . Errorf ( "writing bytes: %w" , err )
151169 }
152170 return nil
153171}
154172
155173func (a * AwsMSKIamAuth ) read (conn DeadlineReaderWriter ) ([]byte , error ) {
156174 if err := conn .SetReadDeadline (time .Now ().Add (a .readTimeout )); err != nil {
157- return nil , errors . Wrap ( err , "setting read deadline" )
175+ return nil , fmt . Errorf ( "setting read deadline: %w" , err )
158176 }
159177
160178 //wait for the handshake response
161179 header := make ([]byte , 8 ) // response header
162180 if _ , err := io .ReadFull (conn , header ); err != nil {
163- return nil , errors . Wrap ( err , "reading header" )
181+ return nil , fmt . Errorf ( "reading header: %w" , err )
164182 }
165183
166184 length := binary .BigEndian .Uint32 (header [:4 ])
167185 payload := make ([]byte , length - 4 )
168186 if _ , err := io .ReadFull (conn , payload ); err != nil {
169- return nil , errors . Wrap ( err , "reading payload" )
187+ return nil , fmt . Errorf ( "reading payload: %w" , err )
170188 }
171189
172190 return payload , nil
0 commit comments