diff --git a/aeron-archive/src/main/java/io/aeron/archive/ControlResponseProxy.java b/aeron-archive/src/main/java/io/aeron/archive/ControlResponseProxy.java index 8144015673..90cc24d96a 100644 --- a/aeron-archive/src/main/java/io/aeron/archive/ControlResponseProxy.java +++ b/aeron-archive/src/main/java/io/aeron/archive/ControlResponseProxy.java @@ -27,16 +27,10 @@ import org.agrona.MutableDirectBuffer; import org.agrona.concurrent.UnsafeBuffer; -import static io.aeron.archive.codecs.RecordingDescriptorEncoder.recordingIdEncodingOffset; -import static org.agrona.BitUtil.SIZE_OF_LONG; - class ControlResponseProxy { private static final int SEND_ATTEMPTS = 3; private static final int MESSAGE_HEADER_LENGTH = MessageHeaderEncoder.ENCODED_LENGTH; - private static final int DESCRIPTOR_PREFIX_LENGTH = MESSAGE_HEADER_LENGTH + 2 * SIZE_OF_LONG; - private static final int DESCRIPTOR_CONTENT_OFFSET = - RecordingDescriptorHeaderDecoder.BLOCK_LENGTH + recordingIdEncodingOffset(); private final ExpandableArrayBuffer buffer = new ExpandableArrayBuffer(1024); private final BufferClaim bufferClaim = new BufferClaim(); @@ -53,11 +47,10 @@ boolean sendDescriptor( final long controlSessionId, final long correlationId, final UnsafeBuffer descriptorBuffer, + final int descriptorOffset, + final int descriptorLength, final ControlSession session) { - final int messageLength = Catalog.descriptorLength(descriptorBuffer) + MESSAGE_HEADER_LENGTH; - final int contentLength = messageLength - recordingIdEncodingOffset() - MESSAGE_HEADER_LENGTH; - recordingDescriptorEncoder .wrapAndApplyHeader(buffer, 0, messageHeaderEncoder) .controlSessionId(controlSessionId) @@ -70,10 +63,10 @@ boolean sendDescriptor( final long position = publication.offer( buffer, 0, - DESCRIPTOR_PREFIX_LENGTH, + MESSAGE_HEADER_LENGTH + RecordingDescriptorDecoder.recordingIdEncodingOffset(), descriptorBuffer, - DESCRIPTOR_CONTENT_OFFSET, - contentLength); + descriptorOffset, + descriptorLength); if (position > 0) { return true; diff --git a/aeron-archive/src/main/java/io/aeron/archive/ControlSession.java b/aeron-archive/src/main/java/io/aeron/archive/ControlSession.java index ef8f7943ac..324abc4496 100644 --- a/aeron-archive/src/main/java/io/aeron/archive/ControlSession.java +++ b/aeron-archive/src/main/java/io/aeron/archive/ControlSession.java @@ -20,6 +20,8 @@ import io.aeron.Subscription; import io.aeron.archive.client.ArchiveEvent; import io.aeron.archive.codecs.ControlResponseCode; +import io.aeron.archive.codecs.RecordingDescriptorDecoder; +import io.aeron.archive.codecs.RecordingDescriptorHeaderDecoder; import io.aeron.archive.codecs.RecordingSignal; import io.aeron.archive.codecs.SourceLocation; import io.aeron.security.Authenticator; @@ -713,12 +715,20 @@ void asyncSendOkResponse(final long correlationId, final long replaySessionId) void sendDescriptor(final long correlationId, final UnsafeBuffer descriptorBuffer) { assertCalledOnConductorThread(); + final int descriptorOffset = + RecordingDescriptorHeaderDecoder.BLOCK_LENGTH + RecordingDescriptorDecoder.recordingIdEncodingOffset(); + final int descriptorLength = + Catalog.descriptorLength(descriptorBuffer) - RecordingDescriptorDecoder.recordingIdEncodingOffset(); if (!syncResponseQueue.isEmpty() || - !controlResponseProxy.sendDescriptor(controlSessionId, correlationId, descriptorBuffer, this)) + !controlResponseProxy.sendDescriptor( + controlSessionId, correlationId, descriptorBuffer, descriptorOffset, descriptorLength, this)) { + final UnsafeBuffer tmpBuffer = new UnsafeBuffer(new byte[descriptorLength]); + tmpBuffer.putBytes(0, descriptorBuffer, descriptorOffset, descriptorLength); + updateActivityDeadline(cachedEpochClock.time()); syncResponseQueue.offer(() -> controlResponseProxy.sendDescriptor( - controlSessionId, correlationId, descriptorBuffer, this)); + controlSessionId, correlationId, tmpBuffer, 0, tmpBuffer.capacity(), this)); } else { diff --git a/aeron-archive/src/test/java/io/aeron/archive/ControlSessionTest.java b/aeron-archive/src/test/java/io/aeron/archive/ControlSessionTest.java index aa1a6e1966..4b2ae568e1 100644 --- a/aeron-archive/src/test/java/io/aeron/archive/ControlSessionTest.java +++ b/aeron-archive/src/test/java/io/aeron/archive/ControlSessionTest.java @@ -17,21 +17,28 @@ import io.aeron.Aeron; import io.aeron.ExclusivePublication; +import io.aeron.archive.codecs.RecordingDescriptorEncoder; +import io.aeron.archive.codecs.RecordingDescriptorHeaderEncoder; +import io.aeron.archive.codecs.RecordingState; import io.aeron.security.Authenticator; +import org.agrona.BitUtil; import org.agrona.concurrent.CachedEpochClock; +import org.agrona.concurrent.UnsafeBuffer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; -import static io.aeron.archive.ControlSession.State.DONE; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static io.aeron.archive.ControlSession.State.*; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; class ControlSessionTest { + private static final long CONTROL_SESSION_ID = -953749534; + private static final long CORRELATION_ID = 47354; private static final long CONTROL_PUBLICATION_ID = 777; private static final long CONNECT_TIMEOUT_MS = TimeUnit.SECONDS.toMillis(5); private static final long SESSION_LIVENESS_CHECK_INTERVAL_NS = TimeUnit.MILLISECONDS.toNanos(100); @@ -50,8 +57,8 @@ class ControlSessionTest void before() { session = new ControlSession( - 1, - 2, + CONTROL_SESSION_ID, + CORRELATION_ID, CONNECT_TIMEOUT_MS, SESSION_LIVENESS_CHECK_INTERVAL_NS, CONTROL_PUBLICATION_ID, @@ -65,6 +72,7 @@ void before() mockSessionProxy); when(mockAeron.getExclusivePublication(CONTROL_PUBLICATION_ID)).thenReturn(mockControlPublication); + when(mockControlPublication.isConnected()).thenReturn(true); } @Test @@ -96,4 +104,63 @@ void shouldTimeoutIfConnectSentButPublicationFailsToSend() assertEquals(DONE, session.state()); assertTrue(session.isDone()); } + + @Test + void shouldCopyDescriptor() + { + final long correlationId = -438682374754L; + final UnsafeBuffer buffer = new UnsafeBuffer(new byte[1024]); + ThreadLocalRandom.current().nextBytes(buffer.byteArray()); + + final RecordingDescriptorEncoder recordingDescriptorEncoder = new RecordingDescriptorEncoder(); + recordingDescriptorEncoder + .wrap(buffer, RecordingDescriptorHeaderEncoder.BLOCK_LENGTH) + .strippedChannel("aeron:udp?endpoint=localhost:12345") + .originalChannel("aeron:udp?mtu=2048|term-length=128k|endpoint=localhost:12345") + .sourceIdentity("the source of this mess"); + + final int fullDescriptorLength = BitUtil.align( + RecordingDescriptorHeaderEncoder.BLOCK_LENGTH + recordingDescriptorEncoder.encodedLength(), + BitUtil.CACHE_LINE_LENGTH); + + final RecordingDescriptorHeaderEncoder headerEncoder = new RecordingDescriptorHeaderEncoder(); + headerEncoder + .wrap(buffer, 0) + .length(fullDescriptorLength) + .state(RecordingState.VALID) + .checksum(ThreadLocalRandom.current().nextInt()); + + final int payloadOffset = + RecordingDescriptorHeaderEncoder.BLOCK_LENGTH + RecordingDescriptorEncoder.recordingIdEncodingOffset(); + final int payloadLength = fullDescriptorLength - RecordingDescriptorEncoder.recordingIdEncodingOffset(); + + session.sendDescriptor(correlationId, buffer); + + verify(mockProxy) + .sendDescriptor(CONTROL_SESSION_ID, correlationId, buffer, payloadOffset, payloadLength, session); + + while (session.state() != CONNECTED) + { + session.doWork(); + } + + session.authenticate(buffer.byteArray()); + assertEquals(AUTHENTICATED, session.state()); + + session.onArchiveId(777); + assertEquals(ACTIVE, session.state()); + + session.doWork(); + + final ArgumentCaptor bufferCaptor = ArgumentCaptor.forClass(UnsafeBuffer.class); + verify(mockProxy).sendDescriptor( + eq(CONTROL_SESSION_ID), eq(correlationId), bufferCaptor.capture(), eq(0), eq(payloadLength), same(session)); + final UnsafeBuffer tmpBuffer = bufferCaptor.getValue(); + assertNotNull(tmpBuffer); + assertNotSame(buffer, tmpBuffer); + for (int i = 0; i < payloadLength; i++) + { + assertEquals(buffer.getByte(payloadOffset + i), tmpBuffer.getByte(i)); + } + } }