diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go index 93fd3067d..4a2964c32 100644 --- a/go/tasks/pluginmachinery/core/phase.go +++ b/go/tasks/pluginmachinery/core/phase.go @@ -218,7 +218,6 @@ func PhaseInfoQueuedWithTaskInfo(version uint32, reason string, info *TaskInfo) } func PhaseInfoInitializing(t time.Time, version uint32, reason string, info *TaskInfo) PhaseInfo { - pi := phaseInfo(PhaseInitializing, version, nil, info, false) pi.reason = reason return pi diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index 9e663319e..f59e94d6f 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/gob" "fmt" + "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "google.golang.org/grpc/credentials" @@ -12,7 +13,7 @@ import ( "google.golang.org/grpc/grpclog" - flyteIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" @@ -40,7 +41,8 @@ type Plugin struct { type ResourceWrapper struct { State admin.State - Outputs *flyteIdl.LiteralMap + Outputs *flyteIdlCore.LiteralMap + Message string } type ResourceMetaWrapper struct { @@ -143,6 +145,7 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba return &ResourceWrapper{ State: res.Resource.State, Outputs: res.Resource.Outputs, + Message: res.Resource.Message, }, nil } @@ -173,6 +176,8 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase taskInfo := &core.TaskInfo{} switch resource.State { + case admin.State_PENDING: + return core.PhaseInfoInitializing(time.Now(), core.DefaultPhaseVersion, resource.Message, taskInfo), nil case admin.State_RUNNING: return core.PhaseInfoRunning(core.DefaultPhaseVersion, taskInfo), nil case admin.State_PERMANENT_FAILURE: diff --git a/go/tasks/plugins/webapi/agent/plugin_test.go b/go/tasks/plugins/webapi/agent/plugin_test.go index 3c413bf37..2915f3893 100644 --- a/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/go/tasks/plugins/webapi/agent/plugin_test.go @@ -5,12 +5,13 @@ import ( "testing" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flytestdlib/config" - "google.golang.org/grpc" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + webapiPlugin "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi/mocks" "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/assert" ) @@ -30,6 +31,7 @@ func TestPlugin(t *testing.T) { metricScope: fakeSetupContext.MetricsScope(), cfg: GetConfig(), } + t.Run("get config", func(t *testing.T) { err := SetConfig(&cfg) assert.NoError(t, err) @@ -99,4 +101,70 @@ func TestPlugin(t *testing.T) { ctx, _ = getFinalContext(context.TODO(), "CreateTask", &Agent{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) assert.NotEqual(t, context.TODO(), ctx) }) + + t.Run("test PENDING Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + taskContext.On("Resource").Return(&ResourceWrapper{ + State: admin.State_PENDING, + Outputs: nil, + Message: "Waiting for cluster", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseInitializing, phase.Phase()) + assert.Equal(t, "Waiting for cluster", phase.Reason()) + }) + + t.Run("test RUNNING Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + taskContext.On("Resource").Return(&ResourceWrapper{ + State: admin.State_RUNNING, + Outputs: nil, + Message: "Job is running", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase()) + }) + + t.Run("test PERMANENT_FAILURE Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + taskContext.On("Resource").Return(&ResourceWrapper{ + State: admin.State_PERMANENT_FAILURE, + Outputs: nil, + Message: "", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhasePermanentFailure, phase.Phase()) + }) + + t.Run("test RETRYABLE_FAILURE Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + taskContext.On("Resource").Return(&ResourceWrapper{ + State: admin.State_RETRYABLE_FAILURE, + Outputs: nil, + Message: "", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, phase.Phase()) + }) + + t.Run("test UNDEFINED Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + taskContext.On("Resource").Return(&ResourceWrapper{ + State: 5, + Outputs: nil, + Message: "", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.Error(t, err) + assert.Equal(t, pluginsCore.PhaseUndefined, phase.Phase()) + }) }