@@ -748,20 +748,13 @@ def test_deprecated_mmds_config(uvm_plain):
748748 )
749749
750750
751- @pytest .mark .parametrize ("version" , MMDS_VERSIONS )
752- @pytest .mark .parametrize ("imds_compat" , [None , False , True ])
753- @pytest .mark .parametrize ("sdk" , ["py" , "go" ])
754- def test_aws_credential_provider (uvm_plain , version , imds_compat , sdk ):
755- """
756- Test AWS SDK's credential provider works on MMDS
757- """
758- test_microvm = uvm_plain
759- test_microvm .spawn ()
760- test_microvm .basic_config ()
761- test_microvm .add_net_iface ()
751+ def _configure_with_aws_credentials (microvm , version , imds_compat ):
752+ microvm .spawn ()
753+ microvm .basic_config ()
754+ microvm .add_net_iface ()
762755 # V2 requires session tokens for GET requests
763756 configure_mmds (
764- test_microvm , iface_ids = ["eth0" ], version = version , imds_compat = imds_compat
757+ microvm , iface_ids = ["eth0" ], version = version , imds_compat = imds_compat
765758 )
766759 now = datetime .now (timezone .utc )
767760 credentials = {
@@ -783,13 +776,24 @@ def test_aws_credential_provider(uvm_plain, version, imds_compat, sdk):
783776 }
784777 }
785778 }
786- populate_data_store (test_microvm , data_store )
787- test_microvm .start ()
788-
789- ssh_connection = test_microvm .ssh
779+ populate_data_store (microvm , data_store )
780+ microvm .start ()
790781
782+ ssh_connection = microvm .ssh
791783 run_guest_cmd (ssh_connection , f"ip route add { DEFAULT_IPV4 } dev eth0" , "" )
792784
785+ return ssh_connection
786+
787+
788+ @pytest .mark .parametrize ("version" , MMDS_VERSIONS )
789+ @pytest .mark .parametrize ("imds_compat" , [None , False , True ])
790+ @pytest .mark .parametrize ("sdk" , ["py" , "go" ])
791+ def test_aws_credential_provider (uvm_plain , version , imds_compat , sdk ):
792+ """
793+ Test AWS SDK's credential provider works on MMDS
794+ """
795+ ssh_connection = _configure_with_aws_credentials (uvm_plain , version , imds_compat )
796+
793797 match sdk :
794798 case "py" :
795799 cmd = r"""python3 - <<EOF
@@ -815,3 +819,35 @@ def test_aws_credential_provider(uvm_plain, version, imds_compat, sdk):
815819 cmd = "/usr/local/bin/go_sdk_cred_provider"
816820 _ , stdout , stderr = ssh_connection .check_output (cmd )
817821 assert stdout == "AAA,BBB,CCC\n " , stderr
822+
823+
824+ @pytest .mark .parametrize ("version" , MMDS_VERSIONS )
825+ @pytest .mark .parametrize ("imds_compat" , [None , False , True ])
826+ def test_go_sdk_credential_provider_with_custom_endpoint (
827+ uvm_plain , version , imds_compat
828+ ):
829+ """
830+ Test AWS SDK's credential provider with custom endpoint.
831+
832+ It sets "Accept: application/json" in a request to retrieve AWS credentials.
833+ If imds_compat is True, it should work. If False, it should NOT work,
834+ because MMDS responds a string of a JSON object containing the credentials
835+ (i.e. wrapped with doublequotes) with "Content-Type: application/json" but
836+ AWS SDK for Go expects only the inner JSON object.
837+ """
838+ ssh_connection = _configure_with_aws_credentials (uvm_plain , version , imds_compat )
839+
840+ cmd = "/usr/local/bin/go_sdk_cred_provider_with_custom_endpoint"
841+ ret , stdout , stderr = ssh_connection .run (cmd )
842+ if imds_compat :
843+ assert ret == 0
844+ assert stdout == "AAA,BBB,CCC\n " , stderr
845+ else :
846+ assert ret == 1
847+ assert (
848+ "Unable to retrieve credentials: "
849+ "failed to refresh cached credentials, "
850+ "failed to load credentials, deserialization failed, "
851+ "failed to deserialize json response, "
852+ "json: cannot unmarshal string into Go value of type client.GetCredentialsOutput"
853+ ) in stderr
0 commit comments