Skip to content

Commit

Permalink
fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
deepanker13 committed Dec 15, 2023
1 parent f9cae4d commit c59b0b0
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 18 deletions.
13 changes: 10 additions & 3 deletions sdk/python/kubeflow/storage_init_container/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
- from abstract_model_provider import modelProvider
+ from .abstract_model_provider import modelProvider
from dataclasses import dataclass, field
from typing import Literal
import transformers
Expand All @@ -18,9 +18,16 @@

@dataclass
class HuggingFaceModelParams:
- access_token: str
  model_uri: str
- transformer_type: Literal[*TRANSFORMER_TYPES]
+ transformer_type: Literal[
+ "AutoModelForSequenceClassification",
+ "AutoModelForTokenClassification",
+ "AutoModelForQuestionAnswering",
+ "AutoModelForCausalLM",
+ "AutoModelForMaskedLM",
+ "AutoModelForImageClassification",
+ ]
+ access_token: str = None
download_dir: str = field(default="/workspace/models")

def __post_init__(self):
Expand Down
8 changes: 4 additions & 4 deletions sdk/python/kubeflow/storage_init_container/s3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
- from abstract_dataset_provider import datasetProvider
+ from .abstract_dataset_provider import datasetProvider
from dataclasses import dataclass, field
import json
import boto3
Expand All @@ -7,12 +7,12 @@

@dataclass
class S3DatasetParams:
- access_key: str
- secret_key: str
  endpoint_url: str
  bucket_name: str
  file_key: str
- region_name: str
+ region_name: str = None
+ access_key: str = None
+ secret_key: str = None
download_dir: str = field(default="/workspace/datasets")

def is_valid_url(self, url):
Expand Down
32 changes: 23 additions & 9 deletions sdk/python/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,11 @@ def train(
"--dataset_provider_args",
json.dumps(dataset_params.__dict__),
],
- volume_mounts=models.V1VolumeMount(
- name="model_dataset_store", mount_path="/workspace"
- ),
+ volume_mounts=[
+ models.V1VolumeMount(
+ name="model_dataset_store", mount_path="/workspace"
+ )
+ ],
)

# create app container spec
Expand All @@ -152,10 +154,16 @@ def train(
"train_container_image"
],
args=["--parameters", json.dumps(parameters.__dict__)],
- volume_mounts=models.V1VolumeMount(
- name=pvc["name"], mount_path="/workspace"
+ volume_mounts=[
+ models.V1VolumeMount(name=pvc["name"], mount_path="/workspace")
+ ],
+ resources=models.V1ResourceRequirements(
+ limits={
+ "nvidia.com/gpu": resources_per_worker["gpu"],
+ "cpu": resources_per_worker["cpu"],
+ "memory": resources_per_worker["memory"],
+ }
  ),
- resources=resources_per_worker,
)

# create worker pod spec
Expand All @@ -164,7 +172,10 @@ def train(
containers_spec=[container_spec],
volumes_spec=[
models.V1Volume(
- name=pvc["name"], persistent_volume_claim=pvc["claimName"]
+ name=pvc["name"],
+ persistent_volume_claim=models.V1PersistentVolumeClaimVolumeSource(
+ claim_name=pvc["claimName"]
+ ),
)
],
)
Expand All @@ -175,7 +186,10 @@ def train(
containers_spec=[init_container_spec, container_spec],
volumes_spec=[
models.V1Volume(
- name=pvc["name"], persistent_volume_claim=pvc["claimName"]
+ name=pvc["name"],
+ persistent_volume_claim=models.V1PersistentVolumeClaimVolumeSource(
+ claim_name=pvc["claimName"]
+ ),
)
],
)
Expand All @@ -190,7 +204,7 @@ def train(
elastic_policy=models.KubeflowOrgV1ElasticPolicy(rdzv_backend="c10d"),
)

- self.create_job(job)
+ self.create_job(job, namespace=namespace)

def create_job(
self,
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/kubeflow/training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_container_spec(
image: str,
args: Optional[List[str]] = None,
resources: Optional[models.V1ResourceRequirements] = None,
- volume_mounts: Optional[models.V1VolumeMount] = None,
+ volume_mounts: Optional[List[models.V1VolumeMount]] = None,
) -> models.V1Container:
"""
get container spec for given name and image.
Expand Down Expand Up @@ -322,7 +322,7 @@ def get_pytorchjob_template(
)

if num_procs_per_worker > 0:
- pytorchjob.spec.nproc_per_node = num_procs_per_worker
+ pytorchjob.spec.nproc_per_node = str(num_procs_per_worker)
if elastic_policy:
pytorchjob.spec.elastic_policy = elastic_policy

Expand Down

0 comments on commit c59b0b0

Please sign in to comment.