feat: support vllm in controller
- set vllm as the default runtime

Signed-off-by: jerryzhuang <[email protected]>
zhuangqh committed Nov 6, 2024
1 parent ad0dde9 commit 3b52cc4
Showing 22 changed files with 621 additions and 239 deletions.
30 changes: 30 additions & 0 deletions api/v1alpha1/labels.go
@@ -3,6 +3,12 @@

package v1alpha1

import (
"github.com/kaito-project/kaito/pkg/featuregates"
"github.com/kaito-project/kaito/pkg/model"
"github.com/kaito-project/kaito/pkg/utils/consts"
)

const (

// Non-prefixed labels/annotations are reserved for end-use.
@@ -27,4 +33,28 @@ const (

// WorkspaceRevisionAnnotation is the Annotations for revision number
WorkspaceRevisionAnnotation = "workspace.kaito.io/revision"

// AnnotationWorkspaceRuntime is the annotation for runtime selection.
AnnotationWorkspaceRuntime = KAITOPrefix + "runtime"
)

// GetWorkspaceRuntimeName returns the runtime name of the workspace.
func GetWorkspaceRuntimeName(ws *Workspace) model.RuntimeName {
if ws == nil {
panic("workspace is nil")
}
runtime := model.RuntimeNameHuggingfaceTransformers
if featuregates.FeatureGates[consts.FeatureFlagVLLM] {
runtime = model.RuntimeNameVLLM
}

name := ws.Annotations[AnnotationWorkspaceRuntime]
switch name {
case string(model.RuntimeNameHuggingfaceTransformers):
runtime = model.RuntimeNameHuggingfaceTransformers
case string(model.RuntimeNameVLLM):
runtime = model.RuntimeNameVLLM
}

return runtime
}
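Note (illustrative, not part of the diff): runtime selection is annotation-first. The AnnotationWorkspaceRuntime key (KAITOPrefix + "runtime") overrides the feature-gate default, which is transformers unless the vLLM gate is enabled. A minimal sketch, assuming Workspace embeds metav1.ObjectMeta so Annotations can be set directly:

package main

import (
	"fmt"

	"github.com/kaito-project/kaito/api/v1alpha1"
	"github.com/kaito-project/kaito/pkg/model"
)

func main() {
	// The runtime annotation wins over the feature-gate default, so this
	// workspace resolves to the vLLM runtime even with FeatureFlagVLLM off.
	ws := &v1alpha1.Workspace{}
	ws.Annotations = map[string]string{
		v1alpha1.AnnotationWorkspaceRuntime: string(model.RuntimeNameVLLM),
	}
	fmt.Println(v1alpha1.GetWorkspaceRuntimeName(ws))

	// Unrecognized values are ignored by the switch, so the workspace falls
	// back to the feature-gate default rather than erroring.
	ws.Annotations[v1alpha1.AnnotationWorkspaceRuntime] = "not-a-runtime"
	fmt.Println(v1alpha1.GetWorkspaceRuntimeName(ws))
}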
4 changes: 2 additions & 2 deletions api/v1alpha1/workspace_validation.go
@@ -169,7 +169,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
// Currently require a preset to specified, in future we can consider defining a template
if r.Preset == nil {
errs = errs.Also(apis.ErrMissingField("Preset"))
} else if presetName := string(r.Preset.Name); !utils.IsValidPreset(presetName) {
} else if presetName := string(r.Preset.Name); !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
return errs
@@ -407,7 +407,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
if i.Preset != nil {
presetName := string(i.Preset.Name)
// Validate preset name
if !utils.IsValidPreset(presetName) {
if !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
2 changes: 1 addition & 1 deletion pkg/controllers/workspace_controller.go
@@ -780,7 +780,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a
} else if apierrors.IsNotFound(err) {
var workloadObj client.Object
// Need to create a new workload
workloadObj, err = inference.CreatePresetInference(ctx, wObj, revisionStr, inferenceParam, model.SupportDistributedInference(), c.Client)
workloadObj, err = inference.CreatePresetInference(ctx, wObj, revisionStr, model, c.Client)
if err != nil {
return
}
1 change: 1 addition & 0 deletions pkg/featuregates/featuregates.go
@@ -15,6 +15,7 @@ var (
// FeatureGates is a map that holds the feature gates and their default values for Kaito.
FeatureGates = map[string]bool{
consts.FeatureFlagKarpenter: false,
consts.FeatureFlagVLLM: false,
// Add more feature gates here
}
)
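Since the gate ships disabled here, transformers remains the effective default until FeatureFlagVLLM is turned on. A hypothetical test-style sketch (not from this commit) of how flipping the gate changes the default returned by GetWorkspaceRuntimeName:

package featuregates_test

import (
	"testing"

	"github.com/kaito-project/kaito/api/v1alpha1"
	"github.com/kaito-project/kaito/pkg/featuregates"
	"github.com/kaito-project/kaito/pkg/model"
	"github.com/kaito-project/kaito/pkg/utils/consts"
)

// Hypothetical: an un-annotated workspace defaults to vLLM once the gate is on.
func TestVLLMGateFlipsDefaultRuntime(t *testing.T) {
	featuregates.FeatureGates[consts.FeatureFlagVLLM] = true
	defer func() { featuregates.FeatureGates[consts.FeatureFlagVLLM] = false }()

	ws := &v1alpha1.Workspace{}
	if got := v1alpha1.GetWorkspaceRuntimeName(ws); got != model.RuntimeNameVLLM {
		t.Fatalf("expected vLLM runtime, got %s", got)
	}
}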
87 changes: 42 additions & 45 deletions pkg/inference/preset-inferences.go
@@ -11,6 +11,7 @@ import (
"github.com/kaito-project/kaito/pkg/utils"
"github.com/kaito-project/kaito/pkg/utils/consts"

"github.com/kaito-project/kaito/api/v1alpha1"
kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
"github.com/kaito-project/kaito/pkg/model"
"github.com/kaito-project/kaito/pkg/resources"
@@ -22,9 +23,8 @@
)

const (
ProbePath = "/healthz"
Port5000 = int32(5000)
InferenceFile = "inference_api.py"
ProbePath = "/health"
Port5000 = int32(5000)
)

var (
@@ -70,26 +70,31 @@
}
)

func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) error {
func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceParam *model.PresetParam) error {
runtimeName := v1alpha1.GetWorkspaceRuntimeName(wObj)
if runtimeName != model.RuntimeNameHuggingfaceTransformers {
return fmt.Errorf("distributed inference is not supported for runtime %s", runtimeName)
}

existingService := &corev1.Service{}
err := resources.GetResource(ctx, wObj.Name, wObj.Namespace, kubeClient, existingService)
if err != nil {
return err
}

nodes := *wObj.Resource.Count
inferenceObj.TorchRunParams["nnodes"] = strconv.Itoa(nodes)
inferenceObj.TorchRunParams["nproc_per_node"] = strconv.Itoa(inferenceObj.WorldSize / nodes)
inferenceParam.Transformers.TorchRunParams["nnodes"] = strconv.Itoa(nodes)
inferenceParam.Transformers.TorchRunParams["nproc_per_node"] = strconv.Itoa(inferenceParam.WorldSize / nodes)
if nodes > 1 {
inferenceObj.TorchRunParams["node_rank"] = "$(echo $HOSTNAME | grep -o '[^-]*$')"
inferenceObj.TorchRunParams["master_addr"] = existingService.Spec.ClusterIP
inferenceObj.TorchRunParams["master_port"] = "29500"
}
if inferenceObj.TorchRunRdzvParams != nil {
inferenceObj.TorchRunRdzvParams["max_restarts"] = "3"
inferenceObj.TorchRunRdzvParams["rdzv_id"] = "job"
inferenceObj.TorchRunRdzvParams["rdzv_backend"] = "c10d"
inferenceObj.TorchRunRdzvParams["rdzv_endpoint"] =
inferenceParam.Transformers.TorchRunParams["node_rank"] = "$(echo $HOSTNAME | grep -o '[^-]*$')"
inferenceParam.Transformers.TorchRunParams["master_addr"] = existingService.Spec.ClusterIP
inferenceParam.Transformers.TorchRunParams["master_port"] = "29500"
}
if inferenceParam.Transformers.TorchRunRdzvParams != nil {
inferenceParam.Transformers.TorchRunRdzvParams["max_restarts"] = "3"
inferenceParam.Transformers.TorchRunRdzvParams["rdzv_id"] = "job"
inferenceParam.Transformers.TorchRunRdzvParams["rdzv_backend"] = "c10d"
inferenceParam.Transformers.TorchRunRdzvParams["rdzv_endpoint"] =
fmt.Sprintf("%s-0.%s-headless.%s.svc.cluster.local:29500", wObj.Name, wObj.Name, wObj.Namespace)
}
return nil
@@ -121,14 +126,17 @@ func GetInferenceImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Work
}

func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, revisionNum string,
inferenceObj *model.PresetParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) {
if inferenceObj.TorchRunParams != nil && supportDistributedInference {
if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, inferenceObj); err != nil {
model model.Model, kubeClient client.Client) (client.Object, error) {
inferenceParam := model.GetInferenceParameters().DeepCopy()

if model.SupportDistributedInference() {
if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, inferenceParam); err != nil { //
klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj)
return nil, err
}
}

// additional volume
var volumes []corev1.Volume
var volumeMounts []corev1.VolumeMount
shmVolume, shmVolumeMount := utils.ConfigSHMVolume(*workspaceObj.Resource.Count)
@@ -138,24 +146,35 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
if shmVolumeMount.Name != "" {
volumeMounts = append(volumeMounts, shmVolumeMount)
}

if len(workspaceObj.Inference.Adapters) > 0 {
adapterVolume, adapterVolumeMount := utils.ConfigAdapterVolume()
volumes = append(volumes, adapterVolume)
volumeMounts = append(volumeMounts, adapterVolumeMount)
}

// resource requirements
skuNumGPUs, err := utils.GetSKUNumGPUs(ctx, kubeClient, workspaceObj.Status.WorkerNodes,
workspaceObj.Resource.InstanceType, inferenceObj.GPUCountRequirement)
workspaceObj.Resource.InstanceType, inferenceParam.GPUCountRequirement)
if err != nil {
return nil, fmt.Errorf("failed to get SKU num GPUs: %v", err)
}
resourceReq := corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
Limits: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
}

commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj, skuNumGPUs)
image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceObj)
// inference command
runtimeName := kaitov1alpha1.GetWorkspaceRuntimeName(workspaceObj)
commands := inferenceParam.GetInferenceCommand(runtimeName)

image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceParam)

var depObj client.Object
if supportDistributedInference {
if model.SupportDistributedInference() {
depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands,
containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts)
} else {
@@ -168,25 +187,3 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
}
return depObj, nil
}

// prepareInferenceParameters builds a PyTorch command:
// torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> baseCommand <MODEL_PARAMS>
// and sets the GPU resources required for inference.
// Returns the command and resource configuration.
func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam, skuNumGPUs string) ([]string, corev1.ResourceRequirements) {
torchCommand := utils.BuildCmdStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams)
torchCommand = utils.BuildCmdStr(torchCommand, inferenceObj.TorchRunRdzvParams)
modelCommand := utils.BuildCmdStr(InferenceFile, inferenceObj.ModelRunParams)
commands := utils.ShellCmd(torchCommand + " " + modelCommand)

resourceRequirements := corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
Limits: corev1.ResourceList{
corev1.ResourceName(resources.CapacityNvidiaGPU): resource.MustParse(skuNumGPUs),
},
}

return commands, resourceRequirements
}
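Callers now hand CreatePresetInference the registered model instead of a pre-extracted PresetParam plus a distributed-inference flag; the preset parameters are deep-copied and the runtime-specific command is resolved internally via GetInferenceCommand. A minimal caller sketch (hypothetical helper, mirroring the controller and the updated test; assumes the inference preset name is the model-registry key):

package controllers

import (
	"context"

	"sigs.k8s.io/controller-runtime/pkg/client"

	kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
	"github.com/kaito-project/kaito/pkg/inference"
	"github.com/kaito-project/kaito/pkg/utils/plugin"
)

// applyPresetInference is a hypothetical wrapper showing the new call shape.
func applyPresetInference(ctx context.Context, wObj *kaitov1alpha1.Workspace,
	revisionStr string, kubeClient client.Client) (client.Object, error) {
	// Assumption: the inference preset name keys the model registry.
	registeredModel := plugin.KaitoModelRegister.MustGet(string(wObj.Inference.Preset.Name))

	// Returns a StatefulSet when the model supports distributed inference,
	// otherwise a Deployment; the command differs per runtime (vLLM vs
	// transformers), as the updated unit tests below illustrate.
	return inference.CreatePresetInference(ctx, wObj, revisionStr, registeredModel, kubeClient)
}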
61 changes: 35 additions & 26 deletions pkg/inference/preset-inferences_test.go
@@ -10,12 +10,10 @@ import (
"strings"
"testing"

"github.com/kaito-project/kaito/pkg/utils/consts"

"github.com/kaito-project/kaito/api/v1alpha1"
"github.com/kaito-project/kaito/pkg/utils/consts"
"github.com/kaito-project/kaito/pkg/utils/test"

"github.com/kaito-project/kaito/pkg/model"
"github.com/kaito-project/kaito/pkg/utils/plugin"
"github.com/stretchr/testify/mock"
appsv1 "k8s.io/api/apps/v1"
@@ -28,6 +26,7 @@ var ValidStrength string = "0.5"
func TestCreatePresetInference(t *testing.T) {
test.RegisterTestModel()
testcases := map[string]struct {
workspace *v1alpha1.Workspace
nodeCount int
modelName string
callMocks func(c *test.MockClient)
@@ -37,7 +36,8 @@
expectedVolume string
}{

"test-model": {
"test-model/vllm": {
workspace: test.MockWorkspaceWithPresetVLLM,
nodeCount: 1,
modelName: "test-model",
callMocks: func(c *test.MockClient) {
Expand All @@ -46,32 +46,48 @@ func TestCreatePresetInference(t *testing.T) {
workload: "Deployment",
// No BaseCommand, TorchRunParams, TorchRunRdzvParams, or ModelRunParams
// So expected cmd consists of shell command and inference file
expectedCmd: "/bin/sh -c inference_api.py",
expectedCmd: "/bin/sh -c python3 /workspace/vllm/inference_api.py",
hasAdapters: false,
},

"test-distributed-model": {
"test-model-with-adapters/vllm": {
workspace: test.MockWorkspaceWithPresetVLLM,
nodeCount: 1,
modelName: "test-distributed-model",
modelName: "test-model",
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.TODO()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
},
workload: "StatefulSet",
expectedCmd: "/bin/sh -c inference_api.py",
hasAdapters: false,
workload: "Deployment",
expectedCmd: "/bin/sh -c python3 /workspace/vllm/inference_api.py",
hasAdapters: true,
expectedVolume: "adapter-volume",
},

"test-model-with-adapters": {
"test-model/transformers": {
workspace: test.MockWorkspaceWithPreset,
nodeCount: 1,
modelName: "test-model",
callMocks: func(c *test.MockClient) {
c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
},
workload: "Deployment",
expectedCmd: "/bin/sh -c inference_api.py",
hasAdapters: true,
expectedVolume: "adapter-volume",
workload: "Deployment",
// No BaseCommand, TorchRunParams, TorchRunRdzvParams, or ModelRunParams
// So expected cmd consists of shell command and inference file
expectedCmd: "/bin/sh -c accelerate launch /workspace/tfs/inference_api.py",
hasAdapters: false,
},

"test-distributed-model/transformers": {
workspace: test.MockWorkspaceDistributedModel,
nodeCount: 1,
modelName: "test-distributed-model",
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.TODO()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil)
},
workload: "StatefulSet",
expectedCmd: "/bin/sh -c accelerate launch --nnodes=1 --nproc_per_node=0 --max_restarts=3 --rdzv_id=job --rdzv_backend=c10d --rdzv_endpoint=testWorkspace-0.testWorkspace-headless.kaito.svc.cluster.local:29500 /workspace/tfs/inference_api.p",
hasAdapters: false,
},
}

@@ -81,7 +97,7 @@
mockClient := test.NewClient()
tc.callMocks(mockClient)

workspace := test.MockWorkspaceWithPreset
workspace := tc.workspace
workspace.Resource.Count = &tc.nodeCount
expectedSecrets := []string{"fake-secret"}
if tc.hasAdapters {
Expand All @@ -97,15 +113,8 @@ func TestCreatePresetInference(t *testing.T) {
}
}

useHeadlessSvc := false

var inferenceObj *model.PresetParam
model := plugin.KaitoModelRegister.MustGet(tc.modelName)
inferenceObj = model.GetInferenceParameters()

if strings.Contains(tc.modelName, "distributed") {
useHeadlessSvc = true
}
svc := &corev1.Service{
ObjectMeta: v1.ObjectMeta{
Name: workspace.Name,
Expand All @@ -117,7 +126,7 @@ func TestCreatePresetInference(t *testing.T) {
}
mockClient.CreateOrUpdateObjectInMap(svc)

createdObject, _ := CreatePresetInference(context.TODO(), workspace, test.MockWorkspaceWithPresetHash, inferenceObj, useHeadlessSvc, mockClient)
createdObject, _ := CreatePresetInference(context.TODO(), workspace, test.MockWorkspaceWithPresetHash, model, mockClient)
createdWorkload := ""
switch createdObject.(type) {
case *appsv1.Deployment: