Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
samhita-alla committed Jun 13, 2024
1 parent 6c88bdc commit 1159209
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .nim.serve import nim
from .sidecar_template import ModelInferenceTemplate
from .sidecar_template import Cloud, ModelInferenceTemplate
45 changes: 32 additions & 13 deletions plugins/flytekit-inference/flytekitplugins/inference/nim/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
V1VolumeMount,
)

from flytekit import Secret
from flytekit.extras.accelerators import GPUAccelerator

from ..sidecar_template import Cloud, ModelInferenceTemplate

Expand All @@ -20,31 +20,52 @@ class nim(ModelInferenceTemplate):
def __init__(
self,
task_function: Optional[Callable] = None,
cloud: Cloud = Cloud.AWS,
cloud: Optional[Cloud] = None,
device: Optional[GPUAccelerator] = None,
image: str = "nvcr.io/nim/meta/llama3-8b-instruct:1.0.0",
port: int = 8000,
cpu: int = 1,
gpu: int = 1,
mem: str = "20Gi",
shm_size: str = "16Gi",
nvcr_image_secret: str = "nvcrio-cred",
ngc_secret: Secret = Secret(group="ngc", key="api_key"),
ngc_image_secret: Optional[str] = None,
ngc_secret_group: Optional[str] = None,
ngc_secret_key: Optional[str] = None,
health_endpoint: str = "v1/health/ready",
**init_kwargs: dict,
):
if ngc_image_secret is None:
raise ValueError("NGC image pull credentials must be provided.")
if ngc_secret_group is None:
raise ValueError("NGC secret group must be provided.")
if ngc_secret_key is None:
raise ValueError("NGC secret key must be provided.")
if not isinstance(cloud, Cloud):
raise ValueError("cloud should derive from Cloud enum. Import Cloud from flytekitplugns.nim")
if not isinstance(device, GPUAccelerator):
raise ValueError("device must be a GPUAccelerator instance.")

self.shm_size = shm_size
self.nvcr_secret = nvcr_image_secret
self.ngc_secret = ngc_secret
self.ngc_image_secret = ngc_image_secret
self.ngc_secret_group = ngc_secret_group
self.ngc_secret_key = ngc_secret_key
self.health_endpoint = health_endpoint

# All kwargs need to be passed up so that the function wrapping works for both `@nim` and `@nim(...)`
super().__init__(
task_function=task_function,
task_function,
cloud=cloud,
device=device,
image=image,
health_endpoint=health_endpoint,
port=port,
cpu=cpu,
gpu=gpu,
mem=mem,
health_endpoint="/v1/health/ready",
shm_size=shm_size,
ngc_image_secret=ngc_image_secret,
ngc_secret_group=ngc_secret_group,
ngc_secret_key=ngc_secret_key,
**init_kwargs,
)

Expand All @@ -59,7 +80,7 @@ def update_pod_template(self):
empty_dir=V1EmptyDirVolumeSource(medium="Memory", size_limit=self.shm_size),
)
]
self.pod_template.pod_spec.image_pull_secrets = [V1LocalObjectReference(name=self.nvcr_secret)]
self.pod_template.pod_spec.image_pull_secrets = [V1LocalObjectReference(name=self.ngc_image_secret)]

# Update the init containers with the additional environment variables
model_server_container = self.pod_template.pod_spec.init_containers[0]
Expand All @@ -68,13 +89,11 @@ def update_pod_template(self):
name="NGC_API_KEY",
value_from=V1EnvVarSource(
secret_key_ref=V1SecretKeySelector(
name=self.ngc_secret.group,
key=self.ngc_secret.key,
name=self.ngc_secret_group,
key=self.ngc_secret_key,
)
),
)
]
model_server_container.volume_mounts = [V1VolumeMount(name="dshm", mount_path="/dev/shm")]
model_server_container.security_context = V1SecurityContext(run_as_user=1000)

self.task_function.secret_requests.append(self.ngc_secret)
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,39 @@ class ModelInferenceTemplate(ClassDecorator):

def __init__(
self,
port: int,
cpu: int,
gpu: int,
mem: str,
task_function: Optional[Callable] = None,
cloud: Optional[Cloud] = None,
device: Optional[GPUAccelerator] = None,
image: Optional[str] = None,
health_endpoint: str = "/",
port: int = 8000,
cpu: int = 1,
gpu: int = 1,
mem: str = "1Gi",
**init_kwargs: dict,
):
self.cloud = cloud
self.device = device
self.image = image
self.health_endpoint = health_endpoint
self.port = port
self.cpu = cpu
self.gpu = gpu
self.mem = mem
self.health_endpoint = health_endpoint
self.pod_template = PodTemplate()
self.device = device

super().__init__(task_function, **init_kwargs)
super().__init__(
task_function,
cloud=cloud,
device=device,
image=image,
health_endpoint=health_endpoint,
port=port,
cpu=cpu,
gpu=gpu,
mem=mem,
**init_kwargs,
)
self.update_pod_template()

def update_pod_template(self):
Expand Down

0 comments on commit 1159209

Please sign in to comment.