diff --git a/psc-flink/src/main/java/com/pinterest/flink/streaming/connectors/psc/internals/FlinkPscInternalProducer.java b/psc-flink/src/main/java/com/pinterest/flink/streaming/connectors/psc/internals/FlinkPscInternalProducer.java index 2aadaaf..75f7b58 100644 --- a/psc-flink/src/main/java/com/pinterest/flink/streaming/connectors/psc/internals/FlinkPscInternalProducer.java +++ b/psc-flink/src/main/java/com/pinterest/flink/streaming/connectors/psc/internals/FlinkPscInternalProducer.java @@ -30,6 +30,7 @@ import com.pinterest.psc.producer.PscProducer; import com.pinterest.psc.producer.PscProducerMessage; import com.pinterest.psc.producer.PscProducerTransactionalProperties; +import com.pinterest.psc.producer.transaction.TransactionManagerUtils; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.shaded.guava30.com.google.common.base.Joiner; @@ -165,14 +166,12 @@ public String getTransactionalId() { public long getProducerId(PscProducerMessage pscProducerMessage) throws ProducerException { Object transactionManager = super.getTransactionManager(pscProducerMessage); - Object producerIdAndEpoch = getField(transactionManager, "producerIdAndEpoch"); - return (long) getField(producerIdAndEpoch, "producerId"); + return TransactionManagerUtils.getProducerId(transactionManager); } public short getEpoch(PscProducerMessage pscProducerMessage) throws ProducerException { Object transactionManager = super.getTransactionManager(pscProducerMessage); - Object producerIdAndEpoch = getField(transactionManager, "producerIdAndEpoch"); - return (short) getField(producerIdAndEpoch, "epoch"); + return TransactionManagerUtils.getEpoch(transactionManager); } @VisibleForTesting @@ -180,7 +179,7 @@ public Set getTransactionCoordinatorIds() throws ProducerException { Set coordinatorIds = new HashSet<>(); super.getTransactionManagers().forEach(transactionManager -> coordinatorIds.add( - ((Node) invoke(transactionManager, "coordinator", FindCoordinatorRequest.CoordinatorType.TRANSACTION)).id() + TransactionManagerUtils.getTransactionCoordinatorId(transactionManager) ) ); return coordinatorIds; @@ -200,9 +199,15 @@ private void ensureNotClosed() { */ private void flushNewPartitions() throws ProducerException { LOG.info("Flushing new partitions"); - Set results = enqueueNewPartitions(); + Set> results = enqueueNewPartitions(); super.wakeup(); - results.forEach(TransactionalRequestResult::await); + results.forEach(future -> { + try { + future.get(); + } catch (Exception e) { + throw new RuntimeException("Error while flushing new partitions", e); + } + }); } /** @@ -212,97 +217,17 @@ private void flushNewPartitions() throws ProducerException { *

If there are no new transactions we return a {@link TransactionalRequestResult} that is * already done. */ - private Set enqueueNewPartitions() throws ProducerException { - Set transactionalRequestResults = new HashSet<>(); + private Set> enqueueNewPartitions() throws ProducerException { + Set> transactionalRequestResults = new HashSet<>(); Set transactionManagers = super.getTransactionManagers(); for (Object transactionManager : transactionManagers) { synchronized (transactionManager) { - Object newPartitionsInTransaction = getField(transactionManager, "newPartitionsInTransaction"); - Object newPartitionsInTransactionIsEmpty = invoke(newPartitionsInTransaction, "isEmpty"); - TransactionalRequestResult result; - if (newPartitionsInTransactionIsEmpty instanceof Boolean && !((Boolean) newPartitionsInTransactionIsEmpty)) { - Object txnRequestHandler = invoke(transactionManager, "addPartitionsToTransactionHandler"); - invoke(transactionManager, "enqueueRequest", new Class[]{txnRequestHandler.getClass().getSuperclass()}, new Object[]{txnRequestHandler}); - result = (TransactionalRequestResult) getField(txnRequestHandler, txnRequestHandler.getClass().getSuperclass(), "result"); - } else { - // we don't have an operation but this operation string is also used in - // addPartitionsToTransactionHandler. - result = new TransactionalRequestResult("AddPartitionsToTxn"); - result.done(); - } - transactionalRequestResults.add(result); + Future transactionalRequestResultFuture = + TransactionManagerUtils.enqueueInFlightTransactions(transactionManager); + transactionalRequestResults.add(transactionalRequestResultFuture); } } return transactionalRequestResults; } - protected static Enum getEnum(String enumFullName) { - String[] x = enumFullName.split("\\.(?=[^\\.]+$)"); - if (x.length == 2) { - String enumClassName = x[0]; - String enumName = x[1]; - try { - Class cl = (Class) Class.forName(enumClassName); - return Enum.valueOf(cl, enumName); - } catch (ClassNotFoundException e) { - throw new RuntimeException("Incompatible KafkaProducer version", e); - } - } - return null; - } - - protected static Object invoke(Object object, String methodName, Object... args) { - Class[] argTypes = new Class[args.length]; - for (int i = 0; i < args.length; i++) { - argTypes[i] = args[i].getClass(); - } - return invoke(object, methodName, argTypes, args); - } - - private static Object invoke(Object object, String methodName, Class[] argTypes, Object[] args) { - try { - Method method = object.getClass().getDeclaredMethod(methodName, argTypes); - method.setAccessible(true); - return method.invoke(object, args); - } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { - throw new RuntimeException("Incompatible PscProducer version", e); - } - } - - /** - * Gets and returns the field {@code fieldName} from the given Object {@code object} using - * reflection. - */ - protected static Object getField(Object object, String fieldName) { - return getField(object, object.getClass(), fieldName); - } - - /** - * Gets and returns the field {@code fieldName} from the given Object {@code object} using - * reflection. - */ - private static Object getField(Object object, Class clazz, String fieldName) { - try { - Field field = clazz.getDeclaredField(fieldName); - field.setAccessible(true); - return field.get(object); - } catch (NoSuchFieldException | IllegalAccessException e) { - throw new RuntimeException("Incompatible KafkaProducer version", e); - } - } - - /** - * Sets the field {@code fieldName} on the given Object {@code object} to {@code value} using - * reflection. - */ - protected static void setField(Object object, String fieldName, Object value) { - try { - Field field = object.getClass().getDeclaredField(fieldName); - field.setAccessible(true); - field.set(object, value); - } catch (NoSuchFieldException | IllegalAccessException e) { - throw new RuntimeException("Incompatible KafkaProducer version", e); - } - } - } diff --git a/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerOperator.java b/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerOperator.java index 1ffcc49..53292d8 100644 --- a/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerOperator.java +++ b/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerOperator.java @@ -25,4 +25,6 @@ public interface TransactionManagerOperator { Future enqueueInFlightTransactions(Object transactionManager); void resumeTransaction(Object transactionManager, PscProducerTransactionalProperties transactionalProperties); + + int getTransactionCoordinatorId(Object transactionManager); } diff --git a/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerUtils.java b/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerUtils.java index 837715a..7e39708 100644 --- a/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerUtils.java +++ b/psc/src/main/java/com/pinterest/psc/producer/transaction/TransactionManagerUtils.java @@ -113,4 +113,14 @@ public static Future enqueueInFlightTransactions(Object transactionMana public static void resumeTransaction(Object transactionManager, PscProducerTransactionalProperties transactionalProperties) { getOrCreateTransactionManagerOperator(transactionManager).resumeTransaction(transactionManager, transactionalProperties); } + + /** + * Get the transaction coordinator ID of the transaction manager. + * + * @param transactionManager the transaction manager object + * @return the transaction coordinator ID of the transaction manager + */ + public static int getTransactionCoordinatorId(Object transactionManager) { + return getOrCreateTransactionManagerOperator(transactionManager).getTransactionCoordinatorId(transactionManager); + } } diff --git a/psc/src/main/java/com/pinterest/psc/producer/transaction/kafka/KafkaTransactionManagerOperator.java b/psc/src/main/java/com/pinterest/psc/producer/transaction/kafka/KafkaTransactionManagerOperator.java index bb5de0a..f59c5c2 100644 --- a/psc/src/main/java/com/pinterest/psc/producer/transaction/kafka/KafkaTransactionManagerOperator.java +++ b/psc/src/main/java/com/pinterest/psc/producer/transaction/kafka/KafkaTransactionManagerOperator.java @@ -5,6 +5,8 @@ import com.pinterest.psc.producer.transaction.TransactionManagerOperator; import org.apache.kafka.clients.producer.internals.TransactionManager; import org.apache.kafka.clients.producer.internals.TransactionalRequestResult; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.requests.FindCoordinatorRequest; import org.apache.kafka.common.utils.ProducerIdAndEpoch; import java.lang.reflect.Constructor; @@ -235,6 +237,21 @@ public void resumeTransaction(Object transactionManager, PscProducerTransactiona PscCommon.setField(transactionManager, "transactionStarted", true); } + /** + * Returns the transaction coordinator id of the transaction manager. + * + * @param transactionManager the transaction manager + * @return the transaction coordinator id + */ + @Override + public int getTransactionCoordinatorId(Object transactionManager) { + Node coordinatorNode = (Node) PscCommon.invoke(transactionManager, "coordinator", FindCoordinatorRequest.CoordinatorType.TRANSACTION); + if (coordinatorNode == null) { + throw new IllegalStateException("Transaction coordinator node is null"); + } + return coordinatorNode.id(); + } + /** * Creates a {@link ProducerIdAndEpoch} object with the given producerId and epoch. *