From b62ad9b0687bbe04a7ee026001ca373d597b1877 Mon Sep 17 00:00:00 2001 From: deepanker13 Date: Wed, 10 Jan 2024 20:54:49 +0530 Subject: [PATCH] Fix gpu resource check KeyError in train() and JSON-encode type objects in SetEncoder --- .../kubeflow/training/api/training_client.py | 27 +++++++++++-------- sdk/python/kubeflow/training/utils/utils.py | 2 ++ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/sdk/python/kubeflow/training/api/training_client.py b/sdk/python/kubeflow/training/api/training_client.py index f0997eddff..249d589d2d 100644 --- a/sdk/python/kubeflow/training/api/training_client.py +++ b/sdk/python/kubeflow/training/api/training_client.py @@ -139,15 +139,19 @@ def train( if isinstance(resources_per_worker, dict): if "gpu" in resources_per_worker: - resources_per_worker["nvidia.com/gpu"] = resources_per_worker.pop("gpu") - - if ( - resources_per_worker["gpu"] is not None - and num_procs_per_worker > resources_per_worker["gpu"] - ) or (resources_per_worker["gpu"] is None and num_procs_per_worker != 0): - raise ValueError( - "Insufficient gpu resources allocated to the container." 
+ ) + if resources_per_worker["gpu"] is not None: + resources_per_worker["nvidia.com/gpu"] = resources_per_worker.pop( + "gpu" + ) if ( "cpu" not in resources_per_worker @@ -171,7 +175,8 @@ def train( ), ) except Exception as e: - raise RuntimeError("failed to create pvc") + pass # local + # raise RuntimeError("failed to create pvc") if isinstance(model_provider_parameters, HuggingFaceModelParams): mp = "hf" @@ -189,7 +194,7 @@ def train( "--model_provider", mp, "--model_provider_parameters", - json.dumps(model_provider_parameters.__dict__), + json.dumps(model_provider_parameters.__dict__, cls=utils.SetEncoder), "--dataset_provider", dp, "--dataset_provider_parameters", diff --git a/sdk/python/kubeflow/training/utils/utils.py b/sdk/python/kubeflow/training/utils/utils.py index 3f0cc6ec56..09130a4de1 100644 --- a/sdk/python/kubeflow/training/utils/utils.py +++ b/sdk/python/kubeflow/training/utils/utils.py @@ -372,4 +372,6 @@ class SetEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, set): return list(obj) + if isinstance(obj, type): + return obj.__name__ return json.JSONEncoder.default(self, obj)