diff --git a/docs/classes/singer_sdk.connectors.BaseConnector.rst b/docs/classes/singer_sdk.connectors.BaseConnector.rst new file mode 100644 index 000000000..3ba703887 --- /dev/null +++ b/docs/classes/singer_sdk.connectors.BaseConnector.rst @@ -0,0 +1,8 @@ +singer_sdk.connectors.BaseConnector +=================================== + +.. currentmodule:: singer_sdk.connectors + +.. autoclass:: BaseConnector + :members: + :special-members: __init__, __call__ \ No newline at end of file diff --git a/docs/guides/custom-connector.md b/docs/guides/custom-connector.md new file mode 100644 index 000000000..cf0ec7e20 --- /dev/null +++ b/docs/guides/custom-connector.md @@ -0,0 +1,32 @@ +# Using a custom connector class + +The Singer SDK has a few built-in connector classes that are designed to work with a variety of sources: + +* [`SQLConnector`](../../classes/singer_sdk.SQLConnector) for SQL databases + +If you need to connect to a source that is not supported by one of these built-in connectors, you can create your own connector class. This guide will walk you through the process of creating a custom connector class. + +## Subclass `BaseConnector` + +The first step is to create a subclass of [`BaseConnector`](../../classes/singer_sdk.connectors.BaseConnector). This class is responsible for creating streams and handling the connection to the source. + +```python +from singer_sdk.connectors import BaseConnector + + +class MyConnector(BaseConnector): + pass +``` + +## Implement `get_connection` + +The [`get_connection`](http://127.0.0.1:5500/build/classes/singer_sdk.connectors.BaseConnector.html#singer_sdk.connectors.BaseConnector.get_connection) method is responsible for creating a connection to the source. It should return an object that implements the [context manager protocol](https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers), e.g. it has `__enter__` and `__exit__` methods. + +```python +from singer_sdk.connectors import BaseConnector + + +class MyConnector(BaseConnector): + def get_connection(self): + return MyConnection() +``` diff --git a/docs/guides/index.md b/docs/guides/index.md index 268f27957..37c310cd2 100644 --- a/docs/guides/index.md +++ b/docs/guides/index.md @@ -7,6 +7,7 @@ The following pages contain useful information for developers building on top of porting pagination-classes +custom-connector custom-clis config-schema ``` diff --git a/docs/reference.rst b/docs/reference.rst index b59bd6651..1ef38dca4 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -143,3 +143,12 @@ Batch batch.BaseBatcher batch.JSONLinesBatcher + +Abstract Connector Classes +-------------------------- + +.. autosummary:: + :toctree: classes + :template: class.rst + + connectors.BaseConnector diff --git a/poetry.lock b/poetry.lock index a238790e1..c6f060445 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1006,7 +1006,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "markupsafe" version = "2.1.5" description = "Safely add untrusted strings to HTML/XML markup." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, @@ -1492,6 +1492,20 @@ compat = ["pytest-benchmark (>=4.0.0,<4.1.0)", "pytest-xdist (>=2.0.0,<2.1.0)"] lint = ["mypy (>=1.3.0,<1.4.0)", "ruff (>=0.3.3,<0.4.0)"] test = ["pytest (>=7.0,<8.0)", "pytest-cov (>=4.0.0,<4.1.0)"] +[[package]] +name = "pytest-httpserver" +version = "1.0.12" +description = "pytest-httpserver is a httpserver for pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_httpserver-1.0.12-py3-none-any.whl", hash = "sha256:dae1c79ec7aeda83bfaaf4d0a400867a4b1bc6bf668244daaf13aa814e3022da"}, + {file = "pytest_httpserver-1.0.12.tar.gz", hash = "sha256:c14600b8efb9ea8d7e63251a242ab987f13028b36d3d397ffaca3c929f67eb16"}, +] + +[package.dependencies] +Werkzeug = ">=2.0.0" + [[package]] name = "pytest-snapshot" version = "0.9.0" @@ -2461,6 +2475,23 @@ brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotl secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "werkzeug" +version = "3.0.3" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, + {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "xdoctest" version = "1.1.6" @@ -2510,4 +2541,4 @@ testing = ["pytest"] [metadata] lock-version = "2.0" python-versions = ">=3.8" -content-hash = "8d0665d7e5397609e616976d470dca9863a925a12f513c0b78439f055aeb664a" +content-hash = "6469b9ee5115b54975583b6b110a2c60e6d3bffd60529caef99c3d54585f16b3" diff --git a/pyproject.toml b/pyproject.toml index 68faa36ec..97095b2a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,7 @@ duckdb-engine = { version = ">=0.9.4", python = "<4" } fastjsonschema = ">=2.19.1" pytest-benchmark = ">=4.0.0" +pytest-httpserver = { version = ">=1.0.6", python = "<4" } pytest-snapshot = ">=0.9.0" pytz = ">=2022.2.1" requests-mock = ">=1.10.0" diff --git a/singer_sdk/authenticators.py b/singer_sdk/authenticators.py index c6478cb92..9fce27854 100644 --- a/singer_sdk/authenticators.py +++ b/singer_sdk/authenticators.py @@ -11,6 +11,7 @@ from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit import requests +from requests.auth import AuthBase from singer_sdk.helpers._util import utc_now @@ -590,3 +591,52 @@ def oauth_request_payload(self) -> dict: "RS256", ), } + + +class NoopAuth(AuthBase): + """No-op authenticator.""" + + def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + """Do nothing. + + Args: + r: The prepared request. + + Returns: + The unmodified prepared request. + """ + return r + + +class HeaderAuth(AuthBase): + """Header-based authenticator.""" + + def __init__( + self, + keyword: str, + value: str, + header: str = "Authorization", + ) -> None: + """Initialize the authenticator. + + Args: + keyword: The keyword to use in the header, e.g. "Bearer". + value: The value to use in the header, e.g. "my-token". + header: The header to add the keyword and value to, defaults to + ``"Authorization"``. + """ + self.keyword = keyword + self.value = value + self.header = header + + def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + """Add the header to the request. + + Args: + r: The prepared request. + + Returns: + The prepared request with the header added. + """ + r.headers[self.header] = f"{self.keyword} {self.value}" + return r diff --git a/singer_sdk/connectors/__init__.py b/singer_sdk/connectors/__init__.py index 32799417a..1c3916672 100644 --- a/singer_sdk/connectors/__init__.py +++ b/singer_sdk/connectors/__init__.py @@ -2,6 +2,8 @@ from __future__ import annotations +from ._http import HTTPConnector +from .base import BaseConnector from .sql import SQLConnector -__all__ = ["SQLConnector"] +__all__ = ["BaseConnector", "HTTPConnector", "SQLConnector"] diff --git a/singer_sdk/connectors/_http.py b/singer_sdk/connectors/_http.py new file mode 100644 index 000000000..8301233bb --- /dev/null +++ b/singer_sdk/connectors/_http.py @@ -0,0 +1,139 @@ +"""HTTP-based tap class for Singer SDK.""" + +from __future__ import annotations + +import typing as t + +import requests + +from singer_sdk.authenticators import NoopAuth +from singer_sdk.connectors.base import BaseConnector + +if t.TYPE_CHECKING: + import sys + + from requests.adapters import BaseAdapter + + if sys.version_info >= (3, 10): + from typing import TypeAlias # noqa: ICN003 + else: + from typing_extensions import TypeAlias + +_Auth: TypeAlias = t.Callable[[requests.PreparedRequest], requests.PreparedRequest] + + +class HTTPConnector(BaseConnector[requests.Session]): + """Base class for all HTTP-based connectors.""" + + def __init__(self, config: t.Mapping[str, t.Any] | None) -> None: + """Initialize the HTTP connector. + + Args: + config: Connector configuration parameters. + """ + super().__init__(config) + self.__session = self.get_session() + self.refresh_auth() + + def get_connection(self, *, authenticate: bool = True) -> requests.Session: + """Return a new HTTP session object. + + Adds adapters and optionally authenticates the session. + + Args: + authenticate: Whether to authenticate the request. + + Returns: + A new HTTP session object. + """ + for prefix, adapter in self.adapters.items(): + self.__session.mount(prefix, adapter) + + self.__session.auth = self.auth if authenticate else None + + return self.__session + + def get_session(self) -> requests.Session: # noqa: PLR6301 + """Return a new HTTP session object. + + Returns: + A new HTTP session object. + """ + return requests.Session() + + def get_authenticator(self) -> _Auth: # noqa: PLR6301 + """Authenticate the HTTP session. + + Returns: + An auth callable. + """ + return NoopAuth() + + def refresh_auth(self) -> None: + """Refresh the HTTP session authentication.""" + self.auth = self.get_authenticator() + + @property + def auth(self) -> _Auth: + """Return the HTTP session authenticator. + + Returns: + An auth callable. + """ + return self.__auth + + @auth.setter + def auth(self, auth: _Auth) -> None: + """Set the HTTP session authenticator. + + Args: + auth: An auth callable. + """ + self.__auth = auth + + @property + def session(self) -> requests.Session: + """Return the HTTP session object. + + Returns: + The HTTP session object. + """ + return self.__session + + @property + def adapters(self) -> dict[str, BaseAdapter]: + """Return a mapping of URL prefixes to adapter objects. + + Returns: + A mapping of URL prefixes to adapter objects. + """ + return {} + + @property + def default_request_kwargs(self) -> dict[str, t.Any]: + """Return default kwargs for HTTP requests. + + Returns: + A mapping of default kwargs for HTTP requests. + """ + return {} + + def request( + self, + *args: t.Any, + authenticate: bool = True, + **kwargs: t.Any, + ) -> requests.Response: + """Make an HTTP request. + + Args: + *args: Positional arguments to pass to the request method. + authenticate: Whether to authenticate the request. + **kwargs: Keyword arguments to pass to the request method. + + Returns: + The HTTP response object. + """ + with self.connect(authenticate=authenticate) as session: + kwargs = {**self.default_request_kwargs, **kwargs} + return session.request(*args, **kwargs) diff --git a/singer_sdk/connectors/base.py b/singer_sdk/connectors/base.py new file mode 100644 index 000000000..8e199b943 --- /dev/null +++ b/singer_sdk/connectors/base.py @@ -0,0 +1,54 @@ +"""Base class for all connectors.""" + +from __future__ import annotations + +import abc +import typing as t +from contextlib import contextmanager + +_T = t.TypeVar("_T") + + +# class BaseConnector(abc.ABC, t.Generic[_T_co]): +class BaseConnector(abc.ABC, t.Generic[_T]): + """Base class for all connectors.""" + + def __init__(self, config: t.Mapping[str, t.Any] | None) -> None: + """Initialize the connector. + + Args: + config: Plugin configuration parameters. + """ + self._config = config or {} + + @property + def config(self) -> t.Mapping: + """Return the connector configuration. + + Returns: + A mapping of configuration parameters. + """ + return self._config + + @contextmanager + def connect(self, *args: t.Any, **kwargs: t.Any) -> t.Generator[_T, None, None]: + """Connect to the destination. + + Args: + args: Positional arguments to pass to the connection method. + kwargs: Keyword arguments to pass to the connection method. + + Yields: + A connection object. + """ + yield self.get_connection(*args, **kwargs) + + @abc.abstractmethod + def get_connection(self, *args: t.Any, **kwargs: t.Any) -> _T: + """Connect to the destination. + + Args: + args: Positional arguments to pass to the connection method. + kwargs: Keyword arguments to pass to the connection method. + """ + ... diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index f48222640..a9745f25c 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -13,6 +13,7 @@ from singer_sdk import typing as th from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema +from singer_sdk.connectors.base import BaseConnector from singer_sdk.exceptions import ConfigValidationError from singer_sdk.helpers._util import dump_json, load_json from singer_sdk.helpers.capabilities import TargetLoadMethods @@ -22,7 +23,7 @@ from sqlalchemy.engine.reflection import Inspector -class SQLConnector: # noqa: PLR0904 +class SQLConnector(BaseConnector[sa.engine.Connection]): # noqa: PLR0904 """Base class for SQLAlchemy-based connectors. The connector class serves as a wrapper around the SQL connection. @@ -45,7 +46,7 @@ class SQLConnector: # noqa: PLR0904 def __init__( self, - config: dict | None = None, + config: t.Mapping[str, t.Any] | None = None, sqlalchemy_url: str | None = None, ) -> None: """Initialize the SQL connector. @@ -54,18 +55,9 @@ def __init__( config: The parent tap or target object's config. sqlalchemy_url: Optional URL for the connection. """ - self._config: dict[str, t.Any] = config or {} + super().__init__(config=config) self._sqlalchemy_url: str | None = sqlalchemy_url or None - @property - def config(self) -> dict: - """If set, provides access to the tap or target config. - - Returns: - The settings as a dict. - """ - return self._config - @property def logger(self) -> logging.Logger: """Get logger. @@ -76,9 +68,35 @@ def logger(self) -> logging.Logger: return logging.getLogger("sqlconnector") @contextmanager - def _connect(self) -> t.Iterator[sa.engine.Connection]: - with self._engine.connect().execution_options(stream_results=True) as conn: - yield conn + def _connect(self): # noqa: ANN202 + """Connect to the source. + + Yields: + A connection object. + """ + warnings.warn( + "`SQLConnector._connect` is deprecated. " + "Use `SQLConnector.connect` instead.", + DeprecationWarning, + stacklevel=2, + ) + with self.connect() as connection: + yield connection + + def get_connection( + self, + *, + stream_results: bool = True, + ) -> sa.engine.Connection: + """Return a new SQLAlchemy connection using the provided config. + + Args: + stream_results: Whether to stream results from the database. + + Returns: + A newly created SQLAlchemy connection object. + """ + return self._engine.connect().execution_options(stream_results=stream_results) def create_sqlalchemy_connection(self) -> sa.engine.Connection: """(DEPRECATED) Return a new SQLAlchemy connection using the provided config. @@ -158,7 +176,7 @@ def sqlalchemy_url(self) -> str: return self._sqlalchemy_url - def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str: # noqa: PLR6301 + def get_sqlalchemy_url(self, config: t.Mapping[str, t.Any]) -> str: # noqa: PLR6301 """Return the SQLAlchemy URL string. Developers can generally override just one of the following: @@ -661,7 +679,7 @@ def create_schema(self, schema_name: str) -> None: Args: schema_name: The target schema to create. """ - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute(sa.schema.CreateSchema(schema_name)) def create_empty_table( @@ -738,7 +756,7 @@ def _create_empty_column( column_name=column_name, column_type=sql_type, ) - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute(column_add_ddl) def prepare_schema(self, schema_name: str) -> None: @@ -842,7 +860,7 @@ def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> N column_name=old_name, new_column_name=new_name, ) - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute(column_rename_ddl) def merge_sql_types( @@ -1149,7 +1167,7 @@ def _adapt_column_type( column_name=column_name, column_type=compatible_sql_type, ) - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute(alter_column_ddl) def serialize_json(self, obj: object) -> str: # noqa: PLR6301 @@ -1201,7 +1219,7 @@ def delete_old_versions( version_column_name: The name of the version column. current_version: The current ACTIVATE version of the table. """ - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute( sa.text( f"DELETE FROM {full_table_name} " # noqa: S608 diff --git a/singer_sdk/sinks/sql.py b/singer_sdk/sinks/sql.py index 33a741614..25b88be73 100644 --- a/singer_sdk/sinks/sql.py +++ b/singer_sdk/sinks/sql.py @@ -333,8 +333,7 @@ def bulk_insert_records( ] self.logger.info("Inserting with SQL: %s", insert_sql) - - with self.connector._connect() as conn, conn.begin(): # noqa: SLF001 + with self.connector.connect() as conn, conn.begin(): result = conn.execute(insert_sql, new_records) return result.rowcount @@ -413,7 +412,7 @@ def activate_version(self, new_version: int) -> None: bindparam("deletedate", value=deleted_at, type_=sa.types.DateTime), bindparam("version", value=new_version, type_=sa.types.Integer), ) - with self.connector._connect() as conn, conn.begin(): # noqa: SLF001 + with self.connector.connect() as conn, conn.begin(): conn.execute(query) diff --git a/singer_sdk/streams/rest.py b/singer_sdk/streams/rest.py index 559389e8c..44afd23ba 100644 --- a/singer_sdk/streams/rest.py +++ b/singer_sdk/streams/rest.py @@ -15,6 +15,7 @@ from singer_sdk import metrics from singer_sdk.authenticators import SimpleAuthenticator +from singer_sdk.connectors import HTTPConnector from singer_sdk.exceptions import FatalAPIError, RetriableAPIError from singer_sdk.helpers.jsonpath import extract_jsonpath from singer_sdk.pagination import ( @@ -44,7 +45,7 @@ class RESTStream(Stream, t.Generic[_TToken], metaclass=abc.ABCMeta): # noqa: PL """Abstract base class for REST API streams.""" _page_size: int = DEFAULT_PAGE_SIZE - _requests_session: requests.Session | None + _requests_session: requests.Session #: HTTP method to use for requests. Defaults to "GET". rest_method = "GET" @@ -101,7 +102,13 @@ def __init__( if path: self.path = path self._http_headers: dict = {} - self._requests_session = requests.Session() + + self.connector = HTTPConnector(self.config) + + # Override the connector's auth with the stream's auth + self.connector.auth = self.authenticator + + self._requests_session = self.connector.session self._compiled_jsonpath = None self._next_page_token_compiled_jsonpath = None @@ -146,8 +153,12 @@ def requests_session(self) -> requests.Session: Returns: The :class:`requests.Session` object for HTTP requests. """ - if not self._requests_session: - self._requests_session = requests.Session() + warn( + "The `requests_session` property is deprecated and will be removed in a " + "future release. Use the `connector` property instead.", + DeprecationWarning, + stacklevel=2, + ) return self._requests_session def validate_response(self, response: requests.Response) -> None: @@ -263,11 +274,13 @@ def _request( Returns: TODO """ - response = self.requests_session.send( - prepared_request, - timeout=self.timeout, - allow_redirects=self.allow_redirects, - ) + with self.connector.connect() as session: + response = session.send( + prepared_request, + timeout=self.timeout, + allow_redirects=self.allow_redirects, + ) + self._write_request_duration_log( endpoint=self.path, response=response, @@ -331,8 +344,8 @@ def build_prepared_request( A :class:`requests.PreparedRequest` object. """ request = requests.Request(*args, **kwargs) - self.requests_session.auth = self.authenticator - return self.requests_session.prepare_request(request) + with self.connector.connect(authenticate=True) as session: + return session.prepare_request(request) def prepare_request( self, diff --git a/singer_sdk/streams/sql.py b/singer_sdk/streams/sql.py index 954159885..689bc1434 100644 --- a/singer_sdk/streams/sql.py +++ b/singer_sdk/streams/sql.py @@ -207,7 +207,7 @@ def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]: # processed. query = query.limit(self.ABORT_AT_RECORD_COUNT + 1) - with self.connector._connect() as conn: # noqa: SLF001 + with self.connector.connect() as conn: for record in conn.execute(query).mappings(): # TODO: Standardize record mapping type # https://github.com/meltano/sdk/issues/2096 diff --git a/tests/core/connectors/__init__.py b/tests/core/connectors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/core/connectors/test_http_connector.py b/tests/core/connectors/test_http_connector.py new file mode 100644 index 000000000..2c62ec94a --- /dev/null +++ b/tests/core/connectors/test_http_connector.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import json +import typing as t + +import requests +import werkzeug +from requests.adapters import BaseAdapter + +from singer_sdk.authenticators import HeaderAuth +from singer_sdk.connectors import HTTPConnector + +if t.TYPE_CHECKING: + from pytest_httpserver import HTTPServer + + +class MockAdapter(BaseAdapter): + def send( + self, + request: requests.PreparedRequest, + stream: bool = False, # noqa: FBT002 + timeout: float | tuple[float, float] | tuple[float, None] | None = None, + verify: bool | str = True, # noqa: FBT002 + cert: bytes | str | tuple[bytes | str, bytes | str] | None = None, + proxies: t.Mapping[str, str] | None = None, + ) -> requests.Response: + """Send a request.""" + response = requests.Response() + data = { + "url": request.url, + "headers": dict(request.headers), + "method": request.method, + "body": request.body, + "stream": stream, + "timeout": timeout, + "verify": verify, + "cert": cert, + "proxies": proxies, + } + response.status_code = 200 + response._content = json.dumps(data).encode("utf-8") + return response + + def close(self) -> None: + pass + + +class HeaderAuthConnector(HTTPConnector): + def get_authenticator(self) -> HeaderAuth: + return HeaderAuth("Bearer", self.config["token"]) + + +def test_base_connector(httpserver: HTTPServer): + connector = HTTPConnector({}) + + httpserver.expect_request("").respond_with_json({"foo": "bar"}) + url = httpserver.url_for("/") + + response = connector.request("GET", url) + data = response.json() + assert data["foo"] == "bar" + + +def test_auth(httpserver: HTTPServer): + connector = HeaderAuthConnector({"token": "s3cr3t"}) + + def _handler(request: werkzeug.Request) -> werkzeug.Response: + return werkzeug.Response( + json.dumps( + { + "headers": dict(request.headers), + "url": request.url, + }, + ), + status=200, + mimetype="application/json", + ) + + httpserver.expect_request("").respond_with_handler(_handler) + url = httpserver.url_for("/") + + response = connector.request("GET", url) + data = response.json() + assert data["headers"]["Authorization"] == "Bearer s3cr3t" + + response = connector.request("GET", url, authenticate=False) + data = response.json() + assert "Authorization" not in data["headers"] + + +def test_custom_adapters(): + class MyConnector(HTTPConnector): + @property + def adapters(self) -> dict[str, BaseAdapter]: + return { + "https://test": MockAdapter(), + } + + connector = MyConnector({}) + response = connector.request("GET", "https://test") + data = response.json() + + assert data["url"] == "https://test/" + assert data["headers"] + assert data["method"] == "GET" diff --git a/tests/core/test_connector_sql.py b/tests/core/connectors/test_sql_connector.py similarity index 95% rename from tests/core/test_connector_sql.py rename to tests/core/connectors/test_sql_connector.py index 10ee0c0f4..afe5650e1 100644 --- a/tests/core/test_connector_sql.py +++ b/tests/core/connectors/test_sql_connector.py @@ -154,43 +154,45 @@ def test_engine_creates_and_returns_cached_engine(self, connector): engine2 = connector._cached_engine assert engine1 is engine2 - def test_deprecated_functions_warn(self, connector): + def test_deprecated_functions_warn(self, connector: SQLConnector): with pytest.deprecated_call(): connector.create_sqlalchemy_engine() with pytest.deprecated_call(): connector.create_sqlalchemy_connection() with pytest.deprecated_call(): _ = connector.connection + with pytest.deprecated_call(), connector._connect() as _: + pass - def test_connect_calls_engine(self, connector): + def test_connect_calls_engine(self, connector: SQLConnector): with mock.patch.object( SQLConnector, "_engine", - ) as mock_engine, connector._connect() as _: + ) as mock_engine, connector.connect() as _: mock_engine.connect.assert_called_once() - def test_connect_calls_connect(self, connector): + def test_connect_calls_connect(self, connector: SQLConnector): attached_engine = connector._engine with mock.patch.object( attached_engine, "connect", - ) as mock_conn, connector._connect() as _: + ) as mock_conn, connector.connect() as _: mock_conn.assert_called_once() - def test_connect_raises_on_operational_failure(self, connector): + def test_connect_raises_on_operational_failure(self, connector: SQLConnector): with pytest.raises( sa.exc.OperationalError, - ) as _, connector._connect() as conn: + ) as _, connector.connect() as conn: conn.execute(sa.text("SELECT * FROM fake_table")) - def test_rename_column_uses_connect_correctly(self, connector): + def test_rename_column_uses_connect_correctly(self, connector: SQLConnector): attached_engine = connector._engine # Ends up using the attached engine with mock.patch.object(attached_engine, "connect") as mock_conn: connector.rename_column("fake_table", "old_name", "new_name") mock_conn.assert_called_once() # Uses the _connect method - with mock.patch.object(connector, "_connect") as mock_connect_method: + with mock.patch.object(connector, "connect") as mock_connect_method: connector.rename_column("fake_table", "old_name", "new_name") mock_connect_method.assert_called_once() diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index 592f921e6..0add6934e 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -549,3 +549,9 @@ def discover_streams(self): assert all( tap.streams[stream].selected is selection[stream] for stream in selection ) + + +def test_deprecations(tap: SimpleTestTap): + stream = RestTestStream(tap=tap) + with pytest.deprecated_call(): + _ = stream.requests_session diff --git a/tests/samples/conftest.py b/tests/samples/conftest.py index 90cb80dec..c534a7d86 100644 --- a/tests/samples/conftest.py +++ b/tests/samples/conftest.py @@ -19,9 +19,9 @@ def csv_config(outdir: str) -> dict: @pytest.fixture -def _sqlite_sample_db(sqlite_connector): +def _sqlite_sample_db(sqlite_connector: SQLiteConnector): """Return a path to a newly constructed sample DB.""" - with sqlite_connector._connect() as conn, conn.begin(): + with sqlite_connector.connect() as conn, conn.begin(): for t in range(3): conn.execute(sa.text(f"DROP TABLE IF EXISTS t{t}")) conn.execute(