Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/awscli_login/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from botocore.session import Session

from .__main__ import main as login, logout
from .__main__ import main as aws_login, login, logout
from ._version import version
from .account_names import edit_account_names
from .config import Profile, error_handler
Expand Down Expand Up @@ -63,8 +63,8 @@ def init_parser():
return parser


@error_handler()
def _main(profile: Profile, session: Session, interactive: bool = True):
def get_credentials(profile: Profile, session: Session):
"""Get credentials and print them."""
profile.raise_if_logged_out()
if profile.are_credentials_expired():
token = login(profile, session, interactive=False)
Expand All @@ -73,6 +73,11 @@ def _main(profile: Profile, session: Session, interactive: bool = True):
print_credentials(token)


@error_handler()
def _main(profile: Profile, session: Session):
get_credentials(profile, session)


def debug_info():
executable = sys.executable if platform.system() != "Windows" else \
sys.executable.lower()
Expand Down Expand Up @@ -116,7 +121,7 @@ def main():
if ns.debug_info:
debug_info()
return
return login(ns, session)
return aws_login(ns, session)
elif args.logout:
return logout(Namespace(**json.load(args.logout)), session)
elif args.alias:
Expand Down
70 changes: 70 additions & 0 deletions src/tests/test_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from unittest.mock import (
MagicMock,
patch,
)

from awscli_login.credentials import (
get_credentials,
)

from .login import Login


class awsLoginTests(Login):
""" Class to test the aws-login script. """

# NOTA BENE: This is a regression test for issue #257
@patch("awscli_login.credentials.print_credentials")
@patch("awscli_login.__main__.authenticate")
@patch("awscli_login.__main__.save_sts_token")
@patch("awscli_login.__main__.get_selection",
return_value=["PrincipalArn2", "RoleArn2"])
@patch("awscli_login.__main__.refresh",
return_value=("SAML", ["PrincipalArn", "RoleArn"]))
def test_get_credentials_with_refresh(
self, refresh, get_selection, save_sts_token, authenticate,
print_credentials):
""" get_credentials should refresh expired credentials. """
fake_token = {"TOKEN": "FAKE_DATA"}
self.profile.are_credentials_expired = MagicMock(return_value=True)
save_sts_token.return_value = fake_token
self.profile.load_credentials = MagicMock(return_value=fake_token)

get_credentials(self.profile, self.session)

self.session.set_credentials.assert_called_with(None, None)
self.session.create_client.assert_called_with("sts")
self.profile.get_username.assert_not_called()
refresh.assert_called_with(
self.profile.ecp_endpoint_url,
self.profile.cookies,
self.profile.verify_ssl_certificate,
)
self.profile.get_credentials.assert_not_called()
authenticate.assert_not_called()
get_selection.assert_called_with(["PrincipalArn", "RoleArn"],
self.profile.role_arn, False, {})
save_sts_token.assert_called_with(
self.profile,
self.client,
"SAML",
["PrincipalArn2", "RoleArn2"],
self.profile.duration
)
self.profile.load_credentials.assert_not_called()
print_credentials.assert_called_with(fake_token)

@patch("awscli_login.credentials.print_credentials")
@patch("awscli_login.credentials.login")
def test_get_credentials_without_refresh(
self, login, print_credentials):
""" get_credentials should just print current credentials. """
fake_token = {"TOKEN": "FAKE_DATA"}
self.profile.are_credentials_expired = MagicMock(return_value=False)
self.profile.load_credentials = MagicMock(return_value=fake_token)

get_credentials(self.profile, self.session)

login.assert_not_called()
self.profile.load_credentials.assert_called()
print_credentials.assert_called_with(fake_token)