Fix ensure_service_account breakages from the get_or_create_service_account refactor (#3912)

* Wait until the last moment to create kubeclients in s_t_n (setup_tron_namespace)

This should avoid issues with missing spark kubeconfigs on trons that
don't have that kubeconfig puppeted

* Conditionally call ensure_service_account in controller_wrappers

This was previously reverted because we were hitting the ValueError that protects us
from calling this function without an IAM or k8s role (the default get_iam_role()
retval is ""); the call is now guarded by a check on get_iam_role().
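
For context, the failure mode described above amounts to the guard inside the helper being hit with the default empty iam_role (a minimal sketch; the ValueError message is the one shown in the kubernetes_tools.py diff below; the empty-string value is the default get_iam_role() return mentioned in the message):

    iam_role = ""   # default get_iam_role() retval when no role is configured
    k8s_role = None
    if not iam_role and not k8s_role:
        # this is the guard that was being tripped before the conditional call was added
        raise ValueError(
            "Expected at least one of iam_role or k8s_role to be passed in!"
        )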
nemacysts authored Jul 5, 2024
1 parent a243dea commit a47482b
Showing 8 changed files with 399 additions and 194 deletions.
21 changes: 21 additions & 0 deletions paasta_tools/kubernetes/application/controller_wrappers.py
@@ -15,6 +15,7 @@
from paasta_tools.kubernetes_tools import create_deployment
from paasta_tools.kubernetes_tools import create_pod_disruption_budget
from paasta_tools.kubernetes_tools import create_stateful_set
from paasta_tools.kubernetes_tools import ensure_service_account
from paasta_tools.kubernetes_tools import KubeClient
from paasta_tools.kubernetes_tools import KubeDeployment
from paasta_tools.kubernetes_tools import KubernetesDeploymentConfig
@@ -120,6 +121,26 @@ def update_related_api_objects(self, kube_client: KubeClient) -> None:
"""
self.ensure_pod_disruption_budget(kube_client, self.soa_config.get_namespace())

def update_dependency_api_objects(self, kube_client: KubeClient) -> None:
"""
Update related Kubernetes API objects that should be updated before the main object,
such as service accounts.
:param kube_client:
"""
self.ensure_service_account(kube_client)

def ensure_service_account(self, kube_client: KubeClient) -> None:
"""
Ensure that the service account for this application exists
:param kube_client:
"""
if self.soa_config.get_iam_role():
ensure_service_account(
iam_role=self.soa_config.get_iam_role(),
namespace=self.soa_config.get_namespace(),
kube_client=kube_client,
)

def delete_pod_disruption_budget(self, kube_client: KubeClient) -> None:
try:
kube_client.policy.delete_namespaced_pod_disruption_budget(
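
A minimal usage sketch of the new hook added above (assuming an Application wrapper instance app and a configured KubeClient; the create() call is illustrative of the "main object" step the docstring refers to):

    kube_client = KubeClient()
    # dependencies first: this runs ensure_service_account() only when an iam_role is configured
    app.update_dependency_api_objects(kube_client)
    # then the main Kubernetes object (e.g. the Deployment) can be created or updated
    app.create(kube_client)
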
28 changes: 14 additions & 14 deletions paasta_tools/kubernetes_tools.py
@@ -2220,9 +2220,9 @@ def get_pod_template_spec(
annotations["iam.amazonaws.com/role"] = ""
iam_role = self.get_iam_role()
if iam_role:
- pod_spec_kwargs[
- "service_account_name"
- ] = create_or_find_service_account_name(iam_role, self.get_namespace())
+ pod_spec_kwargs["service_account_name"] = get_service_account_name(
+ iam_role
+ )
if fs_group is None:
# We need some reasonable default for group id of a process
# running inside the container. Seems like most of such
@@ -4050,12 +4050,9 @@ def get_all_limit_ranges(
_RE_NORMALIZE_IAM_ROLE = re.compile(r"[^0-9a-zA-Z]+")


- def create_or_find_service_account_name(
+ def get_service_account_name(
iam_role: str,
- namespace: str,
k8s_role: Optional[str] = None,
- kubeconfig_file: Optional[str] = None,
- dry_run: bool = False,
) -> str:
# the service account is expected to always be prefixed with paasta- as using the actual namespace
# potentially wastes a lot of characters (e.g., paasta-nrtsearchservices) that could be used for
@@ -4081,12 +4078,17 @@
"Expected at least one of iam_role or k8s_role to be passed in!"
)

- # if someone is dry-running paasta_setup_tron_namespace or some other tool that
- # calls this function, we probably don't want to mutate k8s state :)
- if dry_run:
- return sa_name
+ return sa_name


+ def ensure_service_account(
+ iam_role: str,
+ namespace: str,
+ kube_client: KubeClient,
+ k8s_role: Optional[str] = None,
+ ) -> None:
+ sa_name = get_service_account_name(iam_role, k8s_role)

- kube_client = KubeClient(config_file=kubeconfig_file)
if not any(
sa.metadata and sa.metadata.name == sa_name
for sa in get_all_service_accounts(kube_client, namespace)
@@ -4135,8 +4137,6 @@ def create_or_find_service_account_name(
namespace=namespace, body=role_binding
)

- return sa_name


def mode_to_int(mode: Optional[Union[str, int]]) -> Optional[int]:
if mode is not None:
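
The refactor above splits the old create_or_find_service_account_name into a pure name helper and a separate mutating helper. A minimal usage sketch of how the two fit together (the IAM role ARN and namespace are made-up values):

    kube_client = KubeClient()
    iam_role = "arn:aws:iam::000000000000:role/example-role"  # hypothetical role
    # pure: derives the paasta-- prefixed service account name, no Kubernetes API calls
    sa_name = get_service_account_name(iam_role=iam_role)
    # side-effecting: creates the ServiceAccount (and role binding, where a k8s_role is involved) if it does not already exist
    ensure_service_account(iam_role=iam_role, namespace="paasta", kube_client=kube_client)
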
1 change: 1 addition & 0 deletions paasta_tools/setup_kubernetes_job.py
@@ -281,6 +281,7 @@ def setup_kube_deployments(
"paasta_namespace": app.kube_deployment.namespace,
}
try:
app.update_dependency_api_objects(kube_client)
if (
app.kube_deployment.service,
app.kube_deployment.instance,
40 changes: 40 additions & 0 deletions paasta_tools/setup_tron_namespace.py
@@ -26,8 +26,13 @@

import ruamel.yaml as yaml

from paasta_tools import spark_tools
from paasta_tools import tron_tools
from paasta_tools.kubernetes_tools import ensure_service_account
from paasta_tools.kubernetes_tools import KubeClient
from paasta_tools.tron_tools import KUBERNETES_NAMESPACE
from paasta_tools.tron_tools import MASTER_NAMESPACE
from paasta_tools.utils import load_system_paasta_config

log = logging.getLogger(__name__)

@@ -63,6 +68,37 @@ def parse_args():
return args


def ensure_service_accounts(raw_config: str) -> None:
# NOTE: these are lru_cache'd so it should be fine to call these for every service
system_paasta_config = load_system_paasta_config()
kube_client = KubeClient()
# this is kinda silly, but the tron create_config functions return strings
# we should refactor to pass the dicts around until we're ready to send the config to tron
# (where we can finally convert it to a string)
config = yaml.safe_load(raw_config)
for _, job in config.get("jobs", {}).items():
for _, action in job.get("actions", {}).items():
if action.get("service_account_name") is not None:
ensure_service_account(
action["service_account_name"],
namespace=KUBERNETES_NAMESPACE,
kube_client=kube_client,
)
# spark executors are special in that we want the SA to exist in two namespaces:
# the tron namespace - for the spark driver
# and the spark namespace - for the spark executor
if action.get("executor") == "spark":
# this kubeclient creation is lru_cache'd so it should be fine to call this for every spark action
spark_kube_client = KubeClient(
config_file=system_paasta_config.get_spark_kubeconfig()
)
ensure_service_account(
action["service_account_name"],
namespace=spark_tools.SPARK_EXECUTOR_NAMESPACE,
kube_client=spark_kube_client,
)


def main():
args = parse_args()
log_level = logging.DEBUG if args.verbose else logging.INFO
@@ -133,6 +169,10 @@ def main():
log.info(f"{new_config}")
updated.append(service)
else:
# PaaSTA will not necessarily have created the SAs we want to use
# ...so let's go ahead and create them!
ensure_service_accounts(new_config)

if client.update_namespace(service, new_config):
updated.append(service)
log.debug(f"Updated {service}")
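
A hypothetical raw_config illustrating the shape ensure_service_accounts() walks above (the keys mirror those the function reads; the job, action, and role names are made up):

    raw_config = """
    jobs:
      example_job:
        actions:
          example_action:
            executor: spark
            service_account_name: paasta--example-role
    """
    # creates the SA in the tron namespace and, because executor is spark, in the spark executor namespace too
    ensure_service_accounts(raw_config)
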
21 changes: 21 additions & 0 deletions paasta_tools/spark_tools.py
@@ -6,6 +6,7 @@
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Mapping
from typing import Set

@@ -23,6 +24,7 @@
SPARK_JOB_USER = "TRON"
SPARK_PROMETHEUS_SHARD = "ml-compute"
SPARK_DNS_POD_TEMPLATE = "/nail/srv/configs/spark_dns_pod_template.yaml"
MEM_MULTIPLIER = {"k": 1024, "m": 1024**2, "g": 1024**3, "t": 1024**4}

log = logging.getLogger(__name__)

@@ -247,3 +249,22 @@ def get_spark_driver_monitoring_labels(
"spark.yelp.com/driver_ui_port": ui_port_str,
}
return labels


def get_spark_memory_in_unit(mem: str, unit: Literal["k", "m", "g", "t"]) -> float:
"""
Converts Spark memory to the desired unit.
mem is the same format as JVM memory strings: just number or number followed by 'k', 'm', 'g' or 't'.
unit can be 'k', 'm', 'g' or 't'.
Returns memory as a float converted to the desired unit.
"""
try:
memory_bytes = float(mem)
except ValueError:
try:
memory_bytes = float(mem[:-1]) * MEM_MULTIPLIER[mem[-1]]
except (ValueError, IndexError):
print(f"Unable to parse memory value {mem}. Defaulting to 2 GB.")
memory_bytes = 2147483648 # default to 2 GB
memory_unit = memory_bytes / MEM_MULTIPLIER[unit]
return memory_unit
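
A few illustrative conversions for the helper above (the results follow directly from MEM_MULTIPLIER):

    get_spark_memory_in_unit("4g", "m")    # -> 4096.0
    get_spark_memory_in_unit("512m", "g")  # -> 0.5
    get_spark_memory_in_unit("2048", "k")  # -> 2.0 (a bare number is treated as bytes)
    get_spark_memory_in_unit("oops", "m")  # unparseable: prints a warning and falls back to 2 GB -> 2048.0
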
81 changes: 61 additions & 20 deletions paasta_tools/tron_tools.py
@@ -62,7 +62,7 @@

from paasta_tools.kubernetes_tools import (
allowlist_denylist_to_requirements,
- create_or_find_service_account_name,
+ get_service_account_name,
limit_size_with_hash,
raw_selectors_to_requirements,
to_node_label,
@@ -280,9 +280,40 @@ def __init__(
soa_dir=soa_dir,
)
self.job, self.action = decompose_instance(instance)

# Indicate whether this config object is created for validation
self.for_validation = for_validation

self.action_spark_config = None
if self.get_executor() == "spark":
# build the complete Spark configuration
# TODO: add conditional check for Spark specific commands spark-submit, pyspark etc ?
self.action_spark_config = self.build_spark_config()

def get_cpus(self) -> float:
# set Spark driver pod CPU if it is specified by Spark arguments
if (
self.action_spark_config
and "spark.driver.cores" in self.action_spark_config
):
return float(self.action_spark_config["spark.driver.cores"])
# we fall back to this default if there's no spark.driver.cores config
return super().get_cpus()

def get_mem(self) -> float:
# set Spark driver pod memory if it is specified by Spark arguments
if (
self.action_spark_config
and "spark.driver.memory" in self.action_spark_config
):
return int(
spark_tools.get_spark_memory_in_unit(
self.action_spark_config["spark.driver.memory"], "m"
)
)
# we fall back to this default if there's no spark.driver.memory config
return super().get_mem()

def build_spark_config(self) -> Dict[str, str]:
system_paasta_config = load_system_paasta_config()
resolved_cluster = system_paasta_config.get_eks_cluster_aliases().get(
@@ -354,15 +385,11 @@ def build_spark_config(self) -> Dict[str, str]:
"spark.kubernetes.executor.label.yelp.com/owner", self.get_team()
)

# We need to make sure the Service Account used by the executors has been created.
# We are using the Service Account created using the provided or default IAM role.
spark_conf[
"spark.kubernetes.authenticate.executor.serviceAccountName"
- ] = create_or_find_service_account_name(
+ ] = get_service_account_name(
iam_role=self.get_spark_executor_iam_role(),
- namespace=spark_tools.SPARK_EXECUTOR_NAMESPACE,
- kubeconfig_file=system_paasta_config.get_spark_kubeconfig(),
- dry_run=self.for_validation,
)

return spark_conf
@@ -440,7 +467,6 @@ def get_env(
system_paasta_config: Optional["SystemPaastaConfig"] = None,
) -> Dict[str, str]:
env = super().get_env(system_paasta_config=system_paasta_config)

if self.get_executor() == "spark":
# Required by some sdks like boto3 client. Throws NoRegionError otherwise.
# AWS_REGION takes precedence if set.
@@ -605,6 +631,20 @@ def validate(self):
error_msgs.append(
f"{self.get_job_name()}.{self.get_action_name()} must have a deploy_group set"
)
# We are not allowing users to specify `cpus` and `mem` configuration if the action is a Spark job
# with driver running on k8s (executor: spark), because we derive these values from `spark.driver.cores`
# and `spark.driver.memory` in order to avoid confusion.
if self.get_executor() == "spark":
if "cpus" in self.config_dict:
error_msgs.append(
f"{self.get_job_name()}.{self.get_action_name()} is a Spark job. `cpus` config is not allowed. "
f"Please specify the driver cores using `spark.driver.cores`."
)
if "mem" in self.config_dict:
error_msgs.append(
f"{self.get_job_name()}.{self.get_action_name()} is a Spark job. `mem` config is not allowed. "
f"Please specify the driver memory using `spark.driver.memory`."
)
return error_msgs

def get_pool(self) -> str:
@@ -952,20 +992,14 @@ def format_tron_action_dict(action_config: TronActionConfig):

result["labels"]["yelp.com/owner"] = "compute_infra_platform_experience"

- # create_or_find_service_account_name requires k8s credentials, and we don't
- # have those available for CI to use (nor do we check these for normal PaaSTA
- # services, so we're not doing anything "new" by skipping this)
if (
action_config.get_iam_role_provider() == "aws"
and action_config.get_iam_role()
- and not action_config.for_validation
):
# this service account will be used for normal Tron batches as well as for Spark drivers
- result["service_account_name"] = create_or_find_service_account_name(
+ result["service_account_name"] = get_service_account_name(
iam_role=action_config.get_iam_role(),
- namespace=EXECUTOR_TYPE_TO_NAMESPACE[executor],
k8s_role=None,
- dry_run=action_config.for_validation,
)

# service account token volumes for service authentication
@@ -975,21 +1009,26 @@ def format_tron_action_dict(action_config: TronActionConfig):
if executor == "spark":
is_mrjob = action_config.config_dict.get("mrjob", False)
system_paasta_config = load_system_paasta_config()
- # inject spark configs to the original spark-submit command
- spark_config = action_config.build_spark_config()
+ # inject additional Spark configs in case of Spark commands
result["command"] = spark_tools.build_spark_command(
result["command"],
- spark_config,
+ action_config.action_spark_config,
is_mrjob,
action_config.config_dict.get(
"max_runtime", spark_tools.DEFAULT_SPARK_RUNTIME_TIMEOUT
),
)
# point to the KUBECONFIG needed by Spark driver
result["env"]["KUBECONFIG"] = system_paasta_config.get_spark_kubeconfig()

# spark, unlike normal batches, needs to expose several ports for things like the spark
# ui and for executor->driver communication
result["ports"] = list(
- set(spark_tools.get_spark_ports_from_config(spark_config))
+ set(
+ spark_tools.get_spark_ports_from_config(
+ action_config.action_spark_config
+ )
+ )
)
# mount KUBECONFIG file for Spark drivers to communicate with EKS cluster
extra_volumes.append(
@@ -1003,10 +1042,12 @@
)
# Add pod annotations and labels for Spark monitoring metrics
monitoring_annotations = (
- spark_tools.get_spark_driver_monitoring_annotations(spark_config)
+ spark_tools.get_spark_driver_monitoring_annotations(
+ action_config.action_spark_config
+ )
)
monitoring_labels = spark_tools.get_spark_driver_monitoring_labels(
- spark_config
+ action_config.action_spark_config
)
result["annotations"].update(monitoring_annotations)
result["labels"].update(monitoring_labels)
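
Tying the TronActionConfig changes above together: for a Spark action, the driver's CPU and memory now come from the Spark configuration rather than from cpus/mem in the action config (which validate() now rejects). A minimal sketch under the assumption that the user's Spark arguments resolve into action_spark_config (how that merge happens is elided here):

    # hypothetical resolved driver settings for a spark action
    action_spark_config = {"spark.driver.cores": "2", "spark.driver.memory": "4g"}
    # with these, get_cpus() returns 2.0 and get_mem() returns 4096 (MB) for the driver pod;
    # specifying `cpus` or `mem` directly on such an action now fails validate()
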