From a35cd331fa0177d947388c2ff04f23be1ff140a1 Mon Sep 17 00:00:00 2001 From: Wojciech Lukowicz Date: Wed, 8 Jan 2025 11:11:16 +0000 Subject: [PATCH] make framer validate MsgType before parsing it --- .../artio/util/MessageTypeEncoding.java | 4 +- .../engine/framer/FixReceiverEndPoint.java | 53 +++++++++++++- .../artio/system_tests/FakeOtfAcceptor.java | 3 +- .../MessageBasedAcceptorSystemTest.java | 73 +++++++++++++++++++ 4 files changed, 127 insertions(+), 6 deletions(-) diff --git a/artio-codecs/src/main/java/uk/co/real_logic/artio/util/MessageTypeEncoding.java b/artio-codecs/src/main/java/uk/co/real_logic/artio/util/MessageTypeEncoding.java index 544b528ee0..744550ced6 100644 --- a/artio-codecs/src/main/java/uk/co/real_logic/artio/util/MessageTypeEncoding.java +++ b/artio-codecs/src/main/java/uk/co/real_logic/artio/util/MessageTypeEncoding.java @@ -31,6 +31,8 @@ */ public final class MessageTypeEncoding { + public static final int MAX_MESSAGE_TYPE_LENGTH = 8; + private static final int MESSAGE_TYPE_BITSHIFT = 8; /** @@ -87,7 +89,7 @@ public static long packMessageType(final char[] messageType, final int length) private static void checkLength(final int length) { - if (length > 8) + if (length > MAX_MESSAGE_TYPE_LENGTH) { throw new IllegalArgumentException("Message types longer than 8 are not supported yet"); } diff --git a/artio-core/src/main/java/uk/co/real_logic/artio/engine/framer/FixReceiverEndPoint.java b/artio-core/src/main/java/uk/co/real_logic/artio/engine/framer/FixReceiverEndPoint.java index bdfbeddbbe..1c5e7c4714 100644 --- a/artio-core/src/main/java/uk/co/real_logic/artio/engine/framer/FixReceiverEndPoint.java +++ b/artio-core/src/main/java/uk/co/real_logic/artio/engine/framer/FixReceiverEndPoint.java @@ -49,6 +49,7 @@ import static uk.co.real_logic.artio.session.Session.UNKNOWN; import static uk.co.real_logic.artio.util.AsciiBuffer.SEPARATOR; import static uk.co.real_logic.artio.util.AsciiBuffer.UNKNOWN_INDEX; +import static uk.co.real_logic.artio.util.MessageTypeEncoding.MAX_MESSAGE_TYPE_LENGTH; /** * Handles incoming data from sockets. @@ -461,9 +462,17 @@ private boolean frameMessages(final long readTimestampInNs) break; // Need more data } - final long messageType = getMessageType(endOfBodyLength, endOfMessage); final int length = (endOfMessage + 1) - offset; - if (!validateChecksum(endOfMessage, startOfChecksumValue, offset, startOfChecksumTag)) + + final long messageType = getMessageType(endOfBodyLength, endOfMessage); + if (messageType == 0) + { + if (saveInvalidMsgTypeMessage(offset, length, readTimestampInNs)) + { + return false; + } + } + else if (!validateChecksum(endOfMessage, startOfChecksumValue, offset, startOfChecksumTag)) { DebugLogger.logFixMessage( FIX_MESSAGE, messageType, "Invalidated (checksum): ", buffer, offset, length); @@ -938,9 +947,25 @@ private boolean isStartOfChecksum(final int startOfChecksumTag) private long getMessageType(final int endOfBodyLength, final int indexOfLastByteOfMessage) { - final int start = buffer.scan(endOfBodyLength, indexOfLastByteOfMessage, '=') + 1; - final int end = buffer.scan(start + 1, indexOfLastByteOfMessage, START_OF_HEADER); + if (0x3d353301 /* 35= */ != buffer.getInt(endOfBodyLength)) + { + return 0; + } + + final int start = endOfBodyLength + 4; + final int limit = Math.min(start + MAX_MESSAGE_TYPE_LENGTH, indexOfLastByteOfMessage) + 1; + final int end = buffer.scan(start, limit, START_OF_HEADER); + if (UNKNOWN_INDEX == end) + { + return 0; + } + final int length = end - start; + if (0 == length) + { + return 0; + } + return buffer.getMessageType(start, length); } @@ -1068,6 +1093,26 @@ private boolean saveInvalidChecksumMessage( return stashIfBackPressured(offset, position); } + private boolean saveInvalidMsgTypeMessage(final int offset, final int length, final long readTimestamp) + { + DebugLogger.log(FIX_MESSAGE, "Invalidated (MsgType): ", buffer, offset, length); + + final long position = publication.saveMessage( + buffer, + offset, + length, + libraryId, + INVALID_MESSAGE_TYPE, + sessionId, + sequenceIndex, + connectionId, + INVALID, + 0, + readTimestamp); + + return stashIfBackPressured(offset, position); + } + void onLogonSent(final int sequenceIndex) { pendingSequenceIndex = sequenceIndex; diff --git a/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/FakeOtfAcceptor.java b/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/FakeOtfAcceptor.java index a2ef4f78a1..44c979b2c7 100644 --- a/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/FakeOtfAcceptor.java +++ b/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/FakeOtfAcceptor.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.function.Predicate; import java.util.stream.Stream; @@ -159,7 +160,7 @@ public Stream receivedMessage(final String messageType) { return messages() .stream() - .filter((fixMessage) -> fixMessage.get(MSG_TYPE).equals(messageType)); + .filter((fixMessage) -> Objects.equals(fixMessage.get(MSG_TYPE), messageType)); } public Stream receivedReplay(final String messageType, final int sequenceNumber) diff --git a/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/MessageBasedAcceptorSystemTest.java b/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/MessageBasedAcceptorSystemTest.java index 6a68c49ef1..63e8d3ddf7 100644 --- a/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/MessageBasedAcceptorSystemTest.java +++ b/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/MessageBasedAcceptorSystemTest.java @@ -28,11 +28,16 @@ import uk.co.real_logic.artio.messages.*; import uk.co.real_logic.artio.session.Session; import uk.co.real_logic.artio.session.SessionWriter; +import uk.co.real_logic.artio.util.MessageTypeEncoding; import uk.co.real_logic.artio.util.MutableAsciiBuffer; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.channels.SocketChannel; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.Arrays; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; @@ -61,6 +66,9 @@ public class MessageBasedAcceptorSystemTest extends AbstractMessageBasedAcceptorSystemTest { + private static final DateTimeFormatter UTC_TIMESTAMP_FORMATTER = + DateTimeFormatter.ofPattern("uuuuMMdd-HH:mm:ss.SSS").withZone(ZoneOffset.UTC); + private final int timeoutDisconnectHeartBtIntInS = 1; private final long timeoutDisconnectHeartBtIntInMs = TimeUnit.SECONDS.toMillis(timeoutDisconnectHeartBtIntInS); @@ -922,6 +930,71 @@ public void shouldDisconnectConnectionTryingToSendOversizedMessage() throws IOEx } } + @Test(timeout = TEST_TIMEOUT_IN_MS) + public void shouldInvalidateMessageWithInvalidMsgType() throws IOException + { + setup(true, true); + + setupLibrary(); + + try (FixConnection connection = FixConnection.initiate(port)) + { + logon(connection); + acquireSession(); + + connection.sendBytes(messageWithMsgType(null)); + connection.sendBytes(messageWithMsgType("")); + connection.sendBytes(messageWithMsgType("ABCDEFGHI")); + + final String longestAllowedMsgType = "ABCDEFGH"; + connection.sendBytes(messageWithMsgType(longestAllowedMsgType)); + + final FixMessage message = testSystem.awaitMessageOf(otfAcceptor, longestAllowedMsgType); + assertEquals(MessageTypeEncoding.packMessageType(longestAllowedMsgType), message.messageType()); + + if (otfAcceptor.messages().size() > 1) + { + fail("received more messages than expected: " + otfAcceptor.messages()); + } + } + } + + private static byte[] messageWithMsgType(final String msgType) + { + final byte[] buffer = new byte[128]; + final MutableAsciiBuffer asciiBuffer = new MutableAsciiBuffer(buffer); + + final int bodyStart = 16; + int index = bodyStart; + + if (null != msgType) + { + index += asciiBuffer.putIntAscii(index, 35); + asciiBuffer.putByte(index++, (byte)'='); + index += asciiBuffer.putAscii(index, msgType); + asciiBuffer.putByte(index++, START_OF_HEADER); + } + + index += asciiBuffer.putAscii(index, "49=initiator\u000156=acceptor\u000134=2\u000152="); + index += asciiBuffer.putAscii(index, UTC_TIMESTAMP_FORMATTER.format(Instant.now())); + asciiBuffer.putByte(index++, START_OF_HEADER); + + asciiBuffer.putByte(bodyStart - 1, START_OF_HEADER); + final int length = index - bodyStart; + int startIndex = asciiBuffer.putNaturalIntAsciiFromEnd(length, bodyStart - 1); + final String prefix = "8=FIX.4.4\u00019="; + final int prefixLength = prefix.length(); + startIndex -= asciiBuffer.putAscii(startIndex - prefixLength, prefix); + + final int checksum = asciiBuffer.computeChecksum(startIndex, index); + index += asciiBuffer.putAscii(index, "10="); + asciiBuffer.putNaturalPaddedIntAscii(index, 3, checksum); + index += 3; + asciiBuffer.putByte(index++, START_OF_HEADER); + + return Arrays.copyOfRange(buffer, startIndex, index); + } + @Test(timeout = TEST_TIMEOUT_IN_MS) public void shouldSupportFollowerSessionLogonWithoutSequenceResetOnDisconnectBeforeLibraryLogonResponse() throws IOException