diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 0e3db8d83a..4e73796e98 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -767,6 +767,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH sentFirstMessage = true; pendingUnencryptedWrites.writeAndRemoveAll(ctx); forceFlush(ctx); + // Explicit start handshake processing once we send the first message. This will also ensure + // we will schedule the timeout if needed. + startHandshakeProcessing(); return; } @@ -1661,14 +1664,16 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } private void startHandshakeProcessing() { - handshakeStarted = true; - if (engine.getUseClientMode()) { - // Begin the initial handshake. - // channelActive() event has been fired already, which means this.channelActive() will - // not be invoked. We have to initialize here instead. - handshake(null); - } else { - applyHandshakeTimeout(null); + if (!handshakeStarted) { + handshakeStarted = true; + if (engine.getUseClientMode()) { + // Begin the initial handshake. + // channelActive() event has been fired already, which means this.channelActive() will + // not be invoked. We have to initialize here instead. + handshake(null, true); + } else { + applyHandshakeTimeout(null); + } } } @@ -1702,13 +1707,13 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH executor.execute(new Runnable() { @Override public void run() { - handshake(promise); + handshake(promise, false); } }); return promise; } - handshake(promise); + handshake(promise, false); return promise; } @@ -1719,7 +1724,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * assuming that the current negotiation has not been finished. * Currently, {@code null} is expected only for the initial handshake. */ - private void handshake(final Promise newHandshakePromise) { + private void handshake(final Promise newHandshakePromise, boolean initialHandshake) { final Promise p; if (newHandshakePromise != null) { final Promise oldHandshakePromise = handshakePromise; @@ -1741,6 +1746,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH handshakePromise = p = newHandshakePromise; } else if (engine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) { + if (initialHandshake) { + // This is the intial handshake either triggered by handlerAdded(...), channelActive(...) or + // flush(...) when starttls was used. In all the cases we need to ensure we schedule a timeout. + applyHandshakeTimeout(null); + } // Not all SSLEngine implementations support calling beginHandshake multiple times while a handshake // is in progress. See https://github.com/netty/netty/issues/4718. return; diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index ab7a63c908..772a15f9eb 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -74,9 +74,7 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLProtocolException; import static io.netty.buffer.Unpooled.wrappedBuffer; -import static org.hamcrest.CoreMatchers.instanceOf; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; @@ -754,4 +752,76 @@ public class SslHandlerTest { ReferenceCountUtil.release(sslClientCtx); } } + + @Test(timeout = 10000) + public void testHandshakeTimeoutFlushStartsHandshake() throws Exception { + testHandshakeTimeout0(false); + } + + @Test(timeout = 10000) + public void testHandshakeTimeoutStartTLS() throws Exception { + testHandshakeTimeout0(true); + } + + private static void testHandshakeTimeout0(final boolean startTls) throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .startTls(true) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(SslProvider.JDK).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final SslHandler sslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT); + sslHandler.setHandshakeTimeout(500, TimeUnit.MILLISECONDS); + + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInboundHandlerAdapter()) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslHandler); + if (startTls) { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.writeAndFlush(wrappedBuffer(new byte[] { 1, 2, 3, 4 })); + } + }); + } + } + }).connect(sc.localAddress()); + if (!startTls) { + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // Write something to trigger the handshake before fireChannelActive is called. + future.channel().writeAndFlush(wrappedBuffer(new byte [] { 1, 2, 3, 4 })); + } + }); + } + cc = future.syncUninterruptibly().channel(); + + Throwable cause = sslHandler.handshakeFuture().await().cause(); + assertThat(cause, CoreMatchers.instanceOf(SSLException.class)); + assertThat(cause.getMessage(), containsString("timed out")); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + } + } }