Skip to content

Commit

Permalink
[clickhouse] validate certain aspects of existing destination tables (#…
Browse files Browse the repository at this point in the history
…2026)

Partially addresses #2019 

Tests column name equality taking exclusion into account, PeerDB
columns, engine and emptiness
  • Loading branch information
heavycrystal authored Sep 4, 2024
1 parent b3d8ce7 commit ee00aad
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 15 deletions.
48 changes: 48 additions & 0 deletions flow/cmd/validate_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/jackc/pgx/v5/pgtype"

"github.com/PeerDB-io/peer-flow/connectors"
connclickhouse "github.com/PeerDB-io/peer-flow/connectors/clickhouse"
connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
Expand Down Expand Up @@ -104,6 +105,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}

sourceTables := make([]*utils.SchemaTable, 0, len(req.ConnectionConfigs.TableMappings))
srcTableNames := make([]string, 0, len(req.ConnectionConfigs.TableMappings))
for _, tableMapping := range req.ConnectionConfigs.TableMappings {
parsedTable, parseErr := utils.ParseSchemaTable(tableMapping.SourceTableIdentifier)
if parseErr != nil {
Expand All @@ -117,6 +119,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}

sourceTables = append(sourceTables, parsedTable)
srcTableNames = append(srcTableNames, tableMapping.SourceTableIdentifier)
}

pubName := req.ConnectionConfigs.PublicationName
Expand Down Expand Up @@ -164,6 +167,51 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}
}

dstPeer, err := connectors.LoadPeer(ctx, h.pool, req.ConnectionConfigs.DestinationName)
if err != nil {
slog.Error("/validatecdc failed to load destination peer", slog.String("peer", req.ConnectionConfigs.DestinationName))
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, err
}
if dstPeer.GetClickhouseConfig() != nil {
chPeer, err := connclickhouse.NewClickhouseConnector(ctx, nil, dstPeer.GetClickhouseConfig())
if err != nil {
displayErr := fmt.Errorf("failed to create clickhouse connector: %v", err)
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(displayErr),
)
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, displayErr
}
defer chPeer.Close()

res, err := pgPeer.GetTableSchema(ctx, &protos.GetTableSchemaBatchInput{
TableIdentifiers: srcTableNames,
System: protos.TypeSystem_PG,
})
if err != nil {
displayErr := fmt.Errorf("failed to get source table schema: %v", err)
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(displayErr),
)
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, displayErr
}

err = chPeer.CheckDestinationTables(ctx, req.ConnectionConfigs, res.TableNameSchemaMapping)
if err != nil {
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(err),
)
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, err
}
}

return &protos.ValidateCDCMirrorResponse{
Ok: true,
}, nil
Expand Down
171 changes: 158 additions & 13 deletions flow/connectors/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"log/slog"
"maps"
"net/url"
"slices"
"strings"
"time"

Expand All @@ -25,12 +27,11 @@ import (

type ClickhouseConnector struct {
*metadataStore.PostgresMetadata
database clickhouse.Conn
tableSchemaMapping map[string]*protos.TableSchema
logger log.Logger
config *protos.ClickhouseConfig
credsProvider *utils.ClickHouseS3Credentials
s3Stage *ClickHouseS3Stage
database clickhouse.Conn
logger log.Logger
config *protos.ClickhouseConfig
credsProvider *utils.ClickHouseS3Credentials
s3Stage *ClickHouseS3Stage
}

func ValidateS3(ctx context.Context, creds *utils.ClickHouseS3Credentials) error {
Expand Down Expand Up @@ -195,13 +196,12 @@ func NewClickhouseConnector(
}

return &ClickhouseConnector{
database: database,
PostgresMetadata: pgMetadata,
tableSchemaMapping: nil,
config: config,
logger: logger,
credsProvider: &clickHouseS3CredentialsNew,
s3Stage: NewClickHouseS3Stage(),
database: database,
PostgresMetadata: pgMetadata,
config: config,
logger: logger,
credsProvider: &clickHouseS3CredentialsNew,
s3Stage: NewClickHouseS3Stage(),
}, nil
}

Expand Down Expand Up @@ -273,3 +273,148 @@ func (c *ClickhouseConnector) execWithLogging(ctx context.Context, query string)
c.logger.Info("[clickhouse] executing DDL statement", slog.String("query", query))
return c.database.Exec(ctx, query)
}

func (c *ClickhouseConnector) checkTablesEmptyAndEngine(ctx context.Context, tables []string) error {
queryInput := make([]interface{}, 0, len(tables)+1)
queryInput = append(queryInput, c.config.Database)
for _, table := range tables {
queryInput = append(queryInput, table)
}
rows, err := c.database.Query(ctx,
fmt.Sprintf("SELECT name,engine,total_rows FROM system.tables WHERE database=? AND table IN (%s)",
strings.Join(slices.Repeat([]string{"?"}, len(tables)), ",")), queryInput...)
if err != nil {
return fmt.Errorf("failed to get information for destination tables: %w", err)
}
defer rows.Close()

for rows.Next() {
var tableName, engine string
var totalRows uint64
err = rows.Scan(&tableName, &engine, &totalRows)
if err != nil {
return fmt.Errorf("failed to scan information for tables: %w", err)
}
if totalRows != 0 {
return fmt.Errorf("table %s exists and is not empty", tableName)
}
if !slices.Contains(acceptableTableEngines, engine) {
return fmt.Errorf("table %s exists and is not using ReplacingMergeTree/MergeTree engine", tableName)
}
}
if rows.Err() != nil {
return fmt.Errorf("failed to read rows: %w", rows.Err())
}
return nil
}

func (c *ClickhouseConnector) getTableColumnsMapping(ctx context.Context,
tables []string,
) (map[string][]*protos.FieldDescription, error) {
tableColumnsMapping := make(map[string][]*protos.FieldDescription, len(tables))
queryInput := make([]interface{}, 0, len(tables)+1)
queryInput = append(queryInput, c.config.Database)
for _, table := range tables {
queryInput = append(queryInput, table)
}
rows, err := c.database.Query(ctx,
fmt.Sprintf("SELECT name,type,table FROM system.columns WHERE database=? AND table IN (%s)",
strings.Join(slices.Repeat([]string{"?"}, len(tables)), ",")), queryInput...)
if err != nil {
return nil, fmt.Errorf("failed to get columns for destination tables: %w", err)
}
defer rows.Close()
for rows.Next() {
var tableName string
var fieldDescription protos.FieldDescription
err = rows.Scan(&fieldDescription.Name, &fieldDescription.Type, &tableName)
if err != nil {
return nil, fmt.Errorf("failed to scan columns for tables: %w", err)
}
tableColumnsMapping[tableName] = append(tableColumnsMapping[tableName], &fieldDescription)
}
if rows.Err() != nil {
return nil, fmt.Errorf("failed to read rows: %w", rows.Err())
}
return tableColumnsMapping, nil
}

func (c *ClickhouseConnector) processTableComparison(dstTableName string, srcSchema *protos.TableSchema,
dstSchema []*protos.FieldDescription, peerDBColumns []string, tableMapping *protos.TableMapping,
) error {
for _, srcField := range srcSchema.Columns {
colName := srcField.Name
// if the column is mapped to a different name, find and use that name instead
for _, col := range tableMapping.Columns {
if col.SourceName == colName {
if col.DestinationName != "" {
colName = col.DestinationName
}
break
}
}
found := false
// compare either the source column name or the mapped destination column name to the ClickHouse schema
for _, dstField := range dstSchema {
// not doing type checks for now
if dstField.Name == colName {
found = true
break
}
}
if !found {
return fmt.Errorf("field %s not found in destination table", srcField.Name)
}
}
foundPeerDBColumns := 0
for _, dstField := range dstSchema {
// all these columns need to be present in the destination table
if slices.Contains(peerDBColumns, dstField.Name) {
foundPeerDBColumns++
}
}
if foundPeerDBColumns != len(peerDBColumns) {
return fmt.Errorf("not all PeerDB columns found in destination table %s", dstTableName)
}
return nil
}

func (c *ClickhouseConnector) CheckDestinationTables(ctx context.Context, req *protos.FlowConnectionConfigs,
tableNameSchemaMapping map[string]*protos.TableSchema,
) error {
peerDBColumns := []string{signColName, versionColName}
if req.SyncedAtColName != "" {
peerDBColumns = append(peerDBColumns, req.SyncedAtColName)
}
// this is for handling column exclusion, processed schema does that in a step
processedMapping := shared.BuildProcessedSchemaMapping(req.TableMappings, tableNameSchemaMapping, c.logger)
dstTableNames := slices.Collect(maps.Keys(processedMapping))
err := c.checkTablesEmptyAndEngine(ctx, dstTableNames)
if err != nil {
return err
}
// optimization: fetching columns for all tables at once
chTableColumnsMapping, err := c.getTableColumnsMapping(ctx, dstTableNames)
if err != nil {
return err
}

for _, tableMapping := range req.TableMappings {
dstTableName := tableMapping.DestinationTableIdentifier
if _, ok := processedMapping[dstTableName]; !ok {
// if destination table is not a key, that means source table was not a key in the original schema mapping(?)
return fmt.Errorf("source table %s not found in schema mapping", tableMapping.SourceTableIdentifier)
}
// if destination table does not exist, we're good
if _, ok := chTableColumnsMapping[dstTableName]; !ok {
continue
}

err = c.processTableComparison(dstTableName, processedMapping[dstTableName],
chTableColumnsMapping[dstTableName], peerDBColumns, tableMapping)
if err != nil {
return err
}
}
return nil
}
5 changes: 3 additions & 2 deletions flow/connectors/clickhouse/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ const (
versionColType = "Int64"
)

var acceptableTableEngines = []string{"ReplacingMergeTree", "MergeTree"}

func (c *ClickhouseConnector) StartSetupNormalizedTables(_ context.Context) (interface{}, error) {
return nil, nil
}
Expand Down Expand Up @@ -396,8 +398,7 @@ func (c *ClickhouseConnector) getDistinctTableNamesInBatch(
tableNames = append(tableNames, tableName.String)
}

err = rows.Err()
if err != nil {
if rows.Err() != nil {
return nil, fmt.Errorf("failed to read rows: %w", err)
}

Expand Down

0 comments on commit ee00aad

Please sign in to comment.