diff --git a/scripts/pipeline_generator/utils.py b/scripts/pipeline_generator/utils.py index 52118a6..0fe7059 100644 --- a/scripts/pipeline_generator/utils.py +++ b/scripts/pipeline_generator/utils.py @@ -32,7 +32,7 @@ class AgentQueue(str, enum.Enum): def get_agent_queue(no_gpu: Optional[bool], gpu_type: Optional[str], num_gpus: Optional[int]) -> AgentQueue: if no_gpu: return AgentQueue.AWS_SMALL_CPU - if gpu_type == A100_GPU: + if gpu_type == GPUType.A100.value: return AgentQueue.A100 return AgentQueue.AWS_1xL4 if num_gpus == 1 else AgentQueue.AWS_4xL4