Allow to configure SslHandler to wait for close_notify response before closing the Channel and fix racy flush close_notify timeout scheduling.

Motivation:

SslHandler closed the channel as soon as it was able to write out the close_notify message. This may not be what the user want as it may make sense to only close it after the actual response to the close_notify was received in order to guarantee a clean-shutdown of the connection in all cases.

Beside this closeNotifyFlushTimeoutMillis is volatile so may change between two reads. We need to cache it in a local variable to ensure it not change int between. Beside this we also need to check if the flush promise was complete the schedule timeout as this may happened but we were not able to cancel the timeout yet. Otherwise we will produce an missleading log message.

Modifications:

- Add new setter / getter to SslHandler which allows to specify the behavior (old behavior is preserved as default)
- Added unit test.
- Cache volatile closeNotifyTimeoutMillis.
- Correctly check if flush promise was complete before we try to forcibly close the Channel and log a warning.
- Add missing javadocs.

Result:

More clean shutdown of connection possible when using SSL and fix racy way of schedule close_notify flush timeouts and javadocs.
This commit is contained in:
Norman Maurer 2017-01-11 16:07:47 +01:00
parent 907726988d
commit 640ef615be
2 changed files with 271 additions and 20 deletions

View File

@ -308,7 +308,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private boolean firedChannelRead; private boolean firedChannelRead;
private volatile long handshakeTimeoutMillis = 10000; private volatile long handshakeTimeoutMillis = 10000;
private volatile long closeNotifyTimeoutMillis = 3000; private volatile long closeNotifyFlushTimeoutMillis = 3000;
private volatile long closeNotifyReadTimeoutMillis;
/** /**
* Creates a new instance. * Creates a new instance.
@ -378,24 +379,86 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
this.handshakeTimeoutMillis = handshakeTimeoutMillis; this.handshakeTimeoutMillis = handshakeTimeoutMillis;
} }
/**
* @deprecated use {@link #getCloseNotifyFlushTimeoutMillis()}
*/
@Deprecated
public long getCloseNotifyTimeoutMillis() { public long getCloseNotifyTimeoutMillis() {
return closeNotifyTimeoutMillis; return getCloseNotifyFlushTimeoutMillis();
} }
/**
* @deprecated use {@link #setCloseNotifyFlushTimeout(long, TimeUnit)}
*/
@Deprecated
public void setCloseNotifyTimeout(long closeNotifyTimeout, TimeUnit unit) { public void setCloseNotifyTimeout(long closeNotifyTimeout, TimeUnit unit) {
if (unit == null) { setCloseNotifyFlushTimeout(closeNotifyTimeout, unit);
throw new NullPointerException("unit");
} }
setCloseNotifyTimeoutMillis(unit.toMillis(closeNotifyTimeout)); /**
* @deprecated use {@link #setCloseNotifyFlushTimeoutMillis(long)}
*/
@Deprecated
public void setCloseNotifyTimeoutMillis(long closeNotifyFlushTimeoutMillis) {
setCloseNotifyFlushTimeoutMillis(closeNotifyFlushTimeoutMillis);
} }
public void setCloseNotifyTimeoutMillis(long closeNotifyTimeoutMillis) { /**
if (closeNotifyTimeoutMillis < 0) { * Gets the timeout for flushing the close_notify that was triggered by closing the
* {@link Channel}. If the close_notify was not flushed in the given timeout the {@link Channel} will be closed
* forcibily.
*/
public final long getCloseNotifyFlushTimeoutMillis() {
return closeNotifyFlushTimeoutMillis;
}
/**
* Sets the timeout for flushing the close_notify that was triggered by closing the
* {@link Channel}. If the close_notify was not flushed in the given timeout the {@link Channel} will be closed
* forcibily.
*/
public final void setCloseNotifyFlushTimeout(long closeNotifyFlushTimeoutMillis, TimeUnit unit) {
setCloseNotifyFlushTimeoutMillis(unit.toMillis(closeNotifyFlushTimeoutMillis));
}
/**
* See {@link #setCloseNotifyFlushTimeout(long, TimeUnit)}.
*/
public final void setCloseNotifyFlushTimeoutMillis(long closeNotifyFlushTimeoutMillis) {
if (closeNotifyFlushTimeoutMillis < 0) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"closeNotifyTimeoutMillis: " + closeNotifyTimeoutMillis + " (expected: >= 0)"); "closeNotifyFlushTimeoutMillis: " + closeNotifyFlushTimeoutMillis + " (expected: >= 0)");
} }
this.closeNotifyTimeoutMillis = closeNotifyTimeoutMillis; this.closeNotifyFlushTimeoutMillis = closeNotifyFlushTimeoutMillis;
}
/**
* Gets the timeout (in ms) for receiving the response for the close_notify that was triggered by closing the
* {@link Channel}. This timeout starts after the close_notify message was successfully written to the
* remote peer. Use {@code 0} to directly close the {@link Channel} and not wait for the response.
*/
public final long getCloseNotifyReadTimeoutMillis() {
return closeNotifyReadTimeoutMillis;
}
/**
* Sets the timeout for receiving the response for the close_notify that was triggered by closing the
* {@link Channel}. This timeout starts after the close_notify message was successfully written to the
* remote peer. Use {@code 0} to directly close the {@link Channel} and not wait for the response.
*/
public final void setCloseNotifyReadTimeout(long closeNotifyReadTimeoutMillis, TimeUnit unit) {
setCloseNotifyReadTimeoutMillis(unit.toMillis(closeNotifyReadTimeoutMillis));
}
/**
* See {@link #setCloseNotifyReadTimeout(long, TimeUnit)}.
*/
public final void setCloseNotifyReadTimeoutMillis(long closeNotifyReadTimeoutMillis) {
if (closeNotifyReadTimeoutMillis < 0) {
throw new IllegalArgumentException(
"closeNotifyReadTimeoutMillis: " + closeNotifyReadTimeoutMillis + " (expected: >= 0)");
}
this.closeNotifyReadTimeoutMillis = closeNotifyReadTimeoutMillis;
} }
/** /**
@ -791,6 +854,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// Ensure we always notify the sslClosePromise as well // Ensure we always notify the sslClosePromise as well
notifyClosePromise(CHANNEL_CLOSED); notifyClosePromise(CHANNEL_CLOSED);
super.channelInactive(ctx); super.channelInactive(ctx);
} }
@ -1506,7 +1570,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
private void safeClose( private void safeClose(
final ChannelHandlerContext ctx, ChannelFuture flushFuture, final ChannelHandlerContext ctx, final ChannelFuture flushFuture,
final ChannelPromise promise) { final ChannelPromise promise) {
if (!ctx.channel().isActive()) { if (!ctx.channel().isActive()) {
ctx.close(promise); ctx.close(promise);
@ -1515,16 +1579,20 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
final ScheduledFuture<?> timeoutFuture; final ScheduledFuture<?> timeoutFuture;
if (!flushFuture.isDone()) { if (!flushFuture.isDone()) {
if (closeNotifyTimeoutMillis > 0) { long closeNotifyTimeout = closeNotifyFlushTimeoutMillis;
if (closeNotifyTimeout > 0) {
// Force-close the connection if close_notify is not fully sent in time. // Force-close the connection if close_notify is not fully sent in time.
timeoutFuture = ctx.executor().schedule(new Runnable() { timeoutFuture = ctx.executor().schedule(new Runnable() {
@Override @Override
public void run() { public void run() {
logger.warn("{} Last write attempt timed out; force-closing the connection.", ctx.channel()); // May be done in the meantime as cancel(...) is only best effort.
if (!flushFuture.isDone()) {
logger.warn("{} Last write attempt timed out; force-closing the connection.",
ctx.channel());
addCloseListener(ctx.close(ctx.newPromise()), promise); addCloseListener(ctx.close(ctx.newPromise()), promise);
} }
}, closeNotifyTimeoutMillis, TimeUnit.MILLISECONDS); }
}, closeNotifyTimeout, TimeUnit.MILLISECONDS);
} else { } else {
timeoutFuture = null; timeoutFuture = null;
} }
@ -1540,9 +1608,43 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (timeoutFuture != null) { if (timeoutFuture != null) {
timeoutFuture.cancel(false); timeoutFuture.cancel(false);
} }
final long closeNotifyReadTimeout = closeNotifyReadTimeoutMillis;
if (closeNotifyReadTimeout <= 0) {
// Trigger the close in all cases to make sure the promise is notified // Trigger the close in all cases to make sure the promise is notified
// See https://github.com/netty/netty/issues/2358 // See https://github.com/netty/netty/issues/2358
addCloseListener(ctx.close(ctx.newPromise()), promise); addCloseListener(ctx.close(ctx.newPromise()), promise);
} else {
final ScheduledFuture<?> closeNotifyReadTimeoutFuture;
if (!sslClosePromise.isDone()) {
closeNotifyReadTimeoutFuture = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (!sslClosePromise.isDone()) {
logger.debug(
"{} did not receive close_notify in {}ms; force-closing the connection.",
ctx.channel(), closeNotifyReadTimeout);
// Do the close now...
addCloseListener(ctx.close(ctx.newPromise()), promise);
}
}
}, closeNotifyReadTimeout, TimeUnit.MILLISECONDS);
} else {
closeNotifyReadTimeoutFuture = null;
}
// Do the close once the we received the close_notify.
sslClosePromise.addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> future) throws Exception {
if (closeNotifyReadTimeoutFuture != null) {
closeNotifyReadTimeoutFuture.cancel(false);
}
addCloseListener(ctx.close(ctx.newPromise()), promise);
}
});
}
} }
}); });
} }

View File

@ -35,6 +35,7 @@ import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel;
@ -45,6 +46,7 @@ import io.netty.util.IllegalReferenceCountException;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.PromiseNotifier;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import org.junit.Test; import org.junit.Test;
@ -68,6 +70,7 @@ import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
public class SslHandlerTest { public class SslHandlerTest {
@ -411,4 +414,150 @@ public class SslHandlerTest {
assertTrue(evt.cause() instanceof ClosedChannelException); assertTrue(evt.cause() instanceof ClosedChannelException);
assertTrue(events.isEmpty()); assertTrue(events.isEmpty());
} }
@Test(timeout = 30000)
public void testCloseNotifyReceivedJdk() throws Exception {
testCloseNotify(SslProvider.JDK, 5000, false);
}
@Test(timeout = 30000)
public void testCloseNotifyReceivedOpenSsl() throws Exception {
assumeTrue(OpenSsl.isAvailable());
testCloseNotify(SslProvider.OPENSSL, 5000, false);
testCloseNotify(SslProvider.OPENSSL_REFCNT, 5000, false);
}
@Test(timeout = 30000)
public void testCloseNotifyReceivedJdkTimeout() throws Exception {
testCloseNotify(SslProvider.JDK, 100, true);
}
@Test(timeout = 30000)
public void testCloseNotifyReceivedOpenSslTimeout() throws Exception {
assumeTrue(OpenSsl.isAvailable());
testCloseNotify(SslProvider.OPENSSL, 100, true);
testCloseNotify(SslProvider.OPENSSL_REFCNT, 100, true);
}
@Test(timeout = 30000)
public void testCloseNotifyNotWaitForResponseJdk() throws Exception {
testCloseNotify(SslProvider.JDK, 0, false);
}
@Test(timeout = 30000)
public void testCloseNotifyNotWaitForResponseOpenSsl() throws Exception {
assumeTrue(OpenSsl.isAvailable());
testCloseNotify(SslProvider.OPENSSL, 0, false);
testCloseNotify(SslProvider.OPENSSL_REFCNT, 0, false);
}
private static void testCloseNotify(SslProvider provider, final long closeNotifyReadTimeout, final boolean timeout)
throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate();
final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
.sslProvider(provider)
.build();
final SslContext sslClientCtx = SslContextBuilder.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(provider).build();
EventLoopGroup group = new NioEventLoopGroup();
Channel sc = null;
Channel cc = null;
try {
final Promise<Channel> clientPromise = group.next().newPromise();
final Promise<Channel> serverPromise = group.next().newPromise();
sc = new ServerBootstrap()
.group(group)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
SslHandler handler = sslServerCtx.newHandler(ch.alloc());
handler.setCloseNotifyReadTimeoutMillis(closeNotifyReadTimeout);
handler.sslCloseFuture().addListener(
new PromiseNotifier<Channel, Future<Channel>>(serverPromise));
handler.handshakeFuture().addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> future) {
if (!future.isSuccess()) {
// Something bad happened during handshake fail the promise!
serverPromise.tryFailure(future.cause());
}
}
});
ch.pipeline().addLast(handler);
}
}).bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
cc = new Bootstrap()
.group(group)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
final AtomicBoolean closeSent = new AtomicBoolean();
if (timeout) {
ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (closeSent.get()) {
// Drop data on the floor so we will get a timeout while waiting for the
// close_notify.
ReferenceCountUtil.release(msg);
} else {
super.channelRead(ctx, msg);
}
}
});
}
SslHandler handler = sslClientCtx.newHandler(ch.alloc());
handler.setCloseNotifyReadTimeoutMillis(closeNotifyReadTimeout);
handler.sslCloseFuture().addListener(
new PromiseNotifier<Channel, Future<Channel>>(clientPromise));
handler.handshakeFuture().addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> future) {
if (future.isSuccess()) {
closeSent.compareAndSet(false, true);
future.getNow().close();
} else {
// Something bad happened during handshake fail the promise!
clientPromise.tryFailure(future.cause());
}
}
});
ch.pipeline().addLast(handler);
}
}).connect(sc.localAddress()).syncUninterruptibly().channel();
serverPromise.awaitUninterruptibly();
clientPromise.awaitUninterruptibly();
// Server always received the close_notify as the client triggers the close sequence.
assertTrue(serverPromise.isSuccess());
// Depending on if we wait for the response or not the promise will be failed or not.
if (closeNotifyReadTimeout > 0 && !timeout) {
assertTrue(clientPromise.isSuccess());
} else {
assertFalse(clientPromise.isSuccess());
}
} finally {
if (cc != null) {
cc.close().syncUninterruptibly();
}
if (sc != null) {
sc.close().syncUninterruptibly();
}
group.shutdownGracefully();
ReferenceCountUtil.release(sslServerCtx);
ReferenceCountUtil.release(sslClientCtx);
}
}
} }