Skip to content

Commit

Permalink
Merge pull request #1584 from real-logic/ingress_unknown_schema_id
Browse files Browse the repository at this point in the history
[Java ]  Delegate decision on unknown schema to consensus agent
  • Loading branch information
langera authored Jun 3, 2024
2 parents 7a52d88 + 3c1089d commit 57376c8
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,13 @@ public static void main(final String[] args)
final ShutdownSignalBarrier barrier = new ShutdownSignalBarrier();
final MediaDriver.Context ctx = new MediaDriver.Context()
.terminationHook(barrier::signalAll);
final Archive.Context archiveCtx = new Archive.Context();

try (ArchivingMediaDriver ignore = launch(ctx, new Archive.Context()))
try (ArchivingMediaDriver ignore = launch(ctx, archiveCtx))
{
System.out.println("MediaDriver.Context " + ctx);
System.out.println("Archive.Context " + archiveCtx);

barrier.await();
System.out.println("Shutdown Archive...");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,8 @@ public static final class Context implements Cloneable
private Random random;
private TimerServiceSupplier timerServiceSupplier;
private Function<Context, LongConsumer> clusterTimeConsumerSupplier;

private Supplier<ConsensusModuleExtension> consensusModuleExtensionSupplier;
private ConsensusModuleExtension consensusModuleExtension;
private DistinctErrorLog errorLog;
private ErrorHandler errorHandler;
private AtomicCounter errorCounter;
Expand Down Expand Up @@ -3977,6 +3978,31 @@ public TimerServiceSupplier timerServiceSupplier()
return timerServiceSupplier;
}

/**
* Registers a ConsensusModuleExtension to extend beahviour of
* consensus module instead of using ClusteredServices
*
* @param extensionSupplier supplier for consensus module extension
* @return this for a fluent API.
*/
public Context consensusModuleExtension(final Supplier<ConsensusModuleExtension> extensionSupplier)
{
consensusModuleExtensionSupplier = extensionSupplier;
return this;
}

/**
* @return Supplier for registered consensus module extension or null
*/
public ConsensusModuleExtension consensusModuleExtension()
{
if (consensusModuleExtensionSupplier != null)
{
consensusModuleExtension = consensusModuleExtensionSupplier.get();
}
return consensusModuleExtension;
}

/**
* Deprecated for removal.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.aeron.driver.media.UdpChannel;
import io.aeron.exceptions.AeronException;
import io.aeron.logbuffer.ControlledFragmentHandler;
import io.aeron.logbuffer.Header;
import io.aeron.security.Authenticator;
import io.aeron.security.AuthorisationService;
import io.aeron.status.LocalSocketAddressStatus;
Expand Down Expand Up @@ -137,7 +138,7 @@ final class ConsensusModuleAgent implements Agent, TimerService.TimerHandler, Co
private final ArrayDeque<ClusterSession> uncommittedClosedSessions = new ArrayDeque<>();
private final LongArrayQueue uncommittedTimers = new LongArrayQueue(Long.MAX_VALUE);
private final PendingServiceMessageTracker[] pendingServiceMessageTrackers;

private final ConsensusModuleExtension consensusModuleExtension;
private final Authenticator authenticator;
private final AuthorisationService authorisationService;
private final ClusterSessionProxy sessionProxy;
Expand Down Expand Up @@ -226,7 +227,7 @@ final class ConsensusModuleAgent implements Agent, TimerService.TimerHandler, Co
pendingServiceMessageTrackers[i] = new PendingServiceMessageTracker(
i, commitPosition, logPublisher, clusterClock);
}

this.consensusModuleExtension = ctx.consensusModuleExtension();
responseChannelTemplate = Strings.isEmpty(ctx.egressChannel()) ? null : ChannelUri.parse(ctx.egressChannel());
}

Expand Down Expand Up @@ -405,6 +406,21 @@ public void onLoadBeginSnapshot(
}
}

public ControlledFragmentHandler.Action onExtensionMessage(
final int schemaId,
final int templateId,
final DirectBuffer buffer,
final int offset,
final int length,
final Header header)
{
if (consensusModuleExtension == null)
{
throw new ClusterException("expected schemaId=" + MessageHeaderDecoder.SCHEMA_ID + ", actual=" + schemaId);
}
return consensusModuleExtension.onMessage(schemaId, templateId, buffer, offset, length, header);
}

public void onLoadEndSnapshot(final DirectBuffer buffer, final int offset, final int length)
{
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2014-2024 Real Logic Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.aeron.cluster;

import org.agrona.DirectBuffer;

import io.aeron.logbuffer.ControlledFragmentHandler;
import io.aeron.logbuffer.Header;

/**
* Adapter for handling messages from external schemas unknown to core Aeron cluster code
* thus providing an extension to the core ingress Consensus module behaviour
*/
public interface ConsensusModuleExtension extends AutoCloseable
{
/**
* schema supported by this extension
*
* @return schema id
*/
int supportedSchemaId();

/**
* Callback for handling fragments of data being read from a log.
* <p>
* Within this callback reentrant calls to the {@link io.aeron.Aeron} client are not permitted and
* will result in undefined behaviour.
*
* @param schemaId the schema id
* @param templateId the message template id (already parsed from header)
* @param buffer containing the data.
* @param offset at which the data begins.
* @param length of the data in bytes.
* @param header representing the metadata for the data.
* @return The action to be taken with regard to the stream position after the callback.
*/
ControlledFragmentHandler.Action onMessage(
int schemaId,
int templateId,
DirectBuffer buffer,
int offset,
int length,
Header header);

@Override
default void close()
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.aeron.ControlledFragmentAssembler;
import io.aeron.Subscription;
import io.aeron.cluster.client.AeronCluster;
import io.aeron.cluster.client.ClusterException;
import io.aeron.cluster.codecs.*;
import io.aeron.logbuffer.ControlledFragmentHandler;
import io.aeron.logbuffer.Header;
Expand Down Expand Up @@ -73,12 +72,12 @@ public Action onFragment(final DirectBuffer buffer, final int offset, final int
messageHeaderDecoder.wrap(buffer, offset);

final int schemaId = messageHeaderDecoder.schemaId();
final int templateId = messageHeaderDecoder.templateId();
if (schemaId != MessageHeaderDecoder.SCHEMA_ID)
{
throw new ClusterException("expected schemaId=" + MessageHeaderDecoder.SCHEMA_ID + ", actual=" + schemaId);
return consensusModuleAgent.onExtensionMessage(schemaId, templateId, buffer, offset, length, header);
}

final int templateId = messageHeaderDecoder.templateId();
if (templateId == SessionMessageHeaderDecoder.TEMPLATE_ID)
{
sessionMessageHeaderDecoder.wrap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,20 @@
*/
package io.aeron.cluster;

import io.aeron.Aeron;
import io.aeron.ChannelUri;
import io.aeron.ConcurrentPublication;
import io.aeron.Counter;
import io.aeron.ExclusivePublication;
import io.aeron.Subscription;
import io.aeron.UnavailableImageHandler;
import io.aeron.archive.client.AeronArchive;
import io.aeron.cluster.codecs.CloseReason;
import io.aeron.cluster.codecs.ClusterAction;
import io.aeron.cluster.codecs.EventCode;
import io.aeron.cluster.service.Cluster;
import io.aeron.cluster.service.ClusterMarkFile;
import io.aeron.cluster.service.ClusterTerminationException;
import io.aeron.driver.DutyCycleTracker;
import io.aeron.security.AuthorisationService;
import io.aeron.security.DefaultAuthenticatorSupplier;
import io.aeron.status.ReadableCounter;
import io.aeron.test.TestContexts;
import io.aeron.test.Tests;
import io.aeron.test.cluster.TestClusterClock;
import static io.aeron.AeronCounters.*;
import static io.aeron.cluster.ClusterControl.ToggleState.*;
import static io.aeron.cluster.ConsensusModule.CLUSTER_ACTION_FLAGS_STANDBY_SNAPSHOT;
import static io.aeron.cluster.ConsensusModule.Configuration.SESSION_LIMIT_MSG;
import static io.aeron.cluster.ConsensusModuleAgent.SLOW_TICK_INTERVAL_NS;
import static io.aeron.cluster.client.AeronCluster.Configuration.PROTOCOL_SEMANTIC_VERSION;
import static java.lang.Boolean.TRUE;
import static org.agrona.concurrent.status.CountersReader.COUNTER_LENGTH;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

import java.util.concurrent.TimeUnit;
import java.util.function.LongConsumer;

import org.agrona.collections.MutableLong;
import org.agrona.concurrent.AgentInvoker;
import org.agrona.concurrent.CountedErrorHandler;
Expand All @@ -50,28 +43,30 @@
import org.mockito.InOrder;
import org.mockito.Mockito;

import java.util.concurrent.TimeUnit;
import java.util.function.LongConsumer;

import static io.aeron.AeronCounters.CLUSTER_CONSENSUS_MODULE_STATE_TYPE_ID;
import static io.aeron.AeronCounters.CLUSTER_CONTROL_TOGGLE_TYPE_ID;
import static io.aeron.cluster.ClusterControl.ToggleState.*;
import static io.aeron.cluster.ConsensusModule.Configuration.SESSION_LIMIT_MSG;
import static io.aeron.cluster.ConsensusModule.CLUSTER_ACTION_FLAGS_STANDBY_SNAPSHOT;
import static io.aeron.cluster.ConsensusModuleAgent.SLOW_TICK_INTERVAL_NS;
import static io.aeron.cluster.client.AeronCluster.Configuration.PROTOCOL_SEMANTIC_VERSION;
import static java.lang.Boolean.TRUE;
import static org.agrona.concurrent.status.CountersReader.COUNTER_LENGTH;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.*;
import io.aeron.*;
import io.aeron.archive.client.AeronArchive;
import io.aeron.cluster.client.ClusterException;
import io.aeron.cluster.codecs.CloseReason;
import io.aeron.cluster.codecs.ClusterAction;
import io.aeron.cluster.codecs.EventCode;
import io.aeron.cluster.service.Cluster;
import io.aeron.cluster.service.ClusterMarkFile;
import io.aeron.cluster.service.ClusterTerminationException;
import io.aeron.driver.DutyCycleTracker;
import io.aeron.security.AuthorisationService;
import io.aeron.security.DefaultAuthenticatorSupplier;
import io.aeron.status.ReadableCounter;
import io.aeron.test.TestContexts;
import io.aeron.test.Tests;
import io.aeron.test.cluster.TestClusterClock;

public class ConsensusModuleAgentTest
{
private static final long SLOW_TICK_INTERVAL_MS = TimeUnit.NANOSECONDS.toMillis(SLOW_TICK_INTERVAL_NS);
private static final String RESPONSE_CHANNEL_ONE = "aeron:udp?endpoint=localhost:11111";
private static final String RESPONSE_CHANNEL_TWO = "aeron:udp?endpoint=localhost:22222";
private static final int SCHEMA_ID = 17;
private static final int MILLIS = 19;

private final EgressPublisher mockEgressPublisher = mock(EgressPublisher.class);
private final LogPublisher mockLogPublisher = mock(LogPublisher.class);
Expand Down Expand Up @@ -161,7 +156,7 @@ public void shouldLimitActiveSessions()
Tests.setField(agent, "appendPosition", mock(ReadableCounter.class));
agent.onSessionConnect(correlationIdOne, 2, PROTOCOL_SEMANTIC_VERSION, RESPONSE_CHANNEL_ONE, new byte[0]);

clock.update(17, TimeUnit.MILLISECONDS);
clock.update(MILLIS, TimeUnit.MILLISECONDS);
agent.doWork();
verify(mockTimeConsumer).accept(clock.time());

Expand Down Expand Up @@ -450,4 +445,33 @@ void onCommmitPositionShouldUpdateTimeOfLastLeaderMessageReceived()

assertEquals(444, consensusModuleAgent.timeOfLastLeaderUpdateNs());
}

@Test
void shouldDelegateHandlingToRegisteredExtension()
{
final ConsensusModuleExtension consensusModuleExtension = mock(ConsensusModuleExtension.class, "used adapter");
when(consensusModuleExtension.supportedSchemaId()).thenReturn(SCHEMA_ID);
final TestClusterClock clock = new TestClusterClock(TimeUnit.MILLISECONDS);
ctx.epochClock(clock)
.clusterClock(clock)
.consensusModuleExtension(() -> consensusModuleExtension);

final ConsensusModuleAgent agent = new ConsensusModuleAgent(ctx);
agent.onExtensionMessage(SCHEMA_ID, 1, null, 0, 0, null);

verify(consensusModuleExtension)
.onMessage(SCHEMA_ID, 1, null, 0, 0, null);
}

@Test
void shouldThrowExceptionOnUnknownSchemaAndNoAdapter()
{
final TestClusterClock clock = new TestClusterClock(TimeUnit.MILLISECONDS);
ctx.epochClock(clock).clusterClock(clock);

final ConsensusModuleAgent agent = new ConsensusModuleAgent(ctx);

assertThrows(ClusterException.class,
() -> agent.onExtensionMessage(SCHEMA_ID, 0, null, 0, 0, null));
}
}
Loading

0 comments on commit 57376c8

Please sign in to comment.