From c921742a4281817052e06ce3bf28c859699e2cbf Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Sun, 5 Nov 2017 18:53:13 +0100 Subject: [PATCH] Dont fire an SslHandshakeEvent if the handshake was not started at all. Motivation: We should not fire a SslHandshakeEvent if the channel is closed but the handshake was not started. Modifications: - Add a variable to SslHandler which tracks if an handshake was started yet or not and depending on this fire the event. - Add a unit test Result: Fixes [#7262]. --- .../netty/handler/ssl/AbstractSniHandler.java | 2 +- .../java/io/netty/handler/ssl/SslHandler.java | 45 ++++++++++--------- .../java/io/netty/handler/ssl/SslUtils.java | 6 ++- .../io/netty/handler/ssl/SslHandlerTest.java | 41 +++++++++++++++++ 4 files changed, 69 insertions(+), 25 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java index 44fda4d16e..d279a022ab 100644 --- a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java @@ -81,7 +81,7 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder impleme "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); in.skipBytes(in.readableBytes()); - SslUtils.notifyHandshakeFailure(ctx, e); + SslUtils.notifyHandshakeFailure(ctx, e, true); throw e; } if (len == SslUtils.NOT_ENOUGH_DATA || diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 82fc18657a..709030fd36 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -371,6 +371,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private boolean sentFirstMessage; private boolean flushedBeforeHandshake; private boolean readDuringHandshake; + private boolean handshakeStarted; private SslHandlerCoalescingBufferQueue pendingUnencryptedWrites; private Promise handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise sslClosePromise = new LazyChannelPromise(); @@ -864,7 +865,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH /** * This method will not call - * {@link #setHandshakeFailure(ChannelHandlerContext, Throwable, boolean)} or + * {@link #setHandshakeFailure(ChannelHandlerContext, Throwable, boolean, boolean)} or * {@link #setHandshakeFailure(ChannelHandlerContext, Throwable)}. * @return {@code true} if this method ends on {@link SSLEngineResult.HandshakeStatus#NOT_HANDSHAKING}. */ @@ -1001,7 +1002,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH public void channelInactive(ChannelHandlerContext ctx) throws Exception { // Make sure to release SSLEngine, // and notify the handshake future if the connection has been closed during handshake. - setHandshakeFailure(ctx, CHANNEL_CLOSED, !outboundClosed); + setHandshakeFailure(ctx, CHANNEL_CLOSED, !outboundClosed, handshakeStarted); // Ensure we always notify the sslClosePromise as well notifyClosePromise(CHANNEL_CLOSED); @@ -1489,13 +1490,13 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * Notify all the handshake futures about the failure during the handshake. */ private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) { - setHandshakeFailure(ctx, cause, true); + setHandshakeFailure(ctx, cause, true, true); } /** * Notify all the handshake futures about the failure during the handshake. */ - private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause, boolean closeInbound) { + private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause, boolean closeInbound, boolean notify) { try { // Release all resources such as internal buffers that SSLEngine // is managing. @@ -1515,7 +1516,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } } } - notifyHandshakeFailure(cause); + notifyHandshakeFailure(cause, notify); } finally { // Ensure we remove and fail all pending writes in all cases and so release memory quickly. releaseAndFailAll(cause); @@ -1528,9 +1529,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } } - private void notifyHandshakeFailure(Throwable cause) { + private void notifyHandshakeFailure(Throwable cause, boolean notify) { if (handshakePromise.tryFailure(cause)) { - SslUtils.notifyHandshakeFailure(ctx, cause); + SslUtils.notifyHandshakeFailure(ctx, cause, notify); } } @@ -1592,14 +1593,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH pendingUnencryptedWrites = new SslHandlerCoalescingBufferQueue(ctx.channel(), 16); if (ctx.channel().isActive()) { - if (engine.getUseClientMode()) { - // Begin the initial handshake. - // channelActive() event has been fired already, which means this.channelActive() will - // not be invoked. We have to initialize here instead. - handshake(null); - } else { - applyHandshakeTimeout(null); - } + startHandshakeProcessing(); + } + } + + private void startHandshakeProcessing() { + handshakeStarted = true; + if (engine.getUseClientMode()) { + // Begin the initial handshake. + // channelActive() event has been fired already, which means this.channelActive() will + // not be invoked. We have to initialize here instead. + handshake(null); + } else { + applyHandshakeTimeout(null); } } @@ -1709,7 +1715,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH return; } try { - notifyHandshakeFailure(HANDSHAKE_TIMED_OUT); + notifyHandshakeFailure(HANDSHAKE_TIMED_OUT, true); } finally { releaseAndFailAll(HANDSHAKE_TIMED_OUT); } @@ -1736,12 +1742,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH @Override public void channelActive(final ChannelHandlerContext ctx) throws Exception { if (!startTls) { - if (engine.getUseClientMode()) { - // Begin the initial handshake. - handshake(null); - } else { - applyHandshakeTimeout(null); - } + startHandshakeProcessing(); } ctx.fireChannelActive(); } diff --git a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java index ae66ae412d..175238d41e 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java @@ -310,11 +310,13 @@ final class SslUtils { return packetLength; } - static void notifyHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) { + static void notifyHandshakeFailure(ChannelHandlerContext ctx, Throwable cause, boolean notify) { // We have may haven written some parts of data before an exception was thrown so ensure we always flush. // See https://github.com/netty/netty/issues/3900#issuecomment-172481830 ctx.flush(); - ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); + if (notify) { + ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); + } ctx.close(); } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 79cc7025bf..65eaffd435 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -28,6 +28,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelId; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; @@ -55,6 +56,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicBoolean; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; @@ -74,6 +76,45 @@ import static org.junit.Assume.assumeTrue; public class SslHandlerTest { + @Test + public void testNoSslHandshakeEventWhenNoHandshake() throws Exception { + final AtomicBoolean inActive = new AtomicBoolean(false); + + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + EmbeddedChannel ch = new EmbeddedChannel( + DefaultChannelId.newInstance(), false, false, new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + // Not forward the event to the SslHandler but just close the Channel. + ctx.close(); + } + }, new SslHandler(engine) { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + // We want to override what Channel.isActive() will return as otherwise it will + // return true and so trigger an handshake. + inActive.set(true); + super.handlerAdded(ctx); + inActive.set(false); + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { + throw (Exception) ((SslHandshakeCompletionEvent) evt).cause(); + } + } + }) { + @Override + public boolean isActive() { + return !inActive.get() && super.isActive(); + } + }; + + ch.register(); + assertFalse(ch.finishAndReleaseAll()); + } + @Test(expected = SSLException.class, timeout = 3000) public void testClientHandshakeTimeout() throws Exception { testHandshakeTimeout(true);