From 6f077265493310943c09759b49d21993045fe1e7 Mon Sep 17 00:00:00 2001 From: Alexander Mohr Date: Tue, 31 Mar 2020 09:57:12 -0700 Subject: [PATCH] async signing support (#659) --- CHANGES.rst | 9 + aiobotocore/__init__.py | 2 +- aiobotocore/args.py | 4 +- aiobotocore/client.py | 82 +- aiobotocore/credentials.py | 797 +++++++++++++++++++ aiobotocore/endpoint.py | 32 +- aiobotocore/hooks.py | 41 + aiobotocore/response.py | 3 + aiobotocore/session.py | 102 ++- aiobotocore/signers.py | 190 +++++ aiobotocore/utils.py | 198 +++++ aiobotocore/waiter.py | 2 +- setup.py | 11 +- tests/botocore/test_credentials.py | 1145 ++++++++++++++++++++++++++++ tests/botocore/test_signers.py | 151 ++++ tests/botocore/test_utils.py | 289 +++++++ tests/conftest.py | 14 +- tests/test_basic_s3.py | 19 +- tests/test_config.py | 10 + tests/test_eventstreams.py | 2 +- tests/test_patches.py | 154 +++- 21 files changed, 3198 insertions(+), 59 deletions(-) create mode 100644 aiobotocore/credentials.py create mode 100644 aiobotocore/hooks.py create mode 100644 aiobotocore/signers.py create mode 100644 aiobotocore/utils.py create mode 100644 tests/botocore/test_credentials.py create mode 100644 tests/botocore/test_signers.py create mode 100644 tests/botocore/test_utils.py diff --git a/CHANGES.rst b/CHANGES.rst index 149943b0..d39fc9ad 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,5 +1,14 @@ Changes ------- +1.0.0 (2020-03-31) +^^^^^^^^^^^^^^^^^^ +* API breaking: The result of create_client is now a required async context class +* Credential refresh should not work +* generate_presigned_url is now an async call along with other credential methods +* Credentials.[access_key/secret_key/token] now raise NotImplementedError because + they won't call refresh like botocore. Instead should use get_frozen_credentials + async method +* Bump botocore and extras 0.12.0 (2020-02-23) ^^^^^^^^^^^^^^^^^^^ diff --git a/aiobotocore/__init__.py b/aiobotocore/__init__.py index 15047d2f..997292ed 100644 --- a/aiobotocore/__init__.py +++ b/aiobotocore/__init__.py @@ -1,4 +1,4 @@ from .session import get_session, AioSession __all__ = ['get_session', 'AioSession'] -__version__ = '0.12.0' +__version__ = '1.0.0' diff --git a/aiobotocore/args.py b/aiobotocore/args.py index 5aa11aa5..3f04f8f7 100644 --- a/aiobotocore/args.py +++ b/aiobotocore/args.py @@ -3,10 +3,10 @@ from botocore.args import ClientArgsCreator import botocore.serialize import botocore.parsers -from botocore.signers import RequestSigner from .config import AioConfig from .endpoint import AioEndpointCreator +from .signers import AioRequestSigner class AioClientArgsCreator(ClientArgsCreator): @@ -32,7 +32,7 @@ def get_client_args(self, service_model, region_name, is_secure, endpoint_region_name = endpoint_config['region_name'] event_emitter = copy.copy(self._event_emitter) - signer = RequestSigner( + signer = AioRequestSigner( service_model.service_id, signing_region, endpoint_config['signing_name'], endpoint_config['signature_version'], diff --git a/aiobotocore/client.py b/aiobotocore/client.py index 07fdd1a7..989e3a35 100644 --- a/aiobotocore/client.py +++ b/aiobotocore/client.py @@ -1,25 +1,53 @@ -from botocore.client import logger, PaginatorDocstring, ClientCreator, BaseClient +from botocore.awsrequest import prepare_request_dict +from botocore.client import logger, PaginatorDocstring, ClientCreator, \ + BaseClient, ClientEndpointBridge from botocore.exceptions import OperationNotPageableError from botocore.history import get_global_history_recorder from botocore.utils import get_service_module_name from botocore.waiter import xform_name +from botocore.hooks import first_non_none_response from .paginate import AioPaginator from .args import AioClientArgsCreator from . import waiter - history_recorder = get_global_history_recorder() class AioClientCreator(ClientCreator): - def _create_client_class(self, service_name, service_model): + async def create_client(self, service_name, region_name, is_secure=True, + endpoint_url=None, verify=None, + credentials=None, scoped_config=None, + api_version=None, + client_config=None): + responses = await self._event_emitter.emit( + 'choose-service-name', service_name=service_name) + service_name = first_non_none_response(responses, default=service_name) + service_model = self._load_service_model(service_name, api_version) + cls = await self._create_client_class(service_name, service_model) + endpoint_bridge = ClientEndpointBridge( + self._endpoint_resolver, scoped_config, client_config, + service_signing_name=service_model.metadata.get('signingName')) + client_args = self._get_client_args( + service_model, region_name, is_secure, endpoint_url, + verify, credentials, scoped_config, client_config, endpoint_bridge) + service_client = cls(**client_args) + self._register_retries(service_client) + self._register_s3_events( + service_client, endpoint_bridge, endpoint_url, client_config, + scoped_config) + self._register_endpoint_discovery( + service_client, endpoint_url, client_config + ) + return service_client + + async def _create_client_class(self, service_name, service_model): class_attributes = self._create_methods(service_model) py_name_to_operation_name = self._create_name_mapping(service_model) class_attributes['_PY_TO_OP_NAME'] = py_name_to_operation_name bases = [AioBaseClient] service_id = service_model.service_id.hyphenize() - self._event_emitter.emit( + await self._event_emitter.emit( 'creating-client-class.%s' % service_id, class_attributes=class_attributes, base_classes=bases) @@ -59,11 +87,11 @@ async def _make_api_call(self, operation_name, api_params): 'has_streaming_input': operation_model.has_streaming_input, 'auth_type': operation_model.auth_type, } - request_dict = self._convert_to_request_dict( + request_dict = await self._convert_to_request_dict( api_params, operation_model, context=request_context) service_id = self._service_model.service_id.hyphenize() - handler, event_response = self.meta.events.emit_until_response( + handler, event_response = await self.meta.events.emit_until_response( 'before-call.{service_id}.{operation_name}'.format( service_id=service_id, operation_name=operation_name), @@ -76,7 +104,7 @@ async def _make_api_call(self, operation_name, api_params): http, parsed_response = await self._make_request( operation_model, request_dict, request_context) - self.meta.events.emit( + await self.meta.events.emit( 'after-call.{service_id}.{operation_name}'.format( service_id=service_id, operation_name=operation_name), @@ -95,7 +123,7 @@ async def _make_request(self, operation_model, request_dict, request_context): try: return await self._endpoint.make_request(operation_model, request_dict) except Exception as e: - self.meta.events.emit( + await self.meta.events.emit( 'after-call-error.{service_id}.{operation_name}'.format( service_id=self._service_model.service_id.hyphenize(), operation_name=operation_model.name), @@ -103,6 +131,44 @@ async def _make_request(self, operation_model, request_dict, request_context): ) raise + async def _convert_to_request_dict(self, api_params, operation_model, + context=None): + api_params = await self._emit_api_params( + api_params, operation_model, context) + request_dict = self._serializer.serialize_to_request( + api_params, operation_model) + if not self._client_config.inject_host_prefix: + request_dict.pop('host_prefix', None) + prepare_request_dict(request_dict, endpoint_url=self._endpoint.host, + user_agent=self._client_config.user_agent, + context=context) + return request_dict + + async def _emit_api_params(self, api_params, operation_model, context): + # Given the API params provided by the user and the operation_model + # we can serialize the request to a request_dict. + operation_name = operation_model.name + + # Emit an event that allows users to modify the parameters at the + # beginning of the method. It allows handlers to modify existing + # parameters or return a new set of parameters to use. + service_id = self._service_model.service_id.hyphenize() + responses = await self.meta.events.emit( + 'provide-client-params.{service_id}.{operation_name}'.format( + service_id=service_id, + operation_name=operation_name), + params=api_params, model=operation_model, context=context) + api_params = first_non_none_response(responses, default=api_params) + + event_name = ( + 'before-parameter-build.{service_id}.{operation_name}') + await self.meta.events.emit( + event_name.format( + service_id=service_id, + operation_name=operation_name), + params=api_params, model=operation_model, context=context) + return api_params + def get_paginator(self, operation_name): """Create a paginator for an operation. diff --git a/aiobotocore/credentials.py b/aiobotocore/credentials.py new file mode 100644 index 00000000..a5b15898 --- /dev/null +++ b/aiobotocore/credentials.py @@ -0,0 +1,797 @@ +import asyncio +import logging +import subprocess +from copy import deepcopy +from typing import Optional + +from botocore import UNSIGNED +import botocore.compat +from botocore.credentials import EnvProvider, Credentials, RefreshableCredentials, \ + ReadOnlyCredentials, ContainerProvider, ContainerMetadataFetcher, \ + _parse_if_needed, InstanceMetadataProvider, _get_client_creator, \ + ProfileProviderBuilder, ConfigProvider, SharedCredentialProvider, \ + ProcessProvider, AssumeRoleWithWebIdentityProvider, _local_now, \ + CachedCredentialFetcher, _serialize_if_needed, BaseAssumeRoleCredentialFetcher, \ + AssumeRoleProvider, AssumeRoleCredentialFetcher, CredentialResolver, \ + CanonicalNameCredentialSourcer, BotoProvider, OriginalEC2Provider +from botocore.exceptions import MetadataRetrievalError, CredentialRetrievalError, \ + InvalidConfigError, PartialCredentialsError, RefreshWithMFAUnsupportedError, \ + UnknownCredentialError +from botocore.compat import compat_shell_split + +from aiobotocore.utils import AioContainerMetadataFetcher, AioInstanceMetadataFetcher +from aiobotocore.config import AioConfig + +logger = logging.getLogger(__name__) + + +def create_credential_resolver(session, cache=None, region_name=None): + """Create a default credential resolver. + This creates a pre-configured credential resolver + that includes the default lookup chain for + credentials. + """ + profile_name = session.get_config_variable('profile') or 'default' + metadata_timeout = session.get_config_variable('metadata_service_timeout') + num_attempts = session.get_config_variable('metadata_service_num_attempts') + disable_env_vars = session.instance_variables().get('profile') is not None + + if cache is None: + cache = {} + + env_provider = AioEnvProvider() + container_provider = AioContainerProvider() + instance_metadata_provider = AioInstanceMetadataProvider( + iam_role_fetcher=AioInstanceMetadataFetcher( + timeout=metadata_timeout, + num_attempts=num_attempts, + user_agent=session.user_agent()) + ) + + profile_provider_builder = AioProfileProviderBuilder( + session, cache=cache, region_name=region_name) + assume_role_provider = AioAssumeRoleProvider( + load_config=lambda: session.full_config, + client_creator=_get_client_creator(session, region_name), + cache=cache, + profile_name=profile_name, + credential_sourcer=AioCanonicalNameCredentialSourcer([ + env_provider, container_provider, instance_metadata_provider + ]), + profile_provider_builder=profile_provider_builder, + ) + + pre_profile = [ + env_provider, + assume_role_provider, + ] + profile_providers = profile_provider_builder.providers( + profile_name=profile_name, + disable_env_vars=disable_env_vars, + ) + post_profile = [ + AioOriginalEC2Provider(), + AioBotoProvider(), + container_provider, + instance_metadata_provider, + ] + providers = pre_profile + profile_providers + post_profile + + if disable_env_vars: + # An explicitly provided profile will negate an EnvProvider. + # We will defer to providers that understand the "profile" + # concept to retrieve credentials. + # The one edge case if is all three values are provided via + # env vars: + # export AWS_ACCESS_KEY_ID=foo + # export AWS_SECRET_ACCESS_KEY=bar + # export AWS_PROFILE=baz + # Then, just like our client() calls, the explicit credentials + # will take precedence. + # + # This precedence is enforced by leaving the EnvProvider in the chain. + # This means that the only way a "profile" would win is if the + # EnvProvider does not return credentials, which is what we want + # in this scenario. + providers.remove(env_provider) + logger.debug('Skipping environment variable credential check' + ' because profile name was explicitly set.') + + resolver = AioCredentialResolver(providers=providers) + return resolver + + +class AioProfileProviderBuilder(ProfileProviderBuilder): + def _create_process_provider(self, profile_name): + return AioProcessProvider( + profile_name=profile_name, + load_config=lambda: self._session.full_config, + ) + + def _create_shared_credential_provider(self, profile_name): + credential_file = self._session.get_config_variable('credentials_file') + return AioSharedCredentialProvider( + profile_name=profile_name, + creds_filename=credential_file, + ) + + def _create_config_provider(self, profile_name): + config_file = self._session.get_config_variable('config_file') + return AioConfigProvider( + profile_name=profile_name, + config_filename=config_file, + ) + + def _create_web_identity_provider(self, profile_name, disable_env_vars): + return AioAssumeRoleWithWebIdentityProvider( + load_config=lambda: self._session.full_config, + client_creator=_get_client_creator( + self._session, self._region_name), + cache=self._cache, + profile_name=profile_name, + disable_env_vars=disable_env_vars, + ) + + +async def get_credentials(session): + resolver = create_credential_resolver(session) + return await resolver.load_credentials() + + +def create_assume_role_refresher(client, params): + async def refresh(): + async with client as sts: + response = await sts.assume_role(**params) + credentials = response['Credentials'] + # We need to normalize the credential names to + # the values expected by the refresh creds. + return { + 'access_key': credentials['AccessKeyId'], + 'secret_key': credentials['SecretAccessKey'], + 'token': credentials['SessionToken'], + 'expiry_time': _serialize_if_needed(credentials['Expiration']), + } + return refresh + + +def create_aio_mfa_serial_refresher(actual_refresh): + class _Refresher(object): + def __init__(self, refresh): + self._refresh = refresh + self._has_been_called = False + + async def call(self): + if self._has_been_called: + # We can explore an option in the future to support + # reprompting for MFA, but for now we just error out + # when the temp creds expire. + raise RefreshWithMFAUnsupportedError() + self._has_been_called = True + return await self._refresh() + + return _Refresher(actual_refresh).call + + +class AioCredentials(Credentials): + async def get_frozen_credentials(self): + return ReadOnlyCredentials(self.access_key, + self.secret_key, + self.token) + + @classmethod + def from_credentials(cls, obj: Optional[Credentials]): + if obj is None: + return None + return cls( + obj.access_key, obj.secret_key, + obj.token, obj.method) + + +class AioRefreshableCredentials(RefreshableCredentials): + def __init__(self, *args, **kwargs): + super(AioRefreshableCredentials, self).__init__(*args, **kwargs) + self._refresh_lock = asyncio.Lock() + + @classmethod + def from_refreshable_credentials(cls, obj: Optional[RefreshableCredentials]): + if obj is None: + return None + return cls( # Using interval values here to skip property calling .refresh() + obj._access_key, obj._secret_key, + obj._token, obj._expiry_time, + obj._refresh_using, obj.method, + obj._time_fetcher + ) + + # Redeclaring the properties so it doesnt call refresh + # Have to redeclare setter as we're overriding the getter + @property + def access_key(self): + # TODO: this needs to be resolved + raise NotImplementedError("missing call to self._refresh. " + "Use get_frozen_credentials instead") + return self._access_key + + @access_key.setter + def access_key(self, value): + self._access_key = value + + @property + def secret_key(self): + # TODO: this needs to be resolved + raise NotImplementedError("missing call to self._refresh. " + "Use get_frozen_credentials instead") + return self._secret_key + + @secret_key.setter + def secret_key(self, value): + self._secret_key = value + + @property + def token(self): + # TODO: this needs to be resolved + raise NotImplementedError("missing call to self._refresh. " + "Use get_frozen_credentials instead") + return self._token + + @token.setter + def token(self, value): + self._token = value + + async def _refresh(self): + if not self.refresh_needed(self._advisory_refresh_timeout): + return + + # By this point we need a refresh but its not critical + if not self._refresh_lock.locked(): + async with self._refresh_lock: + if not self.refresh_needed(self._advisory_refresh_timeout): + return + is_mandatory_refresh = self.refresh_needed( + self._mandatory_refresh_timeout) + await self._protected_refresh(is_mandatory=is_mandatory_refresh) + return + elif self.refresh_needed(self._mandatory_refresh_timeout): + # If we're here, we absolutely need a refresh and the + # lock is held so wait for it + async with self._refresh_lock: + # Might have refreshed by now + if not self.refresh_needed(self._mandatory_refresh_timeout): + return + await self._protected_refresh(is_mandatory=True) + + async def _protected_refresh(self, is_mandatory): + try: + metadata = await self._refresh_using() + except Exception: + period_name = 'mandatory' if is_mandatory else 'advisory' + logger.warning("Refreshing temporary credentials failed " + "during %s refresh period.", + period_name, exc_info=True) + if is_mandatory: + # If this is a mandatory refresh, then + # all errors that occur when we attempt to refresh + # credentials are propagated back to the user. + raise + # Otherwise we'll just return. + # The end result will be that we'll use the current + # set of temporary credentials we have. + return + self._set_from_data(metadata) + self._frozen_credentials = ReadOnlyCredentials( + self._access_key, self._secret_key, self._token) + if self._is_expired(): + msg = ("Credentials were refreshed, but the " + "refreshed credentials are still expired.") + logger.warning(msg) + raise RuntimeError(msg) + + async def get_frozen_credentials(self): + await self._refresh() + return self._frozen_credentials + + +class AioDeferredRefreshableCredentials(AioRefreshableCredentials): + def __init__(self, refresh_using, method, time_fetcher=_local_now): + self._refresh_using = refresh_using + self._access_key = None + self._secret_key = None + self._token = None + self._expiry_time = None + self._time_fetcher = time_fetcher + self._refresh_lock = asyncio.Lock() + self.method = method + self._frozen_credentials = None + + def refresh_needed(self, refresh_in=None): + if self._frozen_credentials is None: + return True + return super(AioDeferredRefreshableCredentials, self).refresh_needed( + refresh_in + ) + + +class AioCachedCredentialFetcher(CachedCredentialFetcher): + async def _get_credentials(self): + raise NotImplementedError('_get_credentials()') + + async def fetch_credentials(self): + return await self._get_cached_credentials() + + async def _get_cached_credentials(self): + """Get up-to-date credentials. + + This will check the cache for up-to-date credentials, calling assume + role if none are available. + """ + response = self._load_from_cache() + if response is None: + response = await self._get_credentials() + self._write_to_cache(response) + else: + logger.debug("Credentials for role retrieved from cache.") + + creds = response['Credentials'] + expiration = _serialize_if_needed(creds['Expiration'], iso=True) + return { + 'access_key': creds['AccessKeyId'], + 'secret_key': creds['SecretAccessKey'], + 'token': creds['SessionToken'], + 'expiry_time': expiration, + } + + +class AioBaseAssumeRoleCredentialFetcher(BaseAssumeRoleCredentialFetcher, + AioCachedCredentialFetcher): + pass + + +class AioAssumeRoleCredentialFetcher(AssumeRoleCredentialFetcher, + AioBaseAssumeRoleCredentialFetcher): + async def _get_credentials(self): + """Get credentials by calling assume role.""" + kwargs = self._assume_role_kwargs() + client = await self._create_client() + async with client as sts: + return await sts.assume_role(**kwargs) + + async def _create_client(self): + """Create an STS client using the source credentials.""" + frozen_credentials = await self._source_credentials.get_frozen_credentials() + return self._client_creator( + 'sts', + aws_access_key_id=frozen_credentials.access_key, + aws_secret_access_key=frozen_credentials.secret_key, + aws_session_token=frozen_credentials.token, + ) + + +class AioAssumeRoleWithWebIdentityCredentialFetcher( + AioBaseAssumeRoleCredentialFetcher +): + def __init__(self, client_creator, web_identity_token_loader, role_arn, + extra_args=None, cache=None, expiry_window_seconds=None): + + self._web_identity_token_loader = web_identity_token_loader + + super(AioAssumeRoleWithWebIdentityCredentialFetcher, self).__init__( + client_creator, role_arn, extra_args=extra_args, + cache=cache, expiry_window_seconds=expiry_window_seconds + ) + + async def _get_credentials(self): + """Get credentials by calling assume role.""" + kwargs = self._assume_role_kwargs() + # Assume role with web identity does not require credentials other than + # the token, explicitly configure the client to not sign requests. + config = AioConfig(signature_version=UNSIGNED) + async with self._client_creator('sts', config=config) as client: + return await client.assume_role_with_web_identity(**kwargs) + + def _assume_role_kwargs(self): + """Get the arguments for assume role based on current configuration.""" + assume_role_kwargs = deepcopy(self._assume_kwargs) + identity_token = self._web_identity_token_loader() + assume_role_kwargs['WebIdentityToken'] = identity_token + + return assume_role_kwargs + + +class AioProcessProvider(ProcessProvider): + def __init__(self, *args, popen=asyncio.create_subprocess_exec, **kwargs): + super(AioProcessProvider, self).__init__(*args, **kwargs, popen=popen) + + async def load(self): + credential_process = self._credential_process + if credential_process is None: + return + + creds_dict = await self._retrieve_credentials_using(credential_process) + if creds_dict.get('expiry_time') is not None: + return AioRefreshableCredentials.create_from_metadata( + creds_dict, + lambda: self._retrieve_credentials_using(credential_process), + self.METHOD + ) + + return AioCredentials( + access_key=creds_dict['access_key'], + secret_key=creds_dict['secret_key'], + token=creds_dict.get('token'), + method=self.METHOD + ) + + async def _retrieve_credentials_using(self, credential_process): + # We're not using shell=True, so we need to pass the + # command and all arguments as a list. + process_list = compat_shell_split(credential_process) + p = await self._popen(process_list, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + stdout, stderr = await p.communicate() + if p.returncode != 0: + raise CredentialRetrievalError( + provider=self.METHOD, error_msg=stderr.decode('utf-8')) + parsed = botocore.compat.json.loads(stdout.decode('utf-8')) + version = parsed.get('Version', '') + if version != 1: + raise CredentialRetrievalError( + provider=self.METHOD, + error_msg=("Unsupported version '%s' for credential process " + "provider, supported versions: 1" % version)) + try: + return { + 'access_key': parsed['AccessKeyId'], + 'secret_key': parsed['SecretAccessKey'], + 'token': parsed.get('SessionToken'), + 'expiry_time': parsed.get('Expiration'), + } + except KeyError as e: + raise CredentialRetrievalError( + provider=self.METHOD, + error_msg="Missing required key in response: %s" % e + ) + + +class AioInstanceMetadataProvider(InstanceMetadataProvider): + async def load(self): + fetcher = self._role_fetcher + metadata = await fetcher.retrieve_iam_role_credentials() + if not metadata: + return None + logger.debug('Found credentials from IAM Role: %s', + metadata['role_name']) + + creds = AioRefreshableCredentials.create_from_metadata( + metadata, + method=self.METHOD, + refresh_using=fetcher.retrieve_iam_role_credentials, + ) + return creds + + +class AioEnvProvider(EnvProvider): + async def load(self): + # It gets credentials from an env var, + # so just convert the response to Aio variants + result = super().load() + if isinstance(result, RefreshableCredentials): + return AioRefreshableCredentials.\ + from_refreshable_credentials(result) + elif isinstance(result, Credentials): + return AioCredentials.from_credentials(result) + + return None + + +class AioOriginalEC2Provider(OriginalEC2Provider): + async def load(self): + result = super(AioOriginalEC2Provider, self).load() + if isinstance(result, Credentials): + result = AioCredentials.from_credentials(result) + return result + + +class AioSharedCredentialProvider(SharedCredentialProvider): + async def load(self): + result = super(AioSharedCredentialProvider, self).load() + if isinstance(result, Credentials): + result = AioCredentials.from_credentials(result) + return result + + +class AioConfigProvider(ConfigProvider): + async def load(self): + result = super(AioConfigProvider, self).load() + if isinstance(result, Credentials): + result = AioCredentials.from_credentials(result) + return result + + +class AioBotoProvider(BotoProvider): + async def load(self): + result = super(AioBotoProvider, self).load() + if isinstance(result, Credentials): + result = AioCredentials.from_credentials(result) + return result + + +class AioAssumeRoleProvider(AssumeRoleProvider): + async def load(self): + self._loaded_config = self._load_config() + profiles = self._loaded_config.get('profiles', {}) + profile = profiles.get(self._profile_name, {}) + if self._has_assume_role_config_vars(profile): + return await self._load_creds_via_assume_role(self._profile_name) + + async def _load_creds_via_assume_role(self, profile_name): + role_config = self._get_role_config(profile_name) + source_credentials = await self._resolve_source_credentials( + role_config, profile_name + ) + + extra_args = {} + role_session_name = role_config.get('role_session_name') + if role_session_name is not None: + extra_args['RoleSessionName'] = role_session_name + + external_id = role_config.get('external_id') + if external_id is not None: + extra_args['ExternalId'] = external_id + + mfa_serial = role_config.get('mfa_serial') + if mfa_serial is not None: + extra_args['SerialNumber'] = mfa_serial + + duration_seconds = role_config.get('duration_seconds') + if duration_seconds is not None: + extra_args['DurationSeconds'] = duration_seconds + + fetcher = AioAssumeRoleCredentialFetcher( + client_creator=self._client_creator, + source_credentials=source_credentials, + role_arn=role_config['role_arn'], + extra_args=extra_args, + mfa_prompter=self._prompter, + cache=self.cache, + ) + refresher = fetcher.fetch_credentials + if mfa_serial is not None: + refresher = create_aio_mfa_serial_refresher(refresher) + + # The initial credentials are empty and the expiration time is set + # to now so that we can delay the call to assume role until it is + # strictly needed. + return AioDeferredRefreshableCredentials( + method=self.METHOD, + refresh_using=refresher, + time_fetcher=_local_now + ) + + async def _resolve_source_credentials(self, role_config, profile_name): + credential_source = role_config.get('credential_source') + if credential_source is not None: + return await self._resolve_credentials_from_source( + credential_source, profile_name + ) + + source_profile = role_config['source_profile'] + self._visited_profiles.append(source_profile) + return await self._resolve_credentials_from_profile(source_profile) + + async def _resolve_credentials_from_profile(self, profile_name): + profiles = self._loaded_config.get('profiles', {}) + profile = profiles[profile_name] + + if self._has_static_credentials(profile) and \ + not self._profile_provider_builder: + return self._resolve_static_credentials_from_profile(profile) + elif self._has_static_credentials(profile) or \ + not self._has_assume_role_config_vars(profile): + profile_providers = self._profile_provider_builder.providers( + profile_name=profile_name, + disable_env_vars=True, + ) + profile_chain = AioCredentialResolver(profile_providers) + credentials = await profile_chain.load_credentials() + if credentials is None: + error_message = ( + 'The source profile "%s" must have credentials.' + ) + raise InvalidConfigError( + error_msg=error_message % profile_name, + ) + return credentials + + return self._load_creds_via_assume_role(profile_name) + + def _resolve_static_credentials_from_profile(self, profile): + try: + return AioCredentials( + access_key=profile['aws_access_key_id'], + secret_key=profile['aws_secret_access_key'], + token=profile.get('aws_session_token') + ) + except KeyError as e: + raise PartialCredentialsError( + provider=self.METHOD, cred_var=str(e)) + + async def _resolve_credentials_from_source(self, credential_source, + profile_name): + credentials = await self._credential_sourcer.source_credentials( + credential_source) + if credentials is None: + raise CredentialRetrievalError( + provider=credential_source, + error_msg=( + 'No credentials found in credential_source referenced ' + 'in profile %s' % profile_name + ) + ) + return credentials + + +class AioAssumeRoleWithWebIdentityProvider(AssumeRoleWithWebIdentityProvider): + async def load(self): + return await self._assume_role_with_web_identity() + + async def _assume_role_with_web_identity(self): + token_path = self._get_config('web_identity_token_file') + if not token_path: + return None + token_loader = self._token_loader_cls(token_path) + + role_arn = self._get_config('role_arn') + if not role_arn: + error_msg = ( + 'The provided profile or the current environment is ' + 'configured to assume role with web identity but has no ' + 'role ARN configured. Ensure that the profile has the role_arn' + 'configuration set or the AWS_ROLE_ARN env var is set.' + ) + raise InvalidConfigError(error_msg=error_msg) + + extra_args = {} + role_session_name = self._get_config('role_session_name') + if role_session_name is not None: + extra_args['RoleSessionName'] = role_session_name + + fetcher = AioAssumeRoleWithWebIdentityCredentialFetcher( + client_creator=self._client_creator, + web_identity_token_loader=token_loader, + role_arn=role_arn, + extra_args=extra_args, + cache=self.cache, + ) + # The initial credentials are empty and the expiration time is set + # to now so that we can delay the call to assume role until it is + # strictly needed. + return AioDeferredRefreshableCredentials( + method=self.METHOD, + refresh_using=fetcher.fetch_credentials, + ) + + +class AioCanonicalNameCredentialSourcer(CanonicalNameCredentialSourcer): + async def source_credentials(self, source_name): + """Loads source credentials based on the provided configuration. + + :type source_name: str + :param source_name: The value of credential_source in the config + file. This is the canonical name of the credential provider. + + :rtype: Credentials + """ + source = self._get_provider(source_name) + if isinstance(source, AioCredentialResolver): + return await source.load_credentials() + return await source.load() + + def _get_provider(self, canonical_name): + """Return a credential provider by its canonical name. + + :type canonical_name: str + :param canonical_name: The canonical name of the provider. + + :raises UnknownCredentialError: Raised if no + credential provider by the provided name + is found. + """ + provider = self._get_provider_by_canonical_name(canonical_name) + + # The AssumeRole provider should really be part of the SharedConfig + # provider rather than being its own thing, but it is not. It is + # effectively part of both the SharedConfig provider and the + # SharedCredentials provider now due to the way it behaves. + # Therefore if we want either of those providers we should return + # the AssumeRole provider with it. + if canonical_name.lower() in ['sharedconfig', 'sharedcredentials']: + assume_role_provider = self._get_provider_by_method('assume-role') + if assume_role_provider is not None: + # The SharedConfig or SharedCredentials provider may not be + # present if it was removed for some reason, but the + # AssumeRole provider could still be present. In that case, + # return the assume role provider by itself. + if provider is None: + return assume_role_provider + + # If both are present, return them both as a + # CredentialResolver so that calling code can treat them as + # a single entity. + return AioCredentialResolver([assume_role_provider, provider]) + + if provider is None: + raise UnknownCredentialError(name=canonical_name) + + return provider + + +class AioContainerProvider(ContainerProvider): + def __init__(self, *args, **kwargs): + super(AioContainerProvider, self).__init__(*args, **kwargs) + + # This will always run if no fetcher arg is provided + if isinstance(self._fetcher, ContainerMetadataFetcher): + self._fetcher = AioContainerMetadataFetcher() + + async def load(self): + if self.ENV_VAR in self._environ or self.ENV_VAR_FULL in self._environ: + return await self._retrieve_or_fail() + + async def _retrieve_or_fail(self): + if self._provided_relative_uri(): + full_uri = self._fetcher.full_url(self._environ[self.ENV_VAR]) + else: + full_uri = self._environ[self.ENV_VAR_FULL] + headers = self._build_headers() + fetcher = self._create_fetcher(full_uri, headers) + creds = await fetcher() + return AioRefreshableCredentials( + access_key=creds['access_key'], + secret_key=creds['secret_key'], + token=creds['token'], + method=self.METHOD, + expiry_time=_parse_if_needed(creds['expiry_time']), + refresh_using=fetcher, + ) + + def _create_fetcher(self, full_uri, headers): + async def fetch_creds(): + try: + response = await self._fetcher.retrieve_full_uri( + full_uri, headers=headers) + except MetadataRetrievalError as e: + logger.debug("Error retrieving container metadata: %s", e, + exc_info=True) + raise CredentialRetrievalError(provider=self.METHOD, + error_msg=str(e)) + return { + 'access_key': response['AccessKeyId'], + 'secret_key': response['SecretAccessKey'], + 'token': response['Token'], + 'expiry_time': response['Expiration'], + } + + return fetch_creds + + +class AioCredentialResolver(CredentialResolver): + async def load_credentials(self): + """ + Goes through the credentials chain, returning the first ``Credentials`` + that could be loaded. + """ + # First provider to return a non-None response wins. + for provider in self.providers: + logger.debug("Looking for credentials via: %s", provider.METHOD) + creds = await provider.load() + if creds is not None: + return creds + + # If we got here, no credentials could be found. + # This feels like it should be an exception, but historically, ``None`` + # is returned. + # + # +1 + # -js + return None diff --git a/aiobotocore/endpoint.py b/aiobotocore/endpoint.py index 4219f02c..bfa4ae49 100644 --- a/aiobotocore/endpoint.py +++ b/aiobotocore/endpoint.py @@ -5,7 +5,7 @@ import aiohttp.http_exceptions from aiohttp.client import URL from botocore.endpoint import EndpointCreator, Endpoint, DEFAULT_TIMEOUT, \ - MAX_POOL_CONNECTIONS, logger, history_recorder + MAX_POOL_CONNECTIONS, logger, history_recorder, create_request_object from botocore.exceptions import ConnectionClosedError from botocore.hooks import first_non_none_response from botocore.utils import is_valid_endpoint_url @@ -63,9 +63,25 @@ def __init__(self, *args, proxies=None, **kwargs): super().__init__(*args, **kwargs) self.proxies = proxies or {} + async def create_request(self, params, operation_model=None): + request = create_request_object(params) + if operation_model: + request.stream_output = any([ + operation_model.has_streaming_output, + operation_model.has_event_stream_output + ]) + service_id = operation_model.service_model.service_id.hyphenize() + event_name = 'request-created.{service_id}.{op_name}'.format( + service_id=service_id, + op_name=operation_model.name) + await self._event_emitter.emit(event_name, request=request, + operation_name=operation_model.name) + prepared_request = self.prepare_request(request) + return prepared_request + async def _send_request(self, request_dict, operation_model): attempts = 1 - request = self.create_request(request_dict, operation_model) + request = await self.create_request(request_dict, operation_model) context = request_dict['context'] success_response, exception = await self._get_response( request, operation_model, context) @@ -79,7 +95,7 @@ async def _send_request(self, request_dict, operation_model): # body. request.reset_stream() # Create a new request when retried (including a new signature). - request = self.create_request( + request = await self.create_request( request_dict, operation_model) success_response, exception = await self._get_response( request, operation_model, context) @@ -114,7 +130,7 @@ async def _get_response(self, request, operation_model, context): kwargs_to_emit['response_dict'] = await convert_to_response_dict( http_response, operation_model) service_id = operation_model.service_model.service_id.hyphenize() - self._event_emitter.emit( + await self._event_emitter.emit( 'response-received.%s.%s' % ( service_id, operation_model.name), **kwargs_to_emit) return success_response, exception @@ -130,8 +146,10 @@ async def _do_get_response(self, request, operation_model): 'body': request.body }) service_id = operation_model.service_model.service_id.hyphenize() - event_name = 'before-send.%s.%s' % (service_id, operation_model.name) - responses = self._event_emitter.emit(event_name, request=request) + event_name = 'before-send.%s.%s' % ( + service_id, operation_model.name) + responses = await self._event_emitter.emit(event_name, + request=request) http_response = first_non_none_response(responses) if http_response is None: http_response = await self._send(request) @@ -170,7 +188,7 @@ async def _needs_retry(self, attempts, operation_model, request_dict, event_name = 'needs-retry.%s.%s' % ( service_id, operation_model.name) - responses = self._event_emitter.emit( + responses = await self._event_emitter.emit( event_name, response=response, endpoint=self, operation=operation_model, attempts=attempts, caught_exception=caught_exception, request_dict=request_dict) diff --git a/aiobotocore/hooks.py b/aiobotocore/hooks.py new file mode 100644 index 00000000..496d5f74 --- /dev/null +++ b/aiobotocore/hooks.py @@ -0,0 +1,41 @@ +import asyncio + +from botocore.hooks import HierarchicalEmitter, logger + + +class AioHierarchicalEmitter(HierarchicalEmitter): + async def _emit(self, event_name, kwargs, stop_on_response=False): + responses = [] + # Invoke the event handlers from most specific + # to least specific, each time stripping off a dot. + handlers_to_call = self._lookup_cache.get(event_name) + if handlers_to_call is None: + handlers_to_call = self._handlers.prefix_search(event_name) + self._lookup_cache[event_name] = handlers_to_call + elif not handlers_to_call: + # Short circuit and return an empty response is we have + # no handlers to call. This is the common case where + # for the majority of signals, nothing is listening. + return [] + kwargs['event_name'] = event_name + responses = [] + for handler in handlers_to_call: + logger.debug('Event %s: calling handler %s', event_name, handler) + + # Await the handler if its a coroutine. + if asyncio.iscoroutinefunction(handler): + response = await handler(**kwargs) + else: + response = handler(**kwargs) + + responses.append((handler, response)) + if stop_on_response and response is not None: + return responses + return responses + + async def emit_until_response(self, event_name, **kwargs): + responses = await self._emit(event_name, kwargs, stop_on_response=True) + if responses: + return responses[-1] + else: + return None, None diff --git a/aiobotocore/response.py b/aiobotocore/response.py index 7993087a..83fbb28e 100644 --- a/aiobotocore/response.py +++ b/aiobotocore/response.py @@ -37,6 +37,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): # NOTE: set_socket_timeout was only for when requests didn't support # read timeouts, so not needed + def tell(self): + return self._self_amount_read + async def read(self, amt=None): """Read at most amt bytes from the stream. diff --git a/aiobotocore/session.py b/aiobotocore/session.py index 6b0d4c3b..273ebb0d 100644 --- a/aiobotocore/session.py +++ b/aiobotocore/session.py @@ -1,25 +1,53 @@ -import botocore.credentials -import botocore.session +from botocore.session import Session, EVENT_ALIASES, ServiceModel, UnknownServiceError from botocore import UNSIGNED from botocore import retryhandler, translate from botocore.exceptions import PartialCredentialsError -from .client import AioClientCreator +from .client import AioClientCreator, AioBaseClient +from .hooks import AioHierarchicalEmitter from .parsers import AioResponseParserFactory +from .signers import add_generate_presigned_url +from .credentials import create_credential_resolver, AioCredentials -class AioSession(botocore.session.Session): +class ClientCreatorContext: + def __init__(self, coro): + self._coro = coro + self._client = None - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + async def __aenter__(self) -> AioBaseClient: + self._client = await self._coro + return await self._client.__aenter__() - # Register the AioResponseParserFactory so event streams will be async'd - self.register_component('response_parser_factory', AioResponseParserFactory()) + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._client.__aexit__(exc_type, exc_val, exc_tb) - def create_client(self, service_name, region_name=None, api_version=None, - use_ssl=True, verify=None, endpoint_url=None, - aws_access_key_id=None, aws_secret_access_key=None, - aws_session_token=None, config=None): + +class AioSession(Session): + + # noinspection PyMissingConstructor + def __init__(self, session_vars=None, event_hooks=None, + include_builtin_handlers=True, profile=None): + if event_hooks is None: + event_hooks = AioHierarchicalEmitter() + + super().__init__(session_vars, event_hooks, include_builtin_handlers, profile) + + # Register our own handlers + self.register('creating-client-class', add_generate_presigned_url) + + def _register_response_parser_factory(self): + self._components.register_component('response_parser_factory', + AioResponseParserFactory()) + + def create_client(self, *args, **kwargs): + return ClientCreatorContext(self._create_client(*args, **kwargs)) + + async def _create_client(self, service_name, region_name=None, + api_version=None, + use_ssl=True, verify=None, endpoint_url=None, + aws_access_key_id=None, aws_secret_access_key=None, + aws_session_token=None, config=None): default_client_config = self.get_default_client_config() # If a config is provided and a default config is set, then @@ -50,7 +78,7 @@ def create_client(self, service_name, region_name=None, api_version=None, credentials = None elif aws_access_key_id is not None and \ aws_secret_access_key is not None: - credentials = botocore.credentials.Credentials( + credentials = AioCredentials( access_key=aws_access_key_id, secret_key=aws_secret_access_key, token=aws_session_token) @@ -61,7 +89,7 @@ def create_client(self, service_name, region_name=None, api_version=None, cred_var=self._missing_cred_vars(aws_access_key_id, aws_secret_access_key)) else: - credentials = self.get_credentials() + credentials = await self.get_credentials() endpoint_resolver = self._get_internal_component('endpoint_resolver') exceptions_factory = self._get_internal_component('exceptions_factory') config_store = self.get_component('config_store') @@ -69,7 +97,7 @@ def create_client(self, service_name, region_name=None, api_version=None, loader, endpoint_resolver, self.user_agent(), event_emitter, retryhandler, translate, response_parser_factory, exceptions_factory, config_store) - client = client_creator.create_client( + client = await client_creator.create_client( service_name=service_name, region_name=region_name, is_secure=use_ssl, endpoint_url=endpoint_url, verify=verify, credentials=credentials, scoped_config=self.get_scoped_config(), @@ -79,6 +107,50 @@ def create_client(self, service_name, region_name=None, api_version=None, monitor.register(client.meta.events) return client + def _create_credential_resolver(self): + return create_credential_resolver( + self, region_name=self._last_client_region_used) + + async def get_credentials(self): + if self._credentials is None: + self._credentials = await (self._components.get_component( + 'credential_provider').load_credentials()) + return self._credentials + + async def get_service_model(self, service_name, api_version=None): + service_description = await self.get_service_data(service_name, api_version) + return ServiceModel(service_description, service_name=service_name) + + async def get_service_data(self, service_name, api_version=None): + """ + Retrieve the fully merged data associated with a service. + """ + data_path = service_name + service_data = self.get_component('data_loader').load_service_model( + data_path, + type_name='service-2', + api_version=api_version + ) + service_id = EVENT_ALIASES.get(service_name, service_name) + self._events.emit('service-data-loaded.%s' % service_id, + service_data=service_data, + service_name=service_name, session=self) + return service_data + + async def get_available_regions(self, service_name, partition_name='aws', + allow_non_regional=False): + resolver = self._get_internal_component('endpoint_resolver') + results = [] + try: + service_data = await self.get_service_data(service_name) + endpoint_prefix = service_data['metadata'].get( + 'endpointPrefix', service_name) + results = resolver.get_available_endpoints( + endpoint_prefix, partition_name, allow_non_regional) + except UnknownServiceError: + pass + return results + def get_session(env_vars=None): """ diff --git a/aiobotocore/signers.py b/aiobotocore/signers.py new file mode 100644 index 00000000..8fd83b67 --- /dev/null +++ b/aiobotocore/signers.py @@ -0,0 +1,190 @@ +import botocore +import botocore.auth +from botocore.signers import RequestSigner, UnknownSignatureVersionError, \ + UnsupportedSignatureVersionError, create_request_object, prepare_request_dict, \ + _should_use_global_endpoint +from botocore.exceptions import UnknownClientMethodError + + +class AioRequestSigner(RequestSigner): + async def handler(self, operation_name=None, request=None, **kwargs): + # This is typically hooked up to the "request-created" event + # from a client's event emitter. When a new request is created + # this method is invoked to sign the request. + # Don't call this method directly. + return await self.sign(operation_name, request) + + async def sign(self, operation_name, request, region_name=None, + signing_type='standard', expires_in=None, + signing_name=None): + explicit_region_name = region_name + if region_name is None: + region_name = self._region_name + + if signing_name is None: + signing_name = self._signing_name + + signature_version = await self._choose_signer( + operation_name, signing_type, request.context) + + # Allow mutating request before signing + await self._event_emitter.emit( + 'before-sign.{0}.{1}'.format( + self._service_id.hyphenize(), operation_name), + request=request, signing_name=signing_name, + region_name=self._region_name, + signature_version=signature_version, request_signer=self, + operation_name=operation_name + ) + + if signature_version != botocore.UNSIGNED: + kwargs = { + 'signing_name': signing_name, + 'region_name': region_name, + 'signature_version': signature_version + } + if expires_in is not None: + kwargs['expires'] = expires_in + if not explicit_region_name and request.context.get( + 'signing', {}).get('region'): + kwargs['region_name'] = request.context['signing']['region'] + try: + auth = await self.get_auth_instance(**kwargs) + except UnknownSignatureVersionError as e: + if signing_type != 'standard': + raise UnsupportedSignatureVersionError( + signature_version=signature_version) + else: + raise e + + auth.add_auth(request) + + async def get_auth_instance(self, signing_name, region_name, + signature_version=None, **kwargs): + if signature_version is None: + signature_version = self._signature_version + + cls = botocore.auth.AUTH_TYPE_MAPS.get(signature_version) + if cls is None: + raise UnknownSignatureVersionError( + signature_version=signature_version) + + frozen_credentials = None + if self._credentials is not None: + frozen_credentials = await self._credentials.get_frozen_credentials() + kwargs['credentials'] = frozen_credentials + if cls.REQUIRES_REGION: + if self._region_name is None: + raise botocore.exceptions.NoRegionError() + kwargs['region_name'] = region_name + kwargs['service_name'] = signing_name + auth = cls(**kwargs) + return auth + + # Alias get_auth for backwards compatibility. + get_auth = get_auth_instance + + async def _choose_signer(self, operation_name, signing_type, context): + signing_type_suffix_map = { + 'presign-post': '-presign-post', + 'presign-url': '-query' + } + suffix = signing_type_suffix_map.get(signing_type, '') + + signature_version = self._signature_version + if signature_version is not botocore.UNSIGNED and not \ + signature_version.endswith(suffix): + signature_version += suffix + + handler, response = await self._event_emitter.emit_until_response( + 'choose-signer.{0}.{1}'.format( + self._service_id.hyphenize(), operation_name), + signing_name=self._signing_name, region_name=self._region_name, + signature_version=signature_version, context=context) + + if response is not None: + signature_version = response + # The suffix needs to be checked again in case we get an improper + # signature version from choose-signer. + if signature_version is not botocore.UNSIGNED and not \ + signature_version.endswith(suffix): + signature_version += suffix + + return signature_version + + async def generate_presigned_url(self, request_dict, operation_name, + expires_in=3600, region_name=None, + signing_name=None): + request = create_request_object(request_dict) + await self.sign(operation_name, request, region_name, + 'presign-url', expires_in, signing_name) + + request.prepare() + return request.url + + +def add_generate_presigned_url(class_attributes, **kwargs): + class_attributes['generate_presigned_url'] = generate_presigned_url + + +async def generate_presigned_url(self, ClientMethod, Params=None, ExpiresIn=3600, + HttpMethod=None): + """Generate a presigned url given a client, its method, and arguments + + :type ClientMethod: string + :param ClientMethod: The client method to presign for + + :type Params: dict + :param Params: The parameters normally passed to + ``ClientMethod``. + + :type ExpiresIn: int + :param ExpiresIn: The number of seconds the presigned url is valid + for. By default it expires in an hour (3600 seconds) + + :type HttpMethod: string + :param HttpMethod: The http method to use on the generated url. By + default, the http method is whatever is used in the method's model. + + :returns: The presigned url + """ + client_method = ClientMethod + params = Params + if params is None: + params = {} + expires_in = ExpiresIn + http_method = HttpMethod + context = { + 'is_presign_request': True, + 'use_global_endpoint': _should_use_global_endpoint(self), + } + + request_signer = self._request_signer + serializer = self._serializer + + try: + operation_name = self._PY_TO_OP_NAME[client_method] + except KeyError: + raise UnknownClientMethodError(method_name=client_method) + + operation_model = self.meta.service_model.operation_model( + operation_name) + + params = await self._emit_api_params(params, operation_model, context) + + # Create a request dict based on the params to serialize. + request_dict = serializer.serialize_to_request( + params, operation_model) + + # Switch out the http method if user specified it. + if http_method is not None: + request_dict['method'] = http_method + + # Prepare the request dict by including the client's endpoint url. + prepare_request_dict( + request_dict, endpoint_url=self.meta.endpoint_url, context=context) + + # Generate the presigned url. + return await request_signer.generate_presigned_url( + request_dict=request_dict, expires_in=expires_in, + operation_name=operation_name) diff --git a/aiobotocore/utils.py b/aiobotocore/utils.py new file mode 100644 index 00000000..b8f3b521 --- /dev/null +++ b/aiobotocore/utils.py @@ -0,0 +1,198 @@ +import asyncio +import logging +import json + +import aiohttp +import aiohttp.client_exceptions +from botocore.utils import ContainerMetadataFetcher, InstanceMetadataFetcher, \ + IMDSFetcher, get_environ_proxies, BadIMDSRequestError +from botocore.exceptions import MetadataRetrievalError +import botocore.awsrequest + + +logger = logging.getLogger(__name__) +RETRYABLE_HTTP_ERRORS = (aiohttp.client_exceptions.ClientError, asyncio.TimeoutError) + + +class AioIMDSFetcher(IMDSFetcher): + class Response(object): + def __init__(self, status_code, text, url): + self.status_code = status_code + self.url = url + self.text = text + self.content = text + + def __init__(self, *args, session=None, **kwargs): + super(AioIMDSFetcher, self).__init__(*args, **kwargs) + self._trust_env = bool(get_environ_proxies(self._base_url)) + self._session = session or aiohttp.ClientSession + + async def _fetch_metadata_token(self): + self._assert_enabled() + url = self._base_url + self._TOKEN_PATH + headers = { + 'x-aws-ec2-metadata-token-ttl-seconds': self._TOKEN_TTL, + } + self._add_user_agent(headers) + + request = botocore.awsrequest.AWSRequest( + method='PUT', url=url, headers=headers) + + timeout = aiohttp.ClientTimeout(total=self._timeout) + async with self._session(timeout=timeout, + trust_env=self._trust_env) as session: + for i in range(self._num_attempts): + try: + async with session.put(url, headers=headers) as resp: + text = await resp.text() + if resp.status == 200: + return text + elif resp.status in (404, 403, 405): + return None + elif resp.status in (400,): + raise BadIMDSRequestError(request) + except asyncio.TimeoutError: + return None + except RETRYABLE_HTTP_ERRORS as e: + logger.debug( + "Caught retryable HTTP exception while making metadata " + "service request to %s: %s", url, e, exc_info=True) + + return None + + async def _get_request(self, url_path, retry_func, token=None): + self._assert_enabled() + if retry_func is None: + retry_func = self._default_retry + url = self._base_url + url_path + headers = {} + if token is not None: + headers['x-aws-ec2-metadata-token'] = token + self._add_user_agent(headers) + + timeout = aiohttp.ClientTimeout(total=self._timeout) + async with self._session(timeout=timeout, + trust_env=self._trust_env) as session: + for i in range(self._num_attempts): + try: + async with session.get(url, headers=headers) as resp: + text = await resp.text() + response = self.Response(resp.status, text, resp.url) + + if not retry_func(response): + return response + except RETRYABLE_HTTP_ERRORS as e: + logger.debug( + "Caught retryable HTTP exception while making metadata " + "service request to %s: %s", url, e, exc_info=True) + raise self._RETRIES_EXCEEDED_ERROR_CLS() + + +class AioInstanceMetadataFetcher(AioIMDSFetcher, InstanceMetadataFetcher): + async def retrieve_iam_role_credentials(self): + try: + token = await self._fetch_metadata_token() + role_name = await self._get_iam_role(token) + credentials = await self._get_credentials(role_name, token) + if self._contains_all_credential_fields(credentials): + return { + 'role_name': role_name, + 'access_key': credentials['AccessKeyId'], + 'secret_key': credentials['SecretAccessKey'], + 'token': credentials['Token'], + 'expiry_time': credentials['Expiration'], + } + else: + if 'Code' in credentials and 'Message' in credentials: + logger.debug('Error response received when retrieving' + 'credentials: %s.', credentials) + return {} + except self._RETRIES_EXCEEDED_ERROR_CLS: + logger.debug("Max number of attempts exceeded (%s) when " + "attempting to retrieve data from metadata service.", + self._num_attempts) + except BadIMDSRequestError as e: + logger.debug("Bad IMDS request: %s", e.request) + return {} + + async def _get_iam_role(self, token=None): + r = await self._get_request( + url_path=self._URL_PATH, + retry_func=self._needs_retry_for_role_name, + token=token + ) + return r.text + + async def _get_credentials(self, role_name, token=None): + r = await self._get_request( + url_path=self._URL_PATH + role_name, + retry_func=self._needs_retry_for_credentials, + token=token + ) + return json.loads(r.text) + + +class AioContainerMetadataFetcher(ContainerMetadataFetcher): + def __init__(self, session=None, sleep=asyncio.sleep): + if session is None: + session = aiohttp.ClientSession + super(AioContainerMetadataFetcher, self).__init__(session, sleep) + + async def retrieve_full_uri(self, full_url, headers=None): + self._validate_allowed_url(full_url) + return await self._retrieve_credentials(full_url, headers) + + async def retrieve_uri(self, relative_uri): + """Retrieve JSON metadata from ECS metadata. + + :type relative_uri: str + :param relative_uri: A relative URI, e.g "/foo/bar?id=123" + + :return: The parsed JSON response. + + """ + full_url = self.full_url(relative_uri) + return await self._retrieve_credentials(full_url) + + async def _retrieve_credentials(self, full_url, extra_headers=None): + headers = {'Accept': 'application/json'} + if extra_headers is not None: + headers.update(extra_headers) + attempts = 0 + while True: + try: + return await self._get_response( + full_url, headers, self.TIMEOUT_SECONDS) + except MetadataRetrievalError as e: + logger.debug("Received error when attempting to retrieve " + "container metadata: %s", e, exc_info=True) + await self._sleep(self.SLEEP_TIME) + attempts += 1 + if attempts >= self.RETRY_ATTEMPTS: + raise + + async def _get_response(self, full_url, headers, timeout): + try: + timeout = aiohttp.ClientTimeout(total=self.TIMEOUT_SECONDS) + async with self._session(timeout=timeout) as session: + async with session.get(full_url, headers=headers) as resp: + if resp.status != 200: + text = await resp.text() + raise MetadataRetrievalError( + error_msg=( + "Received non 200 response (%d) " + "from ECS metadata: %s" + ) % (resp.status, text)) + try: + return await resp.json() + except ValueError: + text = await resp.text() + error_msg = ( + "Unable to parse JSON returned from ECS metadata services" + ) + logger.debug('%s:%s', error_msg, text) + raise MetadataRetrievalError(error_msg=error_msg) + except RETRYABLE_HTTP_ERRORS as e: + error_msg = ("Received error when attempting to retrieve " + "ECS metadata: %s" % e) + raise MetadataRetrievalError(error_msg=error_msg) diff --git a/aiobotocore/waiter.py b/aiobotocore/waiter.py index 0a71390e..9f2b55a5 100644 --- a/aiobotocore/waiter.py +++ b/aiobotocore/waiter.py @@ -2,7 +2,7 @@ # WaiterModel is required for client.py import from botocore.exceptions import ClientError -from botocore.waiter import WaiterModel # noqa: F401 +from botocore.waiter import WaiterModel # noqa: F401, lgtm[py/unused-import] from botocore.waiter import Waiter, xform_name, logger, WaiterError, \ NormalizedOperationMethod as _NormalizedOperationMethod from botocore.docs.docstring import WaiterDocstring diff --git a/setup.py b/setup.py index 5eb9cf49..b0c5b5f5 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ # NOTE: When updating botocore make sure to update awscli/boto3 versions below install_requires = [ # pegged to also match items in `extras_require` - 'botocore>=1.15.3,<1.15.16', + 'botocore>=1.15.32,<1.15.33', 'aiohttp>=3.3.1', 'wrapt>=1.10.10', 'aioitertools>=0.5.1', @@ -19,8 +19,8 @@ def read(f): extras_require = { - 'awscli': ['awscli==1.18.3'], - 'boto3': ['boto3==1.12.3'], + 'awscli': ['awscli==1.18.32'], + 'boto3': ['boto3==1.12.32'], } @@ -33,9 +33,8 @@ def read_version(): match = regexp.match(line) if match is not None: return match.group(1) - else: - raise RuntimeError('Cannot find version in ' - 'aiobotocore/__init__.py') + raise RuntimeError('Cannot find version in ' + 'aiobotocore/__init__.py') classifiers = [ diff --git a/tests/botocore/test_credentials.py b/tests/botocore/test_credentials.py new file mode 100644 index 00000000..ecbe17e8 --- /dev/null +++ b/tests/botocore/test_credentials.py @@ -0,0 +1,1145 @@ +""" +These tests have been taken from +https://github.com/boto/botocore/blob/develop/tests/unit/test_credentials.py +and adapted to work with asyncio and pytest +""" +import asyncio +import datetime +import json +import subprocess + +import mock +from typing import Optional + +import pytest +import botocore.exceptions +from dateutil.tz import tzlocal + +from aiobotocore.session import AioSession +from aiobotocore import credentials +from botocore.configprovider import ConfigValueStore +from botocore.utils import FileWebIdentityTokenLoader + + +# From class TestCredentials(BaseEnvVar): +@pytest.mark.moto +@pytest.mark.parametrize("access,secret", [ + ('foo\xe2\x80\x99', 'bar\xe2\x80\x99'), (u'foo', u'bar')]) +def test_credentials_normalization(access, secret): + c = credentials.AioCredentials(access, secret) + assert isinstance(c.access_key, type(u'u')) + assert isinstance(c.secret_key, type(u'u')) + + +# From class TestRefreshableCredentials(TestCredentials): +@pytest.fixture +def refreshable_creds(): + def _f(mock_time_return_value=None, refresher_return_value='METADATA'): + refresher = mock.AsyncMock() + future_time = datetime.datetime.now(tzlocal()) + datetime.timedelta(hours=24) + expiry_time = datetime.datetime.now(tzlocal()) - datetime.timedelta(minutes=30) + metadata = { + 'access_key': 'NEW-ACCESS', + 'secret_key': 'NEW-SECRET', + 'token': 'NEW-TOKEN', + 'expiry_time': future_time.isoformat(), + 'role_name': 'rolename', + } + refresher.return_value = metadata if refresher_return_value == 'METADATA' \ + else refresher_return_value + mock_time = mock.Mock() + mock_time.return_value = mock_time_return_value + creds = credentials.AioRefreshableCredentials( + 'ORIGINAL-ACCESS', 'ORIGINAL-SECRET', 'ORIGINAL-TOKEN', + expiry_time, refresher, 'iam-role', time_fetcher=mock_time + ) + return creds + return _f + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_refreshablecredentials_get_credentials_set(refreshable_creds): + creds = refreshable_creds( + mock_time_return_value=(datetime.datetime.now(tzlocal()) - + datetime.timedelta(minutes=60)) + ) + + assert not creds.refresh_needed() + + credentials_set = await creds.get_frozen_credentials() + assert isinstance(credentials_set, credentials.ReadOnlyCredentials) + assert credentials_set.access_key == 'ORIGINAL-ACCESS' + assert credentials_set.secret_key == 'ORIGINAL-SECRET' + assert credentials_set.token == 'ORIGINAL-TOKEN' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_refreshablecredentials_refresh_returns_empty_dict(refreshable_creds): + creds = refreshable_creds( + mock_time_return_value=datetime.datetime.now(tzlocal()), + refresher_return_value={} + ) + + assert creds.refresh_needed() + + with pytest.raises(botocore.exceptions.CredentialRetrievalError): + await creds.get_frozen_credentials() + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_refreshablecredentials_refresh_returns_none(refreshable_creds): + creds = refreshable_creds( + mock_time_return_value=datetime.datetime.now(tzlocal()), + refresher_return_value=None + ) + + assert creds.refresh_needed() + + with pytest.raises(botocore.exceptions.CredentialRetrievalError): + await creds.get_frozen_credentials() + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_refreshablecredentials_refresh_returns_partial(refreshable_creds): + creds = refreshable_creds( + mock_time_return_value=datetime.datetime.now(tzlocal()), + refresher_return_value={'access_key': 'akid'} + ) + + assert creds.refresh_needed() + + with pytest.raises(botocore.exceptions.CredentialRetrievalError): + await creds.get_frozen_credentials() + + +# From class TestDeferredRefreshableCredentials(unittest.TestCase): +@pytest.fixture +def deferrable_creds(): + def _f(mock_time_return_value=None, refresher_return_value='METADATA'): + refresher = mock.AsyncMock() + future_time = datetime.datetime.now(tzlocal()) + datetime.timedelta(hours=24) + metadata = { + 'access_key': 'NEW-ACCESS', + 'secret_key': 'NEW-SECRET', + 'token': 'NEW-TOKEN', + 'expiry_time': future_time.isoformat(), + 'role_name': 'rolename', + } + refresher.return_value = metadata if refresher_return_value == 'METADATA' \ + else refresher_return_value + mock_time = mock.Mock() + mock_time.return_value = (mock_time_return_value or + datetime.datetime.now(tzlocal())) + creds = credentials.AioDeferredRefreshableCredentials( + refresher, 'iam-role', time_fetcher=mock_time + ) + return creds + return _f + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_deferrablecredentials_get_credentials_set(deferrable_creds): + creds = deferrable_creds() + + creds._refresh_using.assert_not_called() + + await creds.get_frozen_credentials() + assert creds._refresh_using.call_count == 1 + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_deferrablecredentials_refresh_only_called_once(deferrable_creds): + creds = deferrable_creds() + + creds._refresh_using.assert_not_called() + + for _ in range(5): + await creds.get_frozen_credentials() + + assert creds._refresh_using.call_count == 1 + + +# From class TestAssumeRoleCredentialFetcher(BaseEnvVar): +def assume_role_client_creator(with_response): + class _Client(object): + def __init__(self, resp): + self._resp = resp + + self._called = [] + self._call_count = 0 + + async def assume_role(self, *args, **kwargs): + self._call_count += 1 + self._called.append((args, kwargs)) + + if isinstance(self._resp, list): + return self._resp.pop(0) + return self._resp + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + return mock.Mock(return_value=_Client(with_response)) + + +def some_future_time(): + timeobj = datetime.datetime.now(tzlocal()) + return timeobj + datetime.timedelta(hours=24) + + +def get_expected_creds_from_response(response): + expiration = response['Credentials']['Expiration'] + if isinstance(expiration, datetime.datetime): + expiration = expiration.isoformat() + return { + 'access_key': response['Credentials']['AccessKeyId'], + 'secret_key': response['Credentials']['SecretAccessKey'], + 'token': response['Credentials']['SessionToken'], + 'expiry_time': expiration + } + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_assumerolefetcher_no_cache(): + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': some_future_time().isoformat() + }, + } + refresher = credentials.AioAssumeRoleCredentialFetcher( + assume_role_client_creator(response), + credentials.AioCredentials('a', 'b', 'c'), + 'myrole' + ) + + expected_response = get_expected_creds_from_response(response) + response = await refresher.fetch_credentials() + + assert response == expected_response + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_assumerolefetcher_cache_key_with_role_session_name(): + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': some_future_time().isoformat() + }, + } + cache = {} + client_creator = assume_role_client_creator(response) + role_session_name = 'my_session_name' + + refresher = credentials.AioAssumeRoleCredentialFetcher( + client_creator, + credentials.AioCredentials('a', 'b', 'c'), + 'myrole', + cache=cache, + extra_args={'RoleSessionName': role_session_name} + ) + await refresher.fetch_credentials() + + # This is the sha256 hex digest of the expected assume role args. + cache_key = ( + '2964201f5648c8be5b9460a9cf842d73a266daf2' + ) + assert cache_key in cache + assert cache[cache_key] == response + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_assumerolefetcher_cache_in_cache_but_expired(): + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': some_future_time().isoformat(), + }, + } + client_creator = assume_role_client_creator(response) + cache = { + 'development--myrole': { + 'Credentials': { + 'AccessKeyId': 'foo-cached', + 'SecretAccessKey': 'bar-cached', + 'SessionToken': 'baz-cached', + 'Expiration': datetime.datetime.now(tzlocal()), + } + } + } + + refresher = credentials.AioAssumeRoleCredentialFetcher( + client_creator, + credentials.AioCredentials('a', 'b', 'c'), + 'myrole', + cache=cache + ) + expected = get_expected_creds_from_response(response) + response = await refresher.fetch_credentials() + + assert response == expected + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_assumerolefetcher_mfa(): + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': some_future_time().isoformat(), + }, + } + client_creator = assume_role_client_creator(response) + prompter = mock.Mock(return_value='token-code') + mfa_serial = 'mfa' + + refresher = credentials.AioAssumeRoleCredentialFetcher( + client_creator, + credentials.AioCredentials('a', 'b', 'c'), + 'myrole', + extra_args={'SerialNumber': mfa_serial}, mfa_prompter=prompter + ) + await refresher.fetch_credentials() + + # Slighly different to the botocore mock + client = client_creator.return_value + assert client._call_count == 1 + call_kwargs = client._called[0][1] + assert call_kwargs['SerialNumber'] == 'mfa' + assert call_kwargs['RoleArn'] == 'myrole' + assert call_kwargs['TokenCode'] == 'token-code' + + +# From class TestAssumeRoleWithWebIdentityCredentialFetcher(BaseEnvVar): +def assume_role_web_identity_client_creator(with_response): + class _Client(object): + def __init__(self, resp): + self._resp = resp + + self._called = [] + self._call_count = 0 + + async def assume_role_with_web_identity(self, *args, **kwargs): + self._call_count += 1 + self._called.append((args, kwargs)) + + if isinstance(self._resp, list): + return self._resp.pop(0) + return self._resp + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + return mock.Mock(return_value=_Client(with_response)) + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_webidentfetcher_no_cache(): + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': some_future_time().isoformat() + }, + } + refresher = credentials.AioAssumeRoleWithWebIdentityCredentialFetcher( + assume_role_web_identity_client_creator(response), + lambda: 'totally.a.token', + 'myrole' + ) + + expected_response = get_expected_creds_from_response(response) + response = await refresher.fetch_credentials() + + assert response == expected_response + + +# From class TestInstanceMetadataProvider(BaseEnvVar): +@pytest.mark.moto +@pytest.mark.asyncio +async def test_instancemetadata_load(): + timeobj = datetime.datetime.now(tzlocal()) + timestamp = (timeobj + datetime.timedelta(hours=24)).isoformat() + + fetcher = mock.AsyncMock() + fetcher.retrieve_iam_role_credentials.return_value = { + 'access_key': 'a', + 'secret_key': 'b', + 'token': 'c', + 'expiry_time': timestamp, + 'role_name': 'myrole', + } + + provider = credentials.AioInstanceMetadataProvider( + iam_role_fetcher=fetcher + ) + creds = await provider.load() + assert creds is not None + assert creds.method == 'iam-role' + + creds = await creds.get_frozen_credentials() + assert creds.access_key == 'a' + assert creds.secret_key == 'b' + assert creds.token == 'c' + + +# From class CredentialResolverTest(BaseEnvVar): +@pytest.fixture +def credential_provider(): + def _f(method, canonical_name, creds='None'): + # 'None' so that we can differentiate from None + provider = mock.AsyncMock() + provider.METHOD = method + provider.CANONICAL_NAME = canonical_name + if creds != 'None': + provider.load.return_value = creds + return provider + return _f + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_credresolver_load_credentials_single_provider(credential_provider): + provider1 = credential_provider('provider1', 'CustomProvider1', + credentials.AioCredentials('a', 'b', 'c')) + resolver = credentials.AioCredentialResolver(providers=[provider1]) + + creds = await resolver.load_credentials() + assert creds.access_key == 'a' + assert creds.secret_key == 'b' + assert creds.token == 'c' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_credresolver_no_providers(credential_provider): + provider1 = credential_provider('provider1', 'CustomProvider1', + None) + resolver = credentials.AioCredentialResolver(providers=[provider1]) + + creds = await resolver.load_credentials() + assert creds is None + + +# From class TestCanonicalNameSourceProvider(BaseEnvVar): +@pytest.mark.moto +@pytest.mark.asyncio +async def test_canonicalsourceprovider_source_creds(credential_provider): + creds = credentials.AioCredentials('a', 'b', 'c') + provider1 = credential_provider('provider1', 'CustomProvider1', creds) + provider2 = credential_provider('provider2', 'CustomProvider2') + provider = credentials.AioCanonicalNameCredentialSourcer( + providers=[provider1, provider2]) + + result = await provider.source_credentials('CustomProvider1') + assert result is creds + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_canonicalsourceprovider_source_creds_case_insensitive( + credential_provider): + creds = credentials.AioCredentials('a', 'b', 'c') + provider1 = credential_provider('provider1', 'CustomProvider1', creds) + provider2 = credential_provider('provider2', 'CustomProvider2') + provider = credentials.AioCanonicalNameCredentialSourcer( + providers=[provider1, provider2]) + + result = await provider.source_credentials('cUsToMpRoViDeR1') + assert result is creds + + +# From class TestAssumeRoleCredentialProvider(unittest.TestCase): +@pytest.fixture +def assumerolecredprovider_config_loader(): + fake_config = { + 'profiles': { + 'development': { + 'role_arn': 'myrole', + 'source_profile': 'longterm', + }, + 'longterm': { + 'aws_access_key_id': 'akid', + 'aws_secret_access_key': 'skid', + }, + 'non-static': { + 'role_arn': 'myrole', + 'credential_source': 'Environment' + }, + 'chained': { + 'role_arn': 'chained-role', + 'source_profile': 'development' + } + } + } + + def _f(config=None): + return lambda: (config or fake_config) + + return _f + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_assumerolecredprovider_assume_role_no_cache( + credential_provider, + assumerolecredprovider_config_loader): + creds = credentials.AioCredentials('a', 'b', 'c') + provider1 = credential_provider('provider1', 'CustomProvider1', creds) + provider2 = credential_provider('provider2', 'CustomProvider2') + provider = credentials.AioCanonicalNameCredentialSourcer( + providers=[provider1, provider2]) + + result = await provider.source_credentials('cUsToMpRoViDeR1') + assert result is creds + + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': some_future_time().isoformat() + }, + } + client_creator = assume_role_client_creator(response) + provider = credentials.AioAssumeRoleProvider( + assumerolecredprovider_config_loader(), + client_creator, cache={}, profile_name='development') + + creds = await provider.load() + + # So calling .access_key would cause deferred credentials to be loaded, + # according to the source, you're supposed to call get_frozen_credentials + # so will do that. + creds = await creds.get_frozen_credentials() + assert creds.access_key == 'foo' + assert creds.secret_key == 'bar' + assert creds.token == 'baz' + + +# MFA +@pytest.mark.moto +@pytest.mark.asyncio +async def test_assumerolecredprovider_mfa( + credential_provider, + assumerolecredprovider_config_loader): + + fake_config = { + 'profiles': { + 'development': { + 'role_arn': 'myrole', + 'source_profile': 'longterm', + 'mfa_serial': 'mfa' + }, + 'longterm': { + 'aws_access_key_id': 'akid', + 'aws_secret_access_key': 'skid', + }, + 'non-static': { + 'role_arn': 'myrole', + 'credential_source': 'Environment' + }, + 'chained': { + 'role_arn': 'chained-role', + 'source_profile': 'development' + } + } + } + + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': some_future_time().isoformat() + }, + } + client_creator = assume_role_client_creator(response) + prompter = mock.Mock(return_value='token-code') + provider = credentials.AioAssumeRoleProvider( + assumerolecredprovider_config_loader(fake_config), + client_creator, cache={}, profile_name='development', prompter=prompter) + + creds = await provider.load() + # So calling .access_key would cause deferred credentials to be loaded, + # according to the source, you're supposed to call get_frozen_credentials + # so will do that. + await creds.get_frozen_credentials() + + client = client_creator.return_value + assert client._call_count == 1 + call_kwargs = client._called[0][1] + assert call_kwargs['SerialNumber'] == 'mfa' + assert call_kwargs['RoleArn'] == 'myrole' + assert call_kwargs['TokenCode'] == 'token-code' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_assumerolecredprovider_mfa_cannot_refresh_credentials( + credential_provider, + assumerolecredprovider_config_loader): + + fake_config = { + 'profiles': { + 'development': { + 'role_arn': 'myrole', + 'source_profile': 'longterm', + 'mfa_serial': 'mfa' + }, + 'longterm': { + 'aws_access_key_id': 'akid', + 'aws_secret_access_key': 'skid', + }, + 'non-static': { + 'role_arn': 'myrole', + 'credential_source': 'Environment' + }, + 'chained': { + 'role_arn': 'chained-role', + 'source_profile': 'development' + } + } + } + + expiration_time = some_future_time() + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': expiration_time.isoformat() + }, + } + client_creator = assume_role_client_creator(response) + prompter = mock.Mock(return_value='token-code') + provider = credentials.AioAssumeRoleProvider( + assumerolecredprovider_config_loader(fake_config), + client_creator, cache={}, profile_name='development', prompter=prompter) + + local_now = mock.Mock(return_value=datetime.datetime.now(tzlocal())) + with mock.patch('aiobotocore.credentials._local_now', local_now): + creds = await provider.load() + await creds.get_frozen_credentials() + + local_now.return_value = expiration_time + with pytest.raises(credentials.RefreshWithMFAUnsupportedError): + await creds.get_frozen_credentials() + + +# From class TestAssumeRoleWithWebIdentityCredentialProvider +@pytest.mark.moto +@pytest.mark.asyncio +async def test_assumerolewebidentprovider_no_cache(): + future = datetime.datetime.now(tzlocal()) + datetime.timedelta(hours=24) + + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': future.isoformat() + }, + } + + # client + client_creator = assume_role_web_identity_client_creator(response) + + mock_loader = mock.Mock(spec=FileWebIdentityTokenLoader) + mock_loader.return_value = 'totally.a.token' + mock_loader_cls = mock.Mock(return_value=mock_loader) + + config = { + 'profiles': { + 'some-profile': { + 'role_arn': 'arn:aws:iam::123:role/role-name', + 'web_identity_token_file': '/some/path/token.jwt' + } + } + } + + provider = credentials.AioAssumeRoleWithWebIdentityProvider( + load_config=lambda: config, + client_creator=client_creator, + cache={}, + profile_name='some-profile', + token_loader_cls=mock_loader_cls + ) + + creds = await provider.load() + creds = await creds.get_frozen_credentials() + assert creds.access_key == 'foo' + assert creds.secret_key == 'bar' + assert creds.token == 'baz' + + mock_loader_cls.assert_called_with('/some/path/token.jwt') + + +# From class TestContainerProvider(BaseEnvVar): +def full_url(url): + return 'http://%s%s' % (credentials.AioContainerMetadataFetcher.IP_ADDRESS, url) + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_containerprovider_assume_role_no_cache(): + environ = { + 'AWS_CONTAINER_CREDENTIALS_RELATIVE_URI': '/latest/credentials?id=foo' + } + fetcher = mock.AsyncMock() + fetcher.full_url = full_url + + timeobj = datetime.datetime.now(tzlocal()) + timestamp = (timeobj + datetime.timedelta(hours=24)).isoformat() + fetcher.retrieve_full_uri.return_value = { + "AccessKeyId": "access_key", + "SecretAccessKey": "secret_key", + "Token": "token", + "Expiration": timestamp, + } + provider = credentials.AioContainerProvider(environ, fetcher) + # Will return refreshable credentials + creds = await provider.load() + + url = full_url('/latest/credentials?id=foo') + fetcher.retrieve_full_uri.assert_called_with(url, headers=None) + + assert creds.method == 'container-role' + + creds = await creds.get_frozen_credentials() + assert creds.access_key == 'access_key' + assert creds.secret_key == 'secret_key' + assert creds.token == 'token' + + +# From class TestEnvVar(BaseEnvVar): +@pytest.mark.moto +@pytest.mark.asyncio +async def test_envvarprovider_env_var_present(): + environ = { + 'AWS_ACCESS_KEY_ID': 'foo', + 'AWS_SECRET_ACCESS_KEY': 'bar', + } + provider = credentials.AioEnvProvider(environ) + creds = await provider.load() + assert isinstance(creds, credentials.AioCredentials) + + assert creds.access_key == 'foo' + assert creds.secret_key == 'bar' + assert creds.method == 'env' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_envvarprovider_env_var_absent(): + environ = {} + provider = credentials.AioEnvProvider(environ) + creds = await provider.load() + assert creds is None + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_envvarprovider_env_var_expiry(): + expiry_time = datetime.datetime.now(tzlocal()) - datetime.timedelta(hours=1) + environ = { + 'AWS_ACCESS_KEY_ID': 'foo', + 'AWS_SECRET_ACCESS_KEY': 'bar', + 'AWS_CREDENTIAL_EXPIRATION': expiry_time.isoformat() + } + provider = credentials.AioEnvProvider(environ) + creds = await provider.load() + assert isinstance(creds, credentials.AioRefreshableCredentials) + + del environ['AWS_CREDENTIAL_EXPIRATION'] + + with pytest.raises(botocore.exceptions.PartialCredentialsError): + await creds.get_frozen_credentials() + + +# From class TestConfigFileProvider(BaseEnvVar): +@pytest.fixture +def profile_config(): + parser = mock.Mock() + profile_config = { + 'aws_access_key_id': 'a', + 'aws_secret_access_key': 'b', + 'aws_session_token': 'c', + # Non creds related configs can be in a session's # config. + 'region': 'us-west-2', + 'output': 'json', + } + parsed = {'profiles': {'default': profile_config}} + parser.return_value = parsed + return parser + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_configprovider_file_exists(profile_config): + provider = credentials.AioConfigProvider('cli.cfg', 'default', profile_config) + creds = await provider.load() + assert isinstance(creds, credentials.AioCredentials) + + assert creds.access_key == 'a' + assert creds.secret_key == 'b' + assert creds.method == 'config-file' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_configprovider_file_missing_profile(profile_config): + provider = credentials.AioConfigProvider('cli.cfg', 'NOT-default', profile_config) + creds = await provider.load() + assert creds is None + + +# From class TestSharedCredentialsProvider(BaseEnvVar): +@pytest.mark.moto +@pytest.mark.asyncio +async def test_sharedcredentials_file_exists(): + parser = mock.Mock() + parser.return_value = { + 'default': { + 'aws_access_key_id': 'foo', + 'aws_secret_access_key': 'bar', + } + } + + provider = credentials.AioSharedCredentialProvider( + creds_filename='~/.aws/creds', profile_name='default', + ini_parser=parser) + creds = await provider.load() + assert isinstance(creds, credentials.AioCredentials) + + assert creds.access_key == 'foo' + assert creds.secret_key == 'bar' + assert creds.method == 'shared-credentials-file' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_sharedcredentials_file_missing(): + parser = mock.Mock() + parser.side_effect = botocore.exceptions.ConfigNotFound(path='foo') + + provider = credentials.AioSharedCredentialProvider( + creds_filename='~/.aws/creds', profile_name='dev', + ini_parser=parser) + creds = await provider.load() + assert creds is None + + +# From class TestBotoProvider(BaseEnvVar): +@pytest.mark.moto +@pytest.mark.asyncio +async def test_botoprovider_file_exists(): + parser = mock.Mock() + parser.return_value = { + 'Credentials': { + 'aws_access_key_id': 'a', + 'aws_secret_access_key': 'b', + } + } + + provider = credentials.AioBotoProvider(environ={}, ini_parser=parser) + creds = await provider.load() + assert isinstance(creds, credentials.AioCredentials) + + assert creds.access_key == 'a' + assert creds.secret_key == 'b' + assert creds.method == 'boto-config' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_botoprovider_file_missing(): + parser = mock.Mock() + parser.side_effect = botocore.exceptions.ConfigNotFound(path='foo') + + provider = credentials.AioBotoProvider(environ={}, ini_parser=parser) + creds = await provider.load() + assert creds is None + + +# From class TestOriginalEC2Provider(BaseEnvVar): +@pytest.mark.moto +@pytest.mark.asyncio +async def test_originalec2provider_file_exists(): + envrion = {'AWS_CREDENTIAL_FILE': 'foo.cfg'} + parser = mock.Mock() + parser.return_value = { + 'AWSAccessKeyId': 'a', + 'AWSSecretKey': 'b', + } + + provider = credentials.AioOriginalEC2Provider(environ=envrion, parser=parser) + creds = await provider.load() + assert isinstance(creds, credentials.AioCredentials) + + assert creds.access_key == 'a' + assert creds.secret_key == 'b' + assert creds.method == 'ec2-credentials-file' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_originalec2provider_file_missing(): + provider = credentials.AioOriginalEC2Provider(environ={}) + creds = await provider.load() + assert creds is None + + +# From class TestProcessProvider +@pytest.fixture() +def process_provider(): + def _f(profile_name='default', loaded_config=None, invoked_process=None): + load_config = mock.Mock(return_value=loaded_config) + popen_mock = mock.Mock(return_value=invoked_process or mock.Mock(), + spec=asyncio.create_subprocess_exec) + return popen_mock, credentials.AioProcessProvider(profile_name, + load_config, + popen=popen_mock) + return _f + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_processprovider_retrieve_refereshable_creds(process_provider): + config = {'profiles': {'default': {'credential_process': 'my-process'}}} + invoked_process = mock.AsyncMock() + stdout = json.dumps({ + 'Version': 1, + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': '2999-01-01T00:00:00Z', + }) + invoked_process.communicate.return_value = \ + (stdout.encode('utf-8'), ''.encode('utf-8')) + invoked_process.returncode = 0 + + popen_mock, provider = process_provider( + loaded_config=config, invoked_process=invoked_process) + creds = await provider.load() + assert isinstance(creds, credentials.AioRefreshableCredentials) + assert creds is not None + assert creds.method == 'custom-process' + + creds = await creds.get_frozen_credentials() + assert creds.access_key == 'foo' + assert creds.secret_key == 'bar' + assert creds.token == 'baz' + popen_mock.assert_called_with(['my-process'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_processprovider_retrieve_creds(process_provider): + config = {'profiles': {'default': {'credential_process': 'my-process'}}} + invoked_process = mock.AsyncMock() + stdout = json.dumps({ + 'Version': 1, + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz' + }) + invoked_process.communicate.return_value = \ + (stdout.encode('utf-8'), ''.encode('utf-8')) + invoked_process.returncode = 0 + + popen_mock, provider = process_provider( + loaded_config=config, invoked_process=invoked_process) + creds = await provider.load() + assert isinstance(creds, credentials.AioCredentials) + assert creds is not None + assert creds.access_key == 'foo' + assert creds.secret_key == 'bar' + assert creds.token == 'baz' + assert creds.method == 'custom-process' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_processprovider_bad_version(process_provider): + config = {'profiles': {'default': {'credential_process': 'my-process'}}} + invoked_process = mock.AsyncMock() + stdout = json.dumps({ + 'Version': 2, + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': '2999-01-01T00:00:00Z', + }) + invoked_process.communicate.return_value = \ + (stdout.encode('utf-8'), ''.encode('utf-8')) + invoked_process.returncode = 0 + + popen_mock, provider = process_provider( + loaded_config=config, invoked_process=invoked_process) + with pytest.raises(botocore.exceptions.CredentialRetrievalError): + await provider.load() + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_processprovider_missing_field(process_provider): + config = {'profiles': {'default': {'credential_process': 'my-process'}}} + invoked_process = mock.AsyncMock() + stdout = json.dumps({ + 'Version': 1, + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': '2999-01-01T00:00:00Z', + }) + invoked_process.communicate.return_value = \ + (stdout.encode('utf-8'), ''.encode('utf-8')) + invoked_process.returncode = 0 + + popen_mock, provider = process_provider( + loaded_config=config, invoked_process=invoked_process) + with pytest.raises(botocore.exceptions.CredentialRetrievalError): + await provider.load() + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_processprovider_bad_exitcode(process_provider): + config = {'profiles': {'default': {'credential_process': 'my-process'}}} + invoked_process = mock.AsyncMock() + stdout = 'lah' + invoked_process.communicate.return_value = \ + (stdout.encode('utf-8'), ''.encode('utf-8')) + invoked_process.returncode = 1 + + popen_mock, provider = process_provider( + loaded_config=config, invoked_process=invoked_process) + with pytest.raises(botocore.exceptions.CredentialRetrievalError): + await provider.load() + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_processprovider_bad_config(process_provider): + config = {'profiles': {'default': {'credential_process': None}}} + invoked_process = mock.AsyncMock() + stdout = json.dumps({ + 'Version': 2, + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': '2999-01-01T00:00:00Z', + }) + invoked_process.communicate.return_value = \ + (stdout.encode('utf-8'), ''.encode('utf-8')) + invoked_process.returncode = 0 + + popen_mock, provider = process_provider( + loaded_config=config, invoked_process=invoked_process) + creds = await provider.load() + assert creds is None + + +# From class TestCreateCredentialResolver +@pytest.fixture +def mock_session(): + def _f(config_loader: Optional[ConfigValueStore] = None) -> AioSession: + if not config_loader: + config_loader = ConfigValueStore() + + fake_instance_variables = { + 'credentials_file': 'a', + 'legacy_config_file': 'b', + 'config_file': 'c', + 'metadata_service_timeout': 1, + 'metadata_service_num_attempts': 1, + } + + def fake_get_component(self, key): + if key == 'config_provider': + return config_loader + return None + + def fake_set_config_variable(self, logical_name, value): + fake_instance_variables[logical_name] = value + + session = mock.Mock(spec=AioSession) + session.get_component = fake_get_component + session.full_config = {} + + for name, value in fake_instance_variables.items(): + config_loader.set_config_variable(name, value) + + session.get_config_variable = config_loader.get_config_variable + session.set_config_variable = fake_set_config_variable + + return session + return _f + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_createcredentialresolver(mock_session): + session = mock_session() + + resolver = credentials.create_credential_resolver(session) + assert isinstance(resolver, credentials.AioCredentialResolver) + + +# Disabled on travis as we cant easily disable the tests properly and +# travis has an IAM role which can't be applied to the mock session +# @pytest.mark.moto +@pytest.mark.asyncio +async def test_get_credentials(mock_session): + session = mock_session() + + creds = await credentials.get_credentials(session) + + assert creds is None + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_from_aiocredentials_is_none(): + creds = credentials.AioCredentials.from_credentials(None) + assert creds is None + creds = credentials.AioRefreshableCredentials.from_refreshable_credentials(None) + assert creds is None + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_session_credentials(): + with mock.patch('aiobotocore.credentials.AioCredential' + 'Resolver.load_credentials') as mock_obj: + mock_obj.return_value = 'somecreds' + + session = AioSession() + creds = await session.get_credentials() + assert creds == 'somecreds' diff --git a/tests/botocore/test_signers.py b/tests/botocore/test_signers.py new file mode 100644 index 00000000..0e4035e0 --- /dev/null +++ b/tests/botocore/test_signers.py @@ -0,0 +1,151 @@ +import pytest +import mock + +import aiobotocore +import aiobotocore.credentials +import aiobotocore.signers +import botocore.auth +from botocore.model import ServiceId +from botocore.awsrequest import AWSRequest +from botocore.exceptions import UnknownClientMethodError, NoRegionError, \ + UnknownSignatureVersionError + + +# From class TestSigner +@pytest.fixture +async def base_signer_setup() -> dict: + emitter = mock.AsyncMock() + emitter.emit_until_response.return_value = (None, None) + credentials = aiobotocore.credentials.AioCredentials('key', 'secret') + + signer = aiobotocore.signers.AioRequestSigner(ServiceId('service_name'), + 'region_name', 'signing_name', + 'v4', credentials, emitter) + return { + 'credentials': credentials, + 'emitter': emitter, + 'signer': signer, + 'fixed_credentials': await credentials.get_frozen_credentials(), + 'request': AWSRequest() + } + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_testsigner_get_auth(base_signer_setup: dict): + auth_cls = mock.Mock() + with mock.patch.dict(botocore.auth.AUTH_TYPE_MAPS, {'v4': auth_cls}): + signer = base_signer_setup['signer'] + auth = await signer.get_auth('service_name', 'region_name') + + assert auth_cls.return_value is auth + auth_cls.assert_called_with( + credentials=base_signer_setup['fixed_credentials'], + service_name='service_name', + region_name='region_name' + ) + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_testsigner_region_required_for_sig4(base_signer_setup: dict): + signer = aiobotocore.signers.AioRequestSigner( + ServiceId('service_name'), None, 'signing_name', + 'v4', base_signer_setup['credentials'], base_signer_setup['emitter']) + + with pytest.raises(NoRegionError): + await signer.sign('operation_name', base_signer_setup['request']) + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_testsigner_custom_sign_version(base_signer_setup: dict): + signer = base_signer_setup['signer'] + with pytest.raises(UnknownSignatureVersionError): + await signer.get_auth('service_name', 'region_name', + signature_version='bad') + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_testsigner_choose_signer_override(base_signer_setup: dict): + auth_cls = mock.Mock() + auth_cls.REQUIRES_REGION = False + base_signer_setup['emitter'].emit_until_response.return_value = (None, 'custom') + + with mock.patch.dict(botocore.auth.AUTH_TYPE_MAPS, {'custom': auth_cls}): + signer = base_signer_setup['signer'] + request = base_signer_setup['request'] + await signer.sign('operation_name', request) + + fixed_credentials = base_signer_setup['fixed_credentials'] + auth_cls.assert_called_with(credentials=fixed_credentials) + auth_cls.return_value.add_auth.assert_called_with(request) + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_testsigner_generate_presigned_url(base_signer_setup: dict): + auth_cls = mock.Mock() + auth_cls.REQUIRES_REGION = True + + request_dict = { + 'headers': {}, + 'url': 'https://foo.com', + 'body': b'', + 'url_path': '/', + 'method': 'GET', + 'context': {} + } + + with mock.patch.dict(botocore.auth.AUTH_TYPE_MAPS, {'v4-query': auth_cls}): + signer = base_signer_setup['signer'] + presigned_url = await signer.generate_presigned_url( + request_dict, operation_name='operation_name' + ) + + auth_cls.assert_called_with( + credentials=base_signer_setup['fixed_credentials'], + region_name='region_name', service_name='signing_name', + expires=3600 + ) + assert presigned_url == 'https://foo.com' + + +# From class TestGenerateUrl +@pytest.mark.moto +@pytest.mark.asyncio +async def test_signers_generate_presigned_urls(): + with mock.patch('aiobotocore.signers.AioRequestSigner.generate_presigned_url') \ + as cls_gen_presigned_url_mock: + session = aiobotocore.session.get_session() + async with session.create_client('s3', region_name='us-east-1', + aws_access_key_id='lalala', + aws_secret_access_key='lalala', + aws_session_token='lalala') as client: + + # Uses HEAD as it covers more lines :) + await client.generate_presigned_url('get_object', + Params={'Bucket': 'mybucket', + 'Key': 'mykey'}, + HttpMethod='HEAD') + + ref_request_dict = { + 'body': b'', + 'url': 'https://s3.amazonaws.com/mybucket/mykey', + 'headers': {}, + 'query_string': {}, + 'url_path': '/mybucket/mykey', + 'method': 'HEAD', + 'context': mock.ANY + } + + cls_gen_presigned_url_mock.assert_called_with( + request_dict=ref_request_dict, + expires_in=3600, + operation_name='GetObject') + + cls_gen_presigned_url_mock.reset_mock() + + with pytest.raises(UnknownClientMethodError): + await client.generate_presigned_url('lalala') diff --git a/tests/botocore/test_utils.py b/tests/botocore/test_utils.py new file mode 100644 index 00000000..3d17b87d --- /dev/null +++ b/tests/botocore/test_utils.py @@ -0,0 +1,289 @@ +import asyncio +import pytest +import json +import mock +import itertools +from typing import Union, List, Tuple + +from aiobotocore import utils +from botocore.utils import MetadataRetrievalError, BadIMDSRequestError + + +# From class TestContainerMetadataFetcher +def fake_aiohttp_session(responses: Union[List[Tuple[Union[str, object], int]], + Tuple[Union[str, object], int]]): + """ + Dodgy shim class + """ + if isinstance(responses, Tuple): + data = itertools.cycle([responses]) + else: + data = iter(responses) + + class FakeAioHttpSession(object): + class FakeResponse(object): + def __init__(self, url, *args, **kwargs): + self.url = url + self._body, self.status = next(data) + if not isinstance(self._body, str): + raise self._body + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def text(self): + return self._body + + async def json(self): + return json.loads(self._body) + + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + def get(self, url, *args, **kwargs): + return self.FakeResponse(url) + + def put(self, url, *args, **kwargs): + return self.FakeResponse(url) + + return FakeAioHttpSession + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_containermetadatafetcher_retrieve_url(): + json_body = json.dumps({ + "AccessKeyId": "a", + "SecretAccessKey": "b", + "Token": "c", + "Expiration": "d" + }) + + sleep = mock.AsyncMock() + http = fake_aiohttp_session((json_body, 200)) + + fetcher = utils.AioContainerMetadataFetcher(http, sleep) + resp = await fetcher.retrieve_uri('/foo?id=1') + assert resp['AccessKeyId'] == 'a' + assert resp['SecretAccessKey'] == 'b' + assert resp['Token'] == 'c' + assert resp['Expiration'] == 'd' + + resp = await fetcher.retrieve_full_uri('http://localhost/foo?id=1', + {'extra': 'header'}) + assert resp['AccessKeyId'] == 'a' + assert resp['SecretAccessKey'] == 'b' + assert resp['Token'] == 'c' + assert resp['Expiration'] == 'd' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_containermetadatafetcher_retrieve_url_bad_status(): + json_body = "not json" + + sleep = mock.AsyncMock() + http = fake_aiohttp_session((json_body, 500)) + + fetcher = utils.AioContainerMetadataFetcher(http, sleep) + with pytest.raises(MetadataRetrievalError): + await fetcher.retrieve_uri('/foo?id=1') + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_containermetadatafetcher_retrieve_url_not_json(): + json_body = "not json" + + sleep = mock.AsyncMock() + http = fake_aiohttp_session((json_body, 200)) + + fetcher = utils.AioContainerMetadataFetcher(http, sleep) + with pytest.raises(MetadataRetrievalError): + await fetcher.retrieve_uri('/foo?id=1') + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_instancemetadatafetcher_retrieve_creds(): + with mock.patch('aiobotocore.utils.AioInstance' + 'MetadataFetcher._get_request') as mock_obj: + mock_obj.side_effect = [ + utils.AioIMDSFetcher.Response(200, 'some-role', + 'someurl'), + utils.AioIMDSFetcher.Response(200, '{"AccessKeyId": "foo", ' + '"SecretAccessKey": "bar", ' + '"Token": "baz", ' + '"Expiration": "bah"}', + 'someurl'), + ] + + fetcher = utils.AioInstanceMetadataFetcher() + + creds = await fetcher.retrieve_iam_role_credentials() + assert creds['role_name'] == 'some-role' + assert creds['access_key'] == 'foo' + assert creds['secret_key'] == 'bar' + assert creds['token'] == 'baz' + assert creds['expiry_time'] == 'bah' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_instancemetadatafetcher_partial_response(): + with mock.patch('aiobotocore.utils.AioInstance' + 'MetadataFetcher._get_request') as mock_obj: + mock_obj.side_effect = [ + utils.AioIMDSFetcher.Response(200, 'some-role', + 'someurl'), + utils.AioIMDSFetcher.Response(200, '{"Code": "foo", "Message": "test"}', + 'someurl'), + ] + + fetcher = utils.AioInstanceMetadataFetcher() + + creds = await fetcher.retrieve_iam_role_credentials() + assert creds == {} + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_instancemetadatafetcher_bad_token(): + with mock.patch('aiobotocore.utils.AioInstance' + 'MetadataFetcher._fetch_metadata_token') as mock_obj: + mock_obj.side_effect = BadIMDSRequestError('somereq') + + fetcher = utils.AioInstanceMetadataFetcher() + + creds = await fetcher.retrieve_iam_role_credentials() + assert creds == {} + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_idmsfetcher_disabled(): + env = {'AWS_EC2_METADATA_DISABLED': 'true'} + fetcher = utils.AioIMDSFetcher(env=env) + + with pytest.raises(fetcher._RETRIES_EXCEEDED_ERROR_CLS): + await fetcher._get_request('path', None) + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_idmsfetcher_get_token_success(): + session = fake_aiohttp_session([ + ('blah', 200), + ]) + + fetcher = utils.AioIMDSFetcher(num_attempts=2, + session=session, + user_agent='test') + response = await fetcher._fetch_metadata_token() + assert response == 'blah' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_idmsfetcher_get_token_not_found(): + session = fake_aiohttp_session([ + ('blah', 404), + ]) + + fetcher = utils.AioIMDSFetcher(num_attempts=2, + session=session, + user_agent='test') + response = await fetcher._fetch_metadata_token() + assert response is None + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_idmsfetcher_get_token_bad_request(): + session = fake_aiohttp_session([ + ('blah', 400), + ]) + + fetcher = utils.AioIMDSFetcher(num_attempts=2, + session=session, + user_agent='test') + with pytest.raises(BadIMDSRequestError): + await fetcher._fetch_metadata_token() + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_idmsfetcher_get_token_timeout(): + session = fake_aiohttp_session([ + (asyncio.TimeoutError(), 500), + ]) + + fetcher = utils.AioIMDSFetcher(num_attempts=2, + session=session) + + response = await fetcher._fetch_metadata_token() + assert response is None + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_idmsfetcher_get_token_retry(): + session = fake_aiohttp_session([ + ('blah', 500), + ('blah', 500), + ('token', 200), + ]) + + fetcher = utils.AioIMDSFetcher(num_attempts=3, + session=session) + + response = await fetcher._fetch_metadata_token() + assert response == 'token' + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_idmsfetcher_retry(): + session = fake_aiohttp_session([ + ('blah', 500), + ('data', 200), + ]) + + fetcher = utils.AioIMDSFetcher(num_attempts=2, + session=session, + user_agent='test') + response = await fetcher._get_request('path', None, 'some_token') + + assert response.text == 'data' + + session = fake_aiohttp_session([ + ('blah', 500), + ('data', 200), + ]) + + fetcher = utils.AioIMDSFetcher(num_attempts=1, session=session) + with pytest.raises(fetcher._RETRIES_EXCEEDED_ERROR_CLS): + await fetcher._get_request('path', None) + + +@pytest.mark.moto +@pytest.mark.asyncio +async def test_idmsfetcher_timeout(): + session = fake_aiohttp_session([ + (asyncio.TimeoutError(), 500), + ]) + + fetcher = utils.AioIMDSFetcher(num_attempts=1, + session=session) + + with pytest.raises(fetcher._RETRIES_EXCEEDED_ERROR_CLS): + await fetcher._get_request('path', None) diff --git a/tests/conftest.py b/tests/conftest.py index f78794a0..bc8e1671 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -201,16 +201,16 @@ def sqs_client(request, session, region, config, sqs_server, return client -def create_client(client_type, request, event_loop, session, region, - config, **kw): - async def f(): - return session.create_client(client_type, region_name=region, - config=config, **kw) - client = event_loop.run_until_complete(f()) +def create_client(client_type, request, event_loop: asyncio.AbstractEventLoop, + session, region, config, **kw): + client = session.create_client(client_type, region_name=region, + config=config, **kw) def fin(): - event_loop.run_until_complete(client.close()) + event_loop.run_until_complete(client.__aexit__(None, None, None)) request.addfinalizer(fin) + + client = event_loop.run_until_complete(client.__aenter__()) return client diff --git a/tests/test_basic_s3.py b/tests/test_basic_s3.py index 75104127..2cc5f283 100644 --- a/tests/test_basic_s3.py +++ b/tests/test_basic_s3.py @@ -415,15 +415,14 @@ async def test_presign_with_existing_query_string_values( params = {'Bucket': bucket_name, 'Key': key_name, 'ResponseContentDisposition': content_disposition} - presigned_url = s3_client.generate_presigned_url( + presigned_url = await s3_client.generate_presigned_url( 'get_object', Params=params) # Try to retrieve the object using the presigned url. - resp = await aio_session.get(presigned_url) - data = await resp.read() - await resp.close() - assert resp.headers['Content-Disposition'] == content_disposition - assert data == b'foo' + async with aio_session.get(presigned_url) as resp: + data = await resp.read() + assert resp.headers['Content-Disposition'] == content_disposition + assert data == b'foo' @pytest.mark.parametrize('region', ['us-east-1']) @@ -435,7 +434,7 @@ async def test_presign_sigv4(s3_client, bucket_name, aio_session, create_object): key = 'myobject' await create_object(key_name=key) - presigned_url = s3_client.generate_presigned_url( + presigned_url = await s3_client.generate_presigned_url( 'get_object', Params={'Bucket': bucket_name, 'Key': key}) msg = "Host was suppose to be the us-east-1 endpoint, " \ "instead got: %s" % presigned_url @@ -443,9 +442,9 @@ async def test_presign_sigv4(s3_client, bucket_name, aio_session, % (bucket_name, key)), msg # Try to retrieve the object using the presigned url. - resp = await aio_session.get(presigned_url) - data = await resp.read() - assert data == b'foo' + async with aio_session.get(presigned_url) as resp: + data = await resp.read() + assert data == b'foo' @pytest.mark.parametrize('signature_version', ['s3v4']) diff --git a/tests/test_config.py b/tests/test_config.py index 39f7d060..e43bd691 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -96,3 +96,13 @@ async def test_connector_timeout2(): async def test_get_session(): session = get_session() assert isinstance(session, AioSession) + + +@pytest.mark.moto +def test_merge(): + config = AioConfig() + other_config = AioConfig() + new_config = config.merge(other_config) + assert isinstance(new_config, AioConfig) + assert new_config is not config + assert new_config is not other_config diff --git a/tests/test_eventstreams.py b/tests/test_eventstreams.py index 9cd91750..d61bc213 100644 --- a/tests/test_eventstreams.py +++ b/tests/test_eventstreams.py @@ -96,4 +96,4 @@ async def test_eventstream_no_iter(s3_client): with pytest.raises(NotImplementedError): for _ in event_stream: - print('fail') + pass diff --git a/tests/test_patches.py b/tests/test_patches.py index 09f4322b..0fa77506 100644 --- a/tests/test_patches.py +++ b/tests/test_patches.py @@ -23,7 +23,19 @@ from botocore.parsers import ResponseParserFactory, PROTOCOL_PARSERS, \ RestXMLParser, EC2QueryParser, QueryParser, JSONParser, RestJSONParser from botocore.response import StreamingBody - +from botocore.signers import RequestSigner, add_generate_presigned_url, \ + generate_presigned_url +from botocore.hooks import EventAliaser, HierarchicalEmitter +from botocore.utils import ContainerMetadataFetcher, IMDSFetcher, \ + InstanceMetadataFetcher +from botocore.credentials import Credentials, RefreshableCredentials, \ + CachedCredentialFetcher, AssumeRoleCredentialFetcher, EnvProvider, \ + ContainerProvider, InstanceMetadataProvider, ProfileProviderBuilder, \ + ConfigProvider, SharedCredentialProvider, ProcessProvider, CredentialResolver, \ + AssumeRoleWithWebIdentityProvider, AssumeRoleProvider, \ + CanonicalNameCredentialSourcer, BotoProvider, OriginalEC2Provider, \ + create_credential_resolver, get_credentials, create_mfa_serial_refresher, \ + AssumeRoleWithWebIdentityCredentialFetcher # This file ensures that our private patches will work going forward. If a # method gets updated this will assert and someone will need to validate: @@ -63,11 +75,14 @@ ClientArgsCreator.get_client_args: {'e3a44e6f50159e8e31c3d76f5e8a1110dda495fa'}, # client.py + ClientCreator.create_client: {'ee63a3d60b5917879cb644c1b0aa3fe34538b915'}, ClientCreator._create_client_class: {'5e493d069eedbf314e40e12a7886bbdbcf194335'}, ClientCreator._get_client_args: {'555e1e41f93df7558c8305a60466681e3a267ef3'}, BaseClient._make_api_call: {'0c59329d4c8a55b88250b512b5e69239c42246fb'}, BaseClient._make_request: {'033a386f7d1025522bea7f2bbca85edc5c8aafd2'}, + BaseClient._convert_to_request_dict: {'0071c2a37c3c696d9b0fba5f54b2985489c76b78'}, + BaseClient._emit_api_params: {'2bfadaaa70671b63c50b1beed6d6c66e85813e9b'}, BaseClient.get_paginator: {'c69885f5f73fae048c0b93b43bbfcd1f9c6168b8'}, BaseClient.get_waiter: {'23d57598555bfbc4c6e7ec93406d05771f108d9e'}, @@ -75,9 +90,100 @@ Config.merge: {'c3dd8c3ffe0da86953ceba4a35267dfb79c6a2c8'}, Config: {'2dcc44190a3dc2a4b26ab0ed9410daefcd7c93c1'}, + # credentials.py + create_mfa_serial_refresher: {'180b81fc40c91d1cf40de1a28e32ae7d601e1d50'}, + Credentials.get_frozen_credentials: {'08af57df08ee9953e440aa7aca58137ed936cdb6'}, + RefreshableCredentials.__init__: {'c685fd2c62eb60096fdf8bb885fb642df1819f7f'}, + # We've overridden some properties + RefreshableCredentials.__dict__['access_key'].fset: + {'edc4a25baef877a9662f68cd9ccefcd33a81bab7'}, + RefreshableCredentials.__dict__['access_key'].fget: + {'f6c823210099db99dd343d9e1fae6d4eb5aa5fce'}, + RefreshableCredentials.__dict__['secret_key'].fset: + {'b19fe41d66822c72bd6ae2e60de5c5d27367868a'}, + RefreshableCredentials.__dict__['secret_key'].fget: + {'3e27331a037549104b8669e225bbbb2c465a16d4'}, + RefreshableCredentials.__dict__['token'].fset: + {'1f8a308d4bf21e666f8054a0546e91541661da7b'}, + RefreshableCredentials.__dict__['token'].fget: + {'005c1b44b616f37739ce9276352e4e83644d8220'}, + RefreshableCredentials._refresh: {'f4759b7ef0d1f0d8af07855dcd9ca49ef12c2e7b'}, + RefreshableCredentials._protected_refresh: + {'432409f81601dbeea9ec187d433d190ab7c5ab2f'}, + RefreshableCredentials.get_frozen_credentials: + {'f661c84a8b759786e011f0b1e8a468a0c6294e36'}, + + CachedCredentialFetcher._get_credentials: + {'02a7d13599d972e3f258d2b53f87eeda4cc3e3a4'}, + CachedCredentialFetcher.fetch_credentials: + {'0dd2986a4cbb38764ec747075306a33117e86c3d'}, + CachedCredentialFetcher._get_cached_credentials: + {'a9f8c348d226e62122972da9ccc025365b6803d6'}, + AssumeRoleCredentialFetcher._get_credentials: + {'5c575634bc0a713c10e5668f28fbfa8779d5a1da'}, + AssumeRoleCredentialFetcher._create_client: + {'27c76f07bd43e665899ca8d21b6ba2038b276fbb'}, + # Referenced by AioAssumeRoleWithWebIdentityCredentialFetcher + AssumeRoleWithWebIdentityCredentialFetcher.__init__: + {'85c022a7237a3500ca973b2f7f91bffe894e4577'}, + AssumeRoleWithWebIdentityCredentialFetcher._get_credentials: + {'02eba9d4e846474910cb076710070348e395a819'}, + AssumeRoleWithWebIdentityCredentialFetcher._assume_role_kwargs: + {'8fb4fefe8664b7d82a67e0fd6d6812c1c8d92285'}, + # Ensure that the load method doesn't do anything we should asyncify + EnvProvider.load: {'07cff5032b39b568505779774a1ca66efc513abb'}, + + ContainerProvider.__init__: {'ea6aafb2e12730066af930fb5a27f7659c1736a1'}, + ContainerProvider.load: {'57c35569050b45c1e9e33fcdb3b49da9e342fdcf'}, + ContainerProvider._retrieve_or_fail: + {'7c14f1cdee07217f847a71068866bdd10c3fa0fa'}, + ContainerProvider._create_fetcher: + {'09a3ffded0fc20a574f3b34fa432a1569d5e729f'}, + InstanceMetadataProvider.load: {'4a27eb94fe220fba2b46c97bdd9e16de199ce004'}, + ProfileProviderBuilder._create_process_provider: + {'c5eea47bcfc449a6d73a9892bd0e1897f6be0c20'}, + ProfileProviderBuilder._create_shared_credential_provider: + {'33f99c6a0ef71a92b0c52ccc59c8ca7e33fa0890'}, + ProfileProviderBuilder._create_config_provider: + {'f9a40d4211f6e663ba2ae9682fba5306152178c5'}, + ProfileProviderBuilder._create_web_identity_provider: + {'0907c1ad5573bc5c0fc87efb601a6c4c3fcf34ae'}, + ConfigProvider.load: {'8fb32140086dce65fa28be8edd3ac0d22698c3ae'}, + SharedCredentialProvider.load: {'c0be1fe376d25952461ca18d9bef4b4340203441'}, + ProcessProvider.__init__: {'2e870ec0c6b0bc8483fa9b1159ef68bbd7a12c56'}, + ProcessProvider.load: {'aac90e2c8823939f09936b9c883e67503128e438'}, + ProcessProvider._retrieve_credentials_using: + {'ffc27c7cba0e37cf6db3a3eacfd54be8bd99d3a9'}, + CredentialResolver.load_credentials: + {'ef31ba8817f84c1f61f36259da1cc6e597b8625a'}, + AssumeRoleWithWebIdentityProvider.load: + {'8f48f6cadf08a09cf5a22b1cc668e60bc4ea389d'}, + AssumeRoleWithWebIdentityProvider._assume_role_with_web_identity: + {'32c9d720ab5f12054583758b5cd5d287f652ccd3'}, + AssumeRoleProvider.load: {'ee9ddb43e25eb1105185253c0963a2f5add49a95'}, + AssumeRoleProvider._load_creds_via_assume_role: + {'9fdba45a8dd16b885dea7c1fafc7d02609870fa7'}, + AssumeRoleProvider._resolve_source_credentials: + {'105c0c011e23d76a3b8bd3d9b91b6d945c8307a1'}, + AssumeRoleProvider._resolve_credentials_from_profile: + {'402a1a6b3e0a29c234b7883e5b855110eb655830'}, + AssumeRoleProvider._resolve_static_credentials_from_profile: + {'58f04986bb1027d548212b7769034e5dae5cc30f'}, + AssumeRoleProvider._resolve_credentials_from_source: + {'6f76ae62f477279a2297565f80a5cfbe5ea30eaf'}, + CanonicalNameCredentialSourcer.source_credentials: + {'602930a78e0e64e3b313a046aab5edc3bcf5c2d9'}, + CanonicalNameCredentialSourcer._get_provider: + {'c028b9776383cc566be10999745b6082f458d902'}, + BotoProvider.load: {'9351b8565c2c969937963fc1d3fbc8b3b6d8ccc1'}, + OriginalEC2Provider.load: {'bde9af019f01acf3848a6eda125338b2c588c1ab'}, + create_credential_resolver: {'5ff7fe49d7636b795a50202ff5c089611f4e27c1'}, + get_credentials: {'ff0c735a388ac8dd7fe300a32c1e36cdf33c0f56'}, + # endpoint.py convert_to_response_dict: {'2c73c059fa63552115314b079ae8cbf5c4e78da0'}, + Endpoint.create_request: {'4ccc14de2fd52f5c60017e55ff8e5b78bbaabcec'}, Endpoint._send_request: {'50ab33d6f16e75594d01ab1c2ec6b7c7903798db'}, Endpoint._get_response: {'46c3a8cb4ff7672b75193ce5571dbea48aa9da75'}, Endpoint._do_get_response: {'df29f099d26dc057834c7b25d3b5217f1f7acbe4'}, @@ -91,6 +197,12 @@ 'cc101f3ca2bca4f14ccd6b385af900a15f96967b'}, EventStream.__iter__: {'8a9b454943f8ef6e81f5794d641adddd1fdd5248'}, + # hooks.py + HierarchicalEmitter._emit: {'5d9a6b1aea1323667a9310e707a9f0a006f8f6e8'}, + HierarchicalEmitter.emit_until_response: + {'23670e04e0b09a9575c7533442bca1b2972ade82'}, + EventAliaser.emit_until_response: {'0d635bf7ae5022b1fdde891cd9a91cd4c449fd49'}, + # paginate.py PageIterator.__iter__: {'56b3a1e30f488e2f1f5d5309db42fd5ad8a3895d'}, PageIterator.result_key_iters: {'04d3c647bd98caba3687df80e650fea517a0068e'}, @@ -110,8 +222,47 @@ # session.py Session.__init__: {'ccf156a76beda3425fb54363f3b2718dc0445f6d'}, + Session._register_response_parser_factory: + {'d6cd5a8b1b473b0ec3b71db5f621acfb12cc412c'}, Session.create_client: {'36f4e718fc4bada66808c2f98fa71835c09076f7'}, + Session._create_credential_resolver: {'87e98d201c72d06f7fbdb4ebee2dce1c09de0fb2'}, + Session.get_credentials: {'c0de970743b6b9dd91b5a71031db8a495fde53e4'}, get_session: {'c47d588f5da9b8bde81ccc26eaef3aee19ddd901'}, + Session.get_service_data: {'e28f2de9ebaf13214f1606d33349dfa8e2555923'}, + Session.get_service_model: {'1c8f93e6fb9913e859e43aea9bc2546edbea8365'}, + Session.get_available_regions: {'bc455d24d98fbc112ff22325ebfd12a6773cb7d4'}, + + # signers.py + RequestSigner.handler: {'371909df136a0964ef7469a63d25149176c2b442'}, + RequestSigner.sign: {'7df841d3df3f4015763523c1932652aef754287a'}, + RequestSigner.get_auth: {'4f8099bef30f9a72fa3bcaa1bd3d22c4fbd224a8'}, + RequestSigner.get_auth_instance: {'4f8099bef30f9a72fa3bcaa1bd3d22c4fbd224a8'}, + RequestSigner._choose_signer: {'d1e0e3196ada449d3ae0ec09b8ae9b5868c50d4e'}, + RequestSigner.generate_presigned_url: {'2acffdfd926b7b6f6cc4b70b90c0587e7f424888'}, + add_generate_presigned_url: {'5820f74ac46b004eb79e00eea1adc467bcf4defe'}, + generate_presigned_url: {'9c471f957210c0a71a11f5c73be9fed844ecb5bb'}, + + # utils.py + ContainerMetadataFetcher.__init__: + {'46d90a7249ba8389feb487779b0a02e6faa98e57'}, + ContainerMetadataFetcher.retrieve_full_uri: + {'2c7080f7d6ee5a3dacc1b690945c045dba1b1d21'}, + ContainerMetadataFetcher.retrieve_uri: + {'4ee8aa704cf0a378d68ef9a7b375a1aa8840b000'}, + ContainerMetadataFetcher._retrieve_credentials: + {'f5294f9f811cb3cc370e4824ca106269ea1f44f9'}, + ContainerMetadataFetcher._get_response: + {'7e5acdd2cf0167a047e3d5ee1439565a2f79f6a6'}, + # Overrided session and dealing with proxy support + IMDSFetcher.__init__: {'690e37140ccdcd67c7a85ce5d36331491a79954e'}, + IMDSFetcher._get_request: {'96a0e580cab5a21deb4d2cd7e904aa17d5e1e504'}, + IMDSFetcher._fetch_metadata_token: {'4fdad673b4997b1268c6d9dff09a4b99c1cb5e0d'}, + + InstanceMetadataFetcher.retrieve_iam_role_credentials: + {'76737f6add82a1b9a0dc590cf10bfac0c7026a2e'}, + InstanceMetadataFetcher._get_iam_role: {'80073d7adc9fb604bc6235af87241f5efc296ad7'}, + InstanceMetadataFetcher._get_credentials: + {'1a64f59a3ca70b83700bd14deeac25af14100d58'}, # waiter.py NormalizedOperationMethod.__call__: {'79723632d023739aa19c8a899bc2b814b8ab12ff'}, @@ -119,6 +270,7 @@ create_waiter_with_client: {'c3d12c9a4293105cc8c2ecfc7e69a2152ad564de'}, } + _PROTOCOL_PARSER_CONTENT = {'ec2', 'query', 'json', 'rest-json', 'rest-xml'}