diff --git a/CHANGELOG.md b/CHANGELOG.md index c8ad520..2435368 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## 0.2.24dev +* [Feature] Allows disabling the message that promotes Ploomber Cloud when initializing `Telemetry` + ## 0.2.23 (2024-01-23) * [Fix] Fully removes dependency on `click` diff --git a/src/ploomber_core/telemetry/telemetry.py b/src/ploomber_core/telemetry/telemetry.py index b4a63a2..b06751a 100644 --- a/src/ploomber_core/telemetry/telemetry.py +++ b/src/ploomber_core/telemetry/telemetry.py @@ -28,6 +28,7 @@ telemetry_version - Telemetry version """ + from copy import copy from inspect import signature, _empty import logging @@ -326,7 +327,7 @@ def check_cloud(): internal.last_cloud_check = now -def _get_telemetry_info(package_name, version): +def _get_telemetry_info(): """ The function checks for the local config and uid files, returns the right values according to the config file (True/False). In addition it checks @@ -335,10 +336,6 @@ def _get_telemetry_info(package_name, version): # Check if telemetry is enabled, if not skip, else check for uid telemetry_enabled = check_telemetry_enabled() - # Check latest version - check_version(package_name, version) - check_cloud() - if telemetry_enabled: # Check first time install is_install = check_first_time_usage() @@ -384,7 +381,7 @@ def log_call(self, action=None, payload=False, log_args=False, ignore_args=None) class Telemetry: - def __init__(self, api_key, package_name, version): + def __init__(self, api_key, package_name, version, *, print_cloud_message=True): """ Parameters @@ -398,6 +395,9 @@ def __init__(self, api_key, package_name, version): version : str Version of the package calling the function + print_cloud_message : bool, default=True + If True, it'll print a message to ask the user to sign up for + Ploomber Cloud """ if "_PLOOMBER_TELEMETRY_DEBUG" in os.environ: warnings.warn( @@ -409,16 +409,22 @@ def __init__(self, api_key, package_name, version): self.api_key = api_key self.package_name = package_name self.version = version + self.print_cloud_message = print_cloud_message @classmethod - def from_package(cls, package_name): + def from_package(cls, package_name, *, print_cloud_message=True): """ Initialize a Telemetry client with the default configuration for a package with the given name """ default_api_key = "phc_P9SpSeypyPwxrMdFn2edOOEooQioF2axppyEeDwtMSP" version = get_package_version(package_name) - return cls(api_key=default_api_key, package_name=package_name, version=version) + return cls( + api_key=default_api_key, + package_name=package_name, + version=version, + print_cloud_message=print_cloud_message, + ) def log_api(self, action, client_time=None, total_runtime=None, metadata=None): """ @@ -434,9 +440,13 @@ def log_api(self, action, client_time=None, total_runtime=None, metadata=None): if client_time is None: client_time = datetime.datetime.now() - (telemetry_enabled, uid, is_install) = _get_telemetry_info( - self.package_name, self.version - ) + (telemetry_enabled, uid, is_install) = _get_telemetry_info() + + # Check latest version + check_version(self.package_name, self.version) + + if self.print_cloud_message: + check_cloud() # NOTE: this should not happen anymore if "NO_UID" in uid: diff --git a/tests/telemetry/test_telemetry.py b/tests/telemetry/test_telemetry.py index 4eaebcf..60bfa1c 100644 --- a/tests/telemetry/test_telemetry.py +++ b/tests/telemetry/test_telemetry.py @@ -328,9 +328,7 @@ def test_full_telemetry_info(monkeypatch, ignore_env_var_and_set_tmp_default_hom monkeypatch.setattr(telemetry, "DEFAULT_HOME_DIR", str(Path().absolute())) monkeypatch.setattr(telemetry, "internal", telemetry.Internal()) - (stat_enabled, uid, is_install) = telemetry._get_telemetry_info( - "ploomber", "0.14.0" - ) + (stat_enabled, uid, is_install) = telemetry._get_telemetry_info() assert stat_enabled is True assert isinstance(uid, str) assert is_install is True @@ -1083,3 +1081,35 @@ def test_from_package(): assert _telemetry.api_key assert _telemetry.version assert _telemetry.package_name == "ploomber-core" + + +def test_runs_check_cloud(monkeypatch): + mock = Mock() + monkeypatch.setattr(telemetry, "check_cloud", mock) + + my_telemetry = telemetry.Telemetry( + MOCK_API_KEY, + "some-package", + "1.2.2", + print_cloud_message=True, + ) + + my_telemetry.log_api("some-action") + + mock.assert_called_once() + + +def test_disable_check_cloud(monkeypatch): + mock = Mock() + monkeypatch.setattr(telemetry, "check_cloud", mock) + + my_telemetry = telemetry.Telemetry( + MOCK_API_KEY, + "some-package", + "1.2.2", + print_cloud_message=False, + ) + + my_telemetry.log_api("some-action") + + mock.assert_not_called()