Skip to content

Commit

Permalink
Address minor issues
Browse files Browse the repository at this point in the history
Signed-off-by: Hai Yan <[email protected]>
  • Loading branch information
oeyh committed Jan 17, 2025
1 parent 81c5883 commit a3bec6f
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,15 @@ public void createLogicalReplicationSlot(final List<String> tableNames, final St

@Override
public List<String> getPrimaryKeys(final String fullTableName) {
final String schema = fullTableName.split("\\.")[0];
final String table = fullTableName.split("\\.")[1];
final String[] splits = fullTableName.split("\\.");
final String database = splits[0];
final String schema = splits[1];
final String table = splits[2];
int retry = 0;
while (retry <= NUM_OF_RETRIES) {
final List<String> primaryKeys = new ArrayList<>();
try (final Connection connection = connectionManager.getConnection()) {
try (final ResultSet rs = connection.getMetaData().getPrimaryKeys(null, schema, table)) {
try (final ResultSet rs = connection.getMetaData().getPrimaryKeys(database, schema, table)) {
while (rs.next()) {
primaryKeys.add(rs.getString(COLUMN_NAME));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

package org.opensearch.dataprepper.plugins.source.rds.stream;

import org.opensearch.dataprepper.plugins.source.rds.schema.PostgresConnectionManager;
import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager;
import org.postgresql.PGConnection;
import org.postgresql.replication.LogSequenceNumber;
import org.postgresql.replication.PGReplicationStream;
Expand All @@ -25,14 +25,14 @@ public class LogicalReplicationClient implements ReplicationLogClient {

private static final Logger LOG = LoggerFactory.getLogger(LogicalReplicationClient.class);

private final PostgresConnectionManager connectionManager;
private final ConnectionManager connectionManager;
private final String replicationSlotName;
private LogSequenceNumber startLsn;
private LogicalReplicationEventProcessor eventProcessor;

private volatile boolean disconnectRequested = false;

public LogicalReplicationClient(final PostgresConnectionManager connectionManager,
public LogicalReplicationClient(final ConnectionManager connectionManager,
final String replicationSlotName) {
this.connectionManager = connectionManager;
this.replicationSlotName = replicationSlotName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class LogicalReplicationEventProcessor {
static final int DEFAULT_BUFFER_BATCH_SIZE = 1_000;

private final StreamPartition streamPartition;
private final RdsSourceConfig sourceConfig;
private final StreamRecordConverter recordConverter;
private final Buffer<Record<Event>> buffer;
private final BufferAccumulator<Record<Event>> bufferAccumulator;
Expand All @@ -56,6 +57,7 @@ public LogicalReplicationEventProcessor(final StreamPartition streamPartition,
final Buffer<Record<Event>> buffer,
final String s3Prefix) {
this.streamPartition = streamPartition;
this.sourceConfig = sourceConfig;
recordConverter = new StreamRecordConverter(s3Prefix, sourceConfig.getPartitionCount());
this.buffer = buffer;
bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT);
Expand Down Expand Up @@ -122,7 +124,7 @@ void processRelationMessage(ByteBuffer msg) {
columnNames.add(columnName);
}

final List<String> primaryKeys = getPrimaryKeys(schemaName + "." + tableName);
final List<String> primaryKeys = getPrimaryKeys(schemaName, tableName);
final TableMetadata tableMetadata = new TableMetadata(
tableName, schemaName, columnNames, primaryKeys);

Expand Down Expand Up @@ -301,9 +303,10 @@ private String getNullTerminatedString(ByteBuffer msg) {
return sb.toString();
}

private List<String> getPrimaryKeys(String fullTableName) {
private List<String> getPrimaryKeys(String schemaName, String tableName) {
final String databaseName = sourceConfig.getTableNames().get(0).split("\\.")[0];
StreamProgressState progressState = streamPartition.getProgressState().get();

return progressState.getPrimaryKeyMap().get(fullTableName);
return progressState.getPrimaryKeyMap().get(databaseName + "." + schemaName + "." + tableName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata;
import org.opensearch.dataprepper.plugins.source.rds.schema.PostgresConnectionManager;
import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager;
import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManagerFactory;
import software.amazon.awssdk.services.rds.RdsClient;

import java.util.List;
Expand Down Expand Up @@ -70,13 +71,8 @@ private LogicalReplicationClient createLogicalReplicationClient(StreamPartition
if (replicationSlotName == null) {
throw new NoSuchElementException("Replication slot name is not found in progress state.");
}
final PostgresConnectionManager connectionManager = new PostgresConnectionManager(
dbMetadata.getEndpoint(),
dbMetadata.getPort(),
username,
password,
!sourceConfig.getTlsConfig().isInsecure(),
getDatabaseName(sourceConfig.getTableNames()));
final ConnectionManagerFactory connectionManagerFactory = new ConnectionManagerFactory(sourceConfig, dbMetadata);
final ConnectionManager connectionManager = connectionManagerFactory.getConnectionManager();
return new LogicalReplicationClient(connectionManager, replicationSlotName);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ void test_correct_process_method_invoked_for_relation_message() {
when(message.get()).thenReturn((byte) 'R');
final StreamProgressState progressState = mock(StreamProgressState.class);
when(streamPartition.getProgressState()).thenReturn(Optional.of(progressState));
when(progressState.getPrimaryKeyMap()).thenReturn(Map.of(".", List.of("key1", "key2")));
when(sourceConfig.getTableNames()).thenReturn(List.of("database.schema.table1"));
when(progressState.getPrimaryKeyMap()).thenReturn(Map.of("database.schema.table1", List.of("key1", "key2")));

objectUnderTest.process(message);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void test_create_logical_replication_client() {
final List<String> tableNames = List.of("table1", "table2");

when(sourceConfig.getEngine()).thenReturn(EngineType.POSTGRES);
when(sourceConfig.getTlsConfig().isInsecure()).thenReturn(false);
when(sourceConfig.isTlsEnabled()).thenReturn(true);
when(sourceConfig.getTableNames()).thenReturn(tableNames);
when(sourceConfig.getAuthenticationConfig().getUsername()).thenReturn(username);
when(sourceConfig.getAuthenticationConfig().getPassword()).thenReturn(password);
Expand Down

0 comments on commit a3bec6f

Please sign in to comment.