diff --git a/tests/cli/test_cmds_spark_run.py b/tests/cli/test_cmds_spark_run.py index 42a5445a68..97f81dd432 100644 --- a/tests/cli/test_cmds_spark_run.py +++ b/tests/cli/test_cmds_spark_run.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os import mock import pytest +from service_configuration_lib.spark_config import AWS_CREDENTIALS_DIR +from service_configuration_lib.spark_config import get_aws_credentials from paasta_tools import spark_tools from paasta_tools import utils @@ -1422,9 +1425,6 @@ def test_build_and_push_docker_image_unexpected_output_format( def test_get_aws_credentials(): - import os - from service_configuration_lib.spark_config import get_aws_credentials - with mock.patch.dict( os.environ, { @@ -1450,3 +1450,64 @@ def test_get_aws_credentials(): RoleSessionName=mock.ANY, WebIdentityToken="token-content", ) + + +@mock.patch("service_configuration_lib.spark_config.use_aws_profile", autospec=False) +@mock.patch("service_configuration_lib.spark_config.Session", autospec=True) +def test_get_aws_credentials_session(mock_boto3_session, mock_use_aws_profile): + # prioritize session over `profile_name` if both are provided + session = mock_boto3_session() + + get_aws_credentials( + service="some-service", + session=session, + profile_name="some-profile", + ) + + mock_use_aws_profile.assert_called_once_with(session=session) + session.assert_not_called() + + +@mock.patch("service_configuration_lib.spark_config.use_aws_profile", autospec=False) +@mock.patch("service_configuration_lib.spark_config.Session", autospec=True) +def test_get_aws_credentials_profile(mock_boto3_session, mock_use_aws_profile): + # prioritize `profile_name` over `service` if both are provided + profile_name = "some-profile" + + get_aws_credentials(service="some-service", profile_name=profile_name) + + mock_use_aws_profile.assert_called_once_with(profile_name=profile_name) + + +@mock.patch("service_configuration_lib.spark_config.use_aws_profile", autospec=False) +@mock.patch("os.path.exists", autospec=False) +@mock.patch( + "service_configuration_lib.spark_config._load_aws_credentials_from_yaml", + autospec=True, +) +def test_get_aws_credentials_boto_cfg( + mock_load_aws_credentials_from_yaml, mock_os_path_exists, mock_use_aws_profile +): + # use `service` if profile_name is not provided + service_name = "some-service" + + get_aws_credentials( + service=service_name, + ) + + credentials_path = f"{AWS_CREDENTIALS_DIR}{service_name}.yaml" + mock_os_path_exists.return_value = True + + mock_load_aws_credentials_from_yaml.assert_called_once_with(credentials_path) + mock_use_aws_profile.assert_not_called() + + +@mock.patch("service_configuration_lib.spark_config.use_aws_profile", autospec=False) +@mock.patch("service_configuration_lib.spark_config.Session", autospec=True) +def test_get_aws_credentials_default_profile(mock_boto3_session, mock_use_aws_profile): + # use `default` profile if no valid options are provided + get_aws_credentials( + service="spark", + ) + + mock_use_aws_profile.assert_called_once_with()