From 39997814bfa3ea70d5f63a7c9bacdf34c962d823 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Wed, 28 Sep 2016 10:26:06 +0200 Subject: [PATCH] [#5860] Ensure removal of SslHandler not produce IllegalReferenceCountException Motivation: If the user removes the SslHandler while still in the processing loop we will produce an IllegalReferenceCountException. We should stop looping when the handlerwas removed. Modifications: Ensure we stop looping when the handler is removed. Result: No more IllegalReferenceCountException. --- .../java/io/netty/handler/ssl/SslHandler.java | 21 ++++- .../io/netty/handler/ssl/SslHandlerTest.java | 94 +++++++++++++++++++ 2 files changed, 110 insertions(+), 5 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 11d8f95226..55c09aef43 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -506,7 +506,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH ByteBufAllocator alloc = ctx.alloc(); boolean needUnwrap = false; try { - for (;;) { + // Only continue to loop if the handler was not removed in the meantime. + // See https://github.com/netty/netty/issues/5860 + while (!ctx.isRemoved()) { Object msg = pendingUnencryptedWrites.current(); if (msg == null) { break; @@ -593,7 +595,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH ByteBuf out = null; ByteBufAllocator alloc = ctx.alloc(); try { - for (;;) { + // Only continue to loop if the handler was not removed in the meantime. + // See https://github.com/netty/netty/issues/5860 + while (!ctx.isRemoved()) { if (out == null) { out = allocateOutNetBuf(ctx, 0); } @@ -903,8 +907,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH NotSslRecordException e = new NotSslRecordException( "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); in.skipBytes(in.readableBytes()); - ctx.fireExceptionCaught(e); + + // First fail the handshake promise as we may need to have access to the SSLEngine which may + // be released because the user will remove the SslHandler in an exceptionCaught(...) implementation. setHandshakeFailure(ctx, e); + + ctx.fireExceptionCaught(e); } } @@ -954,7 +962,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH boolean notifyClosure = false; ByteBuf decodeOut = allocate(ctx, length); try { - for (;;) { + // Only continue to loop if the handler was not removed in the meantime. + // See https://github.com/netty/netty/issues/5860 + while (!ctx.isRemoved()) { final SSLEngineResult result = unwrap(engine, packet, offset, length, decodeOut); final Status status = result.getStatus(); final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); @@ -968,6 +978,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH switch (status) { case BUFFER_OVERFLOW: int readableBytes = decodeOut.readableBytes(); + int bufferSize = engine.getSession().getApplicationBufferSize() - readableBytes; if (readableBytes > 0) { decoded = true; ctx.fireChannelRead(decodeOut); @@ -977,7 +988,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // Allocate a new buffer which can hold all the rest data and loop again. // TODO: We may want to reconsider how we calculate the length here as we may // have more then one ssl message to decode. - decodeOut = allocate(ctx, engine.getSession().getApplicationBufferSize() - readableBytes); + decodeOut = allocate(ctx, bufferSize); continue; case CLOSED: // notify about the CLOSED state of the SSLEngine. See #137 diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index f5bcb72e38..4c16eb5196 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -26,6 +26,20 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLProtocolException; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.CodecException; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.Promise; import org.junit.Test; import io.netty.buffer.ByteBufAllocator; @@ -40,6 +54,8 @@ import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import java.net.InetSocketAddress; + public class SslHandlerTest { @Test @@ -148,4 +164,82 @@ public class SslHandlerTest { // the handshake is initiated by flush new TlsReadTest().test(true); } + + @Test(timeout = 5000) + public void testRemoval() throws Exception { + NioEventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + try { + final Promise clientPromise = group.next().newPromise(); + Bootstrap bootstrap = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(newHandler(SslContextBuilder.forClient().trustManager( + InsecureTrustManagerFactory.INSTANCE).build(), clientPromise)); + + SelfSignedCertificate ssc = new SelfSignedCertificate(); + final Promise serverPromise = group.next().newPromise(); + ServerBootstrap serverBootstrap = new ServerBootstrap() + .group(group, group) + .channel(NioServerSocketChannel.class) + .childHandler(newHandler(SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()).build(), + serverPromise)); + sc = serverBootstrap.bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + cc = bootstrap.connect(sc.localAddress()).syncUninterruptibly().channel(); + + serverPromise.syncUninterruptibly(); + clientPromise.syncUninterruptibly(); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + } + } + + private static ChannelHandler newHandler(final SslContext sslCtx, final Promise promise) { + return new ChannelInitializer() { + @Override + protected void initChannel(final Channel ch) { + final SslHandler sslHandler = sslCtx.newHandler(ch.alloc()); + sslHandler.setHandshakeTimeoutMillis(1000); + ch.pipeline().addFirst(sslHandler); + sslHandler.handshakeFuture().addListener(new FutureListener() { + @Override + public void operationComplete(final Future future) { + ch.pipeline().remove(sslHandler); + + // Schedule the close so removal has time to propergate exception if any. + ch.eventLoop().execute(new Runnable() { + @Override + public void run() { + ch.close(); + } + }); + } + }); + + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (cause instanceof CodecException) { + cause = cause.getCause(); + } + if (cause instanceof IllegalReferenceCountException) { + promise.setFailure(cause); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + promise.trySuccess(null); + } + }); + } + }; + } }