diff --git a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java index 9f808fd5..817d0bc5 100644 --- a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java +++ b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java @@ -78,37 +78,42 @@ protected byte decodeProtocolVersion(ByteBuf in) { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - in.markReaderIndex(); - ProtocolCode protocolCode; - Protocol protocol; try { - protocolCode = decodeProtocolCode(in); - if (protocolCode == null) { - // read to end - return; - } + in.markReaderIndex(); + ProtocolCode protocolCode; + Protocol protocol; + try { + protocolCode = decodeProtocolCode(in); + if (protocolCode == null) { + // read to end + return; + } - byte protocolVersion = decodeProtocolVersion(in); - if (ctx.channel().attr(Connection.PROTOCOL).get() == null) { - ctx.channel().attr(Connection.PROTOCOL).set(protocolCode); - if (DEFAULT_ILLEGAL_PROTOCOL_VERSION_LENGTH != protocolVersion) { - ctx.channel().attr(Connection.VERSION).set(protocolVersion); + byte protocolVersion = decodeProtocolVersion(in); + if (ctx.channel().attr(Connection.PROTOCOL).get() == null) { + ctx.channel().attr(Connection.PROTOCOL).set(protocolCode); + if (DEFAULT_ILLEGAL_PROTOCOL_VERSION_LENGTH != protocolVersion) { + ctx.channel().attr(Connection.VERSION).set(protocolVersion); + } } + + protocol = ProtocolManager.getProtocol(protocolCode); + } finally { + // reset the readerIndex before throwing an exception or decoding content + // to ensure that the packet is complete + in.resetReaderIndex(); } - protocol = ProtocolManager.getProtocol(protocolCode); - } finally { - // reset the readerIndex before throwing an exception or decoding content - // to ensure that the packet is complete - in.resetReaderIndex(); - } + if (protocol == null) { + throw new CodecException("Unknown protocol code: [" + protocolCode + + "] while decode in ProtocolDecoder."); + } - if (protocol == null) { - in.release(); - throw new CodecException("Unknown protocol code: [" + protocolCode - + "] while decode in ProtocolDecoder."); + protocol.getDecoder().decode(ctx, in, out); + } catch (Exception e) { + // 清空可读取区域,让 AbstractBatchDecoder#L257行release它 + in.skipBytes(in.readableBytes()); + throw e; } - - protocol.getDecoder().decode(ctx, in, out); } } diff --git a/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java b/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java index 5a713017..44617622 100644 --- a/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java +++ b/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java @@ -33,7 +33,6 @@ import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.Attribute; import io.netty.util.AttributeKey; -import io.netty.util.ResourceLeakDetector; import io.netty.util.concurrent.EventExecutor; import org.junit.Assert; import org.junit.Test; @@ -54,6 +53,7 @@ public void testDecodeIllegalPacket() throws Exception { ProtocolCodeBasedDecoder decoder = new ProtocolCodeBasedDecoder(1); int readerIndex = byteBuf.readerIndex(); + int readableBytes = byteBuf.readableBytes(); Assert.assertEquals(0, readerIndex); Exception exception = null; @@ -67,13 +67,11 @@ public void testDecodeIllegalPacket() throws Exception { Assert.assertNotNull(exception); readerIndex = byteBuf.readerIndex(); - Assert.assertEquals(0, readerIndex); + Assert.assertEquals(readableBytes, readerIndex); } @Test public void testDecodeIllegalPacket2() { - ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID); - EmbeddedChannel channel = new EmbeddedChannel(); ProtocolCodeBasedDecoder decoder = new ProtocolCodeBasedDecoder(1); channel.pipeline().addLast(decoder); @@ -82,6 +80,7 @@ public void testDecodeIllegalPacket2() { byteBuf.writeByte((byte) 13); int readerIndex = byteBuf.readerIndex(); + int readableBytes = byteBuf.readableBytes(); Assert.assertEquals(0, readerIndex); Exception exception = null; try { @@ -92,7 +91,7 @@ public void testDecodeIllegalPacket2() { } Assert.assertNotNull(exception); readerIndex = byteBuf.readerIndex(); - Assert.assertEquals(0, readerIndex); + Assert.assertEquals(readableBytes, readerIndex); Assert.assertTrue(byteBuf.refCnt() == 0); }