diff --git a/CHANGELOG.md b/CHANGELOG.md index 040fb8ab..d8b31ea9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Deprecated ### Removed ### Fixed +- Fixed allow AWSV4SignerAuth to work with a tunnel ([[#184](https://github.com/opensearch-project/opensearch-py/issues/184) ### Security ### Dependencies - Bumps `sphinx` from <7.1 to <7.3 diff --git a/opensearchpy/helpers/signer.py b/opensearchpy/helpers/signer.py index 176f6ac9..81acc27f 100644 --- a/opensearchpy/helpers/signer.py +++ b/opensearchpy/helpers/signer.py @@ -6,7 +6,6 @@ # # Modifications Copyright OpenSearch Contributors. See # GitHub history for details. - import sys import requests @@ -43,12 +42,22 @@ def fetch_url(prepared_request): # type: ignore return url.scheme + "://" + location + path + querystring +def derive_signature_url(original_url: str, singing_port: int) -> str: + url = urlparse(original_url) + if url.hostname is None: + raise RuntimeError("Cannot use derive_signature_url on urls without hostname.") + else: + return url._replace(netloc=url.hostname + ":" + str(singing_port)).geturl() + + class AWSV4SignerAuth(requests.auth.AuthBase): """ AWS V4 Request Signer for Requests. """ - def __init__(self, credentials, region, service="es"): # type: ignore + def __init__(self, credentials, region, service="es", signature_port=None): # type: ignore + # can be used to sign the request for a different port than the request, e.g. due to a tunnel being used + self.signature_port = signature_port if not credentials: raise ValueError("Credentials cannot be empty") self.credentials = credentials @@ -79,7 +88,9 @@ def _sign_request(self, prepared_request): # type: ignore # create an AWS request object and sign it using SigV4Auth aws_request = AWSRequest( method=prepared_request.method.upper(), - url=url, + url=derive_signature_url(url, self.signature_port) + if self.signature_port is not None + else url, data=prepared_request.body, ) diff --git a/test_opensearchpy/test_connection.py b/test_opensearchpy/test_connection.py index 5ec6e09d..e28545ad 100644 --- a/test_opensearchpy/test_connection.py +++ b/test_opensearchpy/test_connection.py @@ -336,12 +336,65 @@ def test_aws_signer_as_http_auth(self): con = RequestsHttpConnection(http_auth=auth) prepared_request = requests.Request("GET", "http://localhost").prepare() auth(prepared_request) + self._assert_auth_and_signature_headers(auth, con, prepared_request) + + @pytest.mark.skipif( + sys.version_info < (3, 6), reason="AWSV4SignerAuth requires python3.6+" + ) + def test_aws_signer_as_http_auth_with_query_path(self): + region = "us-west-2" + + import requests + + from opensearchpy.helpers.signer import AWSV4SignerAuth + + auth = AWSV4SignerAuth(self.mock_session(), region) + con = RequestsHttpConnection(http_auth=auth) + prepared_request = requests.Request( + "GET", "http://localhost?hello=world" + ).prepare() + auth(prepared_request) + self._assert_auth_and_signature_headers(auth, con, prepared_request) + + def _assert_auth_and_signature_headers(self, auth, con, prepared_request): self.assertEqual(auth, con.session.auth) self.assertIn("Authorization", prepared_request.headers) self.assertIn("X-Amz-Date", prepared_request.headers) self.assertIn("X-Amz-Security-Token", prepared_request.headers) self.assertIn("X-Amz-Content-SHA256", prepared_request.headers) + @pytest.mark.skipif( + sys.version_info < (3, 6), reason="AWSV4SignerAuth requires python3.6+" + ) + def test_aws_signer_as_http_auth_with_sign_port_with_port_on_base_url(self): + region = "us-west-2" + + import requests + + from opensearchpy.helpers.signer import AWSV4SignerAuth + + auth = AWSV4SignerAuth(self.mock_session(), region, signature_port=443) + con = RequestsHttpConnection(http_auth=auth) + prepared_request = requests.Request("GET", "http://localhost:1045").prepare() + auth(prepared_request) + self._assert_auth_and_signature_headers(auth, con, prepared_request) + + @pytest.mark.skipif( + sys.version_info < (3, 6), reason="AWSV4SignerAuth requires python3.6+" + ) + def test_aws_signer_as_http_auth_with_sign_port_but_without_port_on_base_url(self): + region = "us-west-2" + + import requests + + from opensearchpy.helpers.signer import AWSV4SignerAuth + + auth = AWSV4SignerAuth(self.mock_session(), region, signature_port=443) + con = RequestsHttpConnection(http_auth=auth) + prepared_request = requests.Request("GET", "http://localhost").prepare() + auth(prepared_request) + self._assert_auth_and_signature_headers(auth, con, prepared_request) + @pytest.mark.skipif( sys.version_info < (3, 6), reason="AWSV4SignerAuth requires python3.6+" ) diff --git a/test_opensearchpy/test_helpers/test_signer.py b/test_opensearchpy/test_helpers/test_signer.py new file mode 100644 index 00000000..597a9fc5 --- /dev/null +++ b/test_opensearchpy/test_helpers/test_signer.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# +# Modifications Copyright OpenSearch Contributors. See +# GitHub history for details. +from unittest import TestCase + +from opensearchpy.helpers.signer import derive_signature_url + + +class TestUrllib3Connection(TestCase): + def test_derive_signature_url(self): + assert ( + derive_signature_url("http://localhost:10552/", singing_port=443) + == "http://localhost:443/" + ) + assert ( + derive_signature_url("http://localhost:10552/foo/bar", singing_port=443) + == "http://localhost:443/foo/bar" + ) + assert ( + derive_signature_url("http://localhost/", singing_port=443) + == "http://localhost:443/" + ) + assert ( + derive_signature_url("http://localhost/foo/bar", singing_port=443) + == "http://localhost:443/foo/bar" + ) + + def test_derive_signature_url_no_hostname(self): + self.assertRaises(RuntimeError, derive_signature_url, "http://", 23)