From 93fe171b8daf7c6b0eaf875a48d1f9454a623620 Mon Sep 17 00:00:00 2001 From: jonded94 Date: Tue, 2 Apr 2024 14:57:59 +0200 Subject: [PATCH] Enable overwrites of default environment variables (#874) * Enable overwrites of default environment variables * Black formatting * Include test for additional worker group; test overriding of environment variables * Black --------- Co-authored-by: Jonas Dedden --- .../operator/controller/controller.py | 32 ++++--- .../tests/resources/simpleworkergroup.yaml | 4 +- .../controller/tests/test_controller.py | 89 +++++++++++++++---- 3 files changed, 93 insertions(+), 32 deletions(-) diff --git a/dask_kubernetes/operator/controller/controller.py b/dask_kubernetes/operator/controller/controller.py index efbd3887b..13702d97e 100644 --- a/dask_kubernetes/operator/controller/controller.py +++ b/dask_kubernetes/operator/controller/controller.py @@ -153,21 +153,25 @@ def build_worker_deployment_spec( "metadata": metadata, "spec": spec, } - env = [ - { - "name": "DASK_WORKER_NAME", - "value": worker_name, - }, - { - "name": "DASK_SCHEDULER_ADDRESS", - "value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786", - }, - ] + worker_env = { + "name": "DASK_WORKER_NAME", + "value": worker_name, + } + scheduler_env = { + "name": "DASK_SCHEDULER_ADDRESS", + "value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786", + } for container in deployment_spec["spec"]["template"]["spec"]["containers"]: - if "env" in container: - container["env"].extend(env) - else: - container["env"] = env + if "env" not in container: + container["env"] = [worker_env, scheduler_env] + continue + + container_env_names = [env_item["name"] for env_item in container["env"]] + + if "DASK_WORKER_NAME" not in container_env_names: + container["env"].append(worker_env) + if "DASK_SCHEDULER_ADDRESS" not in container_env_names: + container["env"].append(scheduler_env) return deployment_spec diff --git 
a/dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml b/dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml index cd7da0e92..e99ebe608 100644 --- a/dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml +++ b/dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml @@ -5,7 +5,7 @@ metadata: spec: cluster: simple worker: - replicas: 2 + replicas: 1 spec: containers: - name: worker @@ -23,3 +23,5 @@ spec: env: - name: WORKER_ENV value: hello-world # We dont test the value, just the name + - name: DASK_WORKER_NAME + value: test-worker diff --git a/dask_kubernetes/operator/controller/tests/test_controller.py b/dask_kubernetes/operator/controller/tests/test_controller.py index 144865955..33abf9c4f 100644 --- a/dask_kubernetes/operator/controller/tests/test_controller.py +++ b/dask_kubernetes/operator/controller/tests/test_controller.py @@ -20,7 +20,6 @@ DIR = pathlib.Path(__file__).parent.absolute() - _EXPECTED_ANNOTATIONS = {"test-annotation": "annotation-value"} _EXPECTED_LABELS = {"test-label": "label-value"} DEFAULT_CLUSTER_NAME = "simple" @@ -47,7 +46,6 @@ def gen_cluster(k8s_cluster, ns, gen_cluster_manifest): @asynccontextmanager async def cm(cluster_name=DEFAULT_CLUSTER_NAME): - cluster_path = gen_cluster_manifest(cluster_name) # Create cluster resource k8s_cluster.kubectl("apply", "-n", ns, "-f", cluster_path) @@ -95,6 +93,36 @@ async def cm(job_file): yield cm +@pytest.fixture() +def gen_worker_group(k8s_cluster, ns): + """Yields an instantiated context manager for creating/deleting a worker group.""" + + @asynccontextmanager + async def cm(worker_group_file): + worker_group_path = os.path.join(DIR, "resources", worker_group_file) + with open(worker_group_path) as f: + worker_group_name = yaml.load(f, yaml.Loader)["metadata"]["name"] + + # Create worker group resource + k8s_cluster.kubectl("apply", "-n", ns, "-f", worker_group_path) + while worker_group_name not in
k8s_cluster.kubectl( "get", "daskworkergroups.kubernetes.dask.org", "-n", ns ): + await asyncio.sleep(0.1) + + try: + yield worker_group_name, ns + finally: + # NOTE: deliberately not passing wait=True here, as waiting blocks the operator + k8s_cluster.kubectl("delete", "-n", ns, "-f", worker_group_path) + while worker_group_name in k8s_cluster.kubectl( + "get", "daskworkergroups.kubernetes.dask.org", "-n", ns + ): + await asyncio.sleep(0.1) + + yield cm + + def test_customresources(k8s_cluster): assert "daskclusters.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd") assert "daskworkergroups.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd") @@ -671,32 +699,59 @@ async def test_object_dask_cluster(k8s_cluster, kopf_runner, gen_cluster): @pytest.mark.anyio -async def test_object_dask_worker_group(k8s_cluster, kopf_runner, gen_cluster): +async def test_object_dask_worker_group( + k8s_cluster, kopf_runner, gen_cluster, gen_worker_group +): with kopf_runner: - async with gen_cluster() as (cluster_name, ns): + async with ( + gen_cluster() as (cluster_name, ns), + gen_worker_group("simpleworkergroup.yaml") as ( + additional_workergroup_name, + _, + ), + ): cluster = await DaskCluster.get(cluster_name, namespace=ns) + additional_workergroup = await DaskWorkerGroup.get( + additional_workergroup_name, namespace=ns + ) worker_groups = [] while not worker_groups: worker_groups = await cluster.worker_groups() await asyncio.sleep(0.1) assert len(worker_groups) == 1 # Just the default worker group - wg = worker_groups[0] - assert isinstance(wg, DaskWorkerGroup) + worker_groups = worker_groups + [additional_workergroup] - pods = [] - while not pods: - pods = await wg.pods() - await asyncio.sleep(0.1) - assert all([isinstance(p, Pod) for p in pods]) + for wg in worker_groups: + assert isinstance(wg, DaskWorkerGroup) - deployments = [] - while not deployments: - deployments = await wg.deployments() - await asyncio.sleep(0.1) - assert all([isinstance(d, Deployment) for
d in deployments]) + deployments = [] + while not deployments: + deployments = await wg.deployments() + await asyncio.sleep(0.1) + assert all([isinstance(d, Deployment) for d in deployments]) - assert (await wg.cluster()).name == cluster.name + pods = [] + while not pods: + pods = await wg.pods() + await asyncio.sleep(0.1) + assert all([isinstance(p, Pod) for p in pods]) + + assert (await wg.cluster()).name == cluster.name + + for deployment in deployments: + assert deployment.labels["dask.org/cluster-name"] == cluster.name + for env in deployment.spec["template"]["spec"]["containers"][0][ + "env" + ]: + if env["name"] == "DASK_WORKER_NAME": + if wg.name == additional_workergroup_name: + assert env["value"] == "test-worker" + else: + assert env["value"] == deployment.name + if env["name"] == "DASK_SCHEDULER_ADDRESS": + scheduler_service = await cluster.scheduler_service() + assert f"{scheduler_service.name}.{ns}" in env["value"] @pytest.mark.anyio