From f569ead1a19c956044a142d6fe75ae0612cac322 Mon Sep 17 00:00:00 2001 From: achettyiitr Date: Wed, 18 Sep 2024 12:27:23 +0530 Subject: [PATCH] feat: snowpipe streaming --- .github/workflows/tests.yaml | 9 +- .../snowpipestreaming_test.go | 735 ++++++++++++++++++ ...docker-compose.rudder-snowpipe-clients.yml | 11 + .../docker-compose.rudder-transformer.yml | 11 + ...{integration_test.go => warehouse_test.go} | 1 + .../asyncdestinationmanager/common/utils.go | 2 +- .../asyncdestinationmanager/manager.go | 3 + .../snowpipestreaming/createchannel.go | 72 ++ .../snowpipestreaming/createchannel_test.go | 136 ++++ .../snowpipestreaming/deletechannel.go | 38 + .../snowpipestreaming/deletechannel_test.go | 108 +++ .../snowpipestreaming/insert.go | 77 ++ .../snowpipestreaming/insert_test.go | 125 +++ .../snowpipestreaming/options.go | 9 + .../snowpipestreaming/snowpipestreaming.go | 488 ++++++++++++ .../snowpipestreaming_test.go | 311 ++++++++ .../snowpipestreaming/status.go | 42 + .../snowpipestreaming/status_test.go | 122 +++ ...docker-compose.rudder-snowpipe-clients.yml | 11 + router/batchrouter/handle.go | 2 + router/batchrouter/handle_async.go | 3 +- router/batchrouter/handle_lifecycle.go | 1 + testhelper/warehouse/records.go | 58 ++ utils/misc/misc.go | 2 +- .../integrations/snowflake/datatype_mapper.go | 2 +- .../snowflake/datatype_mapper_test.go | 2 +- warehouse/integrations/snowflake/snowflake.go | 2 +- warehouse/utils/uploader.go | 58 ++ warehouse/utils/utils.go | 19 - 29 files changed, 2434 insertions(+), 26 deletions(-) create mode 100644 integration_test/snowpipestreaming/snowpipestreaming_test.go create mode 100644 integration_test/snowpipestreaming/testdata/docker-compose.rudder-snowpipe-clients.yml create mode 100644 integration_test/snowpipestreaming/testdata/docker-compose.rudder-transformer.yml rename integration_test/warehouse/{integration_test.go => warehouse_test.go} (99%) create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/createchannel.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/createchannel_test.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/deletechannel.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/deletechannel_test.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/insert.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/insert_test.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/options.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming_test.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/status.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/status_test.go create mode 100644 router/batchrouter/asyncdestinationmanager/snowpipestreaming/testdata/docker-compose.rudder-snowpipe-clients.yml create mode 100644 testhelper/warehouse/records.go create mode 100644 warehouse/utils/uploader.go diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 8f872a9b77..5b1819fedb 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -142,6 +142,7 @@ jobs: - integration_test/tracing - integration_test/backendconfigunavailability - integration_test/trackedusersreporting + - 
integration_test/snowpipestreaming - processor - regulation-worker - router @@ -168,6 +169,11 @@ jobs: go-version-file: 'go.mod' - run: go version - run: go mod download + - name: Login to DockerHub + uses: docker/login-action@v3.3.0 + with: + username: rudderlabs + password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Package Unit [ ${{ matrix.package }} ] env: TEST_KAFKA_CONFLUENT_CLOUD_HOST: ${{ secrets.TEST_KAFKA_CONFLUENT_CLOUD_HOST }} @@ -178,7 +184,8 @@ jobs: TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_CONNECTION_STRING: ${{ secrets.TEST_KAFKA_AZURE_EVENT_HUBS_CLOUD_CONNECTION_STRING }} TEST_S3_DATALAKE_CREDENTIALS: ${{ secrets.TEST_S3_DATALAKE_CREDENTIALS }} BIGQUERY_INTEGRATION_TEST_CREDENTIALS: ${{ secrets.BIGQUERY_INTEGRATION_TEST_CREDENTIALS }} - run: make test exclude="${{ matrix.exclude }}" package=${{ matrix.package }} + SNOWPIPE_STREAMING_KEYPAIR_UNENCRYPTED_INTEGRATION_TEST_CREDENTIALS: ${{ secrets.SNOWPIPE_STREAMING_KEYPAIR_UNENCRYPTED_INTEGRATION_TEST_CREDENTIALS }} + run: FORCE_RUN_INTEGRATION_TESTS=true make test exclude="${{ matrix.exclude }}" package=${{ matrix.package }} - name: Sanitize name for Artifact run: | name=$(echo -n "${{ matrix.package }}" | sed -e 's/[ \t:\/\\"<>|*?]/-/g' -e 's/--*/-/g') diff --git a/integration_test/snowpipestreaming/snowpipestreaming_test.go b/integration_test/snowpipestreaming/snowpipestreaming_test.go new file mode 100644 index 0000000000..68444a6c7f --- /dev/null +++ b/integration_test/snowpipestreaming/snowpipestreaming_test.go @@ -0,0 +1,735 @@ +package snowpipestreaming + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path" + "strconv" + "strings" + "testing" + "time" + + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/rudderlabs/compose-test/compose" + "github.com/rudderlabs/compose-test/testcompose" + "github.com/rudderlabs/rudder-go-kit/config" + kithttputil "github.com/rudderlabs/rudder-go-kit/httputil" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/postgres" + "github.com/rudderlabs/rudder-go-kit/testhelper/rand" + + "github.com/rudderlabs/rudder-server/runner" + "github.com/rudderlabs/rudder-server/testhelper/backendconfigtest" + "github.com/rudderlabs/rudder-server/testhelper/health" + "github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake" + whutils "github.com/rudderlabs/rudder-server/warehouse/utils" +) + +type testCredentials struct { + Account string `json:"account"` + User string `json:"user"` + Role string `json:"role"` + Database string `json:"database"` + Warehouse string `json:"warehouse"` + PrivateKey string `json:"privateKey"` + PrivateKeyPassphrase string `json:"privateKeyPassphrase"` +} + +const ( + testKeyPairUnencrypted = "SNOWPIPE_STREAMING_KEYPAIR_UNENCRYPTED_INTEGRATION_TEST_CREDENTIALS" +) + +func getSnowpipeTestCredentials(key string) (*testCredentials, error) { + cred, exists := os.LookupEnv(key) + if !exists { + return nil, errors.New("snowpipe test credentials not found") + } + + var credentials testCredentials + err := json.Unmarshal([]byte(cred), &credentials) + if err != nil { + return nil, fmt.Errorf("unable to marshall %s to snowpipe test credentials: %v", key, err) + } + return &credentials, nil +} + +func randSchema(provider string) string { // nolint:unparam + hex := 
strings.ToLower(rand.String(12)) + namespace := fmt.Sprintf("test_%s_%d", hex, time.Now().Unix()) + return whutils.ToProviderCase(provider, whutils.ToSafeNamespace(provider, + namespace, + )) +} + +func TestSnowpipeStreaming(t *testing.T) { + for _, key := range []string{ + testKeyPairUnencrypted, + } { + if _, exists := os.LookupEnv(key); !exists { + if os.Getenv("FORCE_RUN_INTEGRATION_TESTS") == "true" { + t.Fatalf("%s environment variable not set", key) + } + t.Skipf("Skipping %s as %s is not set", t.Name(), key) + } + } + + c := testcompose.New(t, compose.FilePaths([]string{"testdata/docker-compose.rudder-snowpipe-clients.yml", "testdata/docker-compose.rudder-transformer.yml"})) + c.Start(context.Background()) + + transformerURL := fmt.Sprintf("http://localhost:%d", c.Port("transformer", 9090)) + snowpipeClientsURL := fmt.Sprintf("http://localhost:%d", c.Port("rudder-snowpipe-clients", 9078)) + + keyPairUnEncryptedCredentials, err := getSnowpipeTestCredentials(testKeyPairUnencrypted) + require.NoError(t, err) + + t.Run("namespace and table already exists", func(t *testing.T) { + config.Reset() + defer config.Reset() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + postgresContainer, err := postgres.Setup(pool, t) + require.NoError(t, err) + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + namespace := randSchema(whutils.SNOWFLAKE) + + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + WithID("destination1"). + WithConfigOption("account", keyPairUnEncryptedCredentials.Account). + WithConfigOption("warehouse", keyPairUnEncryptedCredentials.Warehouse). + WithConfigOption("database", keyPairUnEncryptedCredentials.Database). + WithConfigOption("role", keyPairUnEncryptedCredentials.Role). + WithConfigOption("user", keyPairUnEncryptedCredentials.User). + WithConfigOption("useKeyPairAuth", true). + WithConfigOption("privateKey", keyPairUnEncryptedCredentials.PrivateKey). + WithConfigOption("privateKeyPassphrase", keyPairUnEncryptedCredentials.PrivateKeyPassphrase). + WithConfigOption("namespace", namespace). + WithRevisionID("destination1"). + Build() + warehouse := whutils.ModelWarehouse{ + Namespace: namespace, + Destination: destination, + } + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source1"). + WithWriteKey("writekey1"). + WithConnection(destination). + Build()). + Build()). 
+ Build() + defer bcServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, bcServer.URL, transformerURL, snowpipeClientsURL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + sm := snowflake.New(config.New(), logger.NOP, stats.NOP) + require.NoError(t, err) + require.NoError(t, sm.Setup(ctx, warehouse, &whutils.NopUploader{})) + t.Cleanup(func() { sm.Cleanup(ctx) }) + require.NoError(t, sm.CreateSchema(ctx)) + t.Cleanup(func() { dropSchema(t, sm.DB.DB, namespace) }) + require.NoError(t, sm.CreateTable(ctx, "IDENTIFIES", whutils.ModelTableSchema{ + "CONTEXT_IP": "string", "CONTEXT_LIBRARY_NAME": "string", "CONTEXT_SOURCE_TYPE": "string", "ORIGINAL_TIMESTAMP": "datetime", "UUID_TS": "datetime", "CONTEXT_DESTINATION_ID": "string", "CONTEXT_DESTINATION_TYPE": "string", "CONTEXT_PASSED_IP": "string", "SENT_AT": "datetime", "TIMESTAMP": "datetime", "CONTEXT_SOURCE_ID": "string", "CONTEXT_TRAITS_TRAIT_1": "string", "CONTEXT_REQUEST_IP": "string", "ID": "string", "RECEIVED_AT": "datetime", "TRAIT_1": "string", "USER_ID": "string", + })) + require.NoError(t, sm.CreateTable(ctx, "USERS", whutils.ModelTableSchema{ + "CONTEXT_DESTINATION_ID": "string", "CONTEXT_IP": "string", "CONTEXT_LIBRARY_NAME": "string", "CONTEXT_PASSED_IP": "string", "CONTEXT_REQUEST_IP": "string", "CONTEXT_TRAITS_TRAIT_1": "string", "RECEIVED_AT": "datetime", "CONTEXT_DESTINATION_TYPE": "string", "CONTEXT_SOURCE_ID": "string", "CONTEXT_SOURCE_TYPE": "string", "ID": "string", "TRAIT_1": "string", "UUID_TS": "datetime", + })) + + eventFormat := func(index int) string { + return fmt.Sprintf(`{"batch":[{"userId":"%[1]s","type":"%[2]s","context":{"traits":{"trait1":"new-val"},"ip":"14.5.67.21","library":{"name":"http"}},"timestamp":"2020-02-02T00:23:09.544Z"}]}`, + rand.String(10), + "identify", + ) + } + + err = sendEvents(5, eventFormat, "writekey1", url) + require.NoError(t, err) + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('gw',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("gw processedJobCount: %d", jobsCount) + return jobsCount == 5 + }, 20*time.Second, 1*time.Second, "all gw events should be successfully processed") + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('batch_rt',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("batch_rt succeeded: %d", jobsCount) + return jobsCount == 10 + }, 200*time.Second, 1*time.Second, "all events should be aborted in batch router") + + var ( + identifiesCount int + usersCount int + ) + + err = sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "IDENTIFIES")).Scan(&identifiesCount) + require.NoError(t, err) + require.Equal(t, 5, identifiesCount) + + err = sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "USERS")).Scan(&usersCount) + require.NoError(t, err) + require.Equal(t, 5, usersCount) + + cancel() + _ = wg.Wait() + }) + + t.Run("namespace does not exists", func(t *testing.T) { + config.Reset() + defer config.Reset() + + pool, 
err := dockertest.NewPool("") + require.NoError(t, err) + postgresContainer, err := postgres.Setup(pool, t) + require.NoError(t, err) + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + namespace := randSchema(whutils.SNOWFLAKE) + + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + WithID("destination1"). + WithConfigOption("account", keyPairUnEncryptedCredentials.Account). + WithConfigOption("warehouse", keyPairUnEncryptedCredentials.Warehouse). + WithConfigOption("database", keyPairUnEncryptedCredentials.Database). + WithConfigOption("role", keyPairUnEncryptedCredentials.Role). + WithConfigOption("user", keyPairUnEncryptedCredentials.User). + WithConfigOption("useKeyPairAuth", true). + WithConfigOption("privateKey", keyPairUnEncryptedCredentials.PrivateKey). + WithConfigOption("privateKeyPassphrase", keyPairUnEncryptedCredentials.PrivateKeyPassphrase). + WithConfigOption("namespace", namespace). + WithRevisionID("destination1"). + Build() + warehouse := whutils.ModelWarehouse{ + Namespace: namespace, + Destination: destination, + } + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source1"). + WithWriteKey("writekey1"). + WithConnection(destination). + Build()). + Build()). + Build() + defer bcServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, bcServer.URL, transformerURL, snowpipeClientsURL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + sm := snowflake.New(config.New(), logger.NOP, stats.NOP) + require.NoError(t, err) + require.NoError(t, sm.Setup(ctx, warehouse, &whutils.NopUploader{})) + t.Cleanup(func() { sm.Cleanup(ctx) }) + t.Cleanup(func() { dropSchema(t, sm.DB.DB, namespace) }) + + eventFormat := func(index int) string { + return fmt.Sprintf(`{"batch":[{"userId":"%[1]s","type":"%[2]s","context":{"traits":{"trait1":"new-val"},"ip":"14.5.67.21","library":{"name":"http"}},"timestamp":"2020-02-02T00:23:09.544Z"}]}`, + rand.String(10), + "identify", + ) + } + + err = sendEvents(5, eventFormat, "writekey1", url) + require.NoError(t, err) + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('gw',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("gw processedJobCount: %d", jobsCount) + return jobsCount == 5 + }, 20*time.Second, 1*time.Second, "all gw events should be successfully processed") + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('batch_rt',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("batch_rt succeeded: %d", jobsCount) + return jobsCount == 10 + }, 200*time.Second, 1*time.Second, "all events should be aborted in batch router") + + var ( + identifiesCount int + usersCount int + ) + + err = sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "IDENTIFIES")).Scan(&identifiesCount) + require.NoError(t, err) + require.Equal(t, 5, identifiesCount) + + err = 
sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "USERS")).Scan(&usersCount) + require.NoError(t, err) + require.Equal(t, 5, usersCount) + + cancel() + _ = wg.Wait() + }) + + t.Run("table does not exists", func(t *testing.T) { + config.Reset() + defer config.Reset() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + postgresContainer, err := postgres.Setup(pool, t) + require.NoError(t, err) + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + namespace := randSchema(whutils.SNOWFLAKE) + + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + WithID("destination1"). + WithConfigOption("account", keyPairUnEncryptedCredentials.Account). + WithConfigOption("warehouse", keyPairUnEncryptedCredentials.Warehouse). + WithConfigOption("database", keyPairUnEncryptedCredentials.Database). + WithConfigOption("role", keyPairUnEncryptedCredentials.Role). + WithConfigOption("user", keyPairUnEncryptedCredentials.User). + WithConfigOption("useKeyPairAuth", true). + WithConfigOption("privateKey", keyPairUnEncryptedCredentials.PrivateKey). + WithConfigOption("privateKeyPassphrase", keyPairUnEncryptedCredentials.PrivateKeyPassphrase). + WithConfigOption("namespace", namespace). + WithRevisionID("destination1"). + Build() + warehouse := whutils.ModelWarehouse{ + Namespace: namespace, + Destination: destination, + } + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source1"). + WithWriteKey("writekey1"). + WithConnection(destination). + Build()). + Build()). + Build() + defer bcServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, bcServer.URL, transformerURL, snowpipeClientsURL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + sm := snowflake.New(config.New(), logger.NOP, stats.NOP) + require.NoError(t, err) + require.NoError(t, sm.Setup(ctx, warehouse, &whutils.NopUploader{})) + t.Cleanup(func() { sm.Cleanup(ctx) }) + require.NoError(t, sm.CreateSchema(ctx)) + t.Cleanup(func() { dropSchema(t, sm.DB.DB, namespace) }) + + eventFormat := func(index int) string { + return fmt.Sprintf(`{"batch":[{"userId":"%[1]s","type":"%[2]s","context":{"traits":{"trait1":"new-val"},"ip":"14.5.67.21","library":{"name":"http"}},"timestamp":"2020-02-02T00:23:09.544Z"}]}`, + rand.String(10), + "identify", + ) + } + + err = sendEvents(5, eventFormat, "writekey1", url) + require.NoError(t, err) + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('gw',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("gw processedJobCount: %d", jobsCount) + return jobsCount == 5 + }, 20*time.Second, 1*time.Second, "all gw events should be successfully processed") + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('batch_rt',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("batch_rt succeeded: %d", jobsCount) + return jobsCount == 10 + }, 
200*time.Second, 1*time.Second, "all events should be aborted in batch router") + + var ( + identifiesCount int + usersCount int + ) + + err = sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "IDENTIFIES")).Scan(&identifiesCount) + require.NoError(t, err) + require.Equal(t, 5, identifiesCount) + + err = sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "USERS")).Scan(&usersCount) + require.NoError(t, err) + require.Equal(t, 5, usersCount) + + cancel() + _ = wg.Wait() + }) + + t.Run("events with different schema", func(t *testing.T) { + config.Reset() + defer config.Reset() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + postgresContainer, err := postgres.Setup(pool, t) + require.NoError(t, err) + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + namespace := randSchema(whutils.SNOWFLAKE) + + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + WithID("destination1"). + WithConfigOption("account", keyPairUnEncryptedCredentials.Account). + WithConfigOption("warehouse", keyPairUnEncryptedCredentials.Warehouse). + WithConfigOption("database", keyPairUnEncryptedCredentials.Database). + WithConfigOption("role", keyPairUnEncryptedCredentials.Role). + WithConfigOption("user", keyPairUnEncryptedCredentials.User). + WithConfigOption("useKeyPairAuth", true). + WithConfigOption("privateKey", keyPairUnEncryptedCredentials.PrivateKey). + WithConfigOption("privateKeyPassphrase", keyPairUnEncryptedCredentials.PrivateKeyPassphrase). + WithConfigOption("namespace", namespace). + WithRevisionID("destination1"). + Build() + warehouse := whutils.ModelWarehouse{ + Namespace: namespace, + Destination: destination, + } + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source1"). + WithWriteKey("writekey1"). + WithConnection(destination). + Build()). + Build()). 
+ Build() + defer bcServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, bcServer.URL, transformerURL, snowpipeClientsURL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + sm := snowflake.New(config.New(), logger.NOP, stats.NOP) + require.NoError(t, err) + require.NoError(t, sm.Setup(ctx, warehouse, &whutils.NopUploader{})) + t.Cleanup(func() { sm.Cleanup(ctx) }) + require.NoError(t, sm.CreateSchema(ctx)) + t.Cleanup(func() { dropSchema(t, sm.DB.DB, namespace) }) + + eventFormat := func(index int) string { + return fmt.Sprintf(`{"batch":[{"userId":"%[1]s","type":"%[2]s","context":{"traits":{"trait%[3]d":"new-val"},"ip":"14.5.67.21","library":{"name":"http"}},"timestamp":"2020-02-02T00:23:09.544Z"}]}`, + rand.String(10), + "identify", + index, + ) + } + + err = sendEvents(5, eventFormat, "writekey1", url) + require.NoError(t, err) + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('gw',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("gw processedJobCount: %d", jobsCount) + return jobsCount == 5 + }, 20*time.Second, 1*time.Second, "all gw events should be successfully processed") + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('batch_rt',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("batch_rt succeeded: %d", jobsCount) + return jobsCount == 10 + }, 200*time.Second, 1*time.Second, "all events should be aborted in batch router") + + var ( + identifiesCount int + usersCount int + ) + + err = sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "IDENTIFIES")).Scan(&identifiesCount) + require.NoError(t, err) + require.Equal(t, 5, identifiesCount) + + err = sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "USERS")).Scan(&usersCount) + require.NoError(t, err) + require.Equal(t, 5, usersCount) + + cancel() + _ = wg.Wait() + }) + + t.Run("addition of new properties", func(t *testing.T) { + config.Reset() + defer config.Reset() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + postgresContainer, err := postgres.Setup(pool, t) + require.NoError(t, err) + + gwPort, err := kithelper.GetFreePort() + require.NoError(t, err) + + namespace := randSchema(whutils.SNOWFLAKE) + + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + WithID("destination1"). + WithConfigOption("account", keyPairUnEncryptedCredentials.Account). + WithConfigOption("warehouse", keyPairUnEncryptedCredentials.Warehouse). + WithConfigOption("database", keyPairUnEncryptedCredentials.Database). + WithConfigOption("role", keyPairUnEncryptedCredentials.Role). + WithConfigOption("user", keyPairUnEncryptedCredentials.User). + WithConfigOption("useKeyPairAuth", true). + WithConfigOption("privateKey", keyPairUnEncryptedCredentials.PrivateKey). + WithConfigOption("privateKeyPassphrase", keyPairUnEncryptedCredentials.PrivateKeyPassphrase). + WithConfigOption("namespace", namespace). + WithRevisionID("destination1"). 
+ Build() + warehouse := whutils.ModelWarehouse{ + Namespace: namespace, + Destination: destination, + } + + bcServer := backendconfigtest.NewBuilder(). + WithWorkspaceConfig( + backendconfigtest.NewConfigBuilder(). + WithSource( + backendconfigtest.NewSourceBuilder(). + WithID("source1"). + WithWriteKey("writekey1"). + WithConnection(destination). + Build()). + Build()). + Build() + defer bcServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wg, ctx := errgroup.WithContext(ctx) + wg.Go(func() error { + err := runRudderServer(ctx, gwPort, postgresContainer, bcServer.URL, transformerURL, snowpipeClientsURL, t.TempDir()) + if err != nil { + t.Logf("rudder-server exited with error: %v", err) + } + return err + }) + url := fmt.Sprintf("http://localhost:%d", gwPort) + health.WaitUntilReady(ctx, t, url+"/health", 60*time.Second, 10*time.Millisecond, t.Name()) + + sm := snowflake.New(config.New(), logger.NOP, stats.NOP) + require.NoError(t, err) + require.NoError(t, sm.Setup(ctx, warehouse, &whutils.NopUploader{})) + t.Cleanup(func() { sm.Cleanup(ctx) }) + require.NoError(t, sm.CreateSchema(ctx)) + t.Cleanup(func() { dropSchema(t, sm.DB.DB, namespace) }) + require.NoError(t, sm.CreateTable(ctx, "IDENTIFIES", whutils.ModelTableSchema{ + "CONTEXT_IP": "string", "CONTEXT_LIBRARY_NAME": "string", "CONTEXT_SOURCE_TYPE": "string", "ORIGINAL_TIMESTAMP": "datetime", "UUID_TS": "datetime", "CONTEXT_DESTINATION_ID": "string", "CONTEXT_DESTINATION_TYPE": "string", "CONTEXT_PASSED_IP": "string", "SENT_AT": "datetime", "TIMESTAMP": "datetime", "CONTEXT_SOURCE_ID": "string", "CONTEXT_TRAITS_TRAIT_1": "string", "CONTEXT_REQUEST_IP": "string", "ID": "string", "RECEIVED_AT": "datetime", "TRAIT_1": "string", "USER_ID": "string", + })) + require.NoError(t, sm.CreateTable(ctx, "USERS", whutils.ModelTableSchema{ + "CONTEXT_DESTINATION_ID": "string", "CONTEXT_IP": "string", "CONTEXT_LIBRARY_NAME": "string", "CONTEXT_PASSED_IP": "string", "CONTEXT_REQUEST_IP": "string", "CONTEXT_TRAITS_TRAIT_1": "string", "RECEIVED_AT": "datetime", "CONTEXT_DESTINATION_TYPE": "string", "CONTEXT_SOURCE_ID": "string", "CONTEXT_SOURCE_TYPE": "string", "ID": "string", "TRAIT_1": "string", "UUID_TS": "datetime", + })) + + eventFormat := func(index int) string { + return fmt.Sprintf(`{"batch":[{"userId":"%[1]s","type":"%[2]s","context":{"traits":{"trait%[3]d":"new-val"},"ip":"14.5.67.21","library":{"name":"http"}},"timestamp":"2020-02-02T00:23:09.544Z"}]}`, + rand.String(10), + "identify", + index, + ) + } + + err = sendEvents(5, eventFormat, "writekey1", url) + require.NoError(t, err) + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('gw',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("gw processedJobCount: %d", jobsCount) + return jobsCount == 5 + }, 20*time.Second, 1*time.Second, "all gw events should be successfully processed") + + require.Eventually(t, func() bool { + var jobsCount int + require.NoError(t, postgresContainer.DB.QueryRow("SELECT count(*) FROM unionjobsdbmetadata('batch_rt',1) WHERE job_state = 'succeeded'").Scan(&jobsCount)) + t.Logf("batch_rt succeeded: %d", jobsCount) + return jobsCount == 10 + }, 200*time.Second, 1*time.Second, "all events should be aborted in batch router") + + var ( + identifiesCount int + usersCount int + ) + + err = sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "IDENTIFIES")).Scan(&identifiesCount) + 
require.NoError(t, err) + require.Equal(t, 5, identifiesCount) + + err = sm.DB.DB.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM %q.%q;`, namespace, "USERS")).Scan(&usersCount) + require.NoError(t, err) + require.Equal(t, 5, usersCount) + + cancel() + _ = wg.Wait() + }) +} + +func runRudderServer(ctx context.Context, port int, postgresContainer *postgres.Resource, cbURL, transformerURL, snowpipeClientsURL, tmpDir string) (err error) { + config.Set("CONFIG_BACKEND_URL", cbURL) + config.Set("WORKSPACE_TOKEN", "token") + config.Set("DB.host", postgresContainer.Host) + config.Set("DB.port", postgresContainer.Port) + config.Set("DB.user", postgresContainer.User) + config.Set("DB.name", postgresContainer.Database) + config.Set("DB.password", postgresContainer.Password) + config.Set("DEST_TRANSFORM_URL", transformerURL) + config.Set("Snowpipe.Client.URL", snowpipeClientsURL) + config.Set("BatchRouter.pollStatusLoopSleep", "1s") + config.Set("BatchRouter.asyncUploadTimeout", "1s") + config.Set("BatchRouter.asyncUploadWorkerTimeout", "1s") + config.Set("BatchRouter.mainLoopFreq", "1s") + config.Set("BatchRouter.uploadFreq", "1s") + config.Set("BatchRouter.isolationMode", "none") + + config.Set("Warehouse.mode", "off") + config.Set("DestinationDebugger.disableEventDeliveryStatusUploads", true) + config.Set("SourceDebugger.disableEventUploads", true) + config.Set("TransformationDebugger.disableTransformationStatusUploads", true) + config.Set("JobsDB.backup.enabled", false) + config.Set("JobsDB.migrateDSLoopSleepDuration", "60m") + config.Set("archival.Enabled", false) + config.Set("Reporting.syncer.enabled", false) + config.Set("BatchRouter.mainLoopFreq", "1s") + config.Set("BatchRouter.uploadFreq", "1s") + config.Set("Gateway.webPort", strconv.Itoa(port)) + config.Set("RUDDER_TMPDIR", os.TempDir()) + config.Set("recovery.storagePath", path.Join(tmpDir, "/recovery_data.json")) + config.Set("recovery.enabled", false) + config.Set("Profiler.Enabled", false) + config.Set("Gateway.enableSuppressUserFeature", false) + + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panicked: %v", r) + } + }() + r := runner.New(runner.ReleaseInfo{EnterpriseToken: "TOKEN"}) + c := r.Run(ctx, []string{"proc-isolation-test-rudder-server"}) + if c != 0 { + err = fmt.Errorf("rudder-server exited with a non-0 exit code: %d", c) + } + return +} + +func sendEvents(num int, eventFormat func(index int) string, writeKey, url string) error { // nolint:unparam + for i := 0; i < num; i++ { + payload := []byte(eventFormat(i)) + req, err := http.NewRequest(http.MethodPost, url+"/v1/batch", bytes.NewReader(payload)) + if err != nil { + return err + } + req.SetBasicAuth(writeKey, "password") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return fmt.Errorf("failed to send event to rudder server, status code: %d: %s", resp.StatusCode, string(b)) + } + func() { kithttputil.CloseResponse(resp) }() + } + return nil +} + +func dropSchema(t *testing.T, db *sql.DB, namespace string) { + t.Helper() + t.Log("dropping schema", namespace) + + require.Eventually(t, + func() bool { + _, err := db.ExecContext(context.Background(), fmt.Sprintf(`DROP SCHEMA %q CASCADE;`, namespace)) + if err != nil { + t.Logf("error deleting schema %q: %v", namespace, err) + return false + } + return true + }, + time.Minute, + time.Second, + ) +} diff --git 
a/integration_test/snowpipestreaming/testdata/docker-compose.rudder-snowpipe-clients.yml b/integration_test/snowpipestreaming/testdata/docker-compose.rudder-snowpipe-clients.yml new file mode 100644 index 0000000000..e6690dce64 --- /dev/null +++ b/integration_test/snowpipestreaming/testdata/docker-compose.rudder-snowpipe-clients.yml @@ -0,0 +1,11 @@ +version: "3.9" + +services: + rudder-snowpipe-clients: + image: "rudderstack/rudder-snowpipe-clients:chore.snowpipe-poc" + ports: + - "9078" + healthcheck: + test: wget --no-verbose --tries=1 --spider http://localhost:9078/q/health || exit 1 + interval: 1s + retries: 25 diff --git a/integration_test/snowpipestreaming/testdata/docker-compose.rudder-transformer.yml b/integration_test/snowpipestreaming/testdata/docker-compose.rudder-transformer.yml new file mode 100644 index 0000000000..61bfb370ed --- /dev/null +++ b/integration_test/snowpipestreaming/testdata/docker-compose.rudder-transformer.yml @@ -0,0 +1,11 @@ +version: "3.9" + +services: + transformer: + image: "rudderstack/develop-rudder-transformer:feat.snowpipe-streaming" + ports: + - "9090:9090" + healthcheck: + test: wget --no-verbose --tries=1 --spider http://0.0.0.0:9090/health || exit 1 + interval: 1s + retries: 25 diff --git a/integration_test/warehouse/integration_test.go b/integration_test/warehouse/warehouse_test.go similarity index 99% rename from integration_test/warehouse/integration_test.go rename to integration_test/warehouse/warehouse_test.go index 2b64911686..2720b1c9b3 100644 --- a/integration_test/warehouse/integration_test.go +++ b/integration_test/warehouse/warehouse_test.go @@ -33,6 +33,7 @@ import ( kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/minio" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/postgres" + "github.com/rudderlabs/rudder-server/admin" "github.com/rudderlabs/rudder-server/app" backendconfig "github.com/rudderlabs/rudder-server/backend-config" diff --git a/router/batchrouter/asyncdestinationmanager/common/utils.go b/router/batchrouter/asyncdestinationmanager/common/utils.go index 8b541ea016..df711f816e 100644 --- a/router/batchrouter/asyncdestinationmanager/common/utils.go +++ b/router/batchrouter/asyncdestinationmanager/common/utils.go @@ -3,7 +3,7 @@ package common import "slices" var ( - asyncDestinations = []string{"MARKETO_BULK_UPLOAD", "BINGADS_AUDIENCE", "ELOQUA", "YANDEX_METRICA_OFFLINE_EVENTS", "BINGADS_OFFLINE_CONVERSIONS", "KLAVIYO_BULK_UPLOAD", "LYTICS_BULK_UPLOAD"} + asyncDestinations = []string{"MARKETO_BULK_UPLOAD", "BINGADS_AUDIENCE", "ELOQUA", "YANDEX_METRICA_OFFLINE_EVENTS", "BINGADS_OFFLINE_CONVERSIONS", "KLAVIYO_BULK_UPLOAD", "LYTICS_BULK_UPLOAD", "SNOWPIPE_STREAMING"} sftpDestinations = []string{"SFTP"} ) diff --git a/router/batchrouter/asyncdestinationmanager/manager.go b/router/batchrouter/asyncdestinationmanager/manager.go index 59dd7fe748..527d957754 100644 --- a/router/batchrouter/asyncdestinationmanager/manager.go +++ b/router/batchrouter/asyncdestinationmanager/manager.go @@ -16,6 +16,7 @@ import ( lyticsBulkUpload "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/lytics_bulk_upload" marketobulkupload "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/marketo-bulk-upload" "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/sftp" + "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming" 
"github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/yandexmetrica" ) @@ -41,6 +42,8 @@ func newRegularManager( return klaviyobulkupload.NewManager(logger, statsFactory, destination) case "LYTICS_BULK_UPLOAD": return lyticsBulkUpload.NewManager(logger, statsFactory, destination) + case "SNOWPIPE_STREAMING": + return snowpipestreaming.New(conf, logger, statsFactory, destination), nil } return nil, errors.New("invalid destination type") } diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/createchannel.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/createchannel.go new file mode 100644 index 0000000000..a4f67d035b --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/createchannel.go @@ -0,0 +1,72 @@ +package snowpipestreaming + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + + "github.com/rudderlabs/rudder-go-kit/httputil" +) + +type accountConfig struct { + Account string `json:"account"` + User string `json:"user"` + Role string `json:"role"` + PrivateKey string `json:"privateKey"` + PrivateKeyPassphrase string `json:"privateKeyPassphrase"` +} + +type tableConfig struct { + Database string `json:"database"` + Schema string `json:"schema"` + Table string `json:"table"` +} + +type createChannelRequest struct { + RudderIdentifier string `json:"rudderIdentifier"` + Partition string `json:"partition"` + AccountConfig accountConfig `json:"account"` + TableConfig tableConfig `json:"table"` +} + +type createChannelResponse struct { + ChannelID string `json:"channelId"` + ChannelName string `json:"channelName"` + ClientName string `json:"clientName"` + Valid bool `json:"valid"` + TableSchema map[string]map[string]any `json:"tableSchema"` +} + +func (m *Manager) createChannel(ctx context.Context, channelReq *createChannelRequest) (*createChannelResponse, error) { + reqJSON, err := json.Marshal(channelReq) + if err != nil { + return nil, fmt.Errorf("marshalling create channel request: %w", err) + } + + channelReqURL := m.config.clientURL + "/channels" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, channelReqURL, bytes.NewBuffer(reqJSON)) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, reqErr := m.requestDoer.Do(req) + if reqErr != nil { + return nil, fmt.Errorf("sending request: %w", reqErr) + } + defer func() { httputil.CloseResponse(resp) }() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("invalid status code: %d, body: %s", resp.StatusCode, string(b)) + } + + var res createChannelResponse + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil, fmt.Errorf("decoding response: %w", err) + } + + return &res, nil +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/createchannel_test.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/createchannel_test.go new file mode 100644 index 0000000000..955928240c --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/createchannel_test.go @@ -0,0 +1,136 @@ +package snowpipestreaming + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + + 
"github.com/rudderlabs/rudder-server/testhelper/backendconfigtest" +) + +func TestCreateChannel(t *testing.T) { + ccr := &createChannelRequest{ + RudderIdentifier: "rudderIdentifier", + Partition: "partition", + AccountConfig: accountConfig{ + Account: "account", + User: "user", + Role: "role", + PrivateKey: "privateKey", + PrivateKeyPassphrase: "privateKeyPassphrase", + }, + TableConfig: tableConfig{ + Database: "database", + Schema: "schema", + Table: "table", + }, + } + + snowpipeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, r.Body.Close()) + require.JSONEq(t, `{"rudderIdentifier":"rudderIdentifier","partition":"partition","account":{"account":"account","user":"user","role":"role","privateKey":"privateKey","privateKeyPassphrase":"privateKeyPassphrase"},"table":{"database":"database","schema":"schema","table":"table"}}`, string(body)) + + switch r.URL.String() { + case "/channels": + _, err := w.Write([]byte(`{"channelId":"channelId","channelName":"channelName","clientName":"clientName","valid":true}`)) + require.NoError(t, err) + default: + require.FailNowf(t, "SnowpipeClients", "Unexpected %s to SnowpipeClients, not found: %+v", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer snowpipeServer.Close() + + t.Run("Success", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(snowpipeServer.Client())) + res, err := manager.createChannel(ctx, ccr) + require.NoError(t, err) + require.Equal(t, "channelId", res.ChannelID) + require.Equal(t, "channelName", res.ChannelName) + require.Equal(t, "clientName", res.ClientName) + require.True(t, res.Valid) + }) + t.Run("Request failure", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + err: errors.New("bad client"), + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + res, err := manager.createChannel(ctx, ccr) + require.Error(t, err) + require.Nil(t, res) + }) + t.Run("Request failure (non 200's status code)", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + response: &http.Response{ + StatusCode: http.StatusBadRequest, + Body: nopReadCloser{Reader: bytes.NewReader([]byte(`{}`))}, + }, + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + res, err := manager.createChannel(ctx, ccr) + require.Error(t, err) + require.Nil(t, res) + }) + t.Run("Request failure (invalid response)", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). 
+ Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + response: &http.Response{ + StatusCode: http.StatusOK, + Body: nopReadCloser{Reader: bytes.NewReader([]byte(`{abd}`))}, + }, + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + res, err := manager.createChannel(ctx, ccr) + require.Error(t, err) + require.Nil(t, res) + }) +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/deletechannel.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/deletechannel.go new file mode 100644 index 0000000000..28d6307b7f --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/deletechannel.go @@ -0,0 +1,38 @@ +package snowpipestreaming + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + + "github.com/rudderlabs/rudder-go-kit/httputil" +) + +func (m *Manager) deleteChannel(ctx context.Context, channelReq *createChannelRequest) error { + reqJSON, err := json.Marshal(channelReq) + if err != nil { + return fmt.Errorf("marshalling create channel request: %w", err) + } + + channelReqURL := m.config.clientURL + "/channels" + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, channelReqURL, bytes.NewBuffer(reqJSON)) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, reqErr := m.requestDoer.Do(req) + if reqErr != nil { + return fmt.Errorf("sending request: %w", reqErr) + } + defer func() { httputil.CloseResponse(resp) }() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return fmt.Errorf("invalid status code: %d, body: %s", resp.StatusCode, string(b)) + } + + return nil +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/deletechannel_test.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/deletechannel_test.go new file mode 100644 index 0000000000..96887178e5 --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/deletechannel_test.go @@ -0,0 +1,108 @@ +package snowpipestreaming + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + + "github.com/rudderlabs/rudder-server/testhelper/backendconfigtest" +) + +func TestDeleteChannel(t *testing.T) { + ccr := &createChannelRequest{ + RudderIdentifier: "rudderIdentifier", + Partition: "partition", + AccountConfig: accountConfig{ + Account: "account", + User: "user", + Role: "role", + PrivateKey: "privateKey", + PrivateKeyPassphrase: "privateKeyPassphrase", + }, + TableConfig: tableConfig{ + Database: "database", + Schema: "schema", + Table: "table", + }, + } + + snowpipeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodDelete, r.Method) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, r.Body.Close()) + require.JSONEq(t, `{"rudderIdentifier":"rudderIdentifier","partition":"partition","account":{"account":"account","user":"user","role":"role","privateKey":"privateKey","privateKeyPassphrase":"privateKeyPassphrase"},"table":{"database":"database","schema":"schema","table":"table"}}`, string(body)) + + switch 
r.URL.String() { + case "/channels": + w.WriteHeader(http.StatusOK) + default: + require.FailNowf(t, "SnowpipeClients", "Unexpected %s to SnowpipeClients, not found: %+v", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer snowpipeServer.Close() + + t.Run("Success", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(snowpipeServer.Client())) + err := manager.deleteChannel(ctx, ccr) + require.NoError(t, err) + }) + t.Run("Request failure", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + err: errors.New("bad client"), + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + err := manager.deleteChannel(ctx, ccr) + require.Error(t, err) + }) + t.Run("Request failure (non 200's status code)", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + response: &http.Response{ + StatusCode: http.StatusBadRequest, + Body: nopReadCloser{Reader: bytes.NewReader([]byte(`{}`))}, + }, + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + err := manager.deleteChannel(ctx, ccr) + require.Error(t, err) + }) +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/insert.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/insert.go new file mode 100644 index 0000000000..a0eb667b6f --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/insert.go @@ -0,0 +1,77 @@ +package snowpipestreaming + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + + "github.com/samber/lo" + + "github.com/rudderlabs/rudder-server/utils/httputil" +) + +type Row map[string]any + +type insertRequest struct { + Rows []Row `json:"rows"` + Offset string `json:"offset"` +} + +type insertError struct { + RowIndex int64 `json:"rowIndex"` + ExtraColNames []string `json:"extraColNames"` + MissingNotNullColNames []string `json:"missingNotNullColNames"` + NullValueForNotNullColNames []string `json:"nullvalueForNotNullColNames"` +} + +type insertResponse struct { + Success bool `json:"success"` + Errors []insertError `json:"errors"` +} + +// extraColumns returns the extra columns present in the insert errors. 
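+// Column names are de-duplicated across all row errors; the order of the returned
+// slice is unspecified, since it is derived from map keys. For example, two row
+// errors with extraColNames ["A","B"] and ["B","C"] yield ["A","B","C"] in some order.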
+func (i *insertResponse) extraColumns() []string { + extraColNamesSet := make(map[string]struct{}) + for _, err := range i.Errors { + for _, colName := range err.ExtraColNames { + if _, exists := extraColNamesSet[colName]; !exists { + extraColNamesSet[colName] = struct{}{} + } + } + } + return lo.Keys(extraColNamesSet) +} + +func (m *Manager) insert(ctx context.Context, channelId string, insertRequest *insertRequest) (*insertResponse, error) { + reqJSON, err := json.Marshal(insertRequest) + if err != nil { + return nil, fmt.Errorf("marshalling insert request: %w", err) + } + + insertReqURL := m.config.clientURL + "/channels/" + channelId + "/insert" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, insertReqURL, bytes.NewBuffer(reqJSON)) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, reqErr := m.requestDoer.Do(req) + if reqErr != nil { + return nil, fmt.Errorf("sending request: %w", reqErr) + } + defer func() { httputil.CloseResponse(resp) }() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("invalid status code: %d, body: %s", resp.StatusCode, string(b)) + } + + var res insertResponse + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil, fmt.Errorf("decoding response: %w", err) + } + + return &res, nil +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/insert_test.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/insert_test.go new file mode 100644 index 0000000000..b5fc958489 --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/insert_test.go @@ -0,0 +1,125 @@ +package snowpipestreaming + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + + "github.com/rudderlabs/rudder-server/testhelper/backendconfigtest" +) + +func TestInsert(t *testing.T) { + var ( + channelID = "channelID" + ir = &insertRequest{Rows: []Row{{"key1": "value1"}, {"key2": "value2"}}, Offset: "5"} + ) + + snowpipeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, r.Body.Close()) + require.JSONEq(t, `{"rows":[{"key1":"value1"},{"key2":"value2"}],"offset":"5"}`, string(body)) + + switch r.URL.String() { + case "/channels/" + channelID + "/insert": + _, err := w.Write([]byte(`{"success":true,"errors":[]}`)) + require.NoError(t, err) + default: + require.FailNowf(t, "SnowpipeClients", "Unexpected %s to SnowpipeClients, not found: %+v", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer snowpipeServer.Close() + + t.Run("Success", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). 
+ Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(snowpipeServer.Client())) + res, err := manager.insert(ctx, channelID, ir) + require.NoError(t, err) + require.True(t, res.Success) + require.Empty(t, res.Errors) + }) + t.Run("Request failure", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + err: errors.New("bad client"), + response: &http.Response{ + StatusCode: http.StatusOK, + }, + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + res, err := manager.insert(ctx, channelID, ir) + require.Error(t, err) + require.Nil(t, res) + }) + t.Run("Request failure (non 200's status code)", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + response: &http.Response{ + StatusCode: http.StatusBadRequest, + Body: nopReadCloser{Reader: bytes.NewReader([]byte(`{}`))}, + }, + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + res, err := manager.insert(ctx, channelID, ir) + require.Error(t, err) + require.Nil(t, res) + }) + t.Run("Request failure (invalid response)", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + response: &http.Response{ + StatusCode: http.StatusOK, + Body: nopReadCloser{Reader: bytes.NewReader([]byte(`{abd}`))}, + }, + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + res, err := manager.insert(ctx, channelID, ir) + require.Error(t, err) + require.Nil(t, res) + }) +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/options.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/options.go new file mode 100644 index 0000000000..d7a1bedd15 --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/options.go @@ -0,0 +1,9 @@ +package snowpipestreaming + +type Opt func(*Manager) + +func WithRequestDoer(requestDoer requestDoer) Opt { + return func(s *Manager) { + s.requestDoer = requestDoer + } +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming.go new file mode 100644 index 0000000000..64605ddb24 --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming.go @@ -0,0 +1,488 @@ +package snowpipestreaming + +import ( + "bufio" + "context" + "errors" + "fmt" + "net/http" + "os" + "strconv" + "strings" + "time" + + jsoniter "github.com/json-iterator/go" + "github.com/mitchellh/mapstructure" + "github.com/samber/lo" + "golang.org/x/sync/errgroup" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + obskit "github.com/rudderlabs/rudder-observability-kit/go/labels" + + backendconfig "github.com/rudderlabs/rudder-server/backend-config" + 
"github.com/rudderlabs/rudder-server/jobsdb" + "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/common" + "github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake" + whutils "github.com/rudderlabs/rudder-server/warehouse/utils" +) + +var json = jsoniter.ConfigCompatibleWithStandardLibrary + +type requestDoer interface { + Do(*http.Request) (*http.Response, error) +} + +type event struct { + Message struct { + Metadata struct { + Table string `json:"table"` + Columns map[string]string `json:"columns"` + } `json:"metadata"` + Data map[string]any `json:"data"` + } `json:"message"` + Metadata struct { + JobID int64 `json:"job_id"` + } +} + +type destConfig struct { + Account string `mapstructure:"account"` + Warehouse string `mapstructure:"warehouse"` + Database string `mapstructure:"database"` + User string `mapstructure:"user"` + Role string `mapstructure:"role"` + PrivateKey string `mapstructure:"privateKey"` + PrivateKeyPassphrase string `mapstructure:"privateKeyPassphrase"` + Namespace string `mapstructure:"namespace"` +} + +type channelIDOffset struct { + ChannelID string `json:"channelId"` + Offset string `json:"offset"` +} + +type Manager struct { + conf *config.Config + logger logger.Logger + statsFactory stats.Stats + destination *backendconfig.DestinationT + requestDoer requestDoer + + config struct { + client struct { + maxHTTPConnections int + maxHTTPIdleConnections int + maxIdleConnDuration time.Duration + disableKeepAlives bool + timeoutDuration time.Duration + } + + clientURL string + instanceID string + pollFrequency time.Duration + } + + stats struct { + successJobCount stats.Counter + failedJobCount stats.Counter + } +} + +func New( + conf *config.Config, + logger logger.Logger, + statsFactory stats.Stats, + destination *backendconfig.DestinationT, + opts ...Opt, +) *Manager { + m := &Manager{ + conf: conf, + logger: logger.Child("snowpipestreaming").Withn(obskit.WorkspaceID(destination.WorkspaceID), obskit.DestinationID(destination.ID)), + statsFactory: statsFactory, + destination: destination, + } + m.config.client.maxHTTPConnections = conf.GetInt("Snowpipe.Client.maxHTTPConnections", 20) + m.config.client.maxHTTPIdleConnections = conf.GetInt("Snowpipe.Client.maxHTTPIdleConnections", 10) + m.config.client.maxIdleConnDuration = conf.GetDuration("Snowpipe.Client.maxIdleConnDuration", 30, time.Second) + m.config.client.disableKeepAlives = conf.GetBool("Snowpipe.Client.disableKeepAlives", true) + m.config.client.timeoutDuration = conf.GetDuration("Snowpipe.Client.timeout", 600, time.Second) + m.config.clientURL = conf.GetString("Snowpipe.Client.URL", "http://localhost:9078") + m.config.instanceID = conf.GetString("INSTANCE_ID", "1") + m.config.pollFrequency = conf.GetDuration("Snowpipe.Client.pollFrequency", 1, time.Second) + + m.stats.failedJobCount = statsFactory.NewTaggedStat("snowpipestreaming_failed_jobs_count", stats.CountType, stats.Tags{ + "module": "batch_router", + "destType": destination.DestinationDefinition.Name, + }) + m.stats.successJobCount = statsFactory.NewTaggedStat("snowpipestreaming_success_job_count", stats.CountType, stats.Tags{ + "module": "batch_router", + "destType": destination.DestinationDefinition.Name, + }) + + for _, opt := range opts { + opt(m) + } + if m.requestDoer == nil { + m.requestDoer = &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: m.config.client.disableKeepAlives, + MaxConnsPerHost: m.config.client.maxHTTPConnections, + MaxIdleConnsPerHost: 
m.config.client.maxHTTPIdleConnections,
+				IdleConnTimeout:     m.config.client.maxIdleConnDuration,
+			},
+			Timeout: m.config.client.timeoutDuration,
+		}
+	}
+	return m
+}
+
+func (m *Manager) Transform(job *jobsdb.JobT) (string, error) {
+	return common.GetMarshalledData(string(job.EventPayload), job.JobID)
+}
+
+func (m *Manager) Upload(asyncDestStruct *common.AsyncDestinationStruct) common.AsyncUploadOutput {
+	m.logger.Infon("Uploading data to snowpipe streaming destination")
+
+	var destConfig destConfig
+	err := mapstructure.Decode(asyncDestStruct.Destination.Config, &destConfig)
+	if err != nil {
+		return m.abortJobs(asyncDestStruct, fmt.Errorf("failed to decode destination config: %v", err).Error())
+	}
+
+	file, err := os.Open(asyncDestStruct.FileName)
+	if err != nil {
+		return m.abortJobs(asyncDestStruct, fmt.Errorf("failed to open file: %v", err).Error())
+	}
+	defer func() {
+		_ = file.Close()
+	}()
+
+	var (
+		events                      []event
+		failedJobIDs, successJobIDs []int64
+		channelIDOffsets            []channelIDOffset
+	)
+
+	ctx := context.Background()
+
+	scanner := bufio.NewScanner(file)
+	for scanner.Scan() {
+		var e event
+		err = json.Unmarshal(scanner.Bytes(), &e)
+		if err != nil {
+			return m.abortJobs(asyncDestStruct, fmt.Errorf("failed to unmarshal event: %v", err).Error())
+		}
+
+		events = append(events, e)
+	}
+	if err := scanner.Err(); err != nil {
+		return m.abortJobs(asyncDestStruct, fmt.Errorf("failed to read events from file: %v", err).Error())
+	}
+	m.logger.Infon("Read events from file", logger.NewIntField("events", int64(len(events))))
+
+	eventsByTable := lo.GroupBy(events, func(event event) string {
+		return event.Message.Metadata.Table
+	})
+
+	for table, tableEvents := range eventsByTable {
+		m.logger.Infon("Uploading data to table", logger.NewStringField("table", table), logger.NewIntField("events", int64(len(tableEvents))))
+
+		jobIDs := lo.Map(tableEvents, func(event event, _ int) int64 {
+			return event.Metadata.JobID
+		})
+
+		channelID, offset, err := m.sendEventsToSnowpipe(ctx, asyncDestStruct, destConfig, table, tableEvents)
+		if err != nil {
+			m.logger.Warnn("Failed to send events to snowpipe", obskit.Error(err), logger.NewStringField("table", table))
+			failedJobIDs = append(failedJobIDs, jobIDs...)
+			continue
+		}
+
+		successJobIDs = append(successJobIDs, jobIDs...)
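+		// For illustration: each successfully loaded table contributes one
+		// (channelId, offset) pair, and the slice accumulated below is marshalled
+		// into the importId that Poll later consumes. Given the json tags on
+		// channelIDOffset, an importId covering two tables would look like
+		//
+		//	[{"channelId":"channel-a","offset":"42"},{"channelId":"channel-b","offset":"97"}]
+		//
+		// where each offset is the highest job ID inserted into that channel
+		// ("channel-a" and "channel-b" are made-up example identifiers).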
+		channelIDOffsets = append(channelIDOffsets, channelIDOffset{
+			ChannelID: channelID,
+			Offset:    offset,
+		})
+
+		m.logger.Infon("Successfully uploaded data to table", logger.NewStringField("table", table), logger.NewIntField("events", int64(len(tableEvents))))
+	}
+	importID, err := json.Marshal(channelIDOffsets)
+	if err != nil {
+		return m.abortJobs(asyncDestStruct, fmt.Errorf("failed to marshal import id: %v", err).Error())
+	}
+
+	importParameters, err := json.Marshal(common.ImportParameters{
+		ImportId: string(importID),
+	})
+	if err != nil {
+		return m.abortJobs(asyncDestStruct, fmt.Errorf("failed to marshal import parameters: %v", err).Error())
+	}
+
+	m.logger.Infon("Successfully uploaded data to snowpipe streaming destination")
+
+	m.stats.failedJobCount.Count(len(failedJobIDs))
+	m.stats.successJobCount.Count(len(successJobIDs))
+
+	return common.AsyncUploadOutput{
+		ImportingJobIDs:     successJobIDs,
+		ImportingParameters: importParameters,
+		FailedJobIDs:        failedJobIDs,
+		FailedCount:         len(failedJobIDs),
+		DestinationID:       asyncDestStruct.Destination.ID,
+	}
+}
+
+func (m *Manager) sendEventsToSnowpipe(
+	ctx context.Context,
+	asyncDestStruct *common.AsyncDestinationStruct,
+	destConf destConfig,
+	table string,
+	events []event,
+) (string, string, error) {
+	tableSchema := tableSchemaFromEvents(events)
+
+	channelReq := &createChannelRequest{
+		RudderIdentifier: asyncDestStruct.Destination.ID,
+		Partition:        m.config.instanceID,
+		AccountConfig: accountConfig{
+			Account:              destConf.Account,
+			User:                 destConf.User,
+			Role:                 destConf.Role,
+			PrivateKey:           whutils.FormatPemContent(destConf.PrivateKey),
+			PrivateKeyPassphrase: destConf.PrivateKeyPassphrase,
+		},
+		TableConfig: tableConfig{
+			Database: destConf.Database,
+			Schema:   destConf.Namespace,
+			Table:    table,
+		},
+	}
+	channelResponse, err := m.createChannelWithRetries(ctx, channelReq, tableSchema)
+	if err != nil {
+		return "", "", fmt.Errorf("failed to create channel: %v", err)
+	}
+
+	// Use the highest job ID in the batch as the channel offset; Poll later
+	// treats the channel as done once it reports this offset back.
+	latestEvent := lo.MaxBy(events, func(a, b event) bool {
+		return a.Metadata.JobID > b.Metadata.JobID
+	})
+	offset := strconv.FormatInt(latestEvent.Metadata.JobID, 10)
+
+	insertReq := &insertRequest{
+		Rows: lo.Map(events, func(event event, _ int) Row {
+			return event.Message.Data
+		}),
+		Offset: offset,
+	}
+	channelID, offset, err := m.insertWithRetries(ctx, channelReq, channelResponse, insertReq, tableSchema)
+	if err != nil {
+		return "", "", fmt.Errorf("failed to insert data: %v", err)
+	}
+	return channelID, offset, nil
+}
+
+// tableSchemaFromEvents iterates over the events and merges their column
+// definitions into a single table schema. If the same column appears with
+// different types, the first type seen wins (first come, first served).
+func tableSchemaFromEvents(events []event) whutils.ModelTableSchema {
+	columnsMap := make(whutils.ModelTableSchema)
+	for _, e := range events {
+		for col, typ := range e.Message.Metadata.Columns {
+			if _, ok := columnsMap[col]; !ok {
+				columnsMap[col] = typ
+			}
+		}
+	}
+	return columnsMap
+}
+
+// createChannelWithRetries tries to create a Snowpipe channel and, if
+// Snowflake reports the schema or table as missing, creates the missing
+// objects and retries channel creation once.
+func (m *Manager) createChannelWithRetries(
+	ctx context.Context,
+	channelReq *createChannelRequest,
+	tableSchema map[string]string,
+) (*createChannelResponse, error) {
+	res, err := m.createChannel(ctx, channelReq)
+	if err == nil {
+		return res, nil
+	}
+
+	// Case 1: the supplied schema does not exist or is not authorized.
+	// Create the schema and the table, then retry channel creation.
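+	// (Both this branch and the table branch below match on the plain error
+	// text surfaced by the Snowpipe client; there is no structured error code,
+	// so if the upstream wording ever changes these checks will stop matching
+	// and the error will be returned to the caller as-is.)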
+	if strings.Contains(err.Error(), "The supplied schema does not exist or is not authorized") {
+		sm, err := m.createSnowflakeManager(ctx, channelReq.TableConfig.Schema)
+		if err != nil {
+			return nil, fmt.Errorf("creating snowflake manager: %v", err)
+		}
+
+		err = sm.CreateSchema(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("creating schema: %v", err)
+		}
+
+		err = sm.CreateTable(ctx, channelReq.TableConfig.Table, tableSchema)
+		if err != nil {
+			return nil, fmt.Errorf("creating table: %v", err)
+		}
+
+		return m.createChannel(ctx, channelReq)
+	}
+
+	// Case 2: the supplied table does not exist or is not authorized.
+	// Create the table, then retry channel creation.
+	if strings.Contains(err.Error(), "The supplied table does not exist or is not authorized") {
+		sm, err := m.createSnowflakeManager(ctx, channelReq.TableConfig.Schema)
+		if err != nil {
+			return nil, fmt.Errorf("creating snowflake manager: %v", err)
+		}
+
+		err = sm.CreateTable(ctx, channelReq.TableConfig.Table, tableSchema)
+		if err != nil {
+			return nil, fmt.Errorf("creating table: %v", err)
+		}
+
+		return m.createChannel(ctx, channelReq)
+	}
+
+	return nil, fmt.Errorf("creating channel: %v", err)
+}
+
+// createSnowflakeManager builds a Snowflake warehouse manager scoped to the
+// given namespace; it is used only for DDL (creating schemas, tables and
+// columns) around the streaming path.
+func (m *Manager) createSnowflakeManager(ctx context.Context, namespace string) (*snowflake.Snowflake, error) {
+	warehouse := whutils.ModelWarehouse{
+		Namespace:   namespace,
+		Destination: *m.destination,
+	}
+
+	// Snowpipe streaming currently supports only key-pair authentication for
+	// Snowflake, so force the flag here.
+	warehouse.Destination.Config["useKeyPairAuth"] = true
+
+	sf := snowflake.New(m.conf, m.logger, m.statsFactory)
+	err := sf.Setup(ctx, warehouse, &whutils.NopUploader{})
+	if err != nil {
+		return nil, fmt.Errorf("failed to setup snowflake manager: %v", err)
+	}
+	return sf, nil
+}
+
+// insertWithRetries inserts the rows into the given channel. If the insert
+// response reports columns that are missing from the target table, it adds
+// those columns, recreates the channel so the updated table schema is picked
+// up, and retries the insert once.
+func (m *Manager) insertWithRetries(
+	ctx context.Context,
+	channelReq *createChannelRequest,
+	channelRes *createChannelResponse,
+	insertReq *insertRequest,
+	schemaInEvents whutils.ModelTableSchema,
+) (string, string, error) {
+	res, err := m.insert(ctx, channelRes.ChannelID, insertReq)
+	if err == nil && res.Success {
+		return channelRes.ChannelID, insertReq.Offset, nil
+	}
+	if err != nil {
+		return "", "", fmt.Errorf("failed to insert data: %v", err)
+	}
+	extraColumns := res.extraColumns()
+
+	sm, err := m.createSnowflakeManager(ctx, channelReq.TableConfig.Schema)
+	if err != nil {
+		return "", "", fmt.Errorf("creating snowflake manager: %v", err)
+	}
+	columnsInfo := lo.Map(extraColumns, func(col string, _ int) whutils.ColumnInfo {
+		return whutils.ColumnInfo{
+			Name: col,
+			Type: schemaInEvents[col],
+		}
+	})
+
+	err = sm.AddColumns(ctx, channelReq.TableConfig.Table, columnsInfo)
+	if err != nil {
+		return "", "", fmt.Errorf("adding columns: %v", err)
+	}
+
+	err = m.deleteChannel(ctx, channelReq)
+	if err != nil {
+		return "", "", fmt.Errorf("deleting channel: %v", err)
+	}
+
+	channelRes, err = m.createChannelWithRetries(ctx, channelReq, schemaInEvents)
+	if err != nil {
+		return "", "", fmt.Errorf("creating channel: %v", err)
+	}
+
+	res, err = m.insert(ctx, channelRes.ChannelID, insertReq)
+	if err == nil && res.Success {
+		return channelRes.ChannelID, insertReq.Offset, nil
+	}
+	if err != nil {
+		return "", "", fmt.Errorf("failed to insert data: %v", err)
+	}
+	return "", "", fmt.Errorf("failed to insert data, success: %s", strconv.FormatBool(res.Success))
+}
+
+func (m *Manager) abortJobs(asyncDestStruct *common.AsyncDestinationStruct, abortReason string) common.AsyncUploadOutput {
+	m.stats.failedJobCount.Count(len(asyncDestStruct.ImportingJobIDs))
+	return
common.AsyncUploadOutput{ + AbortJobIDs: asyncDestStruct.ImportingJobIDs, + AbortCount: len(asyncDestStruct.ImportingJobIDs), + AbortReason: abortReason, + DestinationID: asyncDestStruct.Destination.ID, + } +} + +func (m *Manager) Poll(pollInput common.AsyncPoll) common.PollStatusResponse { + log := m.logger.Withn(logger.NewStringField("importId", pollInput.ImportId)) + log.Infon("Polling started") + + var channelIDOffsets []channelIDOffset + err := json.Unmarshal([]byte(pollInput.ImportId), &channelIDOffsets) + if err != nil { + return common.PollStatusResponse{ + InProgress: false, + StatusCode: http.StatusBadRequest, + Complete: true, + HasFailed: true, + Error: fmt.Sprintf("failed to unmarshal import id: %v", err), + } + } + + ctx := context.Background() + g, ctx := errgroup.WithContext(ctx) + + for _, idOffset := range channelIDOffsets { + g.Go(func() error { + for { + log.Infon("Polling for channel", logger.NewStringField("channelId", idOffset.ChannelID)) + + statusRes, err := m.status(ctx, idOffset.ChannelID) + if err != nil { + return fmt.Errorf("failed to get status: %v", err) + } + if !statusRes.Valid { + log.Warnn("Invalid status response", logger.NewStringField("channelId", idOffset.ChannelID)) + return errors.New("invalid status response") + } + if statusRes.Offset != idOffset.Offset { + log.Infon("Polling in progress", logger.NewStringField("channelId", idOffset.ChannelID)) + time.Sleep(m.config.pollFrequency) + continue + } + + log.Infon("Polling completed", logger.NewStringField("channelId", idOffset.ChannelID)) + break + } + return nil + }) + } + if err = g.Wait(); err != nil { + return common.PollStatusResponse{ + InProgress: false, + StatusCode: http.StatusBadRequest, + Complete: true, + HasFailed: true, + Error: fmt.Errorf("failed to get status: %v", err).Error(), + } + } + return common.PollStatusResponse{ + InProgress: false, + StatusCode: http.StatusOK, + Complete: true, + HasFailed: false, + HasWarning: false, + } +} + +func (m *Manager) GetUploadStats(UploadStatsInput common.GetUploadStatsInput) common.GetUploadStatsResponse { + m.logger.Infon("Getting upload stats for snowpipe streaming destination") + return common.GetUploadStatsResponse{} +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming_test.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming_test.go new file mode 100644 index 0000000000..a09784922e --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming_test.go @@ -0,0 +1,311 @@ +package snowpipestreaming + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/compose-test/compose" + "github.com/rudderlabs/compose-test/testcompose" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/testhelper/rand" + + "github.com/rudderlabs/rudder-server/testhelper/backendconfigtest" + thwh "github.com/rudderlabs/rudder-server/testhelper/warehouse" + "github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake" + whutils "github.com/rudderlabs/rudder-server/warehouse/utils" +) + +type mockRequestDoer struct { + err error + response *http.Response +} + +func (c *mockRequestDoer) Do(*http.Request) (*http.Response, error) { + return c.response, c.err +} + +func (nopReadCloser) Close() error { + 
return nil
+}
+
+type nopReadCloser struct {
+	io.Reader
+}
+
+type testCredentials struct {
+	Account              string `json:"account"`
+	User                 string `json:"user"`
+	Role                 string `json:"role"`
+	Database             string `json:"database"`
+	Warehouse            string `json:"warehouse"`
+	PrivateKey           string `json:"privateKey"`
+	PrivateKeyPassphrase string `json:"privateKeyPassphrase"`
+}
+
+const (
+	testKeyPairUnencrypted = "SNOWPIPE_STREAMING_KEYPAIR_UNENCRYPTED_INTEGRATION_TEST_CREDENTIALS"
+)
+
+func getSnowpipeTestCredentials(key string) (*testCredentials, error) {
+	cred, exists := os.LookupEnv(key)
+	if !exists {
+		return nil, errors.New("snowpipe test credentials not found")
+	}
+
+	var credentials testCredentials
+	err := json.Unmarshal([]byte(cred), &credentials)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal %s into snowpipe test credentials: %v", key, err)
+	}
+	return &credentials, nil
+}
+
+func randSchema(provider string) string {
+	hex := strings.ToLower(rand.String(12))
+	namespace := fmt.Sprintf("test_%s_%d", hex, time.Now().Unix())
+	return whutils.ToProviderCase(provider, whutils.ToSafeNamespace(provider,
+		namespace,
+	))
+}
+
+func TestSnowpipeStreaming(t *testing.T) {
+	t.Run("Integration", func(t *testing.T) {
+		for _, key := range []string{
+			testKeyPairUnencrypted,
+		} {
+			if _, exists := os.LookupEnv(key); !exists {
+				if os.Getenv("FORCE_RUN_INTEGRATION_TESTS") == "true" {
+					t.Fatalf("%s environment variable not set", key)
+				}
+				t.Skipf("Skipping %s as %s is not set", t.Name(), key)
+			}
+		}
+
+		t.Run("Create channel + Insert + Status", func(t *testing.T) {
+			c := testcompose.New(t, compose.FilePaths([]string{"testdata/docker-compose.rudder-snowpipe-clients.yml"}))
+			c.Start(context.Background())
+
+			credentials, err := getSnowpipeTestCredentials(testKeyPairUnencrypted)
+			require.NoError(t, err)
+
+			ctx := context.Background()
+
+			namespace := randSchema(whutils.SNOWFLAKE)
+			table := "TEST_TABLE"
+			tableSchema := whutils.ModelTableSchema{
+				"ID": "string", "NAME": "string", "EMAIL": "string", "AGE": "int", "ACTIVE": "boolean", "DOB": "datetime",
+			}
+
+			destination := backendconfigtest.
+				NewDestinationBuilder("SNOWPIPE_STREAMING").
+				WithConfigOption("account", credentials.Account).
+				WithConfigOption("warehouse", credentials.Warehouse).
+				WithConfigOption("database", credentials.Database).
+				WithConfigOption("role", credentials.Role).
+				WithConfigOption("user", credentials.User).
+				WithConfigOption("useKeyPairAuth", true).
+				WithConfigOption("privateKey", credentials.PrivateKey).
+				WithConfigOption("privateKeyPassphrase", credentials.PrivateKeyPassphrase).
+ Build() + warehouse := whutils.ModelWarehouse{ + Namespace: namespace, + Destination: destination, + } + + // Creating namespace and table + sm := snowflake.New(config.New(), logger.NewLogger().Child("test"), stats.NOP) + require.NoError(t, err) + require.NoError(t, sm.Setup(ctx, warehouse, &whutils.NopUploader{})) + t.Cleanup(func() { sm.Cleanup(ctx) }) + require.NoError(t, sm.CreateSchema(ctx)) + t.Cleanup(func() { dropSchema(t, sm.DB.DB, namespace) }) + require.NoError(t, sm.CreateTable(ctx, table, tableSchema)) + + snowpipeClientsURL := fmt.Sprintf("http://localhost:%d", c.Port("rudder-snowpipe-clients", 9078)) + + conf := config.New() + conf.Set("Snowpipe.Client.URL", snowpipeClientsURL) + + snowpipeStreamManager := New(conf, logger.NewLogger().Child("test"), stats.NOP, &destination) + + // creating channel + createChannelRes, err := snowpipeStreamManager.createChannel(ctx, &createChannelRequest{ + RudderIdentifier: "1", + Partition: "1", + AccountConfig: accountConfig{ + Account: credentials.Account, + User: credentials.User, + Role: credentials.Role, + PrivateKey: strings.ReplaceAll(credentials.PrivateKey, "\n", "\\\\\n"), + PrivateKeyPassphrase: credentials.PrivateKeyPassphrase, + }, + TableConfig: tableConfig{ + Database: credentials.Database, + Schema: namespace, + Table: table, + }, + }) + require.NoError(t, err) + require.True(t, createChannelRes.Valid) + + // inserting rows + insertRes, err := snowpipeStreamManager.insert(ctx, createChannelRes.ChannelID, &insertRequest{ + Rows: []Row{ + {"ID": "ID1", "NAME": "Alice Johnson", "EMAIL": "alice.johnson@example.com", "AGE": 28, "ACTIVE": true, "DOB": "1995-06-15T12:30:00Z"}, + {"ID": "ID2", "NAME": "Bob Smith", "EMAIL": "bob.smith@example.com", "AGE": 35, "ACTIVE": true, "DOB": "1988-01-20T09:30:00Z"}, + {"ID": "ID3", "NAME": "Charlie Brown", "EMAIL": "charlie.brown@example.com", "AGE": 22, "ACTIVE": false, "DOB": "2001-11-05T14:45:00Z"}, + {"ID": "ID4", "NAME": "Diana Prince", "EMAIL": "diana.prince@example.com", "AGE": 30, "ACTIVE": true, "DOB": "1993-08-18T08:15:00Z"}, + {"ID": "ID5", "NAME": "Eve Adams", "AGE": 45, "ACTIVE": true, "DOB": "1978-03-22T16:50:00Z"}, // -- No email + {"ID": "ID6", "NAME": "Frank Castle", "EMAIL": "frank.castle@example.com", "AGE": 38, "ACTIVE": false, "DOB": "1985-09-14T10:10:00Z"}, + {"ID": "ID7", "NAME": "Grace Hopper", "EMAIL": "grace.hopper@example.com", "AGE": 85, "ACTIVE": true, "DOB": "1936-12-09T11:30:00Z"}, + }, + Offset: "100", + }) + require.NoError(t, err) + require.True(t, insertRes.Success) + require.Empty(t, insertRes.Errors) + + // getting status + require.Eventually(t, func() bool { + statusRes, err := snowpipeStreamManager.status(ctx, createChannelRes.ChannelID) + if err != nil { + t.Log("Error getting status:", err) + return false + } + return statusRes.Offset == "100" + }, + 30*time.Second, + 300*time.Millisecond, + ) + + // checking records in warehouse + records := thwh.RetrieveRecordsFromWarehouse(t, sm.DB.DB, fmt.Sprintf(`SELECT ID, NAME, EMAIL, AGE, ACTIVE, DOB FROM %q.%q ORDER BY ID;`, namespace, table)) + require.ElementsMatch(t, [][]string{ + {"ID1", "Alice Johnson", "alice.johnson@example.com", "28", "true", "1995-06-15T12:30:00Z"}, + {"ID2", "Bob Smith", "bob.smith@example.com", "35", "true", "1988-01-20T09:30:00Z"}, + {"ID3", "Charlie Brown", "charlie.brown@example.com", "22", "false", "2001-11-05T14:45:00Z"}, + {"ID4", "Diana Prince", "diana.prince@example.com", "30", "true", "1993-08-18T08:15:00Z"}, + {"ID5", "Eve Adams", "", "45", "true", "1978-03-22T16:50:00Z"}, 
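+				// ID5 above was inserted without an EMAIL value, so Snowflake stores
+				// NULL for that column and RetrieveRecordsFromWarehouse stringifies the
+				// NULL as an empty string, which is why "" is expected here.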
+ {"ID6", "Frank Castle", "frank.castle@example.com", "38", "false", "1985-09-14T10:10:00Z"}, + {"ID7", "Grace Hopper", "grace.hopper@example.com", "85", "true", "1936-12-09T11:30:00Z"}, + }, records) + }) + + t.Run("Create + Delete channel", func(t *testing.T) { + c := testcompose.New(t, compose.FilePaths([]string{"testdata/docker-compose.rudder-snowpipe-clients.yml"})) + c.Start(context.Background()) + + credentials, err := getSnowpipeTestCredentials(testKeyPairUnencrypted) + require.NoError(t, err) + + ctx := context.Background() + + namespace := randSchema(whutils.SNOWFLAKE) + table := "TEST_TABLE" + tableSchema := whutils.ModelTableSchema{ + "ID": "string", "NAME": "string", "EMAIL": "string", "AGE": "int", "ACTIVE": "boolean", "DOB": "datetime", + } + + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + WithConfigOption("account", credentials.Account). + WithConfigOption("warehouse", credentials.Warehouse). + WithConfigOption("database", credentials.Database). + WithConfigOption("role", credentials.Role). + WithConfigOption("user", credentials.User). + WithConfigOption("useKeyPairAuth", true). + WithConfigOption("privateKey", credentials.PrivateKey). + WithConfigOption("privateKeyPassphrase", credentials.PrivateKeyPassphrase). + Build() + warehouse := whutils.ModelWarehouse{ + Namespace: namespace, + Destination: destination, + } + + // Creating namespace and table + sm := snowflake.New(config.New(), logger.NOP, stats.NOP) + require.NoError(t, err) + require.NoError(t, sm.Setup(ctx, warehouse, &whutils.NopUploader{})) + t.Cleanup(func() { sm.Cleanup(ctx) }) + require.NoError(t, sm.CreateSchema(ctx)) + t.Cleanup(func() { dropSchema(t, sm.DB.DB, namespace) }) + require.NoError(t, sm.CreateTable(ctx, table, tableSchema)) + + snowpipeClientsURL := fmt.Sprintf("http://localhost:%d", c.Port("rudder-snowpipe-clients", 9078)) + + conf := config.New() + conf.Set("Snowpipe.Client.URL", snowpipeClientsURL) + + snowpipeStreamManager := New(conf, logger.NOP, stats.NOP, &destination) + + // creating channel + createChannelReq := &createChannelRequest{ + RudderIdentifier: "1", + Partition: "1", + AccountConfig: accountConfig{ + Account: credentials.Account, + User: credentials.User, + Role: credentials.Role, + PrivateKey: strings.ReplaceAll(credentials.PrivateKey, "\n", "\\\\\n"), + PrivateKeyPassphrase: credentials.PrivateKeyPassphrase, + }, + TableConfig: tableConfig{ + Database: credentials.Database, + Schema: namespace, + Table: table, + }, + } + + // creating channel + createChannelRes1, err := snowpipeStreamManager.createChannel(ctx, createChannelReq) + require.NoError(t, err) + require.True(t, createChannelRes1.Valid) + + // creating channel again with same request should return same channel id + createChannelRes2, err := snowpipeStreamManager.createChannel(ctx, createChannelReq) + require.NoError(t, err) + require.True(t, createChannelRes2.Valid) + require.Equal(t, createChannelRes1.ChannelID, createChannelRes2.ChannelID) + + // deleting channel + err = snowpipeStreamManager.deleteChannel(ctx, createChannelReq) + require.NoError(t, err) + + // creating channel again, since the previous channel is deleted, it should return a new channel id + createChannelRes3, err := snowpipeStreamManager.createChannel(ctx, createChannelReq) + require.NoError(t, err) + require.True(t, createChannelRes3.Valid) + require.NotEqual(t, createChannelRes1.ChannelID, createChannelRes3.ChannelID) + }) + }) +} + +func dropSchema(t *testing.T, db *sql.DB, namespace string) { + t.Helper() 
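+	// The DROP SCHEMA below is wrapped in require.Eventually (one attempt per
+	// second, up to a minute) because a drop issued immediately after streaming
+	// can fail transiently, presumably while Snowflake still holds resources
+	// against the schema; the retry loop simply absorbs such failures.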
+ t.Log("dropping schema", namespace) + + require.Eventually(t, + func() bool { + _, err := db.ExecContext(context.Background(), fmt.Sprintf(`DROP SCHEMA %q CASCADE;`, namespace)) + if err != nil { + t.Logf("error deleting schema %q: %v", namespace, err) + return false + } + return true + }, + time.Minute, + time.Second, + ) +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/status.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/status.go new file mode 100644 index 0000000000..4dac573694 --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/status.go @@ -0,0 +1,42 @@ +package snowpipestreaming + +import ( + "context" + "fmt" + "io" + "net/http" + + "github.com/rudderlabs/rudder-server/utils/httputil" +) + +type statusResponse struct { + Offset string `json:"offset"` + Valid bool `json:"valid"` +} + +func (m *Manager) status(ctx context.Context, channelId string) (*statusResponse, error) { + statusReqURL := m.config.clientURL + "/channels/" + channelId + "/status" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, statusReqURL, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, reqErr := m.requestDoer.Do(req) + if reqErr != nil { + return nil, fmt.Errorf("sending request: %w", reqErr) + } + defer func() { httputil.CloseResponse(resp) }() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("invalid status code: %d, body: %s", resp.StatusCode, string(b)) + } + + var res statusResponse + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil, fmt.Errorf("decoding response: %w", err) + } + + return &res, nil +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/status_test.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/status_test.go new file mode 100644 index 0000000000..338bf1123f --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/status_test.go @@ -0,0 +1,122 @@ +package snowpipestreaming + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + + "github.com/rudderlabs/rudder-server/testhelper/backendconfigtest" +) + +func TestStatus(t *testing.T) { + channelID := "channelID" + + snowpipeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, r.Body.Close()) + require.Empty(t, body) + + switch r.URL.String() { + case "/channels/" + channelID + "/status": + _, err := w.Write([]byte(`{"offset":"5","valid":true}`)) + require.NoError(t, err) + default: + require.FailNowf(t, "SnowpipeClients", "Unexpected %s to SnowpipeClients, not found: %+v", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + } + })) + defer snowpipeServer.Close() + + t.Run("Success", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). 
+ Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(snowpipeServer.Client())) + res, err := manager.status(ctx, channelID) + require.NoError(t, err) + require.Equal(t, "5", res.Offset) + require.True(t, res.Valid) + }) + t.Run("Request failure", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + err: errors.New("bad client"), + response: &http.Response{ + StatusCode: http.StatusOK, + }, + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + res, err := manager.status(ctx, channelID) + require.Error(t, err) + require.Nil(t, res) + }) + t.Run("Request failure (non 200's status code)", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + response: &http.Response{ + StatusCode: http.StatusBadRequest, + Body: nopReadCloser{Reader: bytes.NewReader([]byte(`{}`))}, + }, + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + res, err := manager.status(ctx, channelID) + require.Error(t, err) + require.Nil(t, res) + }) + t.Run("Request failure (invalid response)", func(t *testing.T) { + ctx := context.Background() + destination := backendconfigtest. + NewDestinationBuilder("SNOWPIPE_STREAMING"). + Build() + + c := config.New() + c.Set("Snowpipe.Client.URL", snowpipeServer.URL) + + reqDoer := &mockRequestDoer{ + response: &http.Response{ + StatusCode: http.StatusOK, + Body: nopReadCloser{Reader: bytes.NewReader([]byte(`{abd}`))}, + }, + } + + manager := New(c, logger.NOP, stats.NOP, &destination, WithRequestDoer(reqDoer)) + res, err := manager.status(ctx, channelID) + require.Error(t, err) + require.Nil(t, res) + }) +} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/testdata/docker-compose.rudder-snowpipe-clients.yml b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/testdata/docker-compose.rudder-snowpipe-clients.yml new file mode 100644 index 0000000000..e6690dce64 --- /dev/null +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/testdata/docker-compose.rudder-snowpipe-clients.yml @@ -0,0 +1,11 @@ +version: "3.9" + +services: + rudder-snowpipe-clients: + image: "rudderstack/rudder-snowpipe-clients:chore.snowpipe-poc" + ports: + - "9078" + healthcheck: + test: wget --no-verbose --tries=1 --spider http://localhost:9078/q/health || exit 1 + interval: 1s + retries: 25 diff --git a/router/batchrouter/handle.go b/router/batchrouter/handle.go index 3b55638667..cb279604ba 100644 --- a/router/batchrouter/handle.go +++ b/router/batchrouter/handle.go @@ -28,6 +28,7 @@ import ( "github.com/rudderlabs/rudder-go-kit/stats" kitsync "github.com/rudderlabs/rudder-go-kit/sync" obskit "github.com/rudderlabs/rudder-observability-kit/go/labels" + backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/jobsdb" asynccommon "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/common" @@ -76,6 +77,7 @@ type Handle struct { maxFailedCountForJob config.ValueLoader[int] maxFailedCountForSourcesJob config.ValueLoader[int] 
asyncUploadTimeout config.ValueLoader[time.Duration] + asyncUploadWorkerTimeout config.ValueLoader[time.Duration] retryTimeWindow config.ValueLoader[time.Duration] sourcesRetryTimeWindow config.ValueLoader[time.Duration] reportingEnabled bool diff --git a/router/batchrouter/handle_async.go b/router/batchrouter/handle_async.go index 3d370a141f..3571503a03 100644 --- a/router/batchrouter/handle_async.go +++ b/router/batchrouter/handle_async.go @@ -16,6 +16,7 @@ import ( "github.com/tidwall/gjson" "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-server/jobsdb" "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/common" asynccommon "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/common" @@ -320,7 +321,7 @@ func (brt *Handle) asyncUploadWorker(ctx context.Context) { select { case <-ctx.Done(): return - case <-time.After(10 * time.Second): + case <-time.After(brt.asyncUploadWorkerTimeout.Load()): brt.configSubscriberMu.RLock() destinationsMap := brt.destinationsMap uploadIntervalMap := brt.uploadIntervalMap diff --git a/router/batchrouter/handle_lifecycle.go b/router/batchrouter/handle_lifecycle.go index 06851978d7..676e5c4145 100644 --- a/router/batchrouter/handle_lifecycle.go +++ b/router/batchrouter/handle_lifecycle.go @@ -202,6 +202,7 @@ func (brt *Handle) setupReloadableVars() { brt.maxFailedCountForJob = config.GetReloadableIntVar(128, 1, "BatchRouter."+brt.destType+".maxFailedCountForJob", "BatchRouter.maxFailedCountForJob") brt.maxFailedCountForSourcesJob = config.GetReloadableIntVar(3, 1, "BatchRouter.RSources."+brt.destType+".maxFailedCountForJob", "BatchRouter.RSources.maxFailedCountForJob") brt.asyncUploadTimeout = config.GetReloadableDurationVar(30, time.Minute, "BatchRouter."+brt.destType+".asyncUploadTimeout", "BatchRouter.asyncUploadTimeout") + brt.asyncUploadWorkerTimeout = config.GetReloadableDurationVar(10, time.Second, "BatchRouter."+brt.destType+".asyncUploadWorkerTimeout", "BatchRouter.asyncUploadWorkerTimeout") brt.retryTimeWindow = config.GetReloadableDurationVar(180, time.Minute, "BatchRouter."+brt.destType+".retryTimeWindow", "BatchRouter."+brt.destType+".retryTimeWindowInMins", "BatchRouter.retryTimeWindow", "BatchRouter.retryTimeWindowInMins") brt.sourcesRetryTimeWindow = config.GetReloadableDurationVar(1, time.Minute, "BatchRouter.RSources."+brt.destType+".retryTimeWindow", "BatchRouter.RSources."+brt.destType+".retryTimeWindowInMins", "BatchRouter.RSources.retryTimeWindow", "BatchRouter.RSources.retryTimeWindowInMins") brt.jobQueryBatchSize = config.GetReloadableIntVar(100000, 1, "BatchRouter."+brt.destType+".jobQueryBatchSize", "BatchRouter.jobQueryBatchSize") diff --git a/testhelper/warehouse/records.go b/testhelper/warehouse/records.go new file mode 100644 index 0000000000..c1aedc836d --- /dev/null +++ b/testhelper/warehouse/records.go @@ -0,0 +1,58 @@ +package warehouse + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/samber/lo" + "github.com/spf13/cast" + "github.com/stretchr/testify/require" +) + +// RetrieveRecordsFromWarehouse retrieves records from the warehouse based on the given query. +// It returns a slice of slices, where each inner slice represents a record's values. 
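+//
+// A hypothetical usage sketch (db, namespace and table are placeholders):
+//
+//	records := RetrieveRecordsFromWarehouse(t, db,
+//		fmt.Sprintf(`SELECT ID, NAME FROM %q.%q ORDER BY ID;`, namespace, table),
+//	)
+//	require.ElementsMatch(t, [][]string{{"ID1", "Alice"}}, records)
+//
+// time.Time values and RFC3339-parsable strings are normalised to RFC3339;
+// everything else is stringified, so expected rows can be written as plain
+// string slices.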
+func RetrieveRecordsFromWarehouse(
+	t testing.TB,
+	db *sql.DB,
+	query string,
+) [][]string {
+	t.Helper()
+
+	rows, err := db.QueryContext(context.Background(), query)
+	require.NoError(t, err)
+	defer func() { _ = rows.Close() }()
+
+	columns, err := rows.Columns()
+	require.NoError(t, err)
+
+	var records [][]string
+	for rows.Next() {
+		resultSet := make([]any, len(columns))
+		resultSetPtrs := make([]any, len(columns))
+		for i := 0; i < len(columns); i++ {
+			resultSetPtrs[i] = &resultSet[i]
+		}
+
+		err = rows.Scan(resultSetPtrs...)
+		require.NoError(t, err)
+
+		records = append(records, lo.Map(resultSet, func(item any, index int) string {
+			switch item := item.(type) {
+			case time.Time:
+				return item.Format(time.RFC3339)
+			case string:
+				if t, err := time.Parse(time.RFC3339Nano, item); err == nil {
+					return t.Format(time.RFC3339)
+				}
+				return item
+			default:
+				return cast.ToString(item)
+			}
+		}))
+	}
+	require.NoError(t, rows.Err())
+	return records
+}
diff --git a/utils/misc/misc.go b/utils/misc/misc.go
index 64e07af891..0c0677e114 100644
--- a/utils/misc/misc.go
+++ b/utils/misc/misc.go
@@ -96,7 +96,7 @@ func Init() {
 }
 
 func BatchDestinations() []string {
-	batchDestinations := []string{"S3", "GCS", "MINIO", "RS", "BQ", "AZURE_BLOB", "SNOWFLAKE", "POSTGRES", "CLICKHOUSE", "DIGITAL_OCEAN_SPACES", "MSSQL", "AZURE_SYNAPSE", "S3_DATALAKE", "MARKETO_BULK_UPLOAD", "GCS_DATALAKE", "AZURE_DATALAKE", "DELTALAKE", "BINGADS_AUDIENCE", "ELOQUA", "YANDEX_METRICA_OFFLINE_EVENTS", "SFTP", "BINGADS_OFFLINE_CONVERSIONS", "KLAVIYO_BULK_UPLOAD", "LYTICS_BULK_UPLOAD"}
+	batchDestinations := []string{"S3", "GCS", "MINIO", "RS", "BQ", "AZURE_BLOB", "SNOWFLAKE", "POSTGRES", "CLICKHOUSE", "DIGITAL_OCEAN_SPACES", "MSSQL", "AZURE_SYNAPSE", "S3_DATALAKE", "MARKETO_BULK_UPLOAD", "GCS_DATALAKE", "AZURE_DATALAKE", "DELTALAKE", "BINGADS_AUDIENCE", "ELOQUA", "YANDEX_METRICA_OFFLINE_EVENTS", "SFTP", "BINGADS_OFFLINE_CONVERSIONS", "KLAVIYO_BULK_UPLOAD", "LYTICS_BULK_UPLOAD", "SNOWPIPE_STREAMING"}
 	return batchDestinations
 }
diff --git a/warehouse/integrations/snowflake/datatype_mapper.go b/warehouse/integrations/snowflake/datatype_mapper.go
index be077a6f74..0a06bd03f2 100644
--- a/warehouse/integrations/snowflake/datatype_mapper.go
+++ b/warehouse/integrations/snowflake/datatype_mapper.go
@@ -44,7 +44,7 @@ var dataTypesMapToRudder = map[string]string{
 	"VARIANT": "json",
 }
 
-func calculateDataType(columnType string, numericScale sql.NullInt64) (string, bool) {
+func CalculateDataType(columnType string, numericScale sql.NullInt64) (string, bool) {
 	if datatype, ok := dataTypesMapToRudder[columnType]; ok {
 		if datatype == "int" && numericScale.Valid && numericScale.Int64 > 0 {
 			datatype = "float"
diff --git a/warehouse/integrations/snowflake/datatype_mapper_test.go b/warehouse/integrations/snowflake/datatype_mapper_test.go
index 2dc8e720fe..6f5ed380a0 100644
--- a/warehouse/integrations/snowflake/datatype_mapper_test.go
+++ b/warehouse/integrations/snowflake/datatype_mapper_test.go
@@ -21,7 +21,7 @@ func TestCalculateDataType(t *testing.T) {
 	}
 
 	for _, tc := range testCases {
-		dataType, exists := calculateDataType(tc.columnType, tc.numericScale)
+		dataType, exists := CalculateDataType(tc.columnType, tc.numericScale)
 		require.Equal(t, tc.expected, dataType)
 		require.Equal(t, tc.exists, exists)
 	}
diff --git a/warehouse/integrations/snowflake/snowflake.go b/warehouse/integrations/snowflake/snowflake.go
index 98b9a573bc..f476e5e6ea 100644
--- a/warehouse/integrations/snowflake/snowflake.go
+++ b/warehouse/integrations/snowflake/snowflake.go
@@ -1392,7 +1392,7 @@ func (sf *Snowflake) FetchSchema(ctx context.Context) (model.Schema, model.Schem schema[tableName] = make(map[string]string) } - if datatype, ok := calculateDataType(columnType, numericScale); ok { + if datatype, ok := CalculateDataType(columnType, numericScale); ok { schema[tableName][columnName] = datatype } else { if _, ok := unrecognizedSchema[tableName]; !ok { diff --git a/warehouse/utils/uploader.go b/warehouse/utils/uploader.go new file mode 100644 index 0000000000..0538607d16 --- /dev/null +++ b/warehouse/utils/uploader.go @@ -0,0 +1,58 @@ +package warehouseutils + +import ( + "context" + "time" + + "github.com/rudderlabs/rudder-server/warehouse/internal/model" +) + +type ( + ModelWarehouse = model.Warehouse + ModelTableSchema = model.TableSchema +) + +//go:generate mockgen -destination=../internal/mocks/utils/mock_uploader.go -package mock_uploader github.com/rudderlabs/rudder-server/warehouse/utils Uploader +type Uploader interface { + IsWarehouseSchemaEmpty() bool + GetLocalSchema(ctx context.Context) (model.Schema, error) + UpdateLocalSchema(ctx context.Context, schema model.Schema) error + GetTableSchemaInWarehouse(tableName string) model.TableSchema + GetTableSchemaInUpload(tableName string) model.TableSchema + GetLoadFilesMetadata(ctx context.Context, options GetLoadFilesOptions) ([]LoadFile, error) + GetSampleLoadFileLocation(ctx context.Context, tableName string) (string, error) + GetSingleLoadFile(ctx context.Context, tableName string) (LoadFile, error) + ShouldOnDedupUseNewRecord() bool + UseRudderStorage() bool + GetLoadFileGenStartTIme() time.Time + GetLoadFileType() string + GetFirstLastEvent() (time.Time, time.Time) + CanAppend() bool +} + +type NopUploader struct{} + +func (n *NopUploader) IsWarehouseSchemaEmpty() bool { + return false +} +func (n *NopUploader) GetLocalSchema(ctx context.Context) (model.Schema, error) { return nil, nil } // nolint:nilnil +func (n *NopUploader) UpdateLocalSchema(ctx context.Context, schema model.Schema) error { return nil } +func (n *NopUploader) GetTableSchemaInWarehouse(tableName string) model.TableSchema { return nil } +func (n *NopUploader) GetTableSchemaInUpload(tableName string) model.TableSchema { return nil } +func (n *NopUploader) ShouldOnDedupUseNewRecord() bool { return false } +func (n *NopUploader) UseRudderStorage() bool { return false } +func (n *NopUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } +func (n *NopUploader) GetLoadFileType() string { return "" } +func (n *NopUploader) CanAppend() bool { return false } +func (n *NopUploader) GetLoadFilesMetadata(ctx context.Context, options GetLoadFilesOptions) ([]LoadFile, error) { + return nil, nil +} + +func (n *NopUploader) GetSampleLoadFileLocation(ctx context.Context, tableName string) (string, error) { + return "", nil +} + +func (n *NopUploader) GetSingleLoadFile(ctx context.Context, tableName string) (LoadFile, error) { + return LoadFile{}, nil +} +func (n *NopUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, time.Time{} } diff --git a/warehouse/utils/utils.go b/warehouse/utils/utils.go index e2a7a4001e..976b355870 100644 --- a/warehouse/utils/utils.go +++ b/warehouse/utils/utils.go @@ -2,7 +2,6 @@ package warehouseutils import ( "bytes" - "context" "crypto/sha512" "database/sql" "encoding/hex" @@ -218,24 +217,6 @@ type KeyValue struct { Value interface{} } -//go:generate mockgen -destination=../internal/mocks/utils/mock_uploader.go -package mock_uploader 
github.com/rudderlabs/rudder-server/warehouse/utils Uploader -type Uploader interface { - IsWarehouseSchemaEmpty() bool - GetLocalSchema(ctx context.Context) (model.Schema, error) - UpdateLocalSchema(ctx context.Context, schema model.Schema) error - GetTableSchemaInWarehouse(tableName string) model.TableSchema - GetTableSchemaInUpload(tableName string) model.TableSchema - GetLoadFilesMetadata(ctx context.Context, options GetLoadFilesOptions) ([]LoadFile, error) - GetSampleLoadFileLocation(ctx context.Context, tableName string) (string, error) - GetSingleLoadFile(ctx context.Context, tableName string) (LoadFile, error) - ShouldOnDedupUseNewRecord() bool - UseRudderStorage() bool - GetLoadFileGenStartTIme() time.Time - GetLoadFileType() string - GetFirstLastEvent() (time.Time, time.Time) - CanAppend() bool -} - type GetLoadFilesOptions struct { Table string StartID int64