diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index d1a80ea81d..76594ceb39 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -823,17 +823,14 @@ public class SslHandler extends ByteToMessageDecoder { // This method will not call setHandshakeFailure(...) ! private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { ByteBuf out = null; - ChannelPromise promise = null; ByteBufAllocator alloc = ctx.alloc(); - boolean needUnwrap = false; - ByteBuf buf = null; try { final int wrapDataSize = this.wrapDataSize; // Only continue to loop if the handler was not removed in the meantime. // See https://github.com/netty/netty/issues/5860 outer: while (!ctx.isRemoved()) { - promise = ctx.newPromise(); - buf = wrapDataSize > 0 ? + ChannelPromise promise = ctx.newPromise(); + ByteBuf buf = wrapDataSize > 0 ? pendingUnencryptedWrites.remove(alloc, wrapDataSize, promise) : pendingUnencryptedWrites.removeFirst(promise); if (buf == null) { @@ -846,9 +843,31 @@ public class SslHandler extends ByteToMessageDecoder { SSLEngineResult result = wrap(alloc, engine, buf, out); - if (result.getStatus() == Status.CLOSED) { + if (buf.isReadable()) { + pendingUnencryptedWrites.addFirst(buf, promise); + // When we add the buffer/promise pair back we need to be sure we don't complete the promise + // later. We only complete the promise if the buffer is completely consumed. + promise = null; + } else { buf.release(); - buf = null; + } + + // We need to write any data before we invoke any methods which may trigger re-entry, otherwise + // writes may occur out of order and TLS sequencing may be off (e.g. SSLV3_ALERT_BAD_RECORD_MAC). + if (out.isReadable()) { + final ByteBuf b = out; + out = null; + if (promise != null) { + ctx.write(b, promise); + } else { + ctx.write(b); + } + } else if (promise != null) { + ctx.write(Unpooled.EMPTY_BUFFER, promise); + } + // else out is not readable we can re-use it and so save an extra allocation + + if (result.getStatus() == Status.CLOSED) { // Make a best effort to preserve any exception that way previously encountered from the handshake // or the transport, else fallback to a general error. Throwable exception = handshakePromise.cause(); @@ -858,23 +877,9 @@ public class SslHandler extends ByteToMessageDecoder { exception = new SslClosedEngineException("SSLEngine closed already"); } } - promise.tryFailure(exception); - promise = null; - // SSLEngine has been closed already. - // Any further write attempts should be denied. pendingUnencryptedWrites.releaseAndFailAll(ctx, exception); return; } else { - if (buf.isReadable()) { - pendingUnencryptedWrites.addFirst(buf, promise); - // When we add the buffer/promise pair back we need to be sure we don't complete the promise - // later in finishWrap. We only complete the promise if the buffer is completely consumed. - promise = null; - } else { - buf.release(); - } - buf = null; - switch (result.getHandshakeStatus()) { case NEED_TASK: if (!runDelegatedTasks(inUnwrap)) { @@ -885,27 +890,11 @@ public class SslHandler extends ByteToMessageDecoder { break; case FINISHED: setHandshakeSuccess(); - // deliberate fall-through + break; case NOT_HANDSHAKING: setHandshakeSuccessIfStillHandshaking(); - // deliberate fall-through - case NEED_WRAP: { - ChannelPromise p = promise; - - // Null out the promise so it is not reused in the finally block in the cause of - // finishWrap(...) throwing. - promise = null; - final ByteBuf b; - - if (out.isReadable()) { - // There is something in the out buffer. Ensure we null it out so it is not re-used. - b = out; - out = null; - } else { - // If out is not readable we can re-use it and so save an extra allocation - b = null; - } - finishWrap(ctx, b, p, inUnwrap, false); + break; + case NEED_WRAP: // If we are expected to wrap again and we produced some data we need to ensure there // is something in the queue to process as otherwise we will not try again before there // was more added. Failing to do so may fail to produce an alert that can be @@ -914,9 +903,10 @@ public class SslHandler extends ByteToMessageDecoder { pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER); } break; - } case NEED_UNWRAP: - needUnwrap = true; + // The underlying engine is starving so we need to feed it with more data. + // See https://github.com/netty/netty/pull/5039 + readIfNeeded(ctx); return; default: throw new IllegalStateException( @@ -925,37 +915,12 @@ public class SslHandler extends ByteToMessageDecoder { } } } finally { - // Ownership of buffer was not transferred, release it. - if (buf != null) { - buf.release(); + if (out != null) { + out.release(); + } + if (inUnwrap) { + needsFlush = true; } - finishWrap(ctx, out, promise, inUnwrap, needUnwrap); - } - } - - private void finishWrap(ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap, - boolean needUnwrap) { - if (out == null) { - out = Unpooled.EMPTY_BUFFER; - } else if (!out.isReadable()) { - out.release(); - out = Unpooled.EMPTY_BUFFER; - } - - if (promise != null) { - ctx.write(out, promise); - } else { - ctx.write(out); - } - - if (inUnwrap) { - needsFlush = true; - } - - if (needUnwrap) { - // The underlying engine is starving so we need to feed it with more data. - // See https://github.com/netty/netty/pull/5039 - readIfNeeded(ctx); } } @@ -979,7 +944,6 @@ public class SslHandler extends ByteToMessageDecoder { out = allocateOutNetBuf(ctx, 2048, 1); } SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out); - HandshakeStatus status = result.getHandshakeStatus(); if (result.bytesProduced() > 0) { ctx.write(out).addListener((ChannelFutureListener) future -> { Throwable cause = future.cause(); @@ -988,21 +952,22 @@ public class SslHandler extends ByteToMessageDecoder { } }); if (inUnwrap) { - // We may be here because we read data and discovered the remote peer initiated a renegotiation - // and this write is to complete the new handshake. The user may have previously done a - // writeAndFlush which wasn't able to wrap data due to needing the pending handshake, so we - // attempt to wrap application data here if any is pending. - if (status == HandshakeStatus.FINISHED && !pendingUnencryptedWrites.isEmpty()) { - wrap(ctx, true); - } needsFlush = true; } out = null; } + HandshakeStatus status = result.getHandshakeStatus(); switch (status) { case FINISHED: setHandshakeSuccess(); + // We may be here because we read data and discovered the remote peer initiated a renegotiation + // and this write is to complete the new handshake. The user may have previously done a + // writeAndFlush which wasn't able to wrap data due to needing the pending handshake, so we + // attempt to wrap application data here if any is pending. + if (inUnwrap && !pendingUnencryptedWrites.isEmpty()) { + wrap(ctx, true); + } return false; case NEED_TASK: if (!runDelegatedTasks(inUnwrap)) { diff --git a/handler/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java index b160970a9a..c3fcf998ec 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java @@ -15,10 +15,12 @@ */ package io.netty.handler.ssl; +import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.DecoderException; +import io.netty.util.CharsetUtil; import org.junit.Test; import javax.net.ssl.SSLContext; @@ -66,13 +68,20 @@ public class ApplicationProtocolNegotiationHandlerTest { }; SSLEngine engine = SSLContext.getDefault().createSSLEngine(); - engine.setUseClientMode(false); + // This test is mocked/simulated and doesn't go through full TLS handshake. Currently only JDK SSLEngineImpl + // client mode will generate a close_notify. + engine.setUseClientMode(true); EmbeddedChannel channel = new EmbeddedChannel(new SslHandler(engine), alpnHandler); channel.pipeline().fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); assertNull(channel.pipeline().context(alpnHandler)); // Should produce the close_notify messages - assertTrue(channel.finishAndReleaseAll()); + channel.releaseOutbound(); + channel.close(); + ByteBuf close_notify = channel.readOutbound(); + assertTrue("close_notify: " + close_notify.toString(CharsetUtil.UTF_8), close_notify.readableBytes() >= 7); + close_notify.release(); + channel.finishAndReleaseAll(); assertTrue(configureCalled.get()); } diff --git a/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java index 4d21f81ed5..449cf68d13 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java @@ -26,12 +26,15 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.EventLoopGroup; import io.netty.channel.MultithreadEventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.nio.NioHandler; + import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.handler.ssl.util.SimpleTrustManagerFactory; +import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.ResourcesUtil; import io.netty.util.concurrent.FutureListener; @@ -53,9 +56,12 @@ import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; +import static io.netty.buffer.ByteBufUtil.writeAscii; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -484,5 +490,120 @@ public class ParameterizedSslHandlerTest { ReferenceCountUtil.release(sslClientCtx); } } + + @Test(timeout = 30000) + public void reentryWriteOnHandshakeComplete() throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(serverProvider) + .build(); + + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(clientProvider) + .build(); + + EventLoopGroup group = new MultithreadEventLoopGroup(NioHandler.newFactory()); + Channel sc = null; + Channel cc = null; + try { + final String expectedContent = "HelloWorld"; + final CountDownLatch serverLatch = new CountDownLatch(1); + final CountDownLatch clientLatch = new CountDownLatch(1); + final StringBuilder serverQueue = new StringBuilder(expectedContent.length()); + final StringBuilder clientQueue = new StringBuilder(expectedContent.length()); + + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, serverQueue, + serverLatch)); + } + }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, clientQueue, + clientLatch)); + } + }).connect(sc.localAddress()).syncUninterruptibly().channel(); + + serverLatch.await(); + assertEquals(expectedContent, serverQueue.toString()); + clientLatch.await(); + assertEquals(expectedContent, clientQueue.toString()); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + } + } + + private static final class ReentryWriteSslHandshakeHandler extends SimpleChannelInboundHandler { + private final String toWrite; + private final StringBuilder readQueue; + private final CountDownLatch doneLatch; + + ReentryWriteSslHandshakeHandler(String toWrite, StringBuilder readQueue, CountDownLatch doneLatch) { + this.toWrite = toWrite; + this.readQueue = readQueue; + this.doneLatch = doneLatch; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + // Write toWrite in two chunks, first here then we get SslHandshakeCompletionEvent (which is re-entry). + ctx.writeAndFlush(writeAscii(ctx.alloc(), toWrite.substring(0, toWrite.length() / 2))); + } + + @Override + protected void messageReceived(ChannelHandlerContext ctx, ByteBuf msg) { + readQueue.append(msg.toString(CharsetUtil.US_ASCII)); + if (readQueue.length() >= toWrite.length()) { + doneLatch.countDown(); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt; + if (sslEvt.isSuccess()) { + // this is the re-entry write, it should be ordered after the subsequent write. + ctx.writeAndFlush(writeAscii(ctx.alloc(), toWrite.substring(toWrite.length() / 2))); + } else { + appendError(sslEvt.cause()); + } + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + appendError(cause); + ctx.fireExceptionCaught(cause); + } + + private void appendError(Throwable cause) { + readQueue.append("failed to write '").append(toWrite).append("': ").append(cause); + doneLatch.countDown(); + } + } }