diff --git a/graphistry/PlotterBase.py b/graphistry/PlotterBase.py
index d56d07f6b..577fde40b 100644
--- a/graphistry/PlotterBase.py
+++ b/graphistry/PlotterBase.py
@@ -1403,7 +1403,20 @@ def plot(
info = PyGraphistry._etl1(dataset)
elif api_version == 3:
logger.debug("3. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name()))
- PyGraphistry.refresh()
+
+ if not PyGraphistry.api_token() and PyGraphistry.sso_state(): # if it is sso login
+ if in_ipython() or in_databricks() or PyGraphistry._config["sso_opt_into_type"] == 'display':
+ PyGraphistry.sso_wait_for_token_text_display()
+ if not PyGraphistry.sso_verify_token_display():
+ from IPython.core.display import HTML
+ msg_html = "Invalid token due to login timeout"
+ return HTML(msg_html)
+ else:
+ PyGraphistry.sso_repeat_get_token()
+
+
+ else: # if not sso mode, just refresh to make sure token is valid
+ PyGraphistry.refresh()
logger.debug("4. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name()))
dataset = self._plot_dispatch(g, n, name, description, 'arrow', self._style, memoize)
diff --git a/graphistry/__init__.py b/graphistry/__init__.py
index befef2c1a..f5e98be77 100644
--- a/graphistry/__init__.py
+++ b/graphistry/__init__.py
@@ -46,7 +46,9 @@
ArrowFileUploader,
PyGraphistry,
from_igraph,
- from_cugraph
+ from_cugraph,
+ sso_wait_for_token_display,
+ register_databricks_sso,
)
from graphistry.compute import (
diff --git a/graphistry/pygraphistry.py b/graphistry/pygraphistry.py
index 0925a6a74..7c41ee3d7 100644
--- a/graphistry/pygraphistry.py
+++ b/graphistry/pygraphistry.py
@@ -15,7 +15,7 @@
from . import util
from . import bolt_util
from .plotter import Plotter
-from .util import in_databricks, setup_logger, in_ipython, make_iframe
+from .util import in_databricks, setup_logger, in_ipython, make_iframe, display_message_html
from .exceptions import SsoRetrieveTokenTimeoutException, TokenExpireException
from .messages import (
@@ -135,6 +135,12 @@ def __reset_token_creds_in_memory():
PyGraphistry._config["api_key"] = None
PyGraphistry._is_authenticated = False
+ @staticmethod
+ def __reset_sso_variables_in_memory():
+ """Reset the sso related variable in memory, used when switching hosts, switching register method"""
+
+ PyGraphistry._config["sso_state"] = None
+ PyGraphistry._config["sso_opt_into_type"] = None
@staticmethod
@@ -271,7 +277,7 @@ def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type):
if in_ipython() or in_databricks() or sso_opt_into_type == 'display': # If run in notebook, just display the HTML
# from IPython.core.display import HTML
from IPython.display import display, HTML
- display(HTML(f'Login SSO'))
+ display(HTML(f'Login SSO
Please click the above URL to open browser to login'))
print("Please click the above URL to open browser to login")
print(f"If you cannot see the URL, please open browser, browse to this URL: {auth_url}")
print("Please close browser tab after SSO login to back to notebook")
@@ -406,6 +412,8 @@ def refresh(token=None, fail_silent=False):
logger.debug("2. @PyGraphistry refresh :relogin")
if isinstance(e, TokenExpireException):
print("Token is expired, you need to relogin")
+ PyGraphistry._config['api_token'] = None
+ PyGraphistry._is_authenticated = False
return PyGraphistry.relogin()
if not fail_silent:
@@ -414,7 +422,7 @@ def refresh(token=None, fail_silent=False):
@staticmethod
def verify_token(token=None, fail_silent=False) -> bool:
- """Return True iff current or provided token is still valid"""
+ """Return True if current or provided token is still valid"""
using_self_token = token is None
try:
logger.debug("JWT refresh")
@@ -570,6 +578,11 @@ def certificate_validation(value=None):
def set_bolt_driver(driver=None):
PyGraphistry._config["bolt_driver"] = bolt_util.to_bolt_driver(driver)
+ @staticmethod
+ def set_sso_opt_into_type(value: Optional[str]):
+ PyGraphistry._config["sso_opt_into_type"] = value
+
+
@staticmethod
def register(
key: Optional[str] = None,
@@ -707,6 +720,8 @@ def register(
PyGraphistry.set_bolt_driver(bolt)
# Reset token creds
PyGraphistry.__reset_token_creds_in_memory()
+ # Reset sso related variables in memory
+ PyGraphistry.__reset_sso_variables_in_memory()
if not (username is None) and not (password is None):
PyGraphistry.login(username, password, org_name)
@@ -2321,7 +2336,22 @@ def layout_settings(
@staticmethod
def org_name(value=None):
- """Set or get the org_name when register/login.
+ """Set or get the organization name during registration or login.
+
+ :param value: The organization name to set. If None, the current organization name is returned.
+ :type value: Optional[str]
+ :return: The current organization name if value is None, otherwise None.
+ :rtype: Optional[str]
+
+ **Example: Setting the organization name**
+ ::
+ import graphistry
+ graphistry.org_name("my_org_name")
+
+ **Example: Getting the organization name**
+ ::
+ import graphistry
+ org_name = graphistry.org_name()
"""
if value is None:
@@ -2339,7 +2369,22 @@ def org_name(value=None):
@staticmethod
def idp_name(value=None):
- """Set or get the idp_name when register/login.
+ """Set or get the IDP (Identity Provider) name during registration or login.
+
+ :param value: The IDP name to set. If None, the current IDP name is returned.
+ :type value: Optional[str]
+ :return: The current IDP name if value is None, otherwise None.
+ :rtype: Optional[str]
+
+ **Example: Setting the IDP name**
+ ::
+ import graphistry
+ graphistry.idp_name("my_idp_name")
+
+ **Example: Getting the IDP name**
+ ::
+ import graphistry
+ idp_name = graphistry.idp_name()
"""
if value is None:
@@ -2354,7 +2399,22 @@ def idp_name(value=None):
@staticmethod
def sso_state(value=None):
- """Set or get the sso_state when register/sso login.
+ """Set or get the SSO state during registration or SSO login.
+
+ :param value: The SSO state to set. If None, the current SSO state is returned.
+ :type value: Optional[str]
+ :return: The current SSO state if value is None, otherwise None.
+ :rtype: Optional[str]
+
+ **Example: Setting the SSO state**
+ ::
+ import graphistry
+ graphistry.sso_state("my_sso_state")
+
+ **Example: Getting the SSO state**
+ ::
+ import graphistry
+ sso_state = graphistry.sso_state()
"""
if value is None:
@@ -2390,7 +2450,22 @@ def scene_settings(
@staticmethod
def personal_key_id(value: Optional[str] = None):
- """Set or get the personal_key_id when register.
+ """Set or get the personal_key_id during registration.
+
+ :param value: The personal key ID to set. If None, the current personal key ID is returned.
+ :type value: Optional[str]
+ :return: The current personal key ID if value is None, otherwise None.
+ :rtype: Optional[str]
+
+ **Example: Setting the personal key ID**
+ ::
+ import graphistry
+ graphistry.personal_key_id("my_personal_key_id")
+
+ **Example: Getting the personal key ID**
+ ::
+ import graphistry
+ key_id = graphistry.personal_key_id()
"""
if value is None:
@@ -2404,7 +2479,22 @@ def personal_key_id(value: Optional[str] = None):
@staticmethod
def personal_key_secret(value: Optional[str] = None):
- """Set or get the personal_key_secret when register.
+ """Set or get the personal_key_secret during registration.
+
+ :param value: The personal key secret to set. If None, the current personal key secret is returned.
+ :type value: Optional[str]
+ :return: The current personal key secret if value is None, otherwise None.
+ :rtype: Optional[str]
+
+ **Example: Setting the personal key secret**
+ ::
+ import graphistry
+ graphistry.personal_key_secret("my_personal_key_secret")
+
+ **Example: Getting the personal key secret**
+ ::
+ import graphistry
+ secret = graphistry.personal_key_secret()
"""
if value is None:
@@ -2446,7 +2536,184 @@ def _handle_api_response(response):
logger.error('Error: %s', response, exc_info=True)
raise Exception("Unknown Error")
+ @staticmethod
+ def sso_repeat_get_token(repeat: int = 20, wait: int = 5):
+ """Repeatedly call to obtain the JWT token after SSO login.
+
+ :param repeat: Number of times to attempt obtaining the token, defaults to 20
+ :type repeat: int, optional
+ :param wait: Number of seconds to wait between attempts, defaults to 5
+ :type wait: int, optional
+ :return: The obtained JWT token or None if unsuccessful
+ :rtype: Optional[str]
+
+ **Example:**
+
+ ::
+
+ token = PyGraphistry.sso_repeat_get_token(repeat=10, wait=2)
+ if token:
+ print("Token obtained:", token)
+ else:
+ print("Failed to obtain token")
+ """
+
+ for _ in range(repeat):
+ token = PyGraphistry.sso_get_token()
+ if token:
+ return token
+ time.sleep(wait)
+
+ return
+
+ @staticmethod
+ def sso_wait_for_token_display(repeat: int = 20, wait: int = 5, fail_silent: bool = False, display_mode: str = 'text'):
+ if display_mode == 'html':
+ PyGraphistry.sso_wait_for_token_html_display(repeat, wait, fail_silent)
+ else:
+ PyGraphistry.sso_wait_for_token_text_display(repeat, wait, fail_silent)
+
+ @staticmethod
+ def sso_wait_for_token_text_display(repeat: int = 20, wait: int = 5, fail_silent: bool = False):
+ """Get the JWT token for SSO login and display the corresponding message in text.
+
+ This method attempts to obtain the JWT token for SSO login and displays the result as a text message.
+
+ :param repeat: Number of times to attempt obtaining the token, defaults to 20
+ :type repeat: int, optional
+ :param wait: Number of seconds to wait between attempts, defaults to 5
+ :type wait: int, optional
+ :param fail_silent: Whether to suppress exceptions on failure, defaults to False
+ :type fail_silent: bool, optional
+
+ **Example:**
+
+ ::
+
+ PyGraphistry.sso_wait_for_token_text_display(repeat=10, wait=2, fail_silent=True)
+ """
+ if not PyGraphistry.api_token():
+ msg_text = '....'
+ if not PyGraphistry.sso_repeat_get_token(repeat, wait):
+ msg_text = f'{msg_text}\nFailed to get token after {repeat * wait} seconds ....'
+ if not fail_silent:
+ msg = f"Failed to get token after {repeat * wait} seconds. Please re-run the login process"
+ if in_ipython() or in_databricks() or PyGraphistry.set_sso_opt_into_type == "display":
+ display_message_html(f"{msg}")
+ raise Exception(msg)
+ else:
+ msg_text = f'{msg_text}\nGot token'
+ print(msg_text)
+ return
+
+ msg_text = f'{msg_text}\nGot token'
+ print(msg_text)
+ else:
+ print('Token is valid, no waiting required.')
+
+
+ @staticmethod
+ def sso_wait_for_token_html_display(repeat: int = 20, wait: int = 5, fail_silent: bool = False):
+ """Get the JWT token for SSO login and display the corresponding message in HTML.
+
+ This method attempts to obtain the JWT token for SSO login and displays the result as an HTML message.
+
+ :param repeat: Number of times to attempt obtaining the token, defaults to 20
+ :type repeat: int, optional
+ :param wait: Number of seconds to wait between attempts, defaults to 5
+ :type wait: int, optional
+ :param fail_silent: Whether to suppress exceptions on failure, defaults to False
+ :type fail_silent: bool, optional
+
+ **Example:**
+
+ ::
+
+ PyGraphistry.sso_wait_for_token_html_display(repeat=10, wait=2, fail_silent=True)
+ """
+ from IPython.display import display, HTML
+ if not PyGraphistry.api_token():
+ msg_html = '
.... '
+ if not PyGraphistry.sso_repeat_get_token(repeat, wait):
+ msg_html = f'{msg_html}
Failed to get token after {repeat * wait} seconds .... '
+ if not fail_silent:
+ raise Exception(f"Failed to get token after {repeat * wait} seconds. Please re-run the login process")
+ else:
+ msg_html = f'{msg_html}
Got token'
+ display(HTML(msg_html))
+ return
+
+ msg_html = f'{msg_html}
Got token'
+ display(HTML(msg_html))
+ else:
+ display(HTML('
Token is valid, no waiting required.'))
+
+
+ @staticmethod
+ def sso_verify_token_display(
+ repeat: int = 20,
+ wait: int = 5,
+ display_mode: str = 'text'
+ ) -> bool:
+ if display_mode == 'html':
+ from IPython.display import display, HTML, clear_output
+ clear_output()
+
+ required_login = False
+ token = PyGraphistry.api_token()
+ if token:
+ is_valid = PyGraphistry.verify_token()
+ print(f"is_valid : {is_valid}")
+ if not is_valid:
+ print("***********token not valid, refresh token*****************")
+ if display_mode == 'html':
+ display(HTML('
Refresh token ....'))
+ try:
+ PyGraphistry.refresh()
+ except Exception:
+ required_login = True
+
+ else:
+ print("Token is still valid")
+ if display_mode == 'html':
+ display(HTML('
Token is still valid ....'))
+
+ else:
+ required_login = True
+
+ if required_login:
+ print("***********Prepare to sign in*****************")
+ msg_html = f'
Prepare to sign in ....
Please Login with the link appear later. Waiting for success login for {repeat * wait} seconds, please login within {wait} seconds....
Please close the browser tab and come back to dashboard....'
+ display(HTML(msg_html))
+
+
+ return not required_login
+
+
+ # Databricks Dashboard SSO helper functions
+ class DatabricksHelper():
+ """Helper class for databricks.
+
+ **Helper class to improve the sso login flow**
+
+ """
+ @staticmethod
+ def register_databricks_sso(
+ server: Optional[str] = None,
+ org_name: Optional[str] = None,
+ idp_name: Optional[str] = None,
+ **kwargs
+ ):
+ if not PyGraphistry.api_token():
+ PyGraphistry.register(api=3, protocol="https", server=server, is_sso_login=True, org_name=org_name, idp_name=idp_name, sso_timeout=None, sso_opt_into_type="display")
+
+ # @staticmethod
+ # def databricks_sso_login(server="hub.graphistry.com", org_name=None, idp_name=None, retry=5, wait=20):
+ # from IPython.display import clear_output
+ # clear_output()
+ # if not PyGraphistry.api_token():
+ # PyGraphistry.register(api=3, protocol="https", server=server, is_sso_login=True, org_name=org_name, idp_name=idp_name, sso_timeout=None, sso_opt_into_type="display")
client_protocol_hostname = PyGraphistry.client_protocol_hostname
@@ -2501,6 +2768,11 @@ def _handle_api_response(response):
personal_key_secret = PyGraphistry.personal_key_secret
switch_org = PyGraphistry.switch_org
+# databricks dashboard helper functions
+sso_wait_for_token_display = PyGraphistry.sso_wait_for_token_display
+sso_verify_token_display = PyGraphistry.sso_verify_token_display
+sso_repeat_get_token = PyGraphistry.sso_repeat_get_token
+register_databricks_sso = PyGraphistry.DatabricksHelper.register_databricks_sso
class NumpyJSONEncoder(json.JSONEncoder):
diff --git a/graphistry/util.py b/graphistry/util.py
index c2c47996f..b4dee6acb 100644
--- a/graphistry/util.py
+++ b/graphistry/util.py
@@ -10,7 +10,7 @@
import uuid
import warnings
from functools import lru_cache
-from typing import Any
+from typing import Any, Optional
from collections import UserDict
from .constants import VERBOSE, CACHE_COERCION_SIZE, TRACE
@@ -172,8 +172,16 @@ def check_set_memoize(
weakref[hashed] = w
return False
+def display_message_html(message: str, cleared: Optional[bool] = False):
+ from IPython.display import display, HTML, clear_output
-def make_iframe(url, height, extra_html="", override_html_style=None):
+ if cleared:
+ clear_output()
+
+ display(HTML(message))
+
+
+def make_iframe(url, height, extra_html="", override_html_style=None, srcdoc: Optional[str] = None):
id = uuid.uuid4()
height_str = (
@@ -203,7 +211,7 @@ def make_iframe(url, height, extra_html="", override_html_style=None):
)
iframe = """
-