Skip to content

Commit

Permalink
KEP-2170: Add validation to Torch numProcPerNode field
Browse files Browse the repository at this point in the history
Signed-off-by: Antonin Stefanutti <[email protected]>
  • Loading branch information
astefanutti committed Feb 6, 2025
1 parent 3060332 commit 4b3c70d
Show file tree
Hide file tree
Showing 14 changed files with 52 additions and 29 deletions.
4 changes: 2 additions & 2 deletions api/openapi-spec/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@
},
"numProcPerNode": {
"description": "Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`.",
"type": "string"
"$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString"
}
}
},
Expand Down Expand Up @@ -716,7 +716,7 @@
},
"numProcPerNode": {
"description": "Number of processes/workers/slots on every training node. For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set. For the MPI runtime only int value can be set.",
"type": "string"
"$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString"
},
"resourcesPerNode": {
"description": "Compute resources for each training node.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,12 +583,17 @@ spec:
type: integer
type: object
numProcPerNode:
anyOf:
- type: integer
- type: string
description: |-
Number of processes per node.
This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
Supported values: `auto`, `cpu`, `gpu`, or int value.
Defaults to `auto`.
type: string
x-kubernetes-int-or-string: true
x-kubernetes-validations:
- rule: self > 0 || self in ['auto', 'cpu', 'gpu']
type: object
type: object
podGroupPolicy:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,12 +583,17 @@ spec:
type: integer
type: object
numProcPerNode:
anyOf:
- type: integer
- type: string
description: |-
Number of processes per node.
This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
Supported values: `auto`, `cpu`, `gpu`, or int value.
Defaults to `auto`.
type: string
x-kubernetes-int-or-string: true
x-kubernetes-validations:
- rule: self > 0 || self in ['auto', 'cpu', 'gpu']
type: object
type: object
podGroupPolicy:
Expand Down
5 changes: 4 additions & 1 deletion manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3138,11 +3138,14 @@ spec:
format: int32
type: integer
numProcPerNode:
anyOf:
- type: integer
- type: string
description: |-
Number of processes/workers/slots on every training node.
For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set.
For the MPI runtime only int value can be set.
type: string
x-kubernetes-int-or-string: true
resourcesPerNode:
description: Compute resources for each training node.
properties:
Expand Down
5 changes: 3 additions & 2 deletions pkg/apis/trainer/v1alpha1/trainingruntime_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package v1alpha1
import (
autoscalingv2 "k8s.io/api/autoscaling/v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)

Expand Down Expand Up @@ -171,9 +172,9 @@ type TorchMLPolicySource struct {
// Number of processes per node.
// This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
// Supported values: `auto`, `cpu`, `gpu`, or a positive int value.
// TODO (andreyvelich): Add kubebuilder validation.
// Defaults to `auto`.
NumProcPerNode *string `json:"numProcPerNode,omitempty"`
// +kubebuilder:validation:XValidation:rule="self > 0 || self in ['auto', 'cpu', 'gpu']"
NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"`

// Elastic policy for the PyTorch training.
ElasticPolicy *TorchElasticPolicy `json:"elasticPolicy,omitempty"`
Expand Down
3 changes: 2 additions & 1 deletion pkg/apis/trainer/v1alpha1/trainjob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package v1alpha1
import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
)

const (
Expand Down Expand Up @@ -194,7 +195,7 @@ type Trainer struct {
// Number of processes/workers/slots on every training node.
// For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set.
// For the MPI runtime only int value can be set.
NumProcPerNode *string `json:"numProcPerNode,omitempty"`
NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"`
}

// DatasetConfig represents the desired dataset configuration.
Expand Down
5 changes: 3 additions & 2 deletions pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 4 additions & 6 deletions pkg/apis/trainer/v1alpha1/zz_generated.openapi.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pkg/client/applyconfiguration/trainer/v1alpha1/trainer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions pkg/runtime/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package core
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/intstr"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -263,7 +264,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
"succeeded to build JobSet with Torch values from the TrainJob": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
TorchPolicy(100, "auto").
TorchPolicy(100, intstr.FromString("auto")).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Obj(),
).Obj(),
Expand All @@ -273,7 +274,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
NumNodes(30).
NumProcPerNode("3").
NumProcPerNode(intstr.FromInt32(3)).
Obj(),
).
Obj(),
Expand Down Expand Up @@ -317,7 +318,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
"succeeded to build JobSet with Torch values from the Runtime and envs.": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
TorchPolicy(100, "auto").
TorchPolicy(100, intstr.FromString("auto")).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
ContainerTrainerEnv(
[]corev1.EnvVar{
Expand Down
7 changes: 4 additions & 3 deletions pkg/runtime/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package torch
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/intstr"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/sets"
Expand Down Expand Up @@ -61,9 +62,9 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
}
info.Trainer.NumNodes = numNodes

numProcPerNode := info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode
numProcPerNode := ptr.Deref(info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode, intstr.FromString("auto"))
if trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumProcPerNode != nil {
numProcPerNode = trainJob.Spec.Trainer.NumProcPerNode
numProcPerNode = ptr.Deref(trainJob.Spec.Trainer.NumProcPerNode, intstr.FromString("auto"))
}

// Update envs for Info object.
Expand All @@ -78,7 +79,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
},
{
Name: constants.TorchEnvNumProcPerNode,
Value: ptr.Deref(numProcPerNode, "auto"),
Value: numProcPerNode.String(),
},
{
Name: constants.TorchEnvNodeRank,
Expand Down
5 changes: 3 additions & 2 deletions pkg/util/testing/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
Expand Down Expand Up @@ -392,7 +393,7 @@ func (t *TrainJobTrainerWrapper) NumNodes(numNodes int32) *TrainJobTrainerWrappe
return t
}

func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode string) *TrainJobTrainerWrapper {
func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode intstr.IntOrString) *TrainJobTrainerWrapper {
t.Trainer.NumProcPerNode = &numProcPerNode
return t
}
Expand Down Expand Up @@ -689,7 +690,7 @@ func (s *TrainingRuntimeSpecWrapper) NumNodes(numNodes int32) *TrainingRuntimeSp
return s
}

func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode string) *TrainingRuntimeSpecWrapper {
func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode intstr.IntOrString) *TrainingRuntimeSpecWrapper {
s.MLPolicy = &trainer.MLPolicy{
NumNodes: &numNodes,
MLPolicySource: trainer.MLPolicySource{
Expand Down
3 changes: 2 additions & 1 deletion test/integration/controller/trainjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
Expand Down Expand Up @@ -278,7 +279,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() {
trainingRuntime = testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha").
RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "alpha").Spec).
TorchPolicy(100, "auto").
TorchPolicy(100, intstr.FromString("auto")).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Obj()).
Obj()
Expand Down

0 comments on commit 4b3c70d

Please sign in to comment.