diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index ac32c9c76f..f72cf87462 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -1536,11 +1536,15 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { this.ctx = ctx; - if (ctx.channel().isActive() && engine.getUseClientMode()) { - // Begin the initial handshake. - // channelActive() event has been fired already, which means this.channelActive() will - // not be invoked. We have to initialize here instead. - handshake(null); + if (ctx.channel().isActive()) { + if (engine.getUseClientMode()) { + // Begin the initial handshake. + // channelActive() event has been fired already, which means this.channelActive() will + // not be invoked. We have to initialize here instead. + handshake(null); + } else { + applyHandshakeTimeout(null); + } } else { // channelActive() event has not been fired yet. this.channelOpen() will be invoked // and initialization will occur there. @@ -1635,17 +1639,21 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } finally { forceFlush(ctx); } + applyHandshakeTimeout(p); + } + private void applyHandshakeTimeout(Promise p) { + final Promise promise = p == null ? handshakePromise : p; // Set timeout if necessary. final long handshakeTimeoutMillis = this.handshakeTimeoutMillis; - if (handshakeTimeoutMillis <= 0 || p.isDone()) { + if (handshakeTimeoutMillis <= 0 || promise.isDone()) { return; } final ScheduledFuture timeoutFuture = ctx.executor().schedule(new Runnable() { @Override public void run() { - if (p.isDone()) { + if (promise.isDone()) { return; } notifyHandshakeFailure(HANDSHAKE_TIMED_OUT); @@ -1653,7 +1661,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); // Cancel the handshake timeout when handshake is finished. - p.addListener(new FutureListener() { + promise.addListener(new FutureListener() { @Override public void operationComplete(Future f) throws Exception { timeoutFuture.cancel(false); @@ -1671,9 +1679,13 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH */ @Override public void channelActive(final ChannelHandlerContext ctx) throws Exception { - if (!startTls && engine.getUseClientMode()) { - // Begin the initial handshake - handshake(null); + if (!startTls) { + if (engine.getUseClientMode()) { + // Begin the initial handshake. + handshake(null); + } else { + applyHandshakeTimeout(null); + } } ctx.fireChannelActive(); } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 5e10a6d569..478c77400f 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -85,6 +85,36 @@ import static org.junit.Assume.assumeTrue; public class SslHandlerTest { + @Test(expected = SSLException.class, timeout = 3000) + public void testClientHandshakeTimeout() throws Exception { + testHandshakeTimeout(true); + } + + @Test(expected = SSLException.class, timeout = 3000) + public void testServerHandshakeTimeout() throws Exception { + testHandshakeTimeout(false); + } + + private static void testHandshakeTimeout(boolean client) throws Exception { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + engine.setUseClientMode(client); + SslHandler handler = new SslHandler(engine); + handler.setHandshakeTimeoutMillis(1); + + EmbeddedChannel ch = new EmbeddedChannel(handler); + try { + while (!handler.handshakeFuture().isDone()) { + Thread.sleep(10); + // We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop. + ch.runPendingTasks(); + } + + handler.handshakeFuture().syncUninterruptibly(); + } finally { + ch.finishAndReleaseAll(); + } + } + @Test public void testTruncatedPacket() throws Exception { SSLEngine engine = SSLContext.getDefault().createSSLEngine();