diff --git a/common/resource.go b/common/resource.go index 4e357305d..fb8a09e5b 100644 --- a/common/resource.go +++ b/common/resource.go @@ -16,17 +16,18 @@ import ( // Resource aims to simplify things like error & deleted entities handling type Resource struct { - Create func(ctx context.Context, d *schema.ResourceData, c *DatabricksClient) error - Read func(ctx context.Context, d *schema.ResourceData, c *DatabricksClient) error - Update func(ctx context.Context, d *schema.ResourceData, c *DatabricksClient) error - Delete func(ctx context.Context, d *schema.ResourceData, c *DatabricksClient) error - CustomizeDiff func(ctx context.Context, d *schema.ResourceDiff) error - StateUpgraders []schema.StateUpgrader - Schema map[string]*schema.Schema - SchemaVersion int - Timeouts *schema.ResourceTimeout - DeprecationMessage string - Importer *schema.ResourceImporter + Create func(ctx context.Context, d *schema.ResourceData, c *DatabricksClient) error + Read func(ctx context.Context, d *schema.ResourceData, c *DatabricksClient) error + Update func(ctx context.Context, d *schema.ResourceData, c *DatabricksClient) error + Delete func(ctx context.Context, d *schema.ResourceData, c *DatabricksClient) error + CustomizeDiff func(ctx context.Context, d *schema.ResourceDiff) error + StateUpgraders []schema.StateUpgrader + Schema map[string]*schema.Schema + SchemaVersion int + Timeouts *schema.ResourceTimeout + DeprecationMessage string + Importer *schema.ResourceImporter + CanSkipReadAfterCreateAndUpdate func(d *schema.ResourceData) bool } func nicerError(ctx context.Context, err error, action string) error { @@ -94,6 +95,9 @@ func (r Resource) ToResource() *schema.Resource { err = nicerError(ctx, err, "update") return diag.FromErr(err) } + if r.CanSkipReadAfterCreateAndUpdate != nil && r.CanSkipReadAfterCreateAndUpdate(d) { + return nil + } if err := recoverable(r.Read)(ctx, d, c); err != nil { err = nicerError(ctx, err, "read") return diag.FromErr(err) @@ -162,6 +166,9 @@ func (r Resource) ToResource() *schema.Resource { err = nicerError(ctx, err, "create") return diag.FromErr(err) } + if r.CanSkipReadAfterCreateAndUpdate != nil && r.CanSkipReadAfterCreateAndUpdate(d) { + return nil + } if err = recoverable(r.Read)(ctx, d, c); err != nil { err = nicerError(ctx, err, "read") return diag.FromErr(err) diff --git a/common/resource_test.go b/common/resource_test.go index f01f373ff..2ece50d28 100644 --- a/common/resource_test.go +++ b/common/resource_test.go @@ -3,6 +3,7 @@ package common import ( "context" "fmt" + "log" "testing" "github.com/databricks/databricks-sdk-go/apierr" @@ -38,6 +39,94 @@ func TestImportingCallsRead(t *testing.T) { assert.Equal(t, 1, d.Get("foo")) } +func createTestResourceForSkipRead(skipRead bool) Resource { + res := Resource{ + Create: func(ctx context.Context, + d *schema.ResourceData, + c *DatabricksClient) error { + log.Println("[DEBUG] Create called") + return d.Set("foo", 1) + }, + Read: func(ctx context.Context, + d *schema.ResourceData, + c *DatabricksClient) error { + log.Println("[DEBUG] Read called") + d.Set("foo", 2) + return nil + }, + Update: func(ctx context.Context, + d *schema.ResourceData, + c *DatabricksClient) error { + log.Println("[DEBUG] Update called") + return d.Set("foo", 3) + }, + Schema: map[string]*schema.Schema{ + "foo": { + Type: schema.TypeInt, + Required: true, + }, + }, + } + if skipRead { + res.CanSkipReadAfterCreateAndUpdate = func(d *schema.ResourceData) bool { + return true + } + } + return res +} + +func TestCreateSkipRead(t *testing.T) { + client := &DatabricksClient{} + ctx := context.Background() + r := createTestResourceForSkipRead(true).ToResource() + d := r.TestResourceData() + diags := r.CreateContext(ctx, d, client) + assert.False(t, diags.HasError()) + assert.Equal(t, 1, d.Get("foo")) +} + +func TestCreateDontSkipRead(t *testing.T) { + client := &DatabricksClient{} + ctx := context.Background() + r := createTestResourceForSkipRead(false).ToResource() + d := r.TestResourceData() + diags := r.CreateContext(ctx, d, client) + assert.False(t, diags.HasError()) + assert.Equal(t, 2, d.Get("foo")) +} + +func TestUpdateSkipRead(t *testing.T) { + client := &DatabricksClient{} + ctx := context.Background() + r := createTestResourceForSkipRead(true).ToResource() + d := r.TestResourceData() + datas, err := r.Importer.StateContext(ctx, d, client) + require.NoError(t, err) + assert.Len(t, datas, 1) + assert.False(t, r.Schema["foo"].ForceNew) + assert.Equal(t, "", d.Id()) + + diags := r.UpdateContext(ctx, d, client) + assert.False(t, diags.HasError()) + assert.Equal(t, 3, d.Get("foo")) +} + +func TestUpdateDontSkipRead(t *testing.T) { + client := &DatabricksClient{} + ctx := context.Background() + r := createTestResourceForSkipRead(false).ToResource() + d := r.TestResourceData() + datas, err := r.Importer.StateContext(ctx, d, client) + require.NoError(t, err) + assert.Len(t, datas, 1) + assert.False(t, r.Schema["foo"].ForceNew) + assert.Equal(t, "", d.Id()) + + diags := r.UpdateContext(ctx, d, client) + assert.False(t, diags.HasError()) + assert.Equal(t, 2, d.Get("foo")) +} + func TestHTTP404TriggersResourceRemovalForReadAndDelete(t *testing.T) { nope := func(ctx context.Context, d *schema.ResourceData,