Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
samhita-alla committed Jun 14, 2024
1 parent 1159209 commit c5155e7
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 66 deletions.
34 changes: 13 additions & 21 deletions plugins/flytekit-inference/flytekitplugins/inference/nim/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,16 @@
V1VolumeMount,
)

from flytekit.extras.accelerators import GPUAccelerator

from ..sidecar_template import Cloud, ModelInferenceTemplate
from ..sidecar_template import ModelInferenceTemplate


class nim(ModelInferenceTemplate):
def __init__(
self,
task_function: Optional[Callable] = None,
cloud: Optional[Cloud] = None,
device: Optional[GPUAccelerator] = None,
node_selector: Optional[dict] = None,
image: str = "nvcr.io/nim/meta/llama3-8b-instruct:1.0.0",
health_endpoint: str = "v1/health/ready",
port: int = 8000,
cpu: int = 1,
gpu: int = 1,
Expand All @@ -31,7 +29,6 @@ def __init__(
ngc_image_secret: Optional[str] = None,
ngc_secret_group: Optional[str] = None,
ngc_secret_key: Optional[str] = None,
health_endpoint: str = "v1/health/ready",
**init_kwargs: dict,
):
if ngc_image_secret is None:
Expand All @@ -40,22 +37,17 @@ def __init__(
raise ValueError("NGC secret group must be provided.")
if ngc_secret_key is None:
raise ValueError("NGC secret key must be provided.")
if not isinstance(cloud, Cloud):
raise ValueError("cloud should derive from Cloud enum. Import Cloud from flytekitplugins.inference")
if not isinstance(device, GPUAccelerator):
raise ValueError("device must be a GPUAccelerator instance.")

self.shm_size = shm_size
self.ngc_image_secret = ngc_image_secret
self.ngc_secret_group = ngc_secret_group
self.ngc_secret_key = ngc_secret_key
self.health_endpoint = health_endpoint
self._shm_size = shm_size
self._ngc_image_secret = ngc_image_secret
self._ngc_secret_group = ngc_secret_group
self._ngc_secret_key = ngc_secret_key
self._health_endpoint = health_endpoint

# All kwargs need to be passed up so that the function wrapping works for both `@nim` and `@nim(...)`
super().__init__(
task_function,
cloud=cloud,
device=device,
node_selector=node_selector,
image=image,
health_endpoint=health_endpoint,
port=port,
Expand All @@ -77,10 +69,10 @@ def update_pod_template(self):
self.pod_template.pod_spec.volumes = [
V1Volume(
name="dshm",
empty_dir=V1EmptyDirVolumeSource(medium="Memory", size_limit=self.shm_size),
empty_dir=V1EmptyDirVolumeSource(medium="Memory", size_limit=self._shm_size),
)
]
self.pod_template.pod_spec.image_pull_secrets = [V1LocalObjectReference(name=self.ngc_image_secret)]
self.pod_template.pod_spec.image_pull_secrets = [V1LocalObjectReference(name=self._ngc_image_secret)]

# Update the init containers with the additional environment variables
model_server_container = self.pod_template.pod_spec.init_containers[0]
Expand All @@ -89,8 +81,8 @@ def update_pod_template(self):
name="NGC_API_KEY",
value_from=V1EnvVarSource(
secret_key_ref=V1SecretKeySelector(
name=self.ngc_secret_group,
key=self.ngc_secret_key,
name=self._ngc_secret_group,
key=self._ngc_secret_key,
)
),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from enum import Enum
from typing import Callable, Optional

from kubernetes.client.models import (
Expand All @@ -10,25 +9,17 @@

from flytekit import FlyteContextManager, PodTemplate
from flytekit.core.utils import ClassDecorator
from flytekit.extras.accelerators import GPUAccelerator


class Cloud(Enum):
AWS = "aws"
GCP = "gcp"


class ModelInferenceTemplate(ClassDecorator):
CLOUD = "cloud"
INSTANCE = "instance"
NODE_SELECTOR = "node_selector"
IMAGE = "image"
PORT = "port"

def __init__(
self,
task_function: Optional[Callable] = None,
cloud: Optional[Cloud] = None,
device: Optional[GPUAccelerator] = None,
node_selector: Optional[dict] = None,
image: Optional[str] = None,
health_endpoint: str = "/",
port: int = 8000,
Expand All @@ -37,20 +28,19 @@ def __init__(
mem: str = "1Gi",
**init_kwargs: dict,
):
self.cloud = cloud
self.device = device
self.image = image
self.health_endpoint = health_endpoint
self.port = port
self.cpu = cpu
self.gpu = gpu
self.mem = mem
self.pod_template = PodTemplate()
self._node_selector = node_selector
self._image = image
self._health_endpoint = health_endpoint
self._port = port
self._cpu = cpu
self._gpu = gpu
self._mem = mem

self._pod_template = PodTemplate()

super().__init__(
task_function,
cloud=cloud,
device=device,
node_selector=node_selector,
image=image,
health_endpoint=health_endpoint,
port=port,
Expand All @@ -61,24 +51,29 @@ def __init__(
)
self.update_pod_template()

@property
def pod_template(self):
return self._pod_template

def update_pod_template(self):
self.pod_template.pod_spec = V1PodSpec(
self._pod_template.pod_spec = V1PodSpec(
node_selector=self._node_selector,
containers=[],
init_containers=[
V1Container(
name="model-server",
image=self.image,
ports=[V1ContainerPort(container_port=self.port)],
image=self._image,
ports=[V1ContainerPort(container_port=self._port)],
resources=V1ResourceRequirements(
requests={
"cpu": self.cpu,
"nvidia.com/gpu": self.gpu,
"memory": self.mem,
"cpu": self._cpu,
"nvidia.com/gpu": self._gpu,
"memory": self._mem,
},
limits={
"cpu": self.cpu,
"nvidia.com/gpu": self.gpu,
"memory": self.mem,
"cpu": self._cpu,
"nvidia.com/gpu": self._gpu,
"memory": self._mem,
},
),
restart_policy="Always", # treat this container as a sidecar
Expand All @@ -89,7 +84,7 @@ def update_pod_template(self):
command=[
"sh",
"-c",
f"until wget -qO- http://localhost:{self.port}/{self.health_endpoint}; do sleep 1; done;",
f"until wget -qO- http://localhost:{self._port}/{self._health_endpoint}; do sleep 1; done;",
],
resources=V1ResourceRequirements(
requests={"cpu": 1, "memory": "100Mi"},
Expand All @@ -99,11 +94,6 @@ def update_pod_template(self):
],
)

if self.cloud == Cloud.AWS and self.device:
self.pod_template.pod_spec.node_selector = {"k8s.amazonaws.com/accelerator": self.device._device}
elif self.cloud == Cloud.GCP and self.device:
self.pod_template.pod_spec.node_selector = {"cloud.google.com/gke-accelerator": self.device._device}

def execute(self, *args, **kwargs):
ctx = FlyteContextManager.current_context()
is_local_execution = ctx.execution_state.is_local_execution()
Expand All @@ -116,11 +106,7 @@ def execute(self, *args, **kwargs):

def get_extra_config(self):
return {
self.CLOUD: self.cloud.value if self.cloud else None,
self.INSTANCE: self.device._device if self.device else None,
self.IMAGE: self.image,
self.PORT: str(self.port),
self.NODE_SELECTOR: self._node_selector,
self.IMAGE: self._image,
self.PORT: self._port,
}

def pod_template(self):
return self.pod_template
2 changes: 1 addition & 1 deletion plugins/flytekit-inference/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.12.2,<2.0.0", "kubernetes"]
plugin_requires = ["flytekit>=1.12.2,<2.0.0", "kubernetes", "openai"]

__version__ = "0.0.0+develop"

Expand Down

0 comments on commit c5155e7

Please sign in to comment.