From 5513b9d3fd161a896904afc9b4b3721696f95a9c Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Mon, 19 Aug 2019 15:03:24 -0400 Subject: [PATCH] Unquote connection string components properly When a connection string component contains characters that have a special meaning in the URI (e.g. '@' or '='), percent-encoding must be used. asyncpg must take care to unquote the parsed components correctly, which it currently does not do. Additionally, this makes asyncpg follow libpq's behavior of parsing the authentication part of netloc, i.e. splitting on the first '@' rather than the last. Fixes: #418 Fixes: #471 --- asyncpg/connect_utils.py | 40 +++++++++++++++++++++++++++------------- tests/test_connect.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index c3a9670c..26fdec59 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -153,7 +153,7 @@ def _validate_port_spec(hosts, port): return port -def _parse_hostlist(hostlist, port): +def _parse_hostlist(hostlist, port, *, unquote=False): if ',' in hostlist: # A comma-separated list of host addresses. 
hostspecs = hostlist.split(',') @@ -185,9 +185,14 @@ def _parse_hostlist(hostlist, port): addr = hostspec hostspec_port = '' + if unquote: + addr = urllib.parse.unquote(addr) + hosts.append(addr) if not port: if hostspec_port: + if unquote: + hostspec_port = urllib.parse.unquote(hostspec_port) hostlist_ports.append(int(hostspec_port)) else: hostlist_ports.append(default_port[i]) @@ -213,25 +218,34 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, 'invalid DSN: scheme is expected to be either ' '"postgresql" or "postgres", got {!r}'.format(parsed.scheme)) - if not host and parsed.netloc: + if parsed.netloc: if '@' in parsed.netloc: - auth, _, hostspec = parsed.netloc.partition('@') + dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@') else: - hostspec = parsed.netloc + dsn_hostspec = parsed.netloc + dsn_auth = '' + else: + dsn_auth = dsn_hostspec = '' + + if dsn_auth: + dsn_user, _, dsn_password = dsn_auth.partition(':') + else: + dsn_user = dsn_password = '' - if hostspec: - host, port = _parse_hostlist(hostspec, port) + if not host and dsn_hostspec: + host, port = _parse_hostlist(dsn_hostspec, port, unquote=True) if parsed.path and database is None: - database = parsed.path - if database.startswith('/'): - database = database[1:] + dsn_database = parsed.path + if dsn_database.startswith('/'): + dsn_database = dsn_database[1:] + database = urllib.parse.unquote(dsn_database) - if parsed.username and user is None: - user = parsed.username + if user is None and dsn_user: + user = urllib.parse.unquote(dsn_user) - if parsed.password and password is None: - password = parsed.password + if password is None and dsn_password: + password = urllib.parse.unquote(dsn_password) if parsed.query: query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) diff --git a/tests/test_connect.py b/tests/test_connect.py index abb30647..5830b38e 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -453,6 +453,41 @@ class TestConnectParams(tb.TestCase): 
'database': 'dbname'}) }, + { + 'dsn': 'postgresql://us%40r:p%40ss@h%40st1,h%40st2:543%33/d%62', + 'result': ( + [('h@st1', 5432), ('h@st2', 5433)], + { + 'user': 'us@r', + 'password': 'p@ss', + 'database': 'db', + } + ) + }, + + { + 'dsn': 'postgresql://user:p@ss@host/db', + 'result': ( + [('ss@host', 5432)], + { + 'user': 'user', + 'password': 'p', + 'database': 'db', + } + ) + }, + + { + 'dsn': 'postgresql:///d%62?user=us%40r&host=h%40st&port=543%33', + 'result': ( + [('h@st', 5433)], + { + 'user': 'us@r', + 'database': 'db', + } + ) + }, + { 'dsn': 'pq:///dbname?host=/unix_sock/test&user=spam', 'error': (ValueError, 'invalid DSN')