From 6f165ee08a44239514a9b6fc3fd7f15f50e87aec Mon Sep 17 00:00:00 2001 From: SireInsectus Date: Sun, 2 Oct 2022 21:41:38 -0500 Subject: [PATCH] Update __init__.py --- src/dbacademy_gems/dbgems/__init__.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/dbacademy_gems/dbgems/__init__.py b/src/dbacademy_gems/dbgems/__init__.py index 4e21e4a..038ec6f 100644 --- a/src/dbacademy_gems/dbgems/__init__.py +++ b/src/dbacademy_gems/dbgems/__init__.py @@ -115,18 +115,20 @@ def get_cloud(): raise Exception("Unable to identify the cloud provider.") -def get_tags() -> dict: +def get_tags(): + tags = dbutils.entry_point.getDbutils().notebook().getContext().tags() # noinspection PyProtectedMember,PyUnresolvedReferences - return sc._jvm.scala.collection.JavaConversions.mapAsJavaMap( - dbutils.entry_point.getDbutils().notebook().getContext().tags()) + java_map = sc._jvm.scala.collection.JavaConversions.mapAsJavaMap(tags) + return java_map def get_tag(tag_name: str, default_value: str = None) -> str: - return get_tags().get(tag_name, default_value) + value = get_tags().get(tag_name) + return value or default_value def get_username() -> str: - return get_tags()["user"] + return get_tag("user") def get_browser_host_name(default_value=None): @@ -134,7 +136,7 @@ def get_browser_host_name(default_value=None): def get_job_id(default_value=None): - return get_tag("jobId", default_value=default_value) + return get_tag(tag_name="jobId", default_value=default_value) def is_job(): @@ -177,7 +179,7 @@ def get_current_spark_version(client=None): if includes_dbrest: # noinspection PyUnresolvedReferences from dbacademy import dbrest - cluster_id = get_tags()["clusterId"] + cluster_id = get_tag("clusterId") client = dbrest.DBAcademyRestClient() if client is None else client cluster = client.clusters().get(cluster_id) return cluster.get("spark_version", None) @@ -190,7 +192,7 @@ def get_current_instance_pool_id(client=None): if includes_dbrest: # noinspection PyUnresolvedReferences from dbacademy import dbrest - cluster_id = get_tags()["clusterId"] + cluster_id = get_tag("clusterId") client = dbrest.DBAcademyRestClient() if client is None else client cluster = client.clusters().get(cluster_id) return cluster.get("instance_pool_id", None) @@ -205,7 +207,7 @@ def get_current_node_type_id(client=None): if includes_dbrest: # noinspection PyUnresolvedReferences from dbacademy import dbrest - cluster_id = get_tags()["clusterId"] + cluster_id = get_tag("clusterId") client = dbrest.DBAcademyRestClient() if client is None else client cluster = client.clusters().get(cluster_id) return cluster.get("node_type_id", None)