diff --git a/dask_kubernetes/operator/controller/controller.py b/dask_kubernetes/operator/controller/controller.py index 00c47aeed..e954f79a6 100644 --- a/dask_kubernetes/operator/controller/controller.py +++ b/dask_kubernetes/operator/controller/controller.py @@ -2,6 +2,7 @@ import time from collections import defaultdict from contextlib import suppress +from copy import deepcopy from datetime import datetime from uuid import uuid4 @@ -64,6 +65,17 @@ def _get_labels(meta): } +def _consolidate_env_vars(existing_env_vars, additional_env_vars): + existing_env_names = {d["name"] for d in existing_env_vars} + additional_env_names = {d["name"] for d in additional_env_vars} + + overlapping_env_names = existing_env_names.intersection(additional_env_names) + + additional_env_vars_to_keep = [d for d in additional_env_vars if d["name"] not in overlapping_env_names] + + return [*existing_env_vars, *additional_env_vars_to_keep] + + def build_scheduler_deployment_spec( cluster_name, namespace, pod_spec, annotations, labels ): @@ -159,9 +171,9 @@ def build_worker_deployment_spec( ] for i in range(len(deployment_spec["spec"]["template"]["spec"]["containers"])): if "env" in deployment_spec["spec"]["template"]["spec"]["containers"][i]: - deployment_spec["spec"]["template"]["spec"]["containers"][i]["env"].extend( - env - ) + existing_env_vars = deployment_spec["spec"]["template"]["spec"]["containers"][i]["env"] + all_env_vars = _consolidate_env_vars(existing_env_vars, env) + deployment_spec["spec"]["template"]["spec"]["containers"][i]["env"] = all_env_vars else: deployment_spec["spec"]["template"]["spec"]["containers"][i]["env"] = env return deployment_spec @@ -197,7 +209,9 @@ def build_job_pod_spec(job_name, cluster_name, namespace, spec, annotations, lab ] for i in range(len(pod_spec["spec"]["containers"])): if "env" in pod_spec["spec"]["containers"][i]: - pod_spec["spec"]["containers"][i]["env"].extend(env) + existing_env_vars = pod_spec["spec"]["template"]["spec"]["containers"][i]["env"] + all_env_vars = _consolidate_env_vars(existing_env_vars, env) + pod_spec["spec"]["containers"][i]["env"] = all_env_vars else: pod_spec["spec"]["containers"][i]["env"] = env return pod_spec @@ -581,7 +595,7 @@ async def daskworkergroup_replica_update( namespace=namespace, cluster_name=cluster_name, uuid=uuid4().hex[:10], - pod_spec=worker_spec["spec"], + pod_spec=deepcopy(worker_spec["spec"]), annotations=annotations, labels=labels, )