SslHandler to fail handshake and pending writes if non-application write fails (#9240)

Motivation:
SslHandler must generate control data as part of the TLS protocol, for example
to do handshakes. SslHandler doesn't capture the status of the future
corresponding to the writes when writing this control (aka non-application
data). If there is another handler before the SslHandler that wants to fail
these writes the SslHandler will not detect the failure and we must wait until
the handshake timeout to detect a failure.

Modifications:
- SslHandler should detect if non application writes fail, tear down the
channel, and clean up any pending state.

Result:
SslHandler detects non application write failures and cleans up immediately.
This commit is contained in:
Scott Mitchell 2019-06-15 22:38:33 -07:00 committed by Norman Maurer
parent db1e662933
commit bf7f41a993
2 changed files with 102 additions and 6 deletions

View File

@ -913,7 +913,7 @@ public class SslHandler extends ByteToMessageDecoder {
* {@link #setHandshakeFailure(ChannelHandlerContext, Throwable)}.
* @return {@code true} if this method ends on {@link SSLEngineResult.HandshakeStatus#NOT_HANDSHAKING}.
*/
private boolean wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
private boolean wrapNonAppData(final ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
ByteBuf out = null;
ByteBufAllocator alloc = ctx.alloc();
try {
@ -929,7 +929,15 @@ public class SslHandler extends ByteToMessageDecoder {
SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out);
if (result.bytesProduced() > 0) {
ctx.write(out);
ctx.write(out).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
Throwable cause = future.cause();
if (cause != null) {
setHandshakeFailureTransportFailure(ctx, cause);
}
}
});
if (inUnwrap) {
needsFlush = true;
}
@ -1768,11 +1776,26 @@ public class SslHandler extends ByteToMessageDecoder {
}
} finally {
// Ensure we remove and fail all pending writes in all cases and so release memory quickly.
releaseAndFailAll(cause);
releaseAndFailAll(ctx, cause);
}
}
private void releaseAndFailAll(Throwable cause) {
private void setHandshakeFailureTransportFailure(ChannelHandlerContext ctx, Throwable cause) {
// If TLS control frames fail to write we are in an unknown state and may become out of
// sync with our peer. We give up and close the channel. This will also take care of
// cleaning up any outstanding state (e.g. handshake promise, queued unencrypted data).
try {
SSLException transportFailure = new SSLException("failure when writing TLS control frames", cause);
releaseAndFailAll(ctx, transportFailure);
if (handshakePromise.tryFailure(transportFailure)) {
ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(transportFailure));
}
} finally {
ctx.close();
}
}
private void releaseAndFailAll(ChannelHandlerContext ctx, Throwable cause) {
if (pendingUnencryptedWrites != null) {
pendingUnencryptedWrites.releaseAndFailAll(ctx, cause);
}
@ -1956,7 +1979,7 @@ public class SslHandler extends ByteToMessageDecoder {
SslUtils.handleHandshakeFailure(ctx, exception, true);
}
} finally {
releaseAndFailAll(exception);
releaseAndFailAll(ctx, exception);
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);

View File

@ -23,6 +23,7 @@ import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
@ -60,8 +61,10 @@ import org.junit.Test;
import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.security.NoSuchAlgorithmException;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
@ -78,9 +81,13 @@ import javax.net.ssl.SSLException;
import javax.net.ssl.SSLProtocolException;
import static io.netty.buffer.Unpooled.wrappedBuffer;
import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -88,6 +95,62 @@ import static org.junit.Assume.assumeTrue;
public class SslHandlerTest {
@Test(timeout = 5000)
public void testNonApplicationDataFailureFailsQueuedWrites() throws NoSuchAlgorithmException, InterruptedException {
final CountDownLatch writeLatch = new CountDownLatch(1);
final Queue<ChannelPromise> writesToFail = new ConcurrentLinkedQueue<ChannelPromise>();
SSLEngine engine = newClientModeSSLEngine();
SslHandler handler = new SslHandler(engine) {
@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
super.write(ctx, msg, promise);
writeLatch.countDown();
}
};
EmbeddedChannel ch = new EmbeddedChannel(new ChannelDuplexHandler() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
if (msg instanceof ByteBuf) {
if (((ByteBuf) msg).isReadable()) {
writesToFail.add(promise);
} else {
promise.setSuccess();
}
}
ReferenceCountUtil.release(msg);
}
}, handler);
try {
final CountDownLatch writeCauseLatch = new CountDownLatch(1);
final AtomicReference<Throwable> failureRef = new AtomicReference<Throwable>();
ch.write(Unpooled.wrappedBuffer(new byte[]{1})).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
failureRef.compareAndSet(null, future.cause());
writeCauseLatch.countDown();
}
});
writeLatch.await();
// Simulate failing the SslHandler non-application writes after there are applications writes queued.
ChannelPromise promiseToFail;
while ((promiseToFail = writesToFail.poll()) != null) {
promiseToFail.setFailure(new RuntimeException("fake exception"));
}
writeCauseLatch.await();
Throwable writeCause = failureRef.get();
assertNotNull(writeCause);
assertThat(writeCause, is(CoreMatchers.<Throwable>instanceOf(SSLException.class)));
Throwable cause = handler.handshakeFuture().cause();
assertNotNull(cause);
assertThat(cause, is(CoreMatchers.<Throwable>instanceOf(SSLException.class)));
} finally {
assertFalse(ch.finishAndReleaseAll());
}
}
@Test
public void testNoSslHandshakeEventWhenNoHandshake() throws Exception {
final AtomicBoolean inActive = new AtomicBoolean(false);
@ -147,6 +210,16 @@ public class SslHandlerTest {
return engine;
}
private static SSLEngine newClientModeSSLEngine() throws NoSuchAlgorithmException {
SSLEngine engine = SSLContext.getDefault().createSSLEngine();
// Set the mode before we try to do the handshake as otherwise it may throw an IllegalStateException.
// See:
// - https://docs.oracle.com/javase/10/docs/api/javax/net/ssl/SSLEngine.html#beginHandshake()
// - http://mail.openjdk.java.net/pipermail/security-dev/2018-July/017715.html
engine.setUseClientMode(true);
return engine;
}
private static void testHandshakeTimeout(boolean client) throws Throwable {
SSLEngine engine = SSLContext.getDefault().createSSLEngine();
engine.setUseClientMode(client);