Skip to content

Commit

Permalink
p
Browse files Browse the repository at this point in the history
Signed-off-by: kevin <[email protected]>
  • Loading branch information
khluu committed Sep 24, 2024
1 parent 9e52954 commit 9d973dc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
6 changes: 6 additions & 0 deletions scripts/pipeline_generator/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DOCKER_PLUGIN_NAME = "docker#v5.2.0"
KUBERNETES_PLUGIN_NAME = "kubernetes"


class DockerPluginConfig(BaseModel):
"""
Configuration for Docker plugin running in a Buildkite step.
Expand All @@ -29,6 +30,7 @@ class DockerPluginConfig(BaseModel):
f"{HF_HOME}:{HF_HOME}"
]


class KubernetesPodContainerConfig(BaseModel):
"""
Configuration for a container running in a Kubernetes pod.
Expand Down Expand Up @@ -59,6 +61,7 @@ class KubernetesPodContainerConfig(BaseModel):
],
)


class KubernetesPodSpec(BaseModel):
"""
Configuration for a Kubernetes pod running in a Buildkite step.
Expand All @@ -76,12 +79,14 @@ class KubernetesPodSpec(BaseModel):
]
)


class KubernetesPluginConfig(BaseModel):
"""
Configuration for Kubernetes plugin running in a Buildkite step.
"""
pod_spec: KubernetesPodSpec = Field(alias="podSpec")


def get_kubernetes_plugin_config(container_image: str, test_bash_command: List[str], num_gpus: int) -> Dict:
pod_spec = KubernetesPodSpec(
containers=[
Expand All @@ -94,6 +99,7 @@ def get_kubernetes_plugin_config(container_image: str, test_bash_command: List[s
)
return {KUBERNETES_PLUGIN_NAME: KubernetesPluginConfig(podSpec=pod_spec).dict(by_alias=True)}


def get_docker_plugin_config(docker_image_path: str, test_bash_command: List[str], no_gpu: bool) -> Dict:
docker_plugin_config = DockerPluginConfig(
image=docker_image_path,
Expand Down
4 changes: 2 additions & 2 deletions scripts/tests/pipeline_generator/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import pytest
import sys

from unittest.mock import patch
from scripts.pipeline_generator.plugin import (
get_kubernetes_plugin_config,
get_docker_plugin_config,
DOCKER_PLUGIN_NAME,
KUBERNETES_PLUGIN_NAME,
)


def test_get_kubernetes_plugin_config():
docker_image_path = "test_image:latest"
test_bash_command = ["echo", "Hello, Kubernetes!"]
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_get_kubernetes_plugin_config():
}
}
}

assert get_kubernetes_plugin_config(docker_image_path, test_bash_command, num_gpus) == expected_config


Expand Down

0 comments on commit 9d973dc

Please sign in to comment.