SslHandler wrap reentry bug fix (#11133)

Motivation:
SslHandler's wrap method notifies the handshakeFuture and sends a
SslHandshakeCompletionEvent user event down the pipeline before writing
the plaintext that has just been wrapped. It is possible the application
may write as a result of these events and re-enter into wrap to write
more data. This will result in out of sequence data and result in alerts
such as SSLV3_ALERT_BAD_RECORD_MAC.

Modifications:
- SslHandler wrap should write any pending data before notifying
  promises, generating user events, or anything else that may create a
  re-entry scenario.

Result:
SslHandler will wrap/write data in the same order.
This commit is contained in:
Scott Mitchell 2021-04-01 01:00:18 -07:00 committed by GitHub
parent e4dd6ee532
commit 6b48e690fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 179 additions and 87 deletions

View File

@ -821,17 +821,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// This method will not call setHandshakeFailure(...) !
private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
ByteBuf out = null;
ChannelPromise promise = null;
ByteBufAllocator alloc = ctx.alloc();
boolean needUnwrap = false;
ByteBuf buf = null;
try {
final int wrapDataSize = this.wrapDataSize;
// Only continue to loop if the handler was not removed in the meantime.
// See https://github.com/netty/netty/issues/5860
outer: while (!ctx.isRemoved()) {
promise = ctx.newPromise();
buf = wrapDataSize > 0 ?
ChannelPromise promise = ctx.newPromise();
ByteBuf buf = wrapDataSize > 0 ?
pendingUnencryptedWrites.remove(alloc, wrapDataSize, promise) :
pendingUnencryptedWrites.removeFirst(promise);
if (buf == null) {
@ -844,9 +841,31 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
SSLEngineResult result = wrap(alloc, engine, buf, out);
if (result.getStatus() == Status.CLOSED) {
if (buf.isReadable()) {
pendingUnencryptedWrites.addFirst(buf, promise);
// When we add the buffer/promise pair back we need to be sure we don't complete the promise
// later. We only complete the promise if the buffer is completely consumed.
promise = null;
} else {
buf.release();
buf = null;
}
// We need to write any data before we invoke any methods which may trigger re-entry, otherwise
// writes may occur out of order and TLS sequencing may be off (e.g. SSLV3_ALERT_BAD_RECORD_MAC).
if (out.isReadable()) {
final ByteBuf b = out;
out = null;
if (promise != null) {
ctx.write(b, promise);
} else {
ctx.write(b);
}
} else if (promise != null) {
ctx.write(Unpooled.EMPTY_BUFFER, promise);
}
// else out is not readable we can re-use it and so save an extra allocation
if (result.getStatus() == Status.CLOSED) {
// Make a best effort to preserve any exception that way previously encountered from the handshake
// or the transport, else fallback to a general error.
Throwable exception = handshakePromise.cause();
@ -856,23 +875,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
exception = new SslClosedEngineException("SSLEngine closed already");
}
}
promise.tryFailure(exception);
promise = null;
// SSLEngine has been closed already.
// Any further write attempts should be denied.
pendingUnencryptedWrites.releaseAndFailAll(ctx, exception);
return;
} else {
if (buf.isReadable()) {
pendingUnencryptedWrites.addFirst(buf, promise);
// When we add the buffer/promise pair back we need to be sure we don't complete the promise
// later in finishWrap. We only complete the promise if the buffer is completely consumed.
promise = null;
} else {
buf.release();
}
buf = null;
switch (result.getHandshakeStatus()) {
case NEED_TASK:
if (!runDelegatedTasks(inUnwrap)) {
@ -883,27 +888,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break;
case FINISHED:
setHandshakeSuccess();
// deliberate fall-through
break;
case NOT_HANDSHAKING:
setHandshakeSuccessIfStillHandshaking();
// deliberate fall-through
case NEED_WRAP: {
ChannelPromise p = promise;
// Null out the promise so it is not reused in the finally block in the cause of
// finishWrap(...) throwing.
promise = null;
final ByteBuf b;
if (out.isReadable()) {
// There is something in the out buffer. Ensure we null it out so it is not re-used.
b = out;
out = null;
} else {
// If out is not readable we can re-use it and so save an extra allocation
b = null;
}
finishWrap(ctx, b, p, inUnwrap, false);
break;
case NEED_WRAP:
// If we are expected to wrap again and we produced some data we need to ensure there
// is something in the queue to process as otherwise we will not try again before there
// was more added. Failing to do so may fail to produce an alert that can be
@ -912,9 +901,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER);
}
break;
}
case NEED_UNWRAP:
needUnwrap = true;
// The underlying engine is starving so we need to feed it with more data.
// See https://github.com/netty/netty/pull/5039
readIfNeeded(ctx);
return;
default:
throw new IllegalStateException(
@ -923,37 +913,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
}
} finally {
// Ownership of buffer was not transferred, release it.
if (buf != null) {
buf.release();
}
finishWrap(ctx, out, promise, inUnwrap, needUnwrap);
}
}
private void finishWrap(ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap,
boolean needUnwrap) {
if (out == null) {
out = Unpooled.EMPTY_BUFFER;
} else if (!out.isReadable()) {
if (out != null) {
out.release();
out = Unpooled.EMPTY_BUFFER;
}
if (promise != null) {
ctx.write(out, promise);
} else {
ctx.write(out);
}
if (inUnwrap) {
needsFlush = true;
}
if (needUnwrap) {
// The underlying engine is starving so we need to feed it with more data.
// See https://github.com/netty/netty/pull/5039
readIfNeeded(ctx);
}
}
@ -977,7 +942,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
out = allocateOutNetBuf(ctx, 2048, 1);
}
SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out);
HandshakeStatus status = result.getHandshakeStatus();
if (result.bytesProduced() > 0) {
ctx.write(out).addListener(new ChannelFutureListener() {
@Override
@ -989,21 +953,22 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
});
if (inUnwrap) {
// We may be here because we read data and discovered the remote peer initiated a renegotiation
// and this write is to complete the new handshake. The user may have previously done a
// writeAndFlush which wasn't able to wrap data due to needing the pending handshake, so we
// attempt to wrap application data here if any is pending.
if (status == HandshakeStatus.FINISHED && !pendingUnencryptedWrites.isEmpty()) {
wrap(ctx, true);
}
needsFlush = true;
}
out = null;
}
HandshakeStatus status = result.getHandshakeStatus();
switch (status) {
case FINISHED:
setHandshakeSuccess();
// We may be here because we read data and discovered the remote peer initiated a renegotiation
// and this write is to complete the new handshake. The user may have previously done a
// writeAndFlush which wasn't able to wrap data due to needing the pending handshake, so we
// attempt to wrap application data here if any is pending.
if (inUnwrap && !pendingUnencryptedWrites.isEmpty()) {
wrap(ctx, true);
}
return false;
case NEED_TASK:
if (!runDelegatedTasks(inUnwrap)) {
@ -1095,11 +1060,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
in.skipBytes(result.bytesConsumed());
out.writerIndex(out.writerIndex() + result.bytesProduced());
switch (result.getStatus()) {
case BUFFER_OVERFLOW:
if (result.getStatus() == Status.BUFFER_OVERFLOW) {
out.ensureWritable(engine.getSession().getPacketBufferSize());
break;
default:
} else {
return result;
}
}

View File

@ -15,10 +15,12 @@
*/
package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderException;
import io.netty.util.CharsetUtil;
import org.junit.Test;
import javax.net.ssl.SSLContext;
@ -66,13 +68,20 @@ public class ApplicationProtocolNegotiationHandlerTest {
};
SSLEngine engine = SSLContext.getDefault().createSSLEngine();
engine.setUseClientMode(false);
// This test is mocked/simulated and doesn't go through full TLS handshake. Currently only JDK SSLEngineImpl
// client mode will generate a close_notify.
engine.setUseClientMode(true);
EmbeddedChannel channel = new EmbeddedChannel(new SslHandler(engine), alpnHandler);
channel.pipeline().fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
assertNull(channel.pipeline().context(alpnHandler));
// Should produce the close_notify messages
assertTrue(channel.finishAndReleaseAll());
channel.releaseOutbound();
channel.close();
ByteBuf close_notify = channel.readOutbound();
assertTrue("close_notify: " + close_notify.toString(CharsetUtil.UTF_8), close_notify.readableBytes() >= 7);
close_notify.release();
channel.finishAndReleaseAll();
assertTrue(configureCalled.get());
}

View File

@ -26,19 +26,21 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.ResourcesUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.PromiseNotifier;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.ResourcesUtil;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@ -54,9 +56,12 @@ import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import static io.netty.buffer.ByteBufUtil.writeAscii;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
@ -494,5 +499,120 @@ public class ParameterizedSslHandlerTest {
ReferenceCountUtil.release(sslClientCtx);
}
}
@Test(timeout = 30000)
public void reentryWriteOnHandshakeComplete() throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate();
final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
.sslProvider(serverProvider)
.build();
final SslContext sslClientCtx = SslContextBuilder.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(clientProvider)
.build();
EventLoopGroup group = new NioEventLoopGroup();
Channel sc = null;
Channel cc = null;
try {
final String expectedContent = "HelloWorld";
final CountDownLatch serverLatch = new CountDownLatch(1);
final CountDownLatch clientLatch = new CountDownLatch(1);
final StringBuilder serverQueue = new StringBuilder(expectedContent.length());
final StringBuilder clientQueue = new StringBuilder(expectedContent.length());
sc = new ServerBootstrap()
.group(group)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc()));
ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, serverQueue,
serverLatch));
}
}).bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
cc = new Bootstrap()
.group(group)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc()));
ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, clientQueue,
clientLatch));
}
}).connect(sc.localAddress()).syncUninterruptibly().channel();
serverLatch.await();
assertEquals(expectedContent, serverQueue.toString());
clientLatch.await();
assertEquals(expectedContent, clientQueue.toString());
} finally {
if (cc != null) {
cc.close().syncUninterruptibly();
}
if (sc != null) {
sc.close().syncUninterruptibly();
}
group.shutdownGracefully();
ReferenceCountUtil.release(sslServerCtx);
ReferenceCountUtil.release(sslClientCtx);
}
}
private static final class ReentryWriteSslHandshakeHandler extends SimpleChannelInboundHandler<ByteBuf> {
private final String toWrite;
private final StringBuilder readQueue;
private final CountDownLatch doneLatch;
ReentryWriteSslHandshakeHandler(String toWrite, StringBuilder readQueue, CountDownLatch doneLatch) {
this.toWrite = toWrite;
this.readQueue = readQueue;
this.doneLatch = doneLatch;
}
@Override
public void channelActive(ChannelHandlerContext ctx) {
// Write toWrite in two chunks, first here then we get SslHandshakeCompletionEvent (which is re-entry).
ctx.writeAndFlush(writeAscii(ctx.alloc(), toWrite.substring(0, toWrite.length() / 2)));
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) {
readQueue.append(msg.toString(CharsetUtil.US_ASCII));
if (readQueue.length() >= toWrite.length()) {
doneLatch.countDown();
}
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt;
if (sslEvt.isSuccess()) {
// this is the re-entry write, it should be ordered after the subsequent write.
ctx.writeAndFlush(writeAscii(ctx.alloc(), toWrite.substring(toWrite.length() / 2)));
} else {
appendError(sslEvt.cause());
}
}
ctx.fireUserEventTriggered(evt);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
appendError(cause);
ctx.fireExceptionCaught(cause);
}
private void appendError(Throwable cause) {
readQueue.append("failed to write '").append(toWrite).append("': ").append(cause);
doneLatch.countDown();
}
}
}