Attempt #2 MLCOMPUTE-1203 | Configure Spark driver pod memory and cores based on Spark args #3903

Merged · 35 commits · Jul 2, 2024

Changes from all commits
f727e9e
MLCOMPUTE-1203 | Configure Spark driver pod memory and cores based on…
Jun 10, 2024
488b3b3
MLCOMPUTE-1203 | add validation to prevent users from specifying cpus…
Jun 13, 2024
b7d0833
Merge branch 'master' of https://github.com/Yelp/paasta into MLCOMPUT…
Jun 13, 2024
fac144d
MLCOMPUTE-1203 | overload instance config methods, add tests
Jun 13, 2024
9614c92
MLCOMPUTE-1203 | fix tests
Jun 13, 2024
bbd6c59
MLCOMPUTE-1203 | fix tests
Jun 13, 2024
9782558
MLCOMPUTE-1203 | minor fixes
Jun 13, 2024
1b45874
MLCOMPUTE-1203 | move building spark config to constructor
Jun 14, 2024
659b96f
MLCOMPUTE-1203 | fix tests and formatting
Jun 17, 2024
9d16aa5
Merge branch 'master' of https://github.com/Yelp/paasta into MLCOMPUT…
Jun 17, 2024
ba197eb
Merge branch 'master' of https://github.com/Yelp/paasta into MLCOMPUT…
Jun 27, 2024
d60db18
DAR-2360 | move the creation of spark config away from the TronAction…
Jun 27, 2024
01e43d4
MLCOMPUTE-1203 | add validation signature to methods loading instance…
Jun 27, 2024
3ce3c0b
Merge branch 'master' of https://github.com/Yelp/paasta into MLCOMPUT…
Jun 27, 2024
45f7b4a
MLCOMPUTE-1203 | add validation signature to methods loading instance…
Jun 27, 2024
5d27b1b
formatting
Jun 27, 2024
b94c2ed
MLCOMPUTE-1203 | fix arg type
Jun 27, 2024
6d903be
add missing pos argument
Jun 27, 2024
2e30fc6
Revert "add missing pos argument"
nemacysts Jun 28, 2024
db49af6
Revert "MLCOMPUTE-1203 | fix arg type"
nemacysts Jun 28, 2024
dae03fa
Revert "formatting"
nemacysts Jun 28, 2024
7790790
Revert "MLCOMPUTE-1203 | add validation signature to methods loading …
nemacysts Jun 28, 2024
2217de8
Revert "MLCOMPUTE-1203 | add validation signature to methods loading …
nemacysts Jun 28, 2024
800f012
Start splitting SA creation and validation (for tron)
nemacysts Jun 28, 2024
fbae67a
whoops
nemacysts Jun 28, 2024
d154f24
correctly iterate over dicts
nemacysts Jun 28, 2024
784856f
correctly parse yaml
nemacysts Jun 28, 2024
8f555e3
remove test call
nemacysts Jun 28, 2024
36dc8d3
use spark kubeconfig
nemacysts Jun 28, 2024
f98ae5d
Merge remote-tracking branch 'origin/u/krall/ensure_service_account' …
nemacysts Jun 28, 2024
def9059
Merge branch 'master' of github.com:Yelp/paasta into MLCOMPUTE-1203_f…
nemacysts Jun 28, 2024
6a4f3b8
Use krall's latest changes
nemacysts Jun 28, 2024
b008146
don't make kubeclients outside of dry-run
nemacysts Jun 28, 2024
7de6ee6
build spark config in constructor
nemacysts Jun 28, 2024
8163fda
Merge branch 'master' of https://github.com/Yelp/paasta into MLCOMPUT…
Jul 2, 2024
41 changes: 41 additions & 0 deletions paasta_tools/setup_tron_namespace.py
@@ -26,8 +26,13 @@

import ruamel.yaml as yaml

from paasta_tools import spark_tools
from paasta_tools import tron_tools
from paasta_tools.kubernetes_tools import ensure_service_account
from paasta_tools.kubernetes_tools import KubeClient
from paasta_tools.tron_tools import KUBERNETES_NAMESPACE
from paasta_tools.tron_tools import MASTER_NAMESPACE
from paasta_tools.utils import load_system_paasta_config

log = logging.getLogger(__name__)

@@ -63,6 +68,32 @@ def parse_args():
    return args


def ensure_service_accounts(
    raw_config: str, kube_client: KubeClient, spark_kube_client: KubeClient
) -> None:
    # this is kinda silly, but the tron create_config functions return strings
    # we should refactor to pass the dicts around until we're ready to send the config to tron
    # (where we can finally convert it to a string)
    config = yaml.safe_load(raw_config)
    for _, job in config.get("jobs", {}).items():
        for _, action in job.get("actions", {}).items():
            if action.get("service_account_name") is not None:
                ensure_service_account(
                    action["service_account_name"],
                    namespace=KUBERNETES_NAMESPACE,
                    kube_client=kube_client,
                )
                # spark executors are special in that we want the SA to exist in two namespaces:
                # the tron namespace - for the spark driver
                # and the spark namespace - for the spark executor
                if action.get("executor") == "spark":
                    ensure_service_account(
                        action["service_account_name"],
                        namespace=spark_tools.SPARK_EXECUTOR_NAMESPACE,
                        kube_client=spark_kube_client,
                    )


def main():
    args = parse_args()
    log_level = logging.DEBUG if args.verbose else logging.INFO
@@ -133,6 +164,16 @@ def main():
log.info(f"{new_config}")
updated.append(service)
else:
# NOTE: these are all lru_cache'd so it should be fine to call these for every service
system_paasta_config = load_system_paasta_config()
kube_client = KubeClient()
spark_kube_client = KubeClient(
config_file=system_paasta_config.get_spark_kubeconfig()
)
# PaaSTA will not necessarily have created the SAs we want to use
# ...so let's go ahead and create them!
ensure_service_accounts(new_config, kube_client, spark_kube_client)

if client.update_namespace(service, new_config):
updated.append(service)
log.debug(f"Updated {service}")
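
For context (not part of the PR), here is a minimal sketch of the config shape that ensure_service_accounts walks. The job name, action name, and service account value are made up for illustration; the real helper calls ensure_service_account against the tron namespace and, for executor: spark, against the Spark executor namespace via the second KubeClient, rather than printing.

import ruamel.yaml as yaml

# Hypothetical serialized Tron config -- names and values are illustrative only.
raw_config = """
jobs:
  example_job:
    actions:
      driver:
        executor: spark
        service_account_name: example-spark-sa
"""

config = yaml.safe_load(raw_config)
for _, job in config.get("jobs", {}).items():
    for _, action in job.get("actions", {}).items():
        sa = action.get("service_account_name")
        if sa is not None:
            # ensure_service_accounts() would call ensure_service_account() here:
            # once in the tron namespace, and once more in the Spark executor
            # namespace when executor == "spark".
            print(f"would ensure SA {sa!r} (executor={action.get('executor')})")
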
21 changes: 21 additions & 0 deletions paasta_tools/spark_tools.py
@@ -6,6 +6,7 @@
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Mapping
from typing import Set

@@ -23,6 +24,7 @@
SPARK_JOB_USER = "TRON"
SPARK_PROMETHEUS_SHARD = "ml-compute"
SPARK_DNS_POD_TEMPLATE = "/nail/srv/configs/spark_dns_pod_template.yaml"
MEM_MULTIPLIER = {"k": 1024, "m": 1024**2, "g": 1024**3, "t": 1024**4}

log = logging.getLogger(__name__)

@@ -247,3 +249,22 @@ def get_spark_driver_monitoring_labels(
"spark.yelp.com/driver_ui_port": ui_port_str,
}
return labels


def get_spark_memory_in_unit(mem: str, unit: Literal["k", "m", "g", "t"]) -> float:
    """
    Converts a Spark memory string to the desired unit.
    mem uses the same format as JVM memory strings: a bare number (bytes), or a number
    followed by 'k', 'm', 'g' or 't'.
    unit can be 'k', 'm', 'g' or 't'.
    Returns the memory as a float in the desired unit.
    """
    try:
        memory_bytes = float(mem)
    except ValueError:
        try:
            memory_bytes = float(mem[:-1]) * MEM_MULTIPLIER[mem[-1]]
        except (ValueError, IndexError):
            print(f"Unable to parse memory value {mem}. Defaulting to 2 GB.")
            memory_bytes = 2147483648  # default to 2 GB
    memory_unit = memory_bytes / MEM_MULTIPLIER[unit]
    return memory_unit
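
As a quick sanity check (not part of the diff), this is how the helper above behaves for a few representative inputs, following the branches of its implementation:

from paasta_tools.spark_tools import get_spark_memory_in_unit

assert get_spark_memory_in_unit("4g", "m") == 4096.0  # 4 GiB expressed in MiB
assert get_spark_memory_in_unit("512m", "g") == 0.5   # 512 MiB expressed in GiB
assert get_spark_memory_in_unit("2048", "k") == 2.0   # a bare number is treated as bytes
assert get_spark_memory_in_unit("oops", "g") == 2.0   # unparseable input falls back to 2 GB
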
65 changes: 58 additions & 7 deletions paasta_tools/tron_tools.py
@@ -280,9 +280,40 @@ def __init__(
            soa_dir=soa_dir,
        )
        self.job, self.action = decompose_instance(instance)

        # Indicate whether this config object is created for validation
        self.for_validation = for_validation

        self.action_spark_config = None
        if self.get_executor() == "spark":
            # build the complete Spark configuration
            # TODO: add a conditional check for Spark-specific commands (spark-submit, pyspark, etc.)?
            self.action_spark_config = self.build_spark_config()

    def get_cpus(self) -> float:
        # use the Spark driver pod CPU count if it is specified in the Spark arguments
        if (
            self.action_spark_config
            and "spark.driver.cores" in self.action_spark_config
        ):
            return float(self.action_spark_config["spark.driver.cores"])
        # fall back to the instance default if there's no spark.driver.cores config
        return super().get_cpus()

    def get_mem(self) -> float:
        # use the Spark driver pod memory if it is specified in the Spark arguments
        if (
            self.action_spark_config
            and "spark.driver.memory" in self.action_spark_config
        ):
            return int(
                spark_tools.get_spark_memory_in_unit(
                    self.action_spark_config["spark.driver.memory"], "m"
                )
            )
        # fall back to the instance default if there's no spark.driver.memory config
        return super().get_mem()

    def build_spark_config(self) -> Dict[str, str]:
        system_paasta_config = load_system_paasta_config()
        resolved_cluster = system_paasta_config.get_eks_cluster_aliases().get(
@@ -436,7 +467,6 @@ def get_env(
system_paasta_config: Optional["SystemPaastaConfig"] = None,
) -> Dict[str, str]:
env = super().get_env(system_paasta_config=system_paasta_config)

if self.get_executor() == "spark":
# Required by some sdks like boto3 client. Throws NoRegionError otherwise.
# AWS_REGION takes precedence if set.
@@ -601,6 +631,20 @@ def validate(self):
            error_msgs.append(
                f"{self.get_job_name()}.{self.get_action_name()} must have a deploy_group set"
            )
        # We don't allow users to set `cpus` or `mem` when the action is a Spark job whose
        # driver runs on k8s (executor: spark): those values are derived from
        # `spark.driver.cores` and `spark.driver.memory` instead, to avoid confusion.
        if self.get_executor() == "spark":
            if "cpus" in self.config_dict:
                error_msgs.append(
                    f"{self.get_job_name()}.{self.get_action_name()} is a Spark job. `cpus` config is not allowed. "
                    f"Please specify the driver cores using `spark.driver.cores`."
                )
            if "mem" in self.config_dict:
                error_msgs.append(
                    f"{self.get_job_name()}.{self.get_action_name()} is a Spark job. `mem` config is not allowed. "
                    f"Please specify the driver memory using `spark.driver.memory`."
                )
        return error_msgs

def get_pool(self) -> str:
Expand Down Expand Up @@ -965,21 +1009,26 @@ def format_tron_action_dict(action_config: TronActionConfig):
if executor == "spark":
is_mrjob = action_config.config_dict.get("mrjob", False)
system_paasta_config = load_system_paasta_config()
# inject spark configs to the original spark-submit command
spark_config = action_config.build_spark_config()
# inject additional Spark configs in case of Spark commands
result["command"] = spark_tools.build_spark_command(
result["command"],
spark_config,
action_config.action_spark_config,
is_mrjob,
action_config.config_dict.get(
"max_runtime", spark_tools.DEFAULT_SPARK_RUNTIME_TIMEOUT
),
)
# point to the KUBECONFIG needed by Spark driver
result["env"]["KUBECONFIG"] = system_paasta_config.get_spark_kubeconfig()

# spark, unlike normal batches, needs to expose several ports for things like the spark
# ui and for executor->driver communication
result["ports"] = list(
set(spark_tools.get_spark_ports_from_config(spark_config))
set(
spark_tools.get_spark_ports_from_config(
action_config.action_spark_config
)
)
)
# mount KUBECONFIG file for Spark drivers to communicate with EKS cluster
extra_volumes.append(
@@ -993,10 +1042,12 @@
)
# Add pod annotations and labels for Spark monitoring metrics
monitoring_annotations = (
spark_tools.get_spark_driver_monitoring_annotations(spark_config)
spark_tools.get_spark_driver_monitoring_annotations(
action_config.action_spark_config
)
)
monitoring_labels = spark_tools.get_spark_driver_monitoring_labels(
spark_config
action_config.action_spark_config
)
result["annotations"].update(monitoring_annotations)
result["labels"].update(monitoring_labels)
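
To tie the pieces together (not part of the PR), here is a standalone sketch of the resource-resolution logic that get_cpus/get_mem add above: the driver pod's cpus and memory come from spark.driver.cores and spark.driver.memory when present, otherwise the regular instance-config defaults apply. The default values below are placeholders, not PaaSTA's real defaults.

from typing import Dict, Optional, Tuple

from paasta_tools.spark_tools import get_spark_memory_in_unit


def resolve_driver_resources(
    action_spark_config: Optional[Dict[str, str]],
    default_cpus: float = 1.0,   # placeholder default
    default_mem_mb: int = 4096,  # placeholder default
) -> Tuple[float, int]:
    cpus = (
        float(action_spark_config["spark.driver.cores"])
        if action_spark_config and "spark.driver.cores" in action_spark_config
        else default_cpus
    )
    mem_mb = (
        int(get_spark_memory_in_unit(action_spark_config["spark.driver.memory"], "m"))
        if action_spark_config and "spark.driver.memory" in action_spark_config
        else default_mem_mb
    )
    return cpus, mem_mb


# e.g. an action configured with spark.driver.cores=2 and spark.driver.memory=8g
print(resolve_driver_resources({"spark.driver.cores": "2", "spark.driver.memory": "8g"}))
# -> (2.0, 8192)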