diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 8722119fe4..83503c8593 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -308,7 +308,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private boolean firedChannelRead; private volatile long handshakeTimeoutMillis = 10000; - private volatile long closeNotifyTimeoutMillis = 3000; + private volatile long closeNotifyFlushTimeoutMillis = 3000; + private volatile long closeNotifyReadTimeoutMillis; /** * Creates a new instance. @@ -378,24 +379,86 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH this.handshakeTimeoutMillis = handshakeTimeoutMillis; } + /** + * @deprecated use {@link #getCloseNotifyFlushTimeoutMillis()} + */ + @Deprecated public long getCloseNotifyTimeoutMillis() { - return closeNotifyTimeoutMillis; + return getCloseNotifyFlushTimeoutMillis(); } + /** + * @deprecated use {@link #setCloseNotifyFlushTimeout(long, TimeUnit)} + */ + @Deprecated public void setCloseNotifyTimeout(long closeNotifyTimeout, TimeUnit unit) { - if (unit == null) { - throw new NullPointerException("unit"); - } - - setCloseNotifyTimeoutMillis(unit.toMillis(closeNotifyTimeout)); + setCloseNotifyFlushTimeout(closeNotifyTimeout, unit); } - public void setCloseNotifyTimeoutMillis(long closeNotifyTimeoutMillis) { - if (closeNotifyTimeoutMillis < 0) { + /** + * @deprecated use {@link #setCloseNotifyFlushTimeoutMillis(long)} + */ + @Deprecated + public void setCloseNotifyTimeoutMillis(long closeNotifyFlushTimeoutMillis) { + setCloseNotifyFlushTimeoutMillis(closeNotifyFlushTimeoutMillis); + } + + /** + * Gets the timeout for flushing the close_notify that was triggered by closing the + * {@link Channel}. If the close_notify was not flushed in the given timeout the {@link Channel} will be closed + * forcibily. + */ + public final long getCloseNotifyFlushTimeoutMillis() { + return closeNotifyFlushTimeoutMillis; + } + + /** + * Sets the timeout for flushing the close_notify that was triggered by closing the + * {@link Channel}. If the close_notify was not flushed in the given timeout the {@link Channel} will be closed + * forcibily. + */ + public final void setCloseNotifyFlushTimeout(long closeNotifyFlushTimeoutMillis, TimeUnit unit) { + setCloseNotifyFlushTimeoutMillis(unit.toMillis(closeNotifyFlushTimeoutMillis)); + } + + /** + * See {@link #setCloseNotifyFlushTimeout(long, TimeUnit)}. + */ + public final void setCloseNotifyFlushTimeoutMillis(long closeNotifyFlushTimeoutMillis) { + if (closeNotifyFlushTimeoutMillis < 0) { throw new IllegalArgumentException( - "closeNotifyTimeoutMillis: " + closeNotifyTimeoutMillis + " (expected: >= 0)"); + "closeNotifyFlushTimeoutMillis: " + closeNotifyFlushTimeoutMillis + " (expected: >= 0)"); } - this.closeNotifyTimeoutMillis = closeNotifyTimeoutMillis; + this.closeNotifyFlushTimeoutMillis = closeNotifyFlushTimeoutMillis; + } + + /** + * Gets the timeout (in ms) for receiving the response for the close_notify that was triggered by closing the + * {@link Channel}. This timeout starts after the close_notify message was successfully written to the + * remote peer. Use {@code 0} to directly close the {@link Channel} and not wait for the response. + */ + public final long getCloseNotifyReadTimeoutMillis() { + return closeNotifyReadTimeoutMillis; + } + + /** + * Sets the timeout for receiving the response for the close_notify that was triggered by closing the + * {@link Channel}. This timeout starts after the close_notify message was successfully written to the + * remote peer. Use {@code 0} to directly close the {@link Channel} and not wait for the response. + */ + public final void setCloseNotifyReadTimeout(long closeNotifyReadTimeoutMillis, TimeUnit unit) { + setCloseNotifyReadTimeoutMillis(unit.toMillis(closeNotifyReadTimeoutMillis)); + } + + /** + * See {@link #setCloseNotifyReadTimeout(long, TimeUnit)}. + */ + public final void setCloseNotifyReadTimeoutMillis(long closeNotifyReadTimeoutMillis) { + if (closeNotifyReadTimeoutMillis < 0) { + throw new IllegalArgumentException( + "closeNotifyReadTimeoutMillis: " + closeNotifyReadTimeoutMillis + " (expected: >= 0)"); + } + this.closeNotifyReadTimeoutMillis = closeNotifyReadTimeoutMillis; } /** @@ -791,6 +854,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // Ensure we always notify the sslClosePromise as well notifyClosePromise(CHANNEL_CLOSED); + super.channelInactive(ctx); } @@ -1506,7 +1570,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } private void safeClose( - final ChannelHandlerContext ctx, ChannelFuture flushFuture, + final ChannelHandlerContext ctx, final ChannelFuture flushFuture, final ChannelPromise promise) { if (!ctx.channel().isActive()) { ctx.close(promise); @@ -1515,16 +1579,20 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH final ScheduledFuture timeoutFuture; if (!flushFuture.isDone()) { - if (closeNotifyTimeoutMillis > 0) { + long closeNotifyTimeout = closeNotifyFlushTimeoutMillis; + if (closeNotifyTimeout > 0) { // Force-close the connection if close_notify is not fully sent in time. timeoutFuture = ctx.executor().schedule(new Runnable() { @Override public void run() { - logger.warn("{} Last write attempt timed out; force-closing the connection.", ctx.channel()); - - addCloseListener(ctx.close(ctx.newPromise()), promise); + // May be done in the meantime as cancel(...) is only best effort. + if (!flushFuture.isDone()) { + logger.warn("{} Last write attempt timed out; force-closing the connection.", + ctx.channel()); + addCloseListener(ctx.close(ctx.newPromise()), promise); + } } - }, closeNotifyTimeoutMillis, TimeUnit.MILLISECONDS); + }, closeNotifyTimeout, TimeUnit.MILLISECONDS); } else { timeoutFuture = null; } @@ -1540,9 +1608,43 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH if (timeoutFuture != null) { timeoutFuture.cancel(false); } - // Trigger the close in all cases to make sure the promise is notified - // See https://github.com/netty/netty/issues/2358 - addCloseListener(ctx.close(ctx.newPromise()), promise); + final long closeNotifyReadTimeout = closeNotifyReadTimeoutMillis; + if (closeNotifyReadTimeout <= 0) { + // Trigger the close in all cases to make sure the promise is notified + // See https://github.com/netty/netty/issues/2358 + addCloseListener(ctx.close(ctx.newPromise()), promise); + } else { + final ScheduledFuture closeNotifyReadTimeoutFuture; + + if (!sslClosePromise.isDone()) { + closeNotifyReadTimeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (!sslClosePromise.isDone()) { + logger.debug( + "{} did not receive close_notify in {}ms; force-closing the connection.", + ctx.channel(), closeNotifyReadTimeout); + + // Do the close now... + addCloseListener(ctx.close(ctx.newPromise()), promise); + } + } + }, closeNotifyReadTimeout, TimeUnit.MILLISECONDS); + } else { + closeNotifyReadTimeoutFuture = null; + } + + // Do the close once the we received the close_notify. + sslClosePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + if (closeNotifyReadTimeoutFuture != null) { + closeNotifyReadTimeoutFuture.cancel(false); + } + addCloseListener(ctx.close(ctx.newPromise()), promise); + } + }); + } } }); } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 17db194058..b39de0fced 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -35,6 +35,7 @@ import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; @@ -45,6 +46,7 @@ import io.netty.util.IllegalReferenceCountException; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.PromiseNotifier; import io.netty.util.internal.EmptyArrays; import org.junit.Test; @@ -68,6 +70,7 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicBoolean; public class SslHandlerTest { @@ -411,4 +414,150 @@ public class SslHandlerTest { assertTrue(evt.cause() instanceof ClosedChannelException); assertTrue(events.isEmpty()); } + + @Test(timeout = 30000) + public void testCloseNotifyReceivedJdk() throws Exception { + testCloseNotify(SslProvider.JDK, 5000, false); + } + + @Test(timeout = 30000) + public void testCloseNotifyReceivedOpenSsl() throws Exception { + assumeTrue(OpenSsl.isAvailable()); + testCloseNotify(SslProvider.OPENSSL, 5000, false); + testCloseNotify(SslProvider.OPENSSL_REFCNT, 5000, false); + } + + @Test(timeout = 30000) + public void testCloseNotifyReceivedJdkTimeout() throws Exception { + testCloseNotify(SslProvider.JDK, 100, true); + } + + @Test(timeout = 30000) + public void testCloseNotifyReceivedOpenSslTimeout() throws Exception { + assumeTrue(OpenSsl.isAvailable()); + testCloseNotify(SslProvider.OPENSSL, 100, true); + testCloseNotify(SslProvider.OPENSSL_REFCNT, 100, true); + } + + @Test(timeout = 30000) + public void testCloseNotifyNotWaitForResponseJdk() throws Exception { + testCloseNotify(SslProvider.JDK, 0, false); + } + + @Test(timeout = 30000) + public void testCloseNotifyNotWaitForResponseOpenSsl() throws Exception { + assumeTrue(OpenSsl.isAvailable()); + testCloseNotify(SslProvider.OPENSSL, 0, false); + testCloseNotify(SslProvider.OPENSSL_REFCNT, 0, false); + } + + private static void testCloseNotify(SslProvider provider, final long closeNotifyReadTimeout, final boolean timeout) + throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(provider) + .build(); + + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(provider).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + try { + final Promise clientPromise = group.next().newPromise(); + final Promise serverPromise = group.next().newPromise(); + + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + SslHandler handler = sslServerCtx.newHandler(ch.alloc()); + handler.setCloseNotifyReadTimeoutMillis(closeNotifyReadTimeout); + handler.sslCloseFuture().addListener( + new PromiseNotifier>(serverPromise)); + handler.handshakeFuture().addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + if (!future.isSuccess()) { + // Something bad happened during handshake fail the promise! + serverPromise.tryFailure(future.cause()); + } + } + }); + ch.pipeline().addLast(handler); + } + }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + final AtomicBoolean closeSent = new AtomicBoolean(); + if (timeout) { + ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (closeSent.get()) { + // Drop data on the floor so we will get a timeout while waiting for the + // close_notify. + ReferenceCountUtil.release(msg); + } else { + super.channelRead(ctx, msg); + } + } + }); + } + + SslHandler handler = sslClientCtx.newHandler(ch.alloc()); + handler.setCloseNotifyReadTimeoutMillis(closeNotifyReadTimeout); + handler.sslCloseFuture().addListener( + new PromiseNotifier>(clientPromise)); + handler.handshakeFuture().addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + if (future.isSuccess()) { + closeSent.compareAndSet(false, true); + future.getNow().close(); + } else { + // Something bad happened during handshake fail the promise! + clientPromise.tryFailure(future.cause()); + } + } + }); + ch.pipeline().addLast(handler); + } + }).connect(sc.localAddress()).syncUninterruptibly().channel(); + + serverPromise.awaitUninterruptibly(); + clientPromise.awaitUninterruptibly(); + + // Server always received the close_notify as the client triggers the close sequence. + assertTrue(serverPromise.isSuccess()); + + // Depending on if we wait for the response or not the promise will be failed or not. + if (closeNotifyReadTimeout > 0 && !timeout) { + assertTrue(clientPromise.isSuccess()); + } else { + assertFalse(clientPromise.isSuccess()); + } + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + } + } }