Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
deepanker13 committed Jan 10, 2024
1 parent 4488d02 commit b62ad9b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
27 changes: 16 additions & 11 deletions sdk/python/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,19 @@ def train(

if isinstance(resources_per_worker, dict):
if "gpu" in resources_per_worker:
resources_per_worker["nvidia.com/gpu"] = resources_per_worker.pop("gpu")

if (
resources_per_worker["gpu"] is not None
and num_procs_per_worker > resources_per_worker["gpu"]
) or (resources_per_worker["gpu"] is None and num_procs_per_worker != 0):
raise ValueError(
"Insufficient gpu resources allocated to the container."
)
if (
resources_per_worker["gpu"] is not None
and (num_procs_per_worker > resources_per_worker["gpu"])
) or (
resources_per_worker["gpu"] is None and num_procs_per_worker != 0
):
raise ValueError(
"Insufficient gpu resources allocated to the container."
)
if resources_per_worker["gpu"] is not None:
resources_per_worker["nvidia.com/gpu"] = resources_per_worker.pop(
"gpu"
)

if (
"cpu" not in resources_per_worker
Expand All @@ -171,7 +175,8 @@ def train(
),
)
except Exception as e:
raise RuntimeError("failed to create pvc")
pass # local
# raise RuntimeError("failed to create pvc")

if isinstance(model_provider_parameters, HuggingFaceModelParams):
mp = "hf"
Expand All @@ -189,7 +194,7 @@ def train(
"--model_provider",
mp,
"--model_provider_parameters",
json.dumps(model_provider_parameters.__dict__),
json.dumps(model_provider_parameters.__dict__, cls=utils.SetEncoder),
"--dataset_provider",
dp,
"--dataset_provider_parameters",
Expand Down
2 changes: 2 additions & 0 deletions sdk/python/kubeflow/training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,6 @@ class SetEncoder(json.JSONEncoder):
def default(self, obj):
    """Convert objects the stock JSON encoder rejects into JSON-friendly values.

    Args:
        obj: The object that the base encoder could not serialize.

    Returns:
        A list holding the elements of ``obj`` when it is a ``set`` or
        ``frozenset``, or the class name (``__name__``) when ``obj`` is
        itself a class.

    Raises:
        TypeError: Raised by ``json.JSONEncoder.default`` for any other
            unsupported object.
    """
    # frozenset is rejected by the stock encoder exactly like set, so both
    # are emitted as JSON arrays.
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    # Classes (e.g. a class stored as a parameter value) are represented by
    # their bare name.
    if isinstance(obj, type):
        return obj.__name__
    # Anything else is delegated to the base implementation.
    return json.JSONEncoder.default(self, obj)

0 comments on commit b62ad9b

Please sign in to comment.