diff --git a/src/awscli_login/credentials.py b/src/awscli_login/credentials.py index f30dd854..420ce8a9 100644 --- a/src/awscli_login/credentials.py +++ b/src/awscli_login/credentials.py @@ -11,7 +11,7 @@ from botocore.session import Session -from .__main__ import main as login, logout +from .__main__ import main as aws_login, login, logout from ._version import version from .account_names import edit_account_names from .config import Profile, error_handler @@ -63,8 +63,8 @@ def init_parser(): return parser -@error_handler() -def _main(profile: Profile, session: Session, interactive: bool = True): +def get_credentials(profile: Profile, session: Session): + """Get credentials and print them.""" profile.raise_if_logged_out() if profile.are_credentials_expired(): token = login(profile, session, interactive=False) @@ -73,6 +73,11 @@ def _main(profile: Profile, session: Session, interactive: bool = True): print_credentials(token) +@error_handler() +def _main(profile: Profile, session: Session): + get_credentials(profile, session) + + def debug_info(): executable = sys.executable if platform.system() != "Windows" else \ sys.executable.lower() @@ -116,7 +121,7 @@ def main(): if ns.debug_info: debug_info() return - return login(ns, session) + return aws_login(ns, session) elif args.logout: return logout(Namespace(**json.load(args.logout)), session) elif args.alias: diff --git a/src/tests/test_credentials.py b/src/tests/test_credentials.py new file mode 100644 index 00000000..d90ee2c8 --- /dev/null +++ b/src/tests/test_credentials.py @@ -0,0 +1,70 @@ +from unittest.mock import ( + MagicMock, + patch, +) + +from awscli_login.credentials import ( + get_credentials, +) + +from .login import Login + + +class awsLoginTests(Login): + """ Class to test the aws-login script. """ + + # NOTA BENE: This is a regression test for issue #257 + @patch("awscli_login.credentials.print_credentials") + @patch("awscli_login.__main__.authenticate") + @patch("awscli_login.__main__.save_sts_token") + @patch("awscli_login.__main__.get_selection", + return_value=["PrincipalArn2", "RoleArn2"]) + @patch("awscli_login.__main__.refresh", + return_value=("SAML", ["PrincipalArn", "RoleArn"])) + def test_get_credentials_with_refresh( + self, refresh, get_selection, save_sts_token, authenticate, + print_credentials): + """ get_credentials should refresh expired credentials. """ + fake_token = {"TOKEN": "FAKE_DATA"} + self.profile.are_credentials_expired = MagicMock(return_value=True) + save_sts_token.return_value = fake_token + self.profile.load_credentials = MagicMock(return_value=fake_token) + + get_credentials(self.profile, self.session) + + self.session.set_credentials.assert_called_with(None, None) + self.session.create_client.assert_called_with("sts") + self.profile.get_username.assert_not_called() + refresh.assert_called_with( + self.profile.ecp_endpoint_url, + self.profile.cookies, + self.profile.verify_ssl_certificate, + ) + self.profile.get_credentials.assert_not_called() + authenticate.assert_not_called() + get_selection.assert_called_with(["PrincipalArn", "RoleArn"], + self.profile.role_arn, False, {}) + save_sts_token.assert_called_with( + self.profile, + self.client, + "SAML", + ["PrincipalArn2", "RoleArn2"], + self.profile.duration + ) + self.profile.load_credentials.assert_not_called() + print_credentials.assert_called_with(fake_token) + + @patch("awscli_login.credentials.print_credentials") + @patch("awscli_login.credentials.login") + def test_get_credentials_without_refresh( + self, login, print_credentials): + """ get_credentials should just print current credentials. """ + fake_token = {"TOKEN": "FAKE_DATA"} + self.profile.are_credentials_expired = MagicMock(return_value=False) + self.profile.load_credentials = MagicMock(return_value=fake_token) + + get_credentials(self.profile, self.session) + + login.assert_not_called() + self.profile.load_credentials.assert_called() + print_credentials.assert_called_with(fake_token)