> results = enqueueNewPartitions();
super.wakeup();
- results.forEach(TransactionalRequestResult::await);
+ results.forEach(future -> {
+ try {
+ future.get();
+ } catch (Exception e) {
+ throw new RuntimeException("Error while flushing new partitions", e);
+ }
+ });
}
/**
@@ -212,97 +217,17 @@ private void flushNewPartitions() throws ProducerException {
* If there are no new transactions we return a {@link TransactionalRequestResult} that is
* already done.
*/
- private Set enqueueNewPartitions() throws ProducerException {
- Set transactionalRequestResults = new HashSet<>();
+ private Set> enqueueNewPartitions() throws ProducerException {
+ Set> transactionalRequestResults = new HashSet<>();
Set transactionManagers = super.getTransactionManagers();
for (Object transactionManager : transactionManagers) {
synchronized (transactionManager) {
- Object newPartitionsInTransaction = getField(transactionManager, "newPartitionsInTransaction");
- Object newPartitionsInTransactionIsEmpty = invoke(newPartitionsInTransaction, "isEmpty");
- TransactionalRequestResult result;
- if (newPartitionsInTransactionIsEmpty instanceof Boolean && !((Boolean) newPartitionsInTransactionIsEmpty)) {
- Object txnRequestHandler = invoke(transactionManager, "addPartitionsToTransactionHandler");
- invoke(transactionManager, "enqueueRequest", new Class[]{txnRequestHandler.getClass().getSuperclass()}, new Object[]{txnRequestHandler});
- result = (TransactionalRequestResult) getField(txnRequestHandler, txnRequestHandler.getClass().getSuperclass(), "result");
- } else {
- // we don't have an operation but this operation string is also used in
- // addPartitionsToTransactionHandler.
- result = new TransactionalRequestResult("AddPartitionsToTxn");
- result.done();
- }
- transactionalRequestResults.add(result);
+ Future transactionalRequestResultFuture =
+ TransactionManagerUtils.enqueueInFlightTransactions(transactionManager);
+ transactionalRequestResults.add(transactionalRequestResultFuture);
}
}
return transactionalRequestResults;
}
- protected static Enum> getEnum(String enumFullName) {
- String[] x = enumFullName.split("\\.(?=[^\\.]+$)");
- if (x.length == 2) {
- String enumClassName = x[0];
- String enumName = x[1];
- try {
- Class cl = (Class) Class.forName(enumClassName);
- return Enum.valueOf(cl, enumName);
- } catch (ClassNotFoundException e) {
- throw new RuntimeException("Incompatible KafkaProducer version", e);
- }
- }
- return null;
- }
-
- protected static Object invoke(Object object, String methodName, Object... args) {
- Class>[] argTypes = new Class[args.length];
- for (int i = 0; i < args.length; i++) {
- argTypes[i] = args[i].getClass();
- }
- return invoke(object, methodName, argTypes, args);
- }
-
- private static Object invoke(Object object, String methodName, Class>[] argTypes, Object[] args) {
- try {
- Method method = object.getClass().getDeclaredMethod(methodName, argTypes);
- method.setAccessible(true);
- return method.invoke(object, args);
- } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
- throw new RuntimeException("Incompatible PscProducer version", e);
- }
- }
-
- /**
- * Gets and returns the field {@code fieldName} from the given Object {@code object} using
- * reflection.
- */
- protected static Object getField(Object object, String fieldName) {
- return getField(object, object.getClass(), fieldName);
- }
-
- /**
- * Gets and returns the field {@code fieldName} from the given Object {@code object} using
- * reflection.
- */
- private static Object getField(Object object, Class> clazz, String fieldName) {
- try {
- Field field = clazz.getDeclaredField(fieldName);
- field.setAccessible(true);
- return field.get(object);
- } catch (NoSuchFieldException | IllegalAccessException e) {
- throw new RuntimeException("Incompatible KafkaProducer version", e);
- }
- }
-
- /**
- * Sets the field {@code fieldName} on the given Object {@code object} to {@code value} using
- * reflection.
- */
- protected static void setField(Object object, String fieldName, Object value) {
- try {
- Field field = object.getClass().getDeclaredField(fieldName);
- field.setAccessible(true);
- field.set(object, value);
- } catch (NoSuchFieldException | IllegalAccessException e) {
- throw new RuntimeException("Incompatible KafkaProducer version", e);
- }
- }
-
}
diff --git a/psc-integration-test/pom.xml b/psc-integration-test/pom.xml
index ae7cd5e..f8b7b42 100644
--- a/psc-integration-test/pom.xml
+++ b/psc-integration-test/pom.xml
@@ -5,7 +5,7 @@
psc-java-oss
com.pinterest.psc
- 3.2.0
+ 3.2.1-SNAPSHOT
../pom.xml
4.0.0
diff --git a/psc-integration-test/src/test/java/com/pinterest/psc/producer/TestMultiKafkaClusterBackends.java b/psc-integration-test/src/test/java/com/pinterest/psc/producer/TestMultiKafkaClusterBackends.java
index 1c410f5..2dd282e 100644
--- a/psc-integration-test/src/test/java/com/pinterest/psc/producer/TestMultiKafkaClusterBackends.java
+++ b/psc-integration-test/src/test/java/com/pinterest/psc/producer/TestMultiKafkaClusterBackends.java
@@ -40,8 +40,10 @@ public class TestMultiKafkaClusterBackends {
private static final int partitions1 = 12;
private static final String topic2 = "topic2";
private static final int partitions2 = 24;
+ private static final String topic3 = "topic3";
+ private static final int partitions3 = 36;
private KafkaCluster kafkaCluster1, kafkaCluster2;
- private String topicUriStr1, topicUriStr2;
+ private String topicUriStr1, topicUriStr2, topicUriStr3;
/**
* Initializes two Kafka clusters that are commonly used by all tests, and creates a single topic on each.
@@ -61,10 +63,14 @@ public void setup() throws IOException, InterruptedException {
kafkaCluster2 = new KafkaCluster("plaintext", "region2", "cluster2", 9092);
topicUriStr2 = String.format("%s:%s%s:kafka:env:cloud_%s::%s:%s",
- kafkaCluster2.getTransport(), TopicUri.SEPARATOR, TopicUri.STANDARD, kafkaCluster2.getRegion(), kafkaCluster2.getCluster(), topic1);
+ kafkaCluster2.getTransport(), TopicUri.SEPARATOR, TopicUri.STANDARD, kafkaCluster2.getRegion(), kafkaCluster2.getCluster(), topic2);
+
+ topicUriStr3 = String.format("%s:%s%s:kafka:env:cloud_%s::%s:%s",
+ kafkaCluster1.getTransport(), TopicUri.SEPARATOR, TopicUri.STANDARD, kafkaCluster1.getRegion(), kafkaCluster1.getCluster(), topic3);
PscTestUtils.createTopicAndVerify(sharedKafkaTestResource1, topic1, partitions1);
PscTestUtils.createTopicAndVerify(sharedKafkaTestResource1, topic2, partitions2);
+ PscTestUtils.createTopicAndVerify(sharedKafkaTestResource1, topic3, partitions3);
}
/**
@@ -78,12 +84,16 @@ public void setup() throws IOException, InterruptedException {
public void tearDown() throws ExecutionException, InterruptedException {
PscTestUtils.deleteTopicAndVerify(sharedKafkaTestResource1, topic1);
PscTestUtils.deleteTopicAndVerify(sharedKafkaTestResource1, topic2);
+ PscTestUtils.deleteTopicAndVerify(sharedKafkaTestResource1, topic3);
Thread.sleep(1000);
}
/**
* Verifies that backend producers each have their own transactional states that could be different at times.
*
+ * Also, verifies that the PscProducer throws the appropriate exception when trying to send messages via a
+ * new backend producer while the PscProducer is already transactional.
+ *
* @throws ConfigurationException
* @throws ProducerException
*/
@@ -100,19 +110,43 @@ public void testTransactionalProducersStates() throws ConfigurationException, Pr
PscBackendProducer backendProducer1 = pscProducer.getBackendProducer(topicUriStr1);
assertEquals(PscProducer.TransactionalState.BEGUN, pscProducer.getBackendProducerState(backendProducer1));
- Exception e = assertThrows(ProducerException.class, () -> pscProducer.beginTransaction());
- assertEquals("Invalid transaction state: consecutive calls to beginTransaction().", e.getMessage());
+ PscProducerMessage producerMessageTopic1 = new PscProducerMessage<>(topicUriStr1, 0);
+ pscProducer.send(producerMessageTopic1);
+ assertEquals(PscProducer.TransactionalState.IN_TRANSACTION, pscProducer.getBackendProducerState(backendProducer1));
+
+ PscProducerMessage producerMessageTopic3 = new PscProducerMessage<>(topicUriStr3, 1);
+ pscProducer.send(producerMessageTopic3);
+ assertEquals(PscProducer.TransactionalState.IN_TRANSACTION, pscProducer.getBackendProducerState(backendProducer1));
+
+ assertEquals(1, pscProducer.getBackendProducers().size()); // topic1 and topic3 belong to same cluster so there should only be one backend producer at this point
+ assertEquals(backendProducer1, pscProducer.getBackendProducers().iterator().next());
+
+ assertEquals(PscProducer.TransactionalState.IN_TRANSACTION, pscProducer.getBackendProducerState(backendProducer1));
+ assertEquals(PscProducer.TransactionalState.INIT_AND_BEGUN, pscProducer.getTransactionalState());
+
+ pscProducer.commitTransaction();
+
+ assertEquals(PscProducer.TransactionalState.INIT_AND_BEGUN, pscProducer.getTransactionalState());
+ assertEquals(PscProducer.TransactionalState.READY, pscProducer.getBackendProducerState(backendProducer1));
+
+ pscProducer.beginTransaction();
+
+ assertEquals(PscProducer.TransactionalState.INIT_AND_BEGUN, pscProducer.getTransactionalState());
+ assertEquals(PscProducer.TransactionalState.BEGUN, pscProducer.getBackendProducerState(backendProducer1));
- PscProducerMessage producerMessage = new PscProducerMessage<>(topicUriStr2, 0);
- pscProducer.send(producerMessage);
+ PscProducerMessage producerMessageTopic2 = new PscProducerMessage<>(topicUriStr2, 0);
+ Exception e = assertThrows(ProducerException.class, () -> pscProducer.send(producerMessageTopic2));
+ assertEquals("Invalid call to send() which would have created a new backend producer. This is not allowed when the PscProducer is already transactional.", e.getMessage());
- assertEquals(2, pscProducer.getBackendProducers().size());
+ assertEquals(1, pscProducer.getBackendProducers().size());
- PscBackendProducer backendProducer2 = pscProducer.getBackendProducer(topicUriStr2);
- assertNotEquals(backendProducer1, backendProducer2);
+ pscProducer.send(producerMessageTopic1); // this should go through
+ assertEquals(PscProducer.TransactionalState.INIT_AND_BEGUN, pscProducer.getTransactionalState());
+ assertEquals(PscProducer.TransactionalState.IN_TRANSACTION, pscProducer.getBackendProducerState(backendProducer1));
+ pscProducer.commitTransaction();
+ assertEquals(PscProducer.TransactionalState.INIT_AND_BEGUN, pscProducer.getTransactionalState());
assertEquals(PscProducer.TransactionalState.READY, pscProducer.getBackendProducerState(backendProducer1));
- assertEquals(PscProducer.TransactionalState.IN_TRANSACTION, pscProducer.getBackendProducerState(backendProducer2));
pscProducer.close();
}
diff --git a/psc-logging/pom.xml b/psc-logging/pom.xml
index 536ae6c..fdc5c26 100644
--- a/psc-logging/pom.xml
+++ b/psc-logging/pom.xml
@@ -5,7 +5,7 @@
psc-java-oss
com.pinterest.psc
- 3.2.0
+ 3.2.1-SNAPSHOT
../pom.xml
4.0.0
diff --git a/psc/pom.xml b/psc/pom.xml
index bdaed1c..5901608 100644
--- a/psc/pom.xml
+++ b/psc/pom.xml
@@ -5,7 +5,7 @@
psc-java-oss
com.pinterest.psc
- 3.2.0
+ 3.2.1-SNAPSHOT
../pom.xml
4.0.0
diff --git a/psc/src/main/java/com/pinterest/psc/producer/PscProducer.java b/psc/src/main/java/com/pinterest/psc/producer/PscProducer.java
index 8ff1fcc..3ab1dbf 100644
--- a/psc/src/main/java/com/pinterest/psc/producer/PscProducer.java
+++ b/psc/src/main/java/com/pinterest/psc/producer/PscProducer.java
@@ -272,6 +272,12 @@ public Future send(PscProducerMessage pscProducerMessage) throw
public Future send(PscProducerMessage pscProducerMessage, Callback callback) throws ProducerException, ConfigurationException {
ensureOpen();
validateProducerMessage(pscProducerMessage);
+ if (transactionalState.get() != TransactionalState.NON_TRANSACTIONAL &&
+ !backendProducers.isEmpty() &&
+ !pscBackendProducerByTopicUriPrefix.containsKey(pscProducerMessage.getTopicUriPartition().getTopicUri().getTopicUriPrefix())) {
+ throw new ProducerException("Invalid call to send() which would have created a new backend producer. This is not allowed when the PscProducer" +
+ " is already transactional.");
+ }
PscBackendProducer backendProducer =
getBackendProducerForTopicUri(pscProducerMessage.getTopicUriPartition().getTopicUri());
@@ -303,7 +309,7 @@ public Future send(PscProducerMessage pscProducerMessage, Callb
}
break;
case BEGUN:
- transactionalStateByBackendProducer.replace(backendProducer, TransactionalState.INIT_AND_BEGUN, TransactionalState.IN_TRANSACTION);
+ transactionalStateByBackendProducer.replace(backendProducer, TransactionalState.BEGUN, TransactionalState.IN_TRANSACTION);
break;
}
@@ -438,6 +444,17 @@ protected PscProducerTransactionalProperties initTransactions(String topicUriStr
TopicUri topicUri = validateTopicUri(topicUriString);
PscBackendProducer backendProducer = getBackendProducerForTopicUri(topicUri);
+ initTransactions(backendProducer);
+ return backendProducer.getTransactionalProperties();
+ }
+
+ /**
+ * Centralized logic for initializing transactions for a given backend producer.
+ *
+ * @param backendProducer the backendProducer to initialize transactions for
+ * @throws ProducerException if the producer is already closed, or is not in the proper state to initialize transactions
+ */
+ private void initTransactions(PscBackendProducer backendProducer) throws ProducerException {
if (!transactionalStateByBackendProducer.get(backendProducer).equals(TransactionalState.NON_TRANSACTIONAL) &&
!transactionalStateByBackendProducer.get(backendProducer).equals(TransactionalState.INIT_AND_BEGUN))
throw new ProducerException("Invalid transaction state: initializing transactions works only once for a PSC producer.");
@@ -456,7 +473,6 @@ protected PscProducerTransactionalProperties initTransactions(String topicUriStr
}
this.beginTransaction();
- return backendProducer.getTransactionalProperties();
}
/**
@@ -671,6 +687,22 @@ public Set getTransactionManagers() throws ProducerException {
return transactionManagers;
}
+ /**
+ * Get exactly one transaction manager. If there is more than one transaction managers / backend producers,
+ * this method will throw an exception. Note that this is added due to a dependency by Flink connector,
+ * and should not need to be used otherwise.
+ *
+ * @return the transaction manager object
+ * @throws ProducerException if there is an error in the backend producer or if there is more than one transaction managers
+ */
+ @InterfaceStability.Evolving
+ protected Object getExactlyOneTransactionManager() throws ProducerException {
+ Set transactionManagers = getTransactionManagers();
+ if (transactionManagers.size() != 1)
+ throw new ProducerException("Expected exactly one transaction manager, but found " + transactionManagers.size());
+ return transactionManagers.iterator().next();
+ }
+
/**
* This API is added due to a dependency by Flink connector, and should not be normally used by a typical producer.
*
@@ -730,6 +762,7 @@ public void flush() throws ProducerException {
*
* @throws ProducerException if closing some backend producer fails
*/
+ @Override
public void close() throws ProducerException {
close(Duration.ofMillis(Long.MAX_VALUE));
}
diff --git a/psc/src/main/java/com/pinterest/psc/producer/kafka/PscKafkaProducer.java b/psc/src/main/java/com/pinterest/psc/producer/kafka/PscKafkaProducer.java
index 9bbe30c..d597051 100644
--- a/psc/src/main/java/com/pinterest/psc/producer/kafka/PscKafkaProducer.java
+++ b/psc/src/main/java/com/pinterest/psc/producer/kafka/PscKafkaProducer.java
@@ -31,6 +31,7 @@
import com.pinterest.psc.producer.PscBackendProducer;
import com.pinterest.psc.producer.PscProducerMessage;
import com.pinterest.psc.producer.PscProducerTransactionalProperties;
+import com.pinterest.psc.producer.transaction.TransactionManagerUtils;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
@@ -113,8 +114,16 @@ private KafkaProducer getNewKafkaProducer(boolean bumpRef) {
// they must be -1 in order for initTransaction to succeed - we don't yet know why
// in some cases they are not -1 even though they should be initialized to -1 on calling the KafkaProducer
// constructor
- setProducerId(-1);
- setEpoch((short) -1);
+ try {
+ // the transactionManager will be non-null iff the producer enables idempotence and transactions
+ if (getTransactionManager() != null) {
+ setProducerId(-1);
+ setEpoch((short) -1);
+ }
+ } catch (ProducerException e) {
+ logger.warn("Error in setting producer ID and epoch." +
+ " This might be ok if the producer won't be transactional.", e);
+ }
updateStatus(kafkaProducer, true);
PscMetricRegistryManager.getInstance().incrementBackendCounterMetric(
null,
@@ -731,42 +740,11 @@ public void resumeTransaction(PscProducerTransactionalProperties pscProducerTran
handleUninitializedKafkaProducer("resumeTransaction()");
try {
- Object transactionManager = PscCommon.getField(kafkaProducer, "transactionManager");
+ Object transactionManager = getTransactionManager();
+ if (transactionManager == null)
+ handleNullTransactionManager();
synchronized (kafkaProducer) {
- Object topicPartitionBookkeeper =
- PscCommon.getField(transactionManager, "topicPartitionBookkeeper");
-
- PscCommon.invoke(
- transactionManager,
- "transitionTo",
- PscCommon.getEnum(
- "org.apache.kafka.clients.producer.internals.TransactionManager$State.INITIALIZING"
- )
- );
-
- PscCommon.invoke(topicPartitionBookkeeper, "reset");
-
- Object producerIdAndEpoch = PscCommon.getField(transactionManager, "producerIdAndEpoch");
- PscCommon.setField(producerIdAndEpoch, "producerId", pscProducerTransactionalProperties.getProducerId());
- PscCommon.setField(producerIdAndEpoch, "epoch", pscProducerTransactionalProperties.getEpoch());
-
- PscCommon.invoke(
- transactionManager,
- "transitionTo",
- PscCommon.getEnum(
- "org.apache.kafka.clients.producer.internals.TransactionManager$State.READY"
- )
- );
-
- PscCommon.invoke(
- transactionManager,
- "transitionTo",
- PscCommon.getEnum(
- "org.apache.kafka.clients.producer.internals.TransactionManager$State.IN_TRANSACTION"
- )
- );
-
- PscCommon.setField(transactionManager, "transactionStarted", true);
+ TransactionManagerUtils.resumeTransaction(transactionManager, pscProducerTransactionalProperties);
}
} catch (Exception exception) {
handleException(exception, true);
@@ -778,11 +756,23 @@ public PscProducerTransactionalProperties getTransactionalProperties() throws Pr
if (kafkaProducer == null)
handleUninitializedKafkaProducer("getTransactionalProperties()");
- Object transactionManager = PscCommon.getField(kafkaProducer, "transactionManager");
- Object producerIdAndEpoch = PscCommon.getField(transactionManager, "producerIdAndEpoch");
+ Object transactionManager = getTransactionManager();
+ if (transactionManager == null)
+ handleNullTransactionManager();
return new PscProducerTransactionalProperties(
- (long) PscCommon.getField(producerIdAndEpoch, "producerId"),
- (short) PscCommon.getField(producerIdAndEpoch, "epoch")
+ TransactionManagerUtils.getProducerId(transactionManager),
+ TransactionManagerUtils.getEpoch(transactionManager)
+ );
+ }
+
+ private void handleNullTransactionManager() throws ProducerException {
+ handleException(
+ new BackendProducerException(
+ "Attempting to get transactionManager in KafkaProducer when " +
+ "transactionManager is null. This indicates that the KafkaProducer " +
+ "was not initialized to be transaction-ready.",
+ PscUtils.BACKEND_TYPE_KAFKA
+ ), true
);
}
@@ -798,30 +788,46 @@ public PscProducerTransactionalProperties getTransactionalProperties() throws Pr
}
private long getProducerId() {
- Object transactionManager = PscCommon.getField(kafkaProducer, "transactionManager");
- Object producerIdAndEpoch = PscCommon.getField(transactionManager, "producerIdAndEpoch");
- return (long) PscCommon.getField(producerIdAndEpoch, "producerId");
+ try {
+ Object transactionManager = getTransactionManager();
+ if (transactionManager == null)
+ handleNullTransactionManager();
+ return TransactionManagerUtils.getProducerId(transactionManager);
+ } catch (ProducerException e) {
+ throw new RuntimeException("Unable to get producerId", e);
+ }
}
private void setProducerId(long producerId) {
- Object transactionManager = PscCommon.getField(kafkaProducer, "transactionManager");
- if (transactionManager != null) {
- Object producerIdAndEpoch = PscCommon.getField(transactionManager, "producerIdAndEpoch");
- PscCommon.setField(producerIdAndEpoch, "producerId", producerId);
+ try {
+ Object transactionManager = getTransactionManager();
+ if (transactionManager == null)
+ handleNullTransactionManager();
+ TransactionManagerUtils.setProducerId(transactionManager, producerId);
+ } catch (ProducerException e) {
+ throw new RuntimeException("Unable to set producerId", e);
}
}
public short getEpoch() {
- Object transactionManager = PscCommon.getField(kafkaProducer, "transactionManager");
- Object producerIdAndEpoch = PscCommon.getField(transactionManager, "producerIdAndEpoch");
- return (short) PscCommon.getField(producerIdAndEpoch, "epoch");
+ try {
+ Object transactionManager = getTransactionManager();
+ if (transactionManager == null)
+ handleNullTransactionManager();
+ return TransactionManagerUtils.getEpoch(transactionManager);
+ } catch (ProducerException e) {
+ throw new RuntimeException("Unable to get epoch", e);
+ }
}
private void setEpoch(short epoch) {
- Object transactionManager = PscCommon.getField(kafkaProducer, "transactionManager");
- if (transactionManager != null) {
- Object producerIdAndEpoch = PscCommon.getField(transactionManager, "producerIdAndEpoch");
- PscCommon.setField(producerIdAndEpoch, "epoch", epoch);
+ try {
+ Object transactionManager = getTransactionManager();
+ if (transactionManager == null)
+ handleNullTransactionManager();
+ TransactionManagerUtils.setEpoch(transactionManager, epoch);
+ } catch (ProducerException e) {
+ throw new RuntimeException("Unable to set epoch", e);
}
}
diff --git a/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerOperator.java b/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerOperator.java
new file mode 100644
index 0000000..53292d8
--- /dev/null
+++ b/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerOperator.java
@@ -0,0 +1,30 @@
+package com.pinterest.psc.producer.transaction;
+
+import com.pinterest.psc.producer.PscProducerTransactionalProperties;
+
+import java.util.concurrent.Future;
+
+/**
+ * Backend-agnostic interface for operating on a transaction manager.
+ * See {@link TransactionManagerUtils} for more details.
+ */
+public interface TransactionManagerOperator {
+
+ short getEpoch(Object transactionManager);
+
+ String getTransactionId(Object transactionManager);
+
+ long getProducerId(Object transactionManager);
+
+ void setEpoch(Object transactionManager, short epoch);
+
+ void setTransactionId(Object transactionManager, String transactionId);
+
+ void setProducerId(Object transactionManager, long producerId);
+
+ Future enqueueInFlightTransactions(Object transactionManager);
+
+ void resumeTransaction(Object transactionManager, PscProducerTransactionalProperties transactionalProperties);
+
+ int getTransactionCoordinatorId(Object transactionManager);
+}
diff --git a/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerUtils.java b/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerUtils.java
new file mode 100644
index 0000000..7e39708
--- /dev/null
+++ b/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerUtils.java
@@ -0,0 +1,126 @@
+package com.pinterest.psc.producer.transaction;
+
+import com.pinterest.psc.producer.PscProducerTransactionalProperties;
+import com.pinterest.psc.producer.transaction.kafka.KafkaTransactionManagerOperator;
+import org.apache.kafka.clients.producer.internals.TransactionManager;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.Future;
+
+/**
+ * This class is used to abstract the differences between different transaction managers and unify the operations on them.
+ * The operations supported here generally require direct field access or reflections on the transaction manager object,
+ * which is necessary for certain operations like setting the producer ID, epoch, or transaction ID.
+ *
+ * For regular transaction operations like beginTransaction, commitTransaction, and abortTransaction, they are not included here
+ * because they should be directly accessible via public APIs of the backend producer implementation.
+ *
+ * Each backend PubSub implementation should provide a {@link TransactionManagerOperator} implementation to support the operations,
+ * and register it in the TXN_MANAGER_CLASSNAME_TO_OPERATOR map.
+ */
+public class TransactionManagerUtils {
+
+ private static final Map TXN_MANAGER_CLASSNAME_TO_OPERATOR = new HashMap<>();
+
+ private static TransactionManagerOperator getOrCreateTransactionManagerOperator(Object transactionManager) {
+ return TXN_MANAGER_CLASSNAME_TO_OPERATOR.computeIfAbsent(transactionManager.getClass().getName(), className -> {
+ if (className.equals(TransactionManager.class.getName())) {
+ return new KafkaTransactionManagerOperator();
+ }
+ throw new IllegalArgumentException("Unsupported transaction manager class: " + className);
+ });
+ }
+
+ /**
+ * Get the epoch of the transaction manager.
+ *
+ * @param transactionManager the transaction manager object
+ * @return the epoch of the transaction manager
+ */
+ public static short getEpoch(Object transactionManager) {
+ return getOrCreateTransactionManagerOperator(transactionManager).getEpoch(transactionManager);
+ }
+
+ /**
+ * Get the producer ID of the transaction manager.
+ *
+ * @param transactionManager the transaction manager object
+ * @return the producer ID of the transaction manager
+ */
+ public static long getProducerId(Object transactionManager) {
+ return getOrCreateTransactionManagerOperator(transactionManager).getProducerId(transactionManager);
+ }
+
+ /**
+ * Get the transaction ID of the transaction manager.
+ *
+ * @param transactionManager the transaction manager object
+ * @return the transaction ID of the transaction manager
+ */
+ public static String getTransactionId(Object transactionManager) {
+ return getOrCreateTransactionManagerOperator(transactionManager).getTransactionId(transactionManager);
+ }
+
+ /**
+ * Set the epoch of the transaction manager.
+ *
+ * @param transactionManager the transaction manager object
+ * @param epoch the epoch to set
+ */
+ public static void setEpoch(Object transactionManager, short epoch) {
+ getOrCreateTransactionManagerOperator(transactionManager).setEpoch(transactionManager, epoch);
+ }
+
+ /**
+ * Set the producer ID of the transaction manager.
+ *
+ * @param transactionManager the transaction manager object
+ * @param producerId the producer ID to set
+ */
+ public static void setProducerId(Object transactionManager, long producerId) {
+ getOrCreateTransactionManagerOperator(transactionManager).setProducerId(transactionManager, producerId);
+ }
+
+ /**
+ * Set the transaction ID of the transaction manager.
+ *
+ * @param transactionManager the transaction manager object
+ * @param transactionalId the transaction ID to set
+ */
+ public static void setTransactionId(Object transactionManager, String transactionalId) {
+ getOrCreateTransactionManagerOperator(transactionManager).setTransactionId(transactionManager, transactionalId);
+ }
+
+ /**
+ * Enqueue in-flight transactions in the transaction manager.
+ *
+ * @param transactionManager the transaction manager object
+ * @return a future that completes when the in-flight transactions are enqueued
+ */
+ public static Future enqueueInFlightTransactions(Object transactionManager) {
+ return getOrCreateTransactionManagerOperator(transactionManager).enqueueInFlightTransactions(transactionManager);
+ }
+
+ /**
+ * Resume the transaction in the transaction manager with the given transactional properties. Typically,
+ * this is used to resume a transaction with the same transaction ID and producer ID after a previous transaction
+ * has been aborted or failed.
+ *
+ * @param transactionManager the transaction manager object
+ * @param transactionalProperties the transactional properties
+ */
+ public static void resumeTransaction(Object transactionManager, PscProducerTransactionalProperties transactionalProperties) {
+ getOrCreateTransactionManagerOperator(transactionManager).resumeTransaction(transactionManager, transactionalProperties);
+ }
+
+ /**
+ * Get the transaction coordinator ID of the transaction manager.
+ *
+ * @param transactionManager the transaction manager object
+ * @return the transaction coordinator ID of the transaction manager
+ */
+ public static int getTransactionCoordinatorId(Object transactionManager) {
+ return getOrCreateTransactionManagerOperator(transactionManager).getTransactionCoordinatorId(transactionManager);
+ }
+}
diff --git a/psc/src/main/java/com/pinterest/psc/producer/transaction/kafka/KafkaTransactionManagerOperator.java b/psc/src/main/java/com/pinterest/psc/producer/transaction/kafka/KafkaTransactionManagerOperator.java
new file mode 100644
index 0000000..f59c5c2
--- /dev/null
+++ b/psc/src/main/java/com/pinterest/psc/producer/transaction/kafka/KafkaTransactionManagerOperator.java
@@ -0,0 +1,289 @@
+package com.pinterest.psc.producer.transaction.kafka;
+
+import com.pinterest.psc.common.PscCommon;
+import com.pinterest.psc.producer.PscProducerTransactionalProperties;
+import com.pinterest.psc.producer.transaction.TransactionManagerOperator;
+import org.apache.kafka.clients.producer.internals.TransactionManager;
+import org.apache.kafka.clients.producer.internals.TransactionalRequestResult;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.requests.FindCoordinatorRequest;
+import org.apache.kafka.common.utils.ProducerIdAndEpoch;
+
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * A Kafka transaction manager operator that provides methods to interact with the transaction manager.
+ */
+public class KafkaTransactionManagerOperator implements TransactionManagerOperator {
+
+ private static final String KAFKA_TXN_MANAGER_PRODUCER_ID_AND_EPOCH_FIELD_NAME = "producerIdAndEpoch";
+ private static final String TRANSACTION_MANAGER_STATE_ENUM =
+ "org.apache.kafka.clients.producer.internals.TransactionManager$State";
+
+ /**
+ * Returns the ProducerIdAndEpoch from the transaction manager.
+ *
+ * @param transactionManager the transaction manager
+ * @return the ProducerIdAndEpoch
+ */
+ private ProducerIdAndEpoch getProducerIdAndEpoch(Object transactionManager) {
+ ProducerIdAndEpoch producerIdAndEpoch = (ProducerIdAndEpoch) PscCommon.getField(transactionManager, KAFKA_TXN_MANAGER_PRODUCER_ID_AND_EPOCH_FIELD_NAME);
+ if (producerIdAndEpoch == null) {
+ throw new IllegalStateException("ProducerIdAndEpoch is null");
+ }
+ return producerIdAndEpoch;
+ }
+
+ /**
+ * Returns the epoch from the transaction manager.
+ *
+ * @param transactionManager the transaction manager
+ * @return the epoch
+ */
+ @Override
+ public short getEpoch(Object transactionManager) {
+ ProducerIdAndEpoch producerIdAndEpoch = getProducerIdAndEpoch(transactionManager);
+ return (short) PscCommon.getField(producerIdAndEpoch, "epoch");
+ }
+
+ /**
+ * Returns the transaction id from the transaction manager.
+ *
+ * @param transactionManager the transaction manager
+ * @return the transaction id
+ */
+ @Override
+ public String getTransactionId(Object transactionManager) {
+ return (String) PscCommon.getField(transactionManager, "transactionalId");
+ }
+
+ /**
+ * Returns the producer id from the transaction manager.
+ *
+ * @param transactionManager the transaction manager
+ * @return the producer id
+ */
+ @Override
+ public long getProducerId(Object transactionManager) {
+ ProducerIdAndEpoch producerIdAndEpoch = getProducerIdAndEpoch(transactionManager);
+ return (long) PscCommon.getField(producerIdAndEpoch, "producerId");
+ }
+
+ /**
+ * Sets the epoch in the transaction manager.
+ *
+ * @param transactionManager the transaction manager
+ * @param epoch the epoch
+ */
+ @Override
+ public void setEpoch(Object transactionManager, short epoch) {
+ ProducerIdAndEpoch producerIdAndEpoch = getProducerIdAndEpoch(transactionManager);
+ PscCommon.setField(producerIdAndEpoch, "epoch", epoch);
+ }
+
+ /**
+ * Sets the transaction id in the transaction manager.
+ *
+ * @param transactionManager the transaction manager
+ * @param transactionId the transaction id
+ */
+ @Override
+ public void setTransactionId(Object transactionManager, String transactionId) {
+ PscCommon.setField(transactionManager, "transactionalId", transactionId);
+ PscCommon.setField(
+ transactionManager,
+ "currentState",
+ getTransactionManagerState("UNINITIALIZED"));
+ }
+
+ /**
+ * Sets the producer id in the transaction manager.
+ *
+ * @param transactionManager the transaction manager
+ * @param producerId the producer id
+ */
+ @Override
+ public void setProducerId(Object transactionManager, long producerId) {
+ ProducerIdAndEpoch producerIdAndEpoch = getProducerIdAndEpoch(transactionManager);
+ PscCommon.setField(producerIdAndEpoch, "producerId", producerId);
+ }
+
+ /**
+ * Enqueues in-flight transactions at the transaction manager. This method is used to ensure that
+ * in-flight transactions are flushed before the producer is closed.
+ *
+ * Calling the {@link Future#get()} method will block until the {@link TransactionalRequestResult} is completed.
+ * The {@link TransactionalRequestResult} is completed when the transaction manager has successfully enqueued the
+ * new partitions. The boolean value returned by the {@link Future#get()} method indicates whether the operation
+ * was successful.
+ *
+ * @param transactionManager the transaction manager
+ * @return a future that can be used to wait for the operation to complete.
+ */
+ @Override
+ public Future enqueueInFlightTransactions(Object transactionManager) {
+ TransactionalRequestResult result = enqueueNewPartitions(transactionManager);
+ return new Future() {
+ @Override
+ public boolean cancel(boolean mayInterruptIfRunning) {
+ return false;
+ }
+
+ @Override
+ public boolean isCancelled() {
+ return false;
+ }
+
+ @Override
+ public boolean isDone() {
+ return result.isCompleted();
+ }
+
+ @Override
+ public Boolean get() {
+ result.await();
+ return result.isSuccessful();
+ }
+
+ @Override
+ public Boolean get(long timeout, TimeUnit unit) {
+ result.await(timeout, unit);
+ return result.isSuccessful();
+ }
+ };
+ }
+
+ /**
+ * Returns the transaction manager state given an enum name as a String.
+ *
+ * @param enumName the enum name String
+ * @return the transaction manager state enum
+ */
+ @SuppressWarnings("unchecked")
+ private Enum> getTransactionManagerState(String enumName) {
+ try {
+ Class cl = (Class) Class.forName(TRANSACTION_MANAGER_STATE_ENUM);
+ return Enum.valueOf(cl, enumName);
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException("Incompatible KafkaProducer version", e);
+ }
+ }
+
+ /**
+ * Enqueues new transactions at the transaction manager and returns a {@link
+ * TransactionalRequestResult} that allows waiting on them.
+ *
+ * If there are no new transactions we return a {@link TransactionalRequestResult} that is
+ * already done.
+ */
+ private TransactionalRequestResult enqueueNewPartitions(Object transactionManager) {
+ Object newPartitionsInTransaction =
+ PscCommon.getField(transactionManager, "newPartitionsInTransaction");
+ Object newPartitionsInTransactionIsEmpty =
+ PscCommon.invoke(newPartitionsInTransaction, "isEmpty");
+ TransactionalRequestResult result;
+ if (newPartitionsInTransactionIsEmpty instanceof Boolean
+ && !((Boolean) newPartitionsInTransactionIsEmpty)) {
+ Object txnRequestHandler =
+ PscCommon.invoke(transactionManager, "addPartitionsToTransactionHandler");
+ PscCommon.invoke(
+ transactionManager,
+ "enqueueRequest",
+ new Class[] {txnRequestHandler.getClass().getSuperclass()},
+ new Object[] {txnRequestHandler});
+ result =
+ (TransactionalRequestResult)
+ PscCommon.getField(
+ txnRequestHandler,
+ txnRequestHandler.getClass().getSuperclass(),
+ "result");
+ } else {
+ // we don't have an operation but this operation string is also used in
+ // addPartitionsToTransactionHandler.
+ result = new TransactionalRequestResult("AddPartitionsToTxn");
+ result.done();
+ }
+ return result;
+ }
+
+ /**
+ * Resumes the transaction in the transaction manager by setting the producer id and epoch in the transactionManager
+ * to what is provided by the {@link PscProducerTransactionalProperties} and transitioning the transactionManager
+ * state, first to "INITIALIZING", then to "READY", and finally to "IN_TRANSACTION".
+ *
+ * @param transactionManager the transaction manager
+ * @param transactionalProperties the transactional properties containing the producerId and epoch to resume the transaction with
+ */
+ @Override
+ public void resumeTransaction(Object transactionManager, PscProducerTransactionalProperties transactionalProperties) {
+ Object topicPartitionBookkeeper =
+ PscCommon.getField(transactionManager, "topicPartitionBookkeeper");
+
+ transitionTransactionManagerStateTo(transactionManager, "INITIALIZING");
+ PscCommon.invoke(topicPartitionBookkeeper, "reset");
+
+ PscCommon.setField(
+ transactionManager,
+ KAFKA_TXN_MANAGER_PRODUCER_ID_AND_EPOCH_FIELD_NAME,
+ createProducerIdAndEpoch(transactionalProperties.getProducerId(), transactionalProperties.getEpoch()));
+
+ transitionTransactionManagerStateTo(transactionManager, "READY");
+
+ transitionTransactionManagerStateTo(transactionManager, "IN_TRANSACTION");
+ PscCommon.setField(transactionManager, "transactionStarted", true);
+ }
+
+ /**
+ * Returns the transaction coordinator id of the transaction manager.
+ *
+ * @param transactionManager the transaction manager
+ * @return the transaction coordinator id
+ */
+ @Override
+ public int getTransactionCoordinatorId(Object transactionManager) {
+ Node coordinatorNode = (Node) PscCommon.invoke(transactionManager, "coordinator", FindCoordinatorRequest.CoordinatorType.TRANSACTION);
+ if (coordinatorNode == null) {
+ throw new IllegalStateException("Transaction coordinator node is null");
+ }
+ return coordinatorNode.id();
+ }
+
+ /**
+ * Creates a {@link ProducerIdAndEpoch} object with the given producerId and epoch.
+ *
+ * @param producerId the producer id
+ * @param epoch the epoch
+ * @return the producer id and epoch object
+ */
+ private ProducerIdAndEpoch createProducerIdAndEpoch(long producerId, short epoch) {
+ try {
+ Field field =
+ TransactionManager.class.getDeclaredField(KAFKA_TXN_MANAGER_PRODUCER_ID_AND_EPOCH_FIELD_NAME);
+ Class> clazz = field.getType();
+ Constructor> constructor = clazz.getDeclaredConstructor(Long.TYPE, Short.TYPE);
+ constructor.setAccessible(true);
+ return (ProducerIdAndEpoch) constructor.newInstance(producerId, epoch);
+ } catch (InvocationTargetException
+ | InstantiationException
+ | IllegalAccessException
+ | NoSuchFieldException
+ | NoSuchMethodException e) {
+ throw new RuntimeException("Incompatible KafkaProducer version", e);
+ }
+ }
+
+ /**
+ * Transitions the transaction manager state to the given state.
+ *
+ * @param transactionManager the transaction manager
+ * @param state the state to transition to
+ */
+ private void transitionTransactionManagerStateTo(
+ Object transactionManager, String state) {
+ PscCommon.invoke(transactionManager, "transitionTo", getTransactionManagerState(state));
+ }
+}
diff --git a/psc/src/test/java/com/pinterest/psc/producer/transaction/TestTransactionManagerUtils.java b/psc/src/test/java/com/pinterest/psc/producer/transaction/TestTransactionManagerUtils.java
new file mode 100644
index 0000000..8277261
--- /dev/null
+++ b/psc/src/test/java/com/pinterest/psc/producer/transaction/TestTransactionManagerUtils.java
@@ -0,0 +1,31 @@
+package com.pinterest.psc.producer.transaction;
+
+import org.apache.kafka.clients.producer.internals.TransactionManager;
+import org.apache.kafka.common.utils.LogContext;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+public class TestTransactionManagerUtils {
+
+ @Test
+ void testKafkaTransactionManagerOperator() {
+ // create a Kafka TransactionManager just for testing
+ TransactionManager transactionManager = new TransactionManager(new LogContext(), null, 10000, 100L, null, false);
+
+ long id = TransactionManagerUtils.getProducerId(transactionManager);
+ short epoch = TransactionManagerUtils.getEpoch(transactionManager);
+ String transactionId = TransactionManagerUtils.getTransactionId(transactionManager);
+ assertEquals(-1, id); // uninitialized
+ assertEquals(-1, epoch); // uninitialized
+ assertNull(transactionId); // uninitialized
+
+ TransactionManagerUtils.setProducerId(transactionManager, 100L);
+ TransactionManagerUtils.setEpoch(transactionManager, (short) 1);
+ TransactionManagerUtils.setTransactionId(transactionManager, "transaction-id");
+ assertEquals(100L, TransactionManagerUtils.getProducerId(transactionManager));
+ assertEquals(1, TransactionManagerUtils.getEpoch(transactionManager));
+ assertEquals("transaction-id", TransactionManagerUtils.getTransactionId(transactionManager));
+ }
+}