diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index c1ed587783..1d0e559f50 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -913,7 +913,7 @@ public class SslHandler extends ByteToMessageDecoder { * {@link #setHandshakeFailure(ChannelHandlerContext, Throwable)}. * @return {@code true} if this method ends on {@link SSLEngineResult.HandshakeStatus#NOT_HANDSHAKING}. */ - private boolean wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { + private boolean wrapNonAppData(final ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { ByteBuf out = null; ByteBufAllocator alloc = ctx.alloc(); try { @@ -929,7 +929,15 @@ public class SslHandler extends ByteToMessageDecoder { SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out); if (result.bytesProduced() > 0) { - ctx.write(out); + ctx.write(out).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + Throwable cause = future.cause(); + if (cause != null) { + setHandshakeFailureTransportFailure(ctx, cause); + } + } + }); if (inUnwrap) { needsFlush = true; } @@ -1768,11 +1776,26 @@ public class SslHandler extends ByteToMessageDecoder { } } finally { // Ensure we remove and fail all pending writes in all cases and so release memory quickly. - releaseAndFailAll(cause); + releaseAndFailAll(ctx, cause); } } - private void releaseAndFailAll(Throwable cause) { + private void setHandshakeFailureTransportFailure(ChannelHandlerContext ctx, Throwable cause) { + // If TLS control frames fail to write we are in an unknown state and may become out of + // sync with our peer. We give up and close the channel. This will also take care of + // cleaning up any outstanding state (e.g. handshake promise, queued unencrypted data). + try { + SSLException transportFailure = new SSLException("failure when writing TLS control frames", cause); + releaseAndFailAll(ctx, transportFailure); + if (handshakePromise.tryFailure(transportFailure)) { + ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(transportFailure)); + } + } finally { + ctx.close(); + } + } + + private void releaseAndFailAll(ChannelHandlerContext ctx, Throwable cause) { if (pendingUnencryptedWrites != null) { pendingUnencryptedWrites.releaseAndFailAll(ctx, cause); } @@ -1956,7 +1979,7 @@ public class SslHandler extends ByteToMessageDecoder { SslUtils.handleHandshakeFailure(ctx, exception, true); } } finally { - releaseAndFailAll(exception); + releaseAndFailAll(ctx, exception); } }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index c8379eb2f6..8829033bfd 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -23,6 +23,7 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; @@ -60,8 +61,10 @@ import org.junit.Test; import java.net.InetSocketAddress; import java.nio.channels.ClosedChannelException; import java.security.NoSuchAlgorithmException; +import java.util.Queue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -78,9 +81,13 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLProtocolException; import static io.netty.buffer.Unpooled.wrappedBuffer; -import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -88,6 +95,62 @@ import static org.junit.Assume.assumeTrue; public class SslHandlerTest { + @Test(timeout = 5000) + public void testNonApplicationDataFailureFailsQueuedWrites() throws NoSuchAlgorithmException, InterruptedException { + final CountDownLatch writeLatch = new CountDownLatch(1); + final Queue writesToFail = new ConcurrentLinkedQueue(); + SSLEngine engine = newClientModeSSLEngine(); + SslHandler handler = new SslHandler(engine) { + @Override + public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + super.write(ctx, msg, promise); + writeLatch.countDown(); + } + }; + EmbeddedChannel ch = new EmbeddedChannel(new ChannelDuplexHandler() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (msg instanceof ByteBuf) { + if (((ByteBuf) msg).isReadable()) { + writesToFail.add(promise); + } else { + promise.setSuccess(); + } + } + ReferenceCountUtil.release(msg); + } + }, handler); + + try { + final CountDownLatch writeCauseLatch = new CountDownLatch(1); + final AtomicReference failureRef = new AtomicReference(); + ch.write(Unpooled.wrappedBuffer(new byte[]{1})).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + failureRef.compareAndSet(null, future.cause()); + writeCauseLatch.countDown(); + } + }); + writeLatch.await(); + + // Simulate failing the SslHandler non-application writes after there are applications writes queued. + ChannelPromise promiseToFail; + while ((promiseToFail = writesToFail.poll()) != null) { + promiseToFail.setFailure(new RuntimeException("fake exception")); + } + + writeCauseLatch.await(); + Throwable writeCause = failureRef.get(); + assertNotNull(writeCause); + assertThat(writeCause, is(CoreMatchers.instanceOf(SSLException.class))); + Throwable cause = handler.handshakeFuture().cause(); + assertNotNull(cause); + assertThat(cause, is(CoreMatchers.instanceOf(SSLException.class))); + } finally { + assertFalse(ch.finishAndReleaseAll()); + } + } + @Test public void testNoSslHandshakeEventWhenNoHandshake() throws Exception { final AtomicBoolean inActive = new AtomicBoolean(false); @@ -147,6 +210,16 @@ public class SslHandlerTest { return engine; } + private static SSLEngine newClientModeSSLEngine() throws NoSuchAlgorithmException { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + // Set the mode before we try to do the handshake as otherwise it may throw an IllegalStateException. + // See: + // - https://docs.oracle.com/javase/10/docs/api/javax/net/ssl/SSLEngine.html#beginHandshake() + // - http://mail.openjdk.java.net/pipermail/security-dev/2018-July/017715.html + engine.setUseClientMode(true); + return engine; + } + private static void testHandshakeTimeout(boolean client) throws Throwable { SSLEngine engine = SSLContext.getDefault().createSSLEngine(); engine.setUseClientMode(client);