Skip to content

Commit

Permalink
Allow different quote_via functions (#19)
Browse files Browse the repository at this point in the history
* Allow different `quote_via` functions
* Add unit tests
* Bump version
  • Loading branch information
d3QUone authored Jan 13, 2024
1 parent 09def87 commit 8d38e20
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 16 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ for Python projects, e.g. when you need a high-load queue consumer or high-load

Supports Python versions 3.8, 3.9, 3.10, 3.11, 3.12.

Supported and tested Amazon-like SQS providers: Amazon, VK Cloud.

----

## Why aiosqs?
Expand Down
2 changes: 1 addition & 1 deletion aiosqs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
SendMessageResponse,
)

VERSION = "1.0.4"
VERSION = "1.0.5"
31 changes: 23 additions & 8 deletions aiosqs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import datetime
import urllib.parse
from logging import getLogger
from typing import Dict, Optional, List, Union
from typing import Dict, Optional, List, Union, Callable, NamedTuple

import aiohttp

Expand All @@ -18,6 +18,11 @@
default_logger = getLogger(__name__)


class SignedRequest(NamedTuple):
headers: Dict
querystring: str


class SQSClient:
algorithm = "AWS4-HMAC-SHA256"
default_timeout_sec = 10
Expand All @@ -31,6 +36,7 @@ def __init__(
timeout_sec: Optional[int] = None,
logger: Optional[LoggerType] = None,
verify_ssl: Optional[bool] = None,
quote_via: Optional[Callable] = None,
):
self.service_name = "sqs"
self.region_name = region_name
Expand All @@ -46,6 +52,11 @@ def __init__(
self.timeout = aiohttp.ClientTimeout(total=timeout_sec or self.default_timeout_sec)
self.session = aiohttp.ClientSession(timeout=self.timeout)

# It's possible to have differen quoting logic for different SQS providers.
# By default Amazon SQS uses `urllib.parse.quote`, so no extra customizations are required.
# Related issue: https://github.com/d3QUone/aiosqs/issues/13
self.quote_via = quote_via or urllib.parse.quote

async def close(self):
await self.session.close()
# https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
Expand All @@ -57,7 +68,7 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()

def get_headers(self, params: Dict):
def build_signed_request(self, params: Dict) -> SignedRequest:
# Create a date for headers and the credential string
t = datetime.datetime.utcnow()
amz_date = t.strftime("%Y%m%dT%H%M%SZ")
Expand All @@ -69,7 +80,7 @@ def get_headers(self, params: Dict):
# Create the canonical query string. Important notes:
# - Query string values must be URL-encoded (space=%20).
# - The parameters must be sorted by name.
canonical_querystring = urllib.parse.urlencode(list(sorted(params.items())))
canonical_querystring = urllib.parse.urlencode(query=list(sorted(params.items())), quote_via=self.quote_via)

# Create the canonical headers and signed headers.
canonical_headers = f"host:{self.host}" + "\n" + f"x-amz-date:{amz_date}" + "\n"
Expand Down Expand Up @@ -116,21 +127,25 @@ def get_headers(self, params: Dict):
# The request can include any headers, but MUST include "host", "x-amz-date",
# and (for this scenario) "Authorization". "host" and "x-amz-date" must
# be included in the canonical_headers and signed_headers. Order here is not significant.
return {
headers = {
"x-amz-date": amz_date,
"Authorization": authorization_header,
"content-type": "application/x-www-form-urlencoded",
}
return SignedRequest(
headers=headers,
querystring=canonical_querystring,
)

async def request(self, params: Dict) -> Union[Dict, List, None]:
params["Version"] = "2012-11-05"
headers = self.get_headers(params=params)
signed_request = self.build_signed_request(params=params)
url = f"{self.endpoint_url}?{signed_request.querystring}"

try:
response = await self.session.get(
url=self.endpoint_url,
headers=headers,
params=params,
url=url,
headers=signed_request.headers,
verify_ssl=self.verify_ssl,
)
except Exception as e:
Expand Down
77 changes: 71 additions & 6 deletions aiosqs/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import unittest
import re
import logging
import urllib.parse

import ddt
from freezegun import freeze_time
from aioresponses import aioresponses

from aiosqs.exceptions import SQSErrorResponse
Expand All @@ -11,25 +13,47 @@


@ddt.ddt(testNameFormat=ddt.TestNameFormat.INDEX_ONLY)
class ClientTestCase(unittest.IsolatedAsyncioTestCase):
class DefaultClientTestCase(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
await super().asyncSetUp()

logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.CRITICAL)

self.client = SQSClient(
aws_access_key_id="access_key_id",
aws_secret_access_key="secret_access_key",
region_name="us-west-2",
host="mocked_amazon_host.com",
timeout_sec=0,
logger=logger,
logger=self.logger,
)

async def asyncTearDown(self):
await self.client.close()

async def test_signature_with_quote_via(self):
params = {
"Action": "SendMessage",
"DelaySeconds": 0,
"MessageBody": "a b c d",
"QueueUrl": "http://host.com/internal/tests",
"Version": "2012-11-05",
}
with freeze_time("2022-03-07T11:30:00.0000"):
signed_request = self.client.build_signed_request(params=params)

self.assertEqual(
signed_request.headers,
{
"x-amz-date": "20220307T113000Z",
"Authorization": "AWS4-HMAC-SHA256 Credential=access_key_id/20220307/us-west-2/sqs/aws4_request, SignedHeaders=host;x-amz-date, Signature=7d7ae7f85d3175f61e5256ed560c7b284491f767b9c352d1231f92ec04043d8e",
"content-type": "application/x-www-form-urlencoded",
},
)
self.assertEqual(
signed_request.querystring,
"Action=SendMessage&DelaySeconds=0&MessageBody=a%20%20%20%20%20b%20%20%20%20c%20%20%20%20%20d&QueueUrl=http%3A%2F%2Fhost.com%2Finternal%2Ftests&Version=2012-11-05",
)

@aioresponses()
async def test_is_context_manager(self, mock):
mock.get(
Expand Down Expand Up @@ -67,3 +91,44 @@ async def test_invalid_auth_keys(self, fixture_name: str, error_message: str, mo
self.assertEqual(exception.error.type, "Sender")
self.assertEqual(exception.error.code, "InvalidClientTokenId")
self.assertEqual(exception.error.message, error_message)


class VKClientTestCase(DefaultClientTestCase):
async def asyncSetUp(self):
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.CRITICAL)

self.client = SQSClient(
aws_access_key_id="access_key_id",
aws_secret_access_key="secret_access_key",
region_name="us-west-2",
host="mocked_amazon_host.com",
timeout_sec=0,
logger=self.logger,
quote_via=urllib.parse.quote_plus,
)

async def test_signature_with_quote_via(self):
params = {
"Action": "SendMessage",
"DelaySeconds": 0,
"MessageBody": "a b c d",
"QueueUrl": "http://host.com/internal/tests",
"Version": "2012-11-05",
}

with freeze_time("2022-03-07T11:30:00.0000"):
signed_request = self.client.build_signed_request(params=params)

self.assertEqual(
signed_request.headers,
{
"x-amz-date": "20220307T113000Z",
"Authorization": "AWS4-HMAC-SHA256 Credential=access_key_id/20220307/us-west-2/sqs/aws4_request, SignedHeaders=host;x-amz-date, Signature=0c36e0d3f62bd7ecb7e78ffe09fbd1224b7f850f3b4f13c7fc82e516fc7f2c57",
"content-type": "application/x-www-form-urlencoded",
},
)
self.assertEqual(
signed_request.querystring,
"Action=SendMessage&DelaySeconds=0&MessageBody=a+++++b++++c+++++d&QueueUrl=http%3A%2F%2Fhost.com%2Finternal%2Ftests&Version=2012-11-05",
)
2 changes: 2 additions & 0 deletions e2e/test_e2e.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import logging
from urllib.parse import quote_plus

from dotenv import dotenv_values

Expand Down Expand Up @@ -29,6 +30,7 @@ async def asyncSetUp(self):
host=self.host,
verify_ssl=False,
logger=logger,
quote_via=quote_plus,
)

async def asyncTearDown(self):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Source = "https://github.com/d3QUone/aiosqs"

[tool.poetry]
name = "aiosqs"
version = "1.0.4"
version = "1.0.5"
description = "Python asynchronous and lightweight SQS client."
authors = ["Vladimir Kasatkin <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit 8d38e20

Please sign in to comment.