Skip to content

Commit

Permalink
MLCOMPUTE-1203 | Configure Spark driver pod memory and cores based on…
Browse files Browse the repository at this point in the history
… Spark args
  • Loading branch information
Sameer Sharma committed Jun 10, 2024
1 parent 94394c6 commit f727e9e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
21 changes: 21 additions & 0 deletions paasta_tools/spark_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SPARK_JOB_USER = "TRON"
SPARK_PROMETHEUS_SHARD = "ml-compute"
SPARK_DNS_POD_TEMPLATE = "/nail/srv/configs/spark_dns_pod_template.yaml"
MEM_MULTIPLIER = {'k': 1024, 'm': 1024**2, 'g': 1024**3, 't': 1024**4}

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -247,3 +248,23 @@ def get_spark_driver_monitoring_labels(
"spark.yelp.com/driver_ui_port": ui_port_str,
}
return labels


def get_spark_memory_in_unit(mem: str, unit: str) -> int:
    """
    Converts a Spark/JVM memory string to the desired unit.

    :param mem: memory in JVM format: a plain number (bytes) or a number
        followed by 'k', 'm', 'g' or 't' (case-insensitive, e.g. "2g").
    :param unit: desired unit, one of 'k', 'm', 'g' or 't' (case-insensitive).
    :returns: memory as an integer in the desired unit (floored); 0 when
        ``mem`` is empty or not parseable.
    :raises KeyError: if ``unit`` is not one of the supported suffixes.
    """
    memory_bytes = 0
    if mem:
        suffix = mem[-1].lower()
        if suffix in MEM_MULTIPLIER:
            try:
                memory_bytes = int(mem[:-1]) * MEM_MULTIPLIER[suffix]
            except ValueError:
                # Malformed numeric part (e.g. "2.5g"): fall back to 0,
                # consistent with the non-suffixed branch below.
                memory_bytes = 0
        else:
            try:
                memory_bytes = int(mem)
            except ValueError:
                memory_bytes = 0
    # Floor division so the result is an int, matching the annotated return
    # type and the tron schema expectation (true division returned a float,
    # making the caller emit strings like "2048.0").
    return memory_bytes // MEM_MULTIPLIER[unit.lower()]
15 changes: 13 additions & 2 deletions paasta_tools/tron_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,18 @@ def format_tron_action_dict(action_config: TronActionConfig):
"max_runtime", spark_tools.DEFAULT_SPARK_RUNTIME_TIMEOUT
),
)
# set Spark driver pod CPU and memory config if it is specified by Spark arguments
if "spark.driver.cores" in spark_config:
result["cpus"] = spark_config["spark.driver.cores"]
if "spark.driver.memory" in spark_config:
# need to set mem in MB based on tron schema
memory_in_mb = spark_tools.get_spark_memory_in_unit(spark_config["spark.driver.memory"], 'm')
if memory_in_mb:
result["mem"] = str(memory_in_mb)

# point to the KUBECONFIG needed by Spark driver
result["env"]["KUBECONFIG"] = system_paasta_config.get_spark_kubeconfig()

# spark, unlike normal batches, needs to expose several ports for things like the spark
# ui and for executor->driver communication
result["ports"] = list(
Expand Down Expand Up @@ -1014,8 +1025,8 @@ def format_tron_action_dict(action_config: TronActionConfig):
# the following config is only valid for k8s/Mesos since we're not running SSH actions
# in a containerized fashion
if executor in (KUBERNETES_EXECUTOR_NAMES + MESOS_EXECUTOR_NAMES):
result["cpus"] = action_config.get_cpus()
result["mem"] = action_config.get_mem()
result.setdefault("cpus", action_config.get_cpus())
result.setdefault("mem", action_config.get_mem())
result["disk"] = action_config.get_disk()
result["extra_volumes"] = format_volumes(extra_volumes)
result["docker_image"] = action_config.get_docker_url()
Expand Down

0 comments on commit f727e9e

Please sign in to comment.