Skip to content

Commit

Permalink
make framer validate MsgType before parsing it
Browse files Browse the repository at this point in the history
  • Loading branch information
wojciech-adaptive committed Jan 8, 2025
1 parent 50ddcde commit a35cd33
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
*/
public final class MessageTypeEncoding
{
public static final int MAX_MESSAGE_TYPE_LENGTH = 8;

private static final int MESSAGE_TYPE_BITSHIFT = 8;

/**
Expand Down Expand Up @@ -87,7 +89,7 @@ public static long packMessageType(final char[] messageType, final int length)

private static void checkLength(final int length)
{
if (length > 8)
if (length > MAX_MESSAGE_TYPE_LENGTH)
{
throw new IllegalArgumentException("Message types longer than 8 are not supported yet");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import static uk.co.real_logic.artio.session.Session.UNKNOWN;
import static uk.co.real_logic.artio.util.AsciiBuffer.SEPARATOR;
import static uk.co.real_logic.artio.util.AsciiBuffer.UNKNOWN_INDEX;
import static uk.co.real_logic.artio.util.MessageTypeEncoding.MAX_MESSAGE_TYPE_LENGTH;

/**
* Handles incoming data from sockets.
Expand Down Expand Up @@ -461,9 +462,17 @@ private boolean frameMessages(final long readTimestampInNs)
break; // Need more data
}

final long messageType = getMessageType(endOfBodyLength, endOfMessage);
final int length = (endOfMessage + 1) - offset;
if (!validateChecksum(endOfMessage, startOfChecksumValue, offset, startOfChecksumTag))

final long messageType = getMessageType(endOfBodyLength, endOfMessage);
if (messageType == 0)
{
if (saveInvalidMsgTypeMessage(offset, length, readTimestampInNs))
{
return false;
}
}
else if (!validateChecksum(endOfMessage, startOfChecksumValue, offset, startOfChecksumTag))
{
DebugLogger.logFixMessage(
FIX_MESSAGE, messageType, "Invalidated (checksum): ", buffer, offset, length);
Expand Down Expand Up @@ -938,9 +947,25 @@ private boolean isStartOfChecksum(final int startOfChecksumTag)

private long getMessageType(final int endOfBodyLength, final int indexOfLastByteOfMessage)
{
final int start = buffer.scan(endOfBodyLength, indexOfLastByteOfMessage, '=') + 1;
final int end = buffer.scan(start + 1, indexOfLastByteOfMessage, START_OF_HEADER);
if (0x3d353301 /* <SOH>35= */ != buffer.getInt(endOfBodyLength))
{
return 0;
}

final int start = endOfBodyLength + 4;
final int limit = Math.min(start + MAX_MESSAGE_TYPE_LENGTH, indexOfLastByteOfMessage) + 1;
final int end = buffer.scan(start, limit, START_OF_HEADER);
if (UNKNOWN_INDEX == end)
{
return 0;
}

final int length = end - start;
if (0 == length)
{
return 0;
}

return buffer.getMessageType(start, length);
}

Expand Down Expand Up @@ -1068,6 +1093,26 @@ private boolean saveInvalidChecksumMessage(
return stashIfBackPressured(offset, position);
}

private boolean saveInvalidMsgTypeMessage(final int offset, final int length, final long readTimestamp)
{
DebugLogger.log(FIX_MESSAGE, "Invalidated (MsgType): ", buffer, offset, length);

final long position = publication.saveMessage(
buffer,
offset,
length,
libraryId,
INVALID_MESSAGE_TYPE,
sessionId,
sequenceIndex,
connectionId,
INVALID,
0,
readTimestamp);

return stashIfBackPressured(offset, position);
}

void onLogonSent(final int sequenceIndex)
{
pendingSequenceIndex = sequenceIndex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Stream;

Expand Down Expand Up @@ -159,7 +160,7 @@ public Stream<FixMessage> receivedMessage(final String messageType)
{
return messages()
.stream()
.filter((fixMessage) -> fixMessage.get(MSG_TYPE).equals(messageType));
.filter((fixMessage) -> Objects.equals(fixMessage.get(MSG_TYPE), messageType));
}

public Stream<FixMessage> receivedReplay(final String messageType, final int sequenceNumber)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@
import uk.co.real_logic.artio.messages.*;
import uk.co.real_logic.artio.session.Session;
import uk.co.real_logic.artio.session.SessionWriter;
import uk.co.real_logic.artio.util.MessageTypeEncoding;
import uk.co.real_logic.artio.util.MutableAsciiBuffer;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.SocketChannel;
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
Expand Down Expand Up @@ -61,6 +66,9 @@

public class MessageBasedAcceptorSystemTest extends AbstractMessageBasedAcceptorSystemTest
{
private static final DateTimeFormatter UTC_TIMESTAMP_FORMATTER =
DateTimeFormatter.ofPattern("uuuuMMdd-HH:mm:ss.SSS").withZone(ZoneOffset.UTC);

private final int timeoutDisconnectHeartBtIntInS = 1;
private final long timeoutDisconnectHeartBtIntInMs = TimeUnit.SECONDS.toMillis(timeoutDisconnectHeartBtIntInS);

Expand Down Expand Up @@ -922,6 +930,71 @@ public void shouldDisconnectConnectionTryingToSendOversizedMessage() throws IOEx
}
}

@Test(timeout = TEST_TIMEOUT_IN_MS)
public void shouldInvalidateMessageWithInvalidMsgType() throws IOException
{
setup(true, true);

setupLibrary();

try (FixConnection connection = FixConnection.initiate(port))
{
logon(connection);
acquireSession();

connection.sendBytes(messageWithMsgType(null));
connection.sendBytes(messageWithMsgType(""));
connection.sendBytes(messageWithMsgType("ABCDEFGHI"));

final String longestAllowedMsgType = "ABCDEFGH";
connection.sendBytes(messageWithMsgType(longestAllowedMsgType));

final FixMessage message = testSystem.awaitMessageOf(otfAcceptor, longestAllowedMsgType);
assertEquals(MessageTypeEncoding.packMessageType(longestAllowedMsgType), message.messageType());

if (otfAcceptor.messages().size() > 1)
{
fail("received more messages than expected: " + otfAcceptor.messages());
}
}
}

private static byte[] messageWithMsgType(final String msgType)
{
final byte[] buffer = new byte[128];
final MutableAsciiBuffer asciiBuffer = new MutableAsciiBuffer(buffer);

final int bodyStart = 16;
int index = bodyStart;

if (null != msgType)
{
index += asciiBuffer.putIntAscii(index, 35);
asciiBuffer.putByte(index++, (byte)'=');
index += asciiBuffer.putAscii(index, msgType);
asciiBuffer.putByte(index++, START_OF_HEADER);
}

index += asciiBuffer.putAscii(index, "49=initiator\u000156=acceptor\u000134=2\u000152=");
index += asciiBuffer.putAscii(index, UTC_TIMESTAMP_FORMATTER.format(Instant.now()));
asciiBuffer.putByte(index++, START_OF_HEADER);

asciiBuffer.putByte(bodyStart - 1, START_OF_HEADER);
final int length = index - bodyStart;
int startIndex = asciiBuffer.putNaturalIntAsciiFromEnd(length, bodyStart - 1);
final String prefix = "8=FIX.4.4\u00019=";
final int prefixLength = prefix.length();
startIndex -= asciiBuffer.putAscii(startIndex - prefixLength, prefix);

final int checksum = asciiBuffer.computeChecksum(startIndex, index);
index += asciiBuffer.putAscii(index, "10=");
asciiBuffer.putNaturalPaddedIntAscii(index, 3, checksum);
index += 3;
asciiBuffer.putByte(index++, START_OF_HEADER);

return Arrays.copyOfRange(buffer, startIndex, index);
}

@Test(timeout = TEST_TIMEOUT_IN_MS)
public void shouldSupportFollowerSessionLogonWithoutSequenceResetOnDisconnectBeforeLibraryLogonResponse()
throws IOException
Expand Down

0 comments on commit a35cd33

Please sign in to comment.