Skip to content

Commit

Permalink
train api updated
Browse files Browse the repository at this point in the history
  • Loading branch information
deepanker13 committed Dec 13, 2023
1 parent 80dbf96 commit 8cbe61e
Showing 1 changed file with 56 additions and 18 deletions.
74 changes: 56 additions & 18 deletions sdk/python/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
import time
import json
from typing import Optional, Callable, List, Dict, Any, Set
from typing import Optional, Callable, List, Dict, Any, Set, Literal
import queue
from kubernetes import client, config, watch

Expand All @@ -25,6 +25,13 @@
from kubeflow.training.constants import constants
from kubeflow.training.utils import utils

from typing import Union
from kubeflow.storage_init_container.hugging_face import (
HuggingFaceModelParams,
HuggingFaceTrainParams,
)
from kubeflow.storage_init_container.s3 import S3DatasetParams

logger = logging.getLogger(__name__)

status_logger = utils.StatusLogger(
Expand Down Expand Up @@ -84,19 +91,38 @@ def __init__(

def train(
self,
name=None,
namespace=None,
workers=1,
model_args=None,
dataset_args=None,
parameters=None,
resources_per_worker={"gpu": 0, "cpu": 0, "memory": "10Gi"},
name: str = None,
namespace: str = None,
num_workers: int = 1,
num_procs_per_worker: int = 1,
pvc: Dict[Literal["name", "claimName"], str] = None,
model_params: HuggingFaceModelParams = None,
dataset_params: S3DatasetParams = None,
parameters: HuggingFaceTrainParams = None,
resources_per_worker: Dict[Literal["gpu", "cpu", "memory"], any] = None,
):
"""
Higher level train api
"""
if not name or not namespace:
raise ValueError("job name or namespace cannot be null")
if (
not name
or not namespace
or not pvc
or not model_params
or not dataset_params
or not parameters
or not resources_per_worker
):
raise ValueError("One of the required parameters is None")

if num_procs_per_worker > resources_per_worker["gpu"]:
raise ValueError("Insufficient gpu resources allocated to the container.")

if isinstance(model_params, HuggingFaceModelParams):
mp = "hf"

if isinstance(dataset_params, S3DatasetParams):
dp = "s3"

# create init container spec
init_container_spec = utils.get_container_spec(
Expand All @@ -108,13 +134,15 @@ def train(
"--model_provider",
mp,
"--model_provider_args",
json.dumps(model_args.__dict__),
json.dumps(model_params.__dict__),
"--dataset_provider",
dp,
"--dataset_provider_args",
json.dumps(dataset_args.__dict__),
json.dumps(dataset_params.__dict__),
],
volume_mounts=models.V1VolumeMount(),
volume_mounts=models.V1VolumeMount(
name="model_dataset_store", mount_path="/workspace"
),
)

# create app container spec
Expand All @@ -124,31 +152,41 @@ def train(
"train_container_image"
],
args=["--parameters", json.dumps(parameters.__dict__)],
volume_mounts=models.V1VolumeMount(),
volume_mounts=models.V1VolumeMount(
name=pvc["name"], mount_path="/workspace"
),
resources=resources_per_worker,
)

# create worker pod spec
worker_pod_template_spec = utils.get_pod_template_spec(
job_kind=constants.PYTORCHJOB_KIND,
containers_spec=[container_spec],
volumes_spec=[models.V1Volume()],
volumes_spec=[
models.V1Volume(
name=pvc["name"], persistent_volume_claim=pvc["claimName"]
)
],
)

# create master pod spec
master_pod_template_spec = utils.get_pod_template_spec(
job_kind=constants.PYTORCHJOB_KIND,
containers_spec=[init_container_spec, container_spec],
volumes_spec=[models.V1Volume()],
volumes_spec=[
models.V1Volume(
name=pvc["name"], persistent_volume_claim=pvc["claimName"]
)
],
)

job = utils.get_pytorchjob_template(
name=name,
namespace=namespace,
master_pod_template_spec=master_pod_template_spec,
worker_pod_template_spec=worker_pod_template_spec,
num_worker_replicas=workers,
num_procs_per_worker=resources_per_worker["gpu"],
num_worker_replicas=num_workers,
num_procs_per_worker=num_procs_per_worker,
elastic_policy=models.KubeflowOrgV1ElasticPolicy(rdzv_backend="c10d"),
)

Expand Down

0 comments on commit 8cbe61e

Please sign in to comment.