Skip to content
This repository has been archived by the owner on Oct 19, 2022. It is now read-only.

Commit

Permalink
Update __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
SireInsectus committed Oct 3, 2022
1 parent 1f1afab commit 6f165ee
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/dbacademy_gems/dbgems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,28 @@ def get_cloud():
raise Exception("Unable to identify the cloud provider.")


def get_tags() -> dict:
def get_tags():
tags = dbutils.entry_point.getDbutils().notebook().getContext().tags()
# noinspection PyProtectedMember,PyUnresolvedReferences
return sc._jvm.scala.collection.JavaConversions.mapAsJavaMap(
dbutils.entry_point.getDbutils().notebook().getContext().tags())
java_map = sc._jvm.scala.collection.JavaConversions.mapAsJavaMap(tags)
return java_map


def get_tag(tag_name: str, default_value: str = None) -> str:
return get_tags().get(tag_name, default_value)
value = get_tags().get(tag_name)
return value or default_value


def get_username() -> str:
return get_tags()["user"]
return get_tag("user")


def get_browser_host_name(default_value=None):
return get_tag(tag_name="browserHostName", default_value=default_value)


def get_job_id(default_value=None):
return get_tag("jobId", default_value=default_value)
return get_tag(tag_name="jobId", default_value=default_value)


def is_job():
Expand Down Expand Up @@ -177,7 +179,7 @@ def get_current_spark_version(client=None):
if includes_dbrest:
# noinspection PyUnresolvedReferences
from dbacademy import dbrest
cluster_id = get_tags()["clusterId"]
cluster_id = get_tag("clusterId")
client = dbrest.DBAcademyRestClient() if client is None else client
cluster = client.clusters().get(cluster_id)
return cluster.get("spark_version", None)
Expand All @@ -190,7 +192,7 @@ def get_current_instance_pool_id(client=None):
if includes_dbrest:
# noinspection PyUnresolvedReferences
from dbacademy import dbrest
cluster_id = get_tags()["clusterId"]
cluster_id = get_tag("clusterId")
client = dbrest.DBAcademyRestClient() if client is None else client
cluster = client.clusters().get(cluster_id)
return cluster.get("instance_pool_id", None)
Expand All @@ -205,7 +207,7 @@ def get_current_node_type_id(client=None):
if includes_dbrest:
# noinspection PyUnresolvedReferences
from dbacademy import dbrest
cluster_id = get_tags()["clusterId"]
cluster_id = get_tag("clusterId")
client = dbrest.DBAcademyRestClient() if client is None else client
cluster = client.clusters().get(cluster_id)
return cluster.get("node_type_id", None)
Expand Down

0 comments on commit 6f165ee

Please sign in to comment.