44# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
66import os
7+ from mock import MagicMock
78import pytest
89from unittest import TestCase , mock
910
2122)
2223from ads .common .auth import (
2324 SecurityToken ,
24- TokenExpiredError ,
25+ SecurityTokenError ,
2526 api_keys ,
2627 resource_principal ,
2728 security_token ,
@@ -539,10 +540,12 @@ class TestSecurityToken(TestCase):
539540
540541 @mock .patch ("oci.auth.signers.SecurityTokenSigner.__init__" )
541542 @mock .patch ("oci.signer.load_private_key_from_file" )
543+ @mock .patch ("ads.common.auth.SecurityToken._read_security_token_file" )
542544 @mock .patch ("ads.common.auth.SecurityToken._validate_and_refresh_token" )
543545 def test_security_token (
544546 self ,
545547 mock_validate_and_refresh_token ,
548+ mock_read_security_token_file ,
546549 mock_load_private_key_from_file ,
547550 mock_security_token_signer
548551 ):
@@ -571,8 +574,9 @@ def test_security_token(
571574 client_kwargs = {"test_client_key" :"test_client_value" }
572575 )
573576
574- mock_validate_and_refresh_token .assert_called_with ("test_security_token" )
575- mock_load_private_key_from_file .assert_called_with ("test_key_file" )
577+ mock_validate_and_refresh_token .assert_called_with (config )
578+ mock_read_security_token_file .assert_called_with ("test_security_token" )
579+ mock_load_private_key_from_file .assert_called_with ("test_key_file" , None )
576580 assert signer ["client_kwargs" ] == {"test_client_key" : "test_client_value" }
577581 assert "additional_user_agent" in signer ["config" ]
578582 assert signer ["config" ]["fingerprint" ] == "test_fingerprint"
@@ -582,7 +586,7 @@ def test_security_token(
582586 assert signer ["config" ]["key_file" ] == "test_key_file"
583587 assert isinstance (signer ["signer" ], SecurityTokenSigner )
584588
585- @mock .patch ("os.system " )
589+ @mock .patch ("ads.common.auth.SecurityToken._refresh_security_token " )
586590 @mock .patch ("oci.auth.security_token_container.SecurityTokenContainer.get_jwt" )
587591 @mock .patch ("time.time" )
588592 @mock .patch ("oci.auth.security_token_container.SecurityTokenContainer.valid" )
@@ -595,7 +599,7 @@ def test_validate_and_refresh_token(
595599 mock_valid ,
596600 mock_time ,
597601 mock_get_jwt ,
598- mock_system
602+ mock_refresh_security_token
599603 ):
600604 security_token = SecurityToken (
601605 args = {
@@ -606,24 +610,94 @@ def test_validate_and_refresh_token(
606610 mock_security_token_container .return_value = None
607611
608612 mock_valid .return_value = False
613+ configuration = {
614+ "fingerprint" : "test_fingerprint" ,
615+ "tenancy" : "test_tenancy" ,
616+ "region" : "us-ashburn-1" ,
617+ "key_file" : "test_key_file" ,
618+ "security_token_file" : "test_security_token" ,
619+ "generic_headers" : [1 ,2 ,3 ],
620+ "body_headers" : [4 ,5 ,6 ]
621+ }
609622 with pytest .raises (
610- TokenExpiredError ,
623+ SecurityTokenError ,
611624 match = "Security token has expired. Call `oci session authenticate` to generate new session."
612625 ):
613- security_token ._validate_and_refresh_token ("test_security_token" )
626+ security_token ._validate_and_refresh_token (configuration )
614627
615628
616629 mock_valid .return_value = True
617630 mock_time .return_value = 1
618631 mock_get_jwt .return_value = {"exp" : 1 }
619632
620- security_token ._validate_and_refresh_token ("test_security_token" )
633+ security_token ._validate_and_refresh_token (configuration )
621634
622635 mock_read_security_token_file .assert_called_with ("test_security_token" )
623636 mock_security_token_container .assert_called ()
624637 mock_time .assert_called ()
625638 mock_get_jwt .assert_called ()
626- mock_system .assert_called_with ("oci session refresh --profile test_profile" )
639+ mock_refresh_security_token .assert_called_with (configuration )
640+
641+ @mock .patch ("oci_cli.cli_util.apply_user_only_access_permissions" )
642+ @mock .patch ("json.loads" )
643+ @mock .patch ("requests.post" )
644+ @mock .patch ("json.dumps" )
645+ @mock .patch ("oci.auth.signers.SecurityTokenSigner.__init__" )
646+ @mock .patch ("oci.signer.load_private_key_from_file" )
647+ @mock .patch ("builtins.open" )
648+ def test_refresh_security_token (
649+ self ,
650+ mock_open ,
651+ mock_load_private_key_from_file ,
652+ mock_security_token_signer ,
653+ mock_dumps ,
654+ mock_post ,
655+ mock_loads ,
656+ mock_apply_user_only_access_permissions
657+ ):
658+ security_token = SecurityToken (args = {})
659+ configuration = {
660+ "fingerprint" : "test_fingerprint" ,
661+ "tenancy" : "test_tenancy" ,
662+ "region" : "us-ashburn-1" ,
663+ "key_file" : "test_key_file" ,
664+ "security_token_file" : "test_security_token" ,
665+ "generic_headers" : [1 ,2 ,3 ],
666+ "body_headers" : [4 ,5 ,6 ]
667+ }
668+ mock_security_token_signer .return_value = None
669+ mock_loads .return_value = {
670+ "token" : "test_token"
671+ }
672+
673+ response = MagicMock ()
674+ response .status_code = 401
675+ mock_post .return_value = response
676+ with pytest .raises (
677+ SecurityTokenError ,
678+ match = "Security token has expired. Call `oci session authenticate` to generate new session."
679+ ):
680+ security_token ._refresh_security_token (configuration )
681+
682+ response .status_code = 500
683+ mock_post .return_value = response
684+ with pytest .raises (
685+ SecurityTokenError ,
686+ ):
687+ security_token ._refresh_security_token (configuration )
688+
689+ response .status_code = 200
690+ response .content = bytes ("test_content" , encoding = 'utf8' )
691+ mock_post .return_value = response
692+ security_token ._refresh_security_token (configuration )
693+
694+ mock_open .assert_called ()
695+ mock_load_private_key_from_file .assert_called_with ("test_key_file" , None )
696+ mock_security_token_signer .assert_called ()
697+ mock_dumps .assert_called ()
698+ mock_post .assert_called ()
699+ mock_loads .assert_called ()
700+ mock_apply_user_only_access_permissions .assert_called ()
627701
628702 @mock .patch ("builtins.open" )
629703 @mock .patch ("os.path.isfile" )
0 commit comments