Skip to content

Commit

Permalink
Address more comments
Browse files Browse the repository at this point in the history
Signed-off-by: Hai Yan <[email protected]>
  • Loading branch information
oeyh committed Jan 24, 2025
1 parent 00abf68 commit f4e1683
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ public class PostgresStreamState {
@JsonProperty("currentLsn")
private String currentLsn;

@JsonProperty("publicationName")
private String publicationName;

@JsonProperty("replicationSlotName")
private String replicationSlotName;

Expand All @@ -27,6 +30,14 @@ public void setCurrentLsn(String currentLsn) {
this.currentLsn = currentLsn;
}

/**
 * Returns the publication name held by this stream state.
 *
 * @return the stored publication name
 */
public String getPublicationName() {
    return this.publicationName;
}

/**
 * Stores the publication name in this stream state.
 *
 * @param publicationName the publication name to remember
 */
public void setPublicationName(final String publicationName) {
    this.publicationName = publicationName;
}

/**
 * Returns the replication slot name held by this stream state.
 *
 * @return the stored replication slot name
 */
public String getReplicationSlotName() {
    return this.replicationSlotName;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ public String getTypeName() {
}

/**
 * Looks up the column type registered under the given type id.
 *
 * @param typeId numeric type id to resolve
 * @return the column type mapped to {@code typeId}
 * @throws IllegalArgumentException if no column type is registered for {@code typeId}
 */
public static ColumnType getByTypeId(int typeId) {
    if (TYPE_ID_MAP.containsKey(typeId)) {
        return TYPE_ID_MAP.get(typeId);
    }
    // Unknown ids are rejected rather than mapped to null.
    throw new IllegalArgumentException("Unsupported column type id: " + typeId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ private void createStreamPartition(RdsSourceConfig sourceConfig) {
final String publicationName = generatePublicationName();
final String slotName = generateReplicationSlotName();
((PostgresSchemaManager)schemaManager).createLogicalReplicationSlot(sourceConfig.getTableNames(), publicationName, slotName);
progressState.getPostgresStreamState().setPublicationName(publicationName);
progressState.getPostgresStreamState().setReplicationSlotName(slotName);
}
StreamPartition streamPartition = new StreamPartition(sourceConfig.getDbIdentifier(), progressState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class PostgresSchemaManager implements SchemaManager {
static final int NUM_OF_RETRIES = 3;
static final int BACKOFF_IN_MILLIS = 500;
static final String COLUMN_NAME = "COLUMN_NAME";
static final String PGOUTPUT = "pgoutput";

public PostgresSchemaManager(ConnectionManager connectionManager) {
this.connectionManager = connectionManager;
Expand All @@ -51,22 +52,34 @@ public void createLogicalReplicationSlot(final List<String> tableNames, final St
PreparedStatement statement = conn.prepareStatement(createPublicationStatement);
statement.executeUpdate();
} catch (Exception e) {
LOG.info("Failed to create publication: {}", e.getMessage());
LOG.warn("Failed to create publication: {}", e.getMessage());
}

PGConnection pgConnection = conn.unwrap(PGConnection.class);

// Create replication slot
PGReplicationConnection replicationConnection = pgConnection.getReplicationAPI();
try {
// Check if replication slot exists
String checkSlotQuery = "SELECT EXISTS (SELECT 1 FROM pg_replication_slots WHERE slot_name = ?);";
PreparedStatement checkSlotStatement = conn.prepareStatement(checkSlotQuery);
checkSlotStatement.setString(1, slotName);
try (ResultSet resultSet = checkSlotStatement.executeQuery()) {
if (resultSet.next() && resultSet.getBoolean(1)) {
LOG.info("Replication slot {} already exists. ", slotName);
return;
}
}

LOG.info("Creating replication slot {}...", slotName);
replicationConnection.createReplicationSlot()
.logical()
.withSlotName(slotName)
.withOutputPlugin("pgoutput")
.withOutputPlugin(PGOUTPUT)
.make();
LOG.info("Replication slot {} created successfully. ", slotName);
} catch (Exception e) {
LOG.info("Failed to create replication slot {}: {}", slotName, e.getMessage());
LOG.warn("Failed to create replication slot {}: {}", slotName, e.getMessage());
}
} catch (Exception e) {
LOG.error("Exception when creating replication slot. ", e);
Expand Down Expand Up @@ -95,8 +108,7 @@ public List<String> getPrimaryKeys(final String fullTableName) {
applyBackoff();
retry++;
}
LOG.warn("Failed to get primary keys for table {}", table);
return List.of();
throw new RuntimeException("Failed to get primary keys for table " + table);
}

private void applyBackoff() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,20 @@ public class LogicalReplicationClient implements ReplicationLogClient {

private static final Logger LOG = LoggerFactory.getLogger(LogicalReplicationClient.class);

static final String PUBLICATION_NAMES_KEY = "publication_names";

private final ConnectionManager connectionManager;
private final String publicationName;
private final String replicationSlotName;
private LogSequenceNumber startLsn;
private LogicalReplicationEventProcessor eventProcessor;

private volatile boolean disconnectRequested = false;

public LogicalReplicationClient(final ConnectionManager connectionManager,
final String replicationSlotName) {
final String replicationSlotName,
final String publicationName) {
this.publicationName = publicationName;
this.connectionManager = connectionManager;
this.replicationSlotName = replicationSlotName;
}
Expand All @@ -49,8 +54,7 @@ public void connect() {
.replicationStream()
.logical()
.withSlotName(replicationSlotName)
.withSlotOption("proto_version", "1")
.withSlotOption("publication_names", "my_publication");
.withSlotOption(PUBLICATION_NAMES_KEY, publicationName);
if (startLsn != null) {
logicalStreamBuilder.withStartPosition(startLsn);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,30 @@
import java.util.Map;

public class LogicalReplicationEventProcessor {
/**
 * Single-character tag read from a replication message ahead of tuple data.
 * The UPDATE handler uses it to tell whether the next tuple is the new row,
 * the old key (primary keys were changed) or the old row (replica identity full).
 */
enum TupleDataType {
    NEW('N'),
    KEY('K'),
    OLD('O');

    /** Character tag identifying this tuple kind. */
    private final char value;

    TupleDataType(final char tag) {
        this.value = tag;
    }

    /**
     * @return the character tag of this tuple kind
     */
    public char getValue() {
        return this.value;
    }

    /**
     * Resolves a character tag to its tuple kind.
     *
     * @param value character tag to resolve
     * @return the constant whose tag equals {@code value}
     * @throws IllegalArgumentException if no constant carries that tag
     */
    public static TupleDataType fromValue(char value) {
        return java.util.Arrays.stream(values())
                .filter(candidate -> candidate.value == value)
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Invalid TupleDataType value: " + value));
    }
}

private static final Logger LOG = LoggerFactory.getLogger(LogicalReplicationEventProcessor.class);

Expand Down Expand Up @@ -170,18 +194,18 @@ void processUpdateMessage(ByteBuffer msg) {
final List<String> primaryKeys = tableMetadata.getPrimaryKeys();
final long eventTimestampMillis = currentEventTimestamp;

char typeId = (char) msg.get();
if (typeId == 'N') {
TupleDataType tupleDataType = TupleDataType.fromValue((char) msg.get());
if (tupleDataType == TupleDataType.NEW) {
doProcess(msg, columnNames, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.INDEX);
LOG.debug("Processed an UPDATE message with table id: {}", tableId);
} else if (typeId == 'K') {
} else if (tupleDataType == TupleDataType.KEY) {
// Primary keys were changed
doProcess(msg, columnNames, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.DELETE);
msg.get(); // should be a char 'N'
doProcess(msg, columnNames, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.INDEX);
LOG.debug("Processed an UPDATE message with table id: {} and primary key(s) were changed", tableId);

} else if (typeId == 'O') {
} else if (tupleDataType == TupleDataType.OLD) {
// Replica Identity is set to full, containing both old and new row data
Map<String, Object> oldRowDataMap = getRowDataMap(msg, columnNames);
msg.get(); // should be a char 'N'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig;
import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
import org.opensearch.dataprepper.plugins.source.rds.coordination.state.PostgresStreamState;
import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata;
import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager;
import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManagerFactory;
Expand Down Expand Up @@ -67,13 +68,15 @@ private BinaryLogClient createBinaryLogClient() {
}

/**
 * Builds a {@link LogicalReplicationClient} from the Postgres stream state stored in the
 * progress state of the given partition.
 *
 * @param streamPartition partition whose progress state carries the publication and
 *                        replication slot names
 * @return a client configured with the stored replication slot and publication names
 * @throws NoSuchElementException if the stored replication slot name is {@code null}
 */
private LogicalReplicationClient createLogicalReplicationClient(StreamPartition streamPartition) {
    final PostgresStreamState postgresStreamState = streamPartition.getProgressState().get().getPostgresStreamState();
    // NOTE(review): publicationName is not null-checked here, unlike the slot name — confirm
    // whether progress state written before publication names were tracked can reach this path.
    final String publicationName = postgresStreamState.getPublicationName();
    final String replicationSlotName = postgresStreamState.getReplicationSlotName();
    if (replicationSlotName == null) {
        throw new NoSuchElementException("Replication slot name is not found in progress state.");
    }
    final ConnectionManagerFactory connectionManagerFactory = new ConnectionManagerFactory(sourceConfig, dbMetadata);
    final ConnectionManager connectionManager = connectionManagerFactory.getConnectionManager();
    // The constructor takes (connectionManager, replicationSlotName, publicationName). Both names
    // are Strings, so a swapped order compiles but streams from the wrong slot/publication.
    return new LogicalReplicationClient(connectionManager, replicationSlotName, publicationName);
}

public void setSSLMode(SSLMode sslMode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.UUID;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand All @@ -51,13 +55,14 @@ void setUp() {
}

@Test
void test_createLogicalReplicationSlot() throws SQLException {
void test_createLogicalReplicationSlot_creates_slot_if_not_exists() throws SQLException {
final List<String> tableNames = List.of("table1", "table2");
final String publicationName = "publication1";
final String slotName = "slot1";
final PreparedStatement preparedStatement = mock(PreparedStatement.class);
final PGConnection pgConnection = mock(PGConnection.class);
final PGReplicationConnection replicationConnection = mock(PGReplicationConnection.class);
final ResultSet resultSet = mock(ResultSet.class);
final ChainedCreateReplicationSlotBuilder chainedCreateSlotBuilder = mock(ChainedCreateReplicationSlotBuilder.class);
final ChainedLogicalCreateSlotBuilder slotBuilder = mock(ChainedLogicalCreateSlotBuilder.class);

Expand All @@ -66,6 +71,8 @@ void test_createLogicalReplicationSlot() throws SQLException {
when(connectionManager.getConnection()).thenReturn(connection);
when(connection.prepareStatement(statementCaptor.capture())).thenReturn(preparedStatement);
when(connection.unwrap(PGConnection.class)).thenReturn(pgConnection);
when(preparedStatement.executeQuery()).thenReturn(resultSet);
when(resultSet.next()).thenReturn(false); // Replication slot doesn't exist
when(pgConnection.getReplicationAPI()).thenReturn(replicationConnection);
when(replicationConnection.createReplicationSlot()).thenReturn(chainedCreateSlotBuilder);
when(chainedCreateSlotBuilder.logical()).thenReturn(slotBuilder);
Expand All @@ -74,9 +81,11 @@ void test_createLogicalReplicationSlot() throws SQLException {

schemaManager.createLogicalReplicationSlot(tableNames, publicationName, slotName);

String statement = statementCaptor.getValue();
assertThat(statement, is("CREATE PUBLICATION " + publicationName + " FOR TABLE " + String.join(", ", tableNames) + ";"));
List<String> statements = statementCaptor.getAllValues();
assertThat(statements.get(0), is("CREATE PUBLICATION " + publicationName + " FOR TABLE " + String.join(", ", tableNames) + ";"));
assertThat(statements.get(1), is("SELECT EXISTS (SELECT 1 FROM pg_replication_slots WHERE slot_name = ?);"));
verify(preparedStatement).executeUpdate();
verify(preparedStatement).executeQuery();
verify(pgConnection).getReplicationAPI();
verify(replicationConnection).createReplicationSlot();
verify(chainedCreateSlotBuilder).logical();
Expand All @@ -85,6 +94,72 @@ void test_createLogicalReplicationSlot() throws SQLException {
verify(slotBuilder).make();
}

// Verifies that when the slot-existence query reports the slot as present, the publication
// statement is still executed but no replication slot creation is attempted.
@Test
void test_createLogicalReplicationSlot_skip_creation_if_slot_exists() throws SQLException {
final List<String> tableNames = List.of("table1", "table2");
final String publicationName = "publication1";
final String slotName = "slot1";
final PreparedStatement preparedStatement = mock(PreparedStatement.class);
final PGConnection pgConnection = mock(PGConnection.class);
final PGReplicationConnection replicationConnection = mock(PGReplicationConnection.class);
final ResultSet resultSet = mock(ResultSet.class);

// Captures every SQL string passed to prepareStatement, in call order.
ArgumentCaptor<String> statementCaptor = ArgumentCaptor.forClass(String.class);

when(connectionManager.getConnection()).thenReturn(connection);
// The same mocked statement is returned for both the publication and the existence query.
when(connection.prepareStatement(statementCaptor.capture())).thenReturn(preparedStatement);
when(connection.unwrap(PGConnection.class)).thenReturn(pgConnection);
when(preparedStatement.executeQuery()).thenReturn(resultSet);
when(resultSet.next()).thenReturn(true); // Replication slot exists
when(resultSet.getBoolean(1)).thenReturn(true);
when(pgConnection.getReplicationAPI()).thenReturn(replicationConnection);

schemaManager.createLogicalReplicationSlot(tableNames, publicationName, slotName);

// First prepared statement creates the publication; second checks for the slot.
List<String> statements = statementCaptor.getAllValues();
assertThat(statements.get(0), is("CREATE PUBLICATION " + publicationName + " FOR TABLE " + String.join(", ", tableNames) + ";"));
assertThat(statements.get(1), is("SELECT EXISTS (SELECT 1 FROM pg_replication_slots WHERE slot_name = ?);"));
verify(preparedStatement).executeUpdate();
verify(preparedStatement).executeQuery();
// The replication API is obtained, but slot creation must never be started.
verify(pgConnection).getReplicationAPI();
verify(replicationConnection, never()).createReplicationSlot();
}

// Verifies that a single primary-key row from the metadata result set is returned as a
// one-element list holding that row's COLUMN_NAME value.
@Test
void test_getPrimaryKeys_returns_primary_keys() throws SQLException {
    final String databaseName = UUID.randomUUID().toString();
    final String schemaName = UUID.randomUUID().toString();
    final String tableName = UUID.randomUUID().toString();
    final String expectedKey = UUID.randomUUID().toString();
    final ResultSet primaryKeyRows = mock(ResultSet.class);

    when(connectionManager.getConnection()).thenReturn(connection);
    when(connection.getMetaData().getPrimaryKeys(databaseName, schemaName, tableName)).thenReturn(primaryKeyRows);
    // Exactly one row, then end of results.
    when(primaryKeyRows.next()).thenReturn(true, false);
    when(primaryKeyRows.getString("COLUMN_NAME")).thenReturn(expectedKey);

    final String qualifiedName = String.join(".", databaseName, schemaName, tableName);
    final List<String> actualKeys = schemaManager.getPrimaryKeys(qualifiedName);

    assertThat(actualKeys.size(), is(1));
    assertThat(actualKeys.get(0), is(expectedKey));
}

// Verifies that getPrimaryKeys throws a RuntimeException when reading the metadata
// result set fails.
@Test
void test_getPrimaryKeys_throws_exception_if_failed() throws SQLException {
    final String databaseName = UUID.randomUUID().toString();
    final String schemaName = UUID.randomUUID().toString();
    final String tableName = UUID.randomUUID().toString();
    final ResultSet failingRows = mock(ResultSet.class);

    when(connectionManager.getConnection()).thenReturn(connection);
    when(connection.getMetaData().getPrimaryKeys(databaseName, schemaName, tableName)).thenReturn(failingRows);
    // Reading the result set fails.
    when(failingRows.next()).thenThrow(RuntimeException.class);

    final String qualifiedName = String.join(".", databaseName, schemaName, tableName);

    assertThrows(RuntimeException.class, () -> schemaManager.getPrimaryKeys(qualifiedName));
}

// Builds the schema manager under test on top of the test's connectionManager.
private PostgresSchemaManager createObjectUnderTest() {
    final PostgresSchemaManager objectUnderTest = new PostgresSchemaManager(connectionManager);
    return objectUnderTest;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ class LogicalReplicationClientTest {
@Mock
private LogicalReplicationEventProcessor eventProcessor;

private String publicationName;
private String replicationSlotName;
private LogicalReplicationClient logicalReplicationClient;

@BeforeEach
void setUp() {
publicationName = UUID.randomUUID().toString();
replicationSlotName = UUID.randomUUID().toString();
logicalReplicationClient = createObjectUnderTest();
logicalReplicationClient.setEventProcessor(eventProcessor);
Expand Down Expand Up @@ -86,6 +88,6 @@ void test_connect() throws SQLException, InterruptedException {
}

private LogicalReplicationClient createObjectUnderTest() {
return new LogicalReplicationClient(connectionManager, replicationSlotName);
return new LogicalReplicationClient(connectionManager, replicationSlotName, publicationName);
}
}

0 comments on commit f4e1683

Please sign in to comment.