Skip to content

Commit

Permalink
[Java] Create a temporary copy of the RecordingDescriptor buffer if s…
Browse files Browse the repository at this point in the history
…ending failed and must be enqueued for a re-send, because the `descriptorBuffer` is shared object.
  • Loading branch information
vyazelenko committed Jan 22, 2025
1 parent 7749c22 commit 61e4aba
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,10 @@
import org.agrona.MutableDirectBuffer;
import org.agrona.concurrent.UnsafeBuffer;

import static io.aeron.archive.codecs.RecordingDescriptorEncoder.recordingIdEncodingOffset;
import static org.agrona.BitUtil.SIZE_OF_LONG;

class ControlResponseProxy
{
private static final int SEND_ATTEMPTS = 3;
private static final int MESSAGE_HEADER_LENGTH = MessageHeaderEncoder.ENCODED_LENGTH;
private static final int DESCRIPTOR_PREFIX_LENGTH = MESSAGE_HEADER_LENGTH + 2 * SIZE_OF_LONG;
private static final int DESCRIPTOR_CONTENT_OFFSET =
RecordingDescriptorHeaderDecoder.BLOCK_LENGTH + recordingIdEncodingOffset();

private final ExpandableArrayBuffer buffer = new ExpandableArrayBuffer(1024);
private final BufferClaim bufferClaim = new BufferClaim();
Expand All @@ -53,11 +47,10 @@ boolean sendDescriptor(
final long controlSessionId,
final long correlationId,
final UnsafeBuffer descriptorBuffer,
final int descriptorOffset,
final int descriptorLength,
final ControlSession session)
{
final int messageLength = Catalog.descriptorLength(descriptorBuffer) + MESSAGE_HEADER_LENGTH;
final int contentLength = messageLength - recordingIdEncodingOffset() - MESSAGE_HEADER_LENGTH;

recordingDescriptorEncoder
.wrapAndApplyHeader(buffer, 0, messageHeaderEncoder)
.controlSessionId(controlSessionId)
Expand All @@ -70,10 +63,10 @@ boolean sendDescriptor(
final long position = publication.offer(
buffer,
0,
DESCRIPTOR_PREFIX_LENGTH,
MESSAGE_HEADER_LENGTH + RecordingDescriptorDecoder.recordingIdEncodingOffset(),
descriptorBuffer,
DESCRIPTOR_CONTENT_OFFSET,
contentLength);
descriptorOffset,
descriptorLength);
if (position > 0)
{
return true;
Expand Down
14 changes: 12 additions & 2 deletions aeron-archive/src/main/java/io/aeron/archive/ControlSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.aeron.Subscription;
import io.aeron.archive.client.ArchiveEvent;
import io.aeron.archive.codecs.ControlResponseCode;
import io.aeron.archive.codecs.RecordingDescriptorDecoder;
import io.aeron.archive.codecs.RecordingDescriptorHeaderDecoder;
import io.aeron.archive.codecs.RecordingSignal;
import io.aeron.archive.codecs.SourceLocation;
import io.aeron.security.Authenticator;
Expand Down Expand Up @@ -713,12 +715,20 @@ void asyncSendOkResponse(final long correlationId, final long replaySessionId)
void sendDescriptor(final long correlationId, final UnsafeBuffer descriptorBuffer)
{
assertCalledOnConductorThread();
final int descriptorOffset =
RecordingDescriptorHeaderDecoder.BLOCK_LENGTH + RecordingDescriptorDecoder.recordingIdEncodingOffset();
final int descriptorLength =
Catalog.descriptorLength(descriptorBuffer) - RecordingDescriptorDecoder.recordingIdEncodingOffset();
if (!syncResponseQueue.isEmpty() ||
!controlResponseProxy.sendDescriptor(controlSessionId, correlationId, descriptorBuffer, this))
!controlResponseProxy.sendDescriptor(
controlSessionId, correlationId, descriptorBuffer, descriptorOffset, descriptorLength, this))
{
final UnsafeBuffer tmpBuffer = new UnsafeBuffer(new byte[descriptorLength]);
tmpBuffer.putBytes(0, descriptorBuffer, descriptorOffset, descriptorLength);

updateActivityDeadline(cachedEpochClock.time());
syncResponseQueue.offer(() -> controlResponseProxy.sendDescriptor(
controlSessionId, correlationId, descriptorBuffer, this));
controlSessionId, correlationId, tmpBuffer, 0, tmpBuffer.capacity(), this));
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,28 @@

import io.aeron.Aeron;
import io.aeron.ExclusivePublication;
import io.aeron.archive.codecs.RecordingDescriptorEncoder;
import io.aeron.archive.codecs.RecordingDescriptorHeaderEncoder;
import io.aeron.archive.codecs.RecordingState;
import io.aeron.security.Authenticator;
import org.agrona.BitUtil;
import org.agrona.concurrent.CachedEpochClock;
import org.agrona.concurrent.UnsafeBuffer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import static io.aeron.archive.ControlSession.State.DONE;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static io.aeron.archive.ControlSession.State.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

class ControlSessionTest
{
private static final long CONTROL_SESSION_ID = -953749534;
private static final long CORRELATION_ID = 47354;
private static final long CONTROL_PUBLICATION_ID = 777;
private static final long CONNECT_TIMEOUT_MS = TimeUnit.SECONDS.toMillis(5);
private static final long SESSION_LIVENESS_CHECK_INTERVAL_NS = TimeUnit.MILLISECONDS.toNanos(100);
Expand All @@ -50,8 +57,8 @@ class ControlSessionTest
void before()
{
session = new ControlSession(
1,
2,
CONTROL_SESSION_ID,
CORRELATION_ID,
CONNECT_TIMEOUT_MS,
SESSION_LIVENESS_CHECK_INTERVAL_NS,
CONTROL_PUBLICATION_ID,
Expand All @@ -65,6 +72,7 @@ void before()
mockSessionProxy);

when(mockAeron.getExclusivePublication(CONTROL_PUBLICATION_ID)).thenReturn(mockControlPublication);
when(mockControlPublication.isConnected()).thenReturn(true);
}

@Test
Expand Down Expand Up @@ -96,4 +104,63 @@ void shouldTimeoutIfConnectSentButPublicationFailsToSend()
assertEquals(DONE, session.state());
assertTrue(session.isDone());
}

@Test
void shouldCopyDescriptor()
{
final long correlationId = -438682374754L;
final UnsafeBuffer buffer = new UnsafeBuffer(new byte[1024]);
ThreadLocalRandom.current().nextBytes(buffer.byteArray());

final RecordingDescriptorEncoder recordingDescriptorEncoder = new RecordingDescriptorEncoder();
recordingDescriptorEncoder
.wrap(buffer, RecordingDescriptorHeaderEncoder.BLOCK_LENGTH)
.strippedChannel("aeron:udp?endpoint=localhost:12345")
.originalChannel("aeron:udp?mtu=2048|term-length=128k|endpoint=localhost:12345")
.sourceIdentity("the source of this mess");

final int fullDescriptorLength = BitUtil.align(
RecordingDescriptorHeaderEncoder.BLOCK_LENGTH + recordingDescriptorEncoder.encodedLength(),
BitUtil.CACHE_LINE_LENGTH);

final RecordingDescriptorHeaderEncoder headerEncoder = new RecordingDescriptorHeaderEncoder();
headerEncoder
.wrap(buffer, 0)
.length(fullDescriptorLength)
.state(RecordingState.VALID)
.checksum(ThreadLocalRandom.current().nextInt());

final int payloadOffset =
RecordingDescriptorHeaderEncoder.BLOCK_LENGTH + RecordingDescriptorEncoder.recordingIdEncodingOffset();
final int payloadLength = fullDescriptorLength - RecordingDescriptorEncoder.recordingIdEncodingOffset();

session.sendDescriptor(correlationId, buffer);

verify(mockProxy)
.sendDescriptor(CONTROL_SESSION_ID, correlationId, buffer, payloadOffset, payloadLength, session);

while (session.state() != CONNECTED)
{
session.doWork();
}

session.authenticate(buffer.byteArray());
assertEquals(AUTHENTICATED, session.state());

session.onArchiveId(777);
assertEquals(ACTIVE, session.state());

session.doWork();

final ArgumentCaptor<UnsafeBuffer> bufferCaptor = ArgumentCaptor.forClass(UnsafeBuffer.class);
verify(mockProxy).sendDescriptor(
eq(CONTROL_SESSION_ID), eq(correlationId), bufferCaptor.capture(), eq(0), eq(payloadLength), same(session));
final UnsafeBuffer tmpBuffer = bufferCaptor.getValue();
assertNotNull(tmpBuffer);
assertNotSame(buffer, tmpBuffer);
for (int i = 0; i < payloadLength; i++)
{
assertEquals(buffer.getByte(payloadOffset + i), tmpBuffer.getByte(i));
}
}
}

0 comments on commit 61e4aba

Please sign in to comment.