From 53fc69390130f7a5a33d454a59ad214f6b4176f5 Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Fri, 3 Mar 2017 09:22:26 -0800 Subject: [PATCH] SslHandler and OpenSslEngine miscalculation of wrap destination buffer size Motivation: When we do a wrap operation we calculate the maximum size of the destination buffer ahead of time, and return a BUFFER_OVERFLOW exception if the destination buffer is not big enough. However if there is a CompositeByteBuf the wrap operation may consist of multiple ByteBuffers and each incurs its own overhead during the encryption. We currently don't account for the overhead required for encryption if there are multiple ByteBuffers and we assume the overhead will only apply once to the entire input size. If there is not enough room to write an entire encrypted packed into the BIO SSL_write will return -1 despite having actually written content to the BIO. We then attempt to retry the write with a bigger buffer, but because SSL_write is stateful the remaining bytes from the previous operation are put into the BIO. This results in sending the second half of the encrypted data being sent to the peer which is not of proper format and the peer will be confused and ultimately not get the expected data (which may result in a fatal error). In this case because SSL_write returns -1 we have no way to know how many bytes were actually consumed and so the best we can do is ensure that we always allocate a destination buffer with enough space so we are guaranteed to complete the write operation synchronously. Modifications: - SslHandler#allocateNetBuf should take into account how many ByteBuffers will be wrapped and apply the encryption overhead for each - Include the TLS header length in the overhead computation Result: Fixes https://github.com/netty/netty/issues/6481 --- .../ssl/ReferenceCountedOpenSslEngine.java | 21 ++- .../java/io/netty/handler/ssl/SslHandler.java | 23 ++-- .../netty/handler/ssl/OpenSslEngineTest.java | 32 +++-- .../io/netty/handler/ssl/SslHandlerTest.java | 130 ++++++++++++++++++ 4 files changed, 168 insertions(+), 38 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index 80b9990936..679a57e9d6 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -57,6 +57,7 @@ import javax.net.ssl.SSLSessionContext; import javax.security.cert.X509Certificate; import static io.netty.handler.ssl.OpenSsl.memoryAddress; +import static io.netty.handler.ssl.SslUtils.SSL_RECORD_HEADER_LENGTH; import static io.netty.util.internal.EmptyArrays.EMPTY_CERTIFICATES; import static io.netty.util.internal.EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES; import static io.netty.util.internal.ObjectUtil.checkNotNull; @@ -107,15 +108,14 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc * allow up to 255 bytes. 16 bytes is the max for PKC#5 (which handles it the same way as PKC#7) as we use a block * size of 16. See rfc5652#section-6.3. * - * 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) + 2 (ProtocolVersion) + 2 (Length) + * TLS Header (5) + 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) + + * 2 (ProtocolVersion) + 2 (Length) * * TODO: We may need to review this calculation once TLS 1.3 becomes available. */ - static final int MAX_ENCRYPTION_OVERHEAD_LENGTH = 15 + 48 + 1 + 16 + 1 + 2 + 2; + static final int MAX_TLS_RECORD_OVERHEAD_LENGTH = SSL_RECORD_HEADER_LENGTH + 16 + 48 + 1 + 15 + 1 + 2 + 2; - static final int MAX_ENCRYPTED_PACKET_LENGTH = MAX_PLAINTEXT_LENGTH + MAX_ENCRYPTION_OVERHEAD_LENGTH; - - private static final int MAX_ENCRYPTION_OVERHEAD_DIFF = Integer.MAX_VALUE - MAX_ENCRYPTION_OVERHEAD_LENGTH; + static final int MAX_ENCRYPTED_PACKET_LENGTH = MAX_PLAINTEXT_LENGTH + MAX_TLS_RECORD_OVERHEAD_LENGTH; private static final AtomicIntegerFieldUpdater DESTROYED_UPDATER = AtomicIntegerFieldUpdater.newUpdater(ReferenceCountedOpenSslEngine.class, "destroyed"); @@ -561,7 +561,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - if (dst.remaining() < calculateOutNetBufSize(srcsLen)) { + if (dst.remaining() < calculateOutNetBufSize(srcsLen, endOffset - offset)) { // Can not hold the maximum packet so we need to tell the caller to use a bigger destination // buffer. return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); @@ -772,7 +772,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - if (len < SslUtils.SSL_RECORD_HEADER_LENGTH) { + if (len < SSL_RECORD_HEADER_LENGTH) { return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0); } @@ -782,7 +782,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc throw new NotSslRecordException("not an SSL/TLS record"); } - if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) { + if (packetLength - SSL_RECORD_HEADER_LENGTH > capacity) { // No enough space in the destination buffer so signal the caller // that the buffer needs to be increased. return newResultMayFinishHandshake(BUFFER_OVERFLOW, status, 0, 0); @@ -1606,9 +1606,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc return destroyed != 0; } - static int calculateOutNetBufSize(int pendingBytes) { - return min(MAX_ENCRYPTED_PACKET_LENGTH, MAX_ENCRYPTION_OVERHEAD_LENGTH - + min(MAX_ENCRYPTION_OVERHEAD_DIFF, pendingBytes)); + static int calculateOutNetBufSize(int pendingBytes, int numComponents) { + return (int) min(Integer.MAX_VALUE, pendingBytes + (long) MAX_TLS_RECORD_OVERHEAD_LENGTH * numComponents); } private final class OpenSslSession implements SSLSession, ApplicationProtocolAccessor { diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 601f8dfd29..ac9843b0e8 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -193,7 +193,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * we can use a special {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} method * that accepts multiple {@link ByteBuffer}s without additional memory copies. */ - OpenSslEngine opensslEngine = (OpenSslEngine) handler.engine; + ReferenceCountedOpenSslEngine opensslEngine = (ReferenceCountedOpenSslEngine) handler.engine; try { handler.singleBuffer[0] = toByteBuffer(out, writerIndex, out.writableBytes()); @@ -210,8 +210,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } @Override - int calculateOutNetBufSize(SslHandler handler, int pendingBytes) { - return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes); + int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) { + return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes, numComponents); } }, JDK(false, MERGE_CUMULATOR) { @@ -226,16 +226,13 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } @Override - int calculateOutNetBufSize(SslHandler handler, int pendingBytes) { + int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) { return handler.maxPacketBufferSize; } }; static SslEngineType forEngine(SSLEngine engine) { - if (engine instanceof OpenSslEngine) { - return TCNATIVE; - } - return JDK; + return engine instanceof ReferenceCountedOpenSslEngine ? TCNATIVE : JDK; } SslEngineType(boolean wantsDirectBuffer, Cumulator cumulator) { @@ -246,7 +243,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out) throws SSLException; - abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes); + abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents); // BEGIN Platform-dependent flags @@ -652,7 +649,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH ByteBuf buf = (ByteBuf) msg; if (out == null) { - out = allocateOutNetBuf(ctx, buf.readableBytes()); + out = allocateOutNetBuf(ctx, buf.readableBytes(), buf.nioBufferCount()); } SSLEngineResult result = wrap(alloc, engine, buf, out); @@ -741,7 +738,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // As this is called for the handshake we have no real idea how big the buffer needs to be. // That said 2048 should give us enough room to include everything like ALPN / NPN data. // If this is not enough we will increase the buffer in wrap(...). - out = allocateOutNetBuf(ctx, 2048); + out = allocateOutNetBuf(ctx, 2048, 1); } SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out); @@ -1694,8 +1691,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which can encrypt * the specified amount of pending bytes. */ - private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) { - return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes)); + private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes, int numComponents) { + return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes, numComponents)); } private final class LazyChannelPromise extends DefaultPromise { diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java index 517271441c..07fcc48380 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java @@ -39,7 +39,11 @@ import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; +import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH; +import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH; +import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH; import static io.netty.internal.tcnative.SSL.SSL_CVERIFY_IGNORED; +import static java.lang.Integer.MAX_VALUE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -200,12 +204,12 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer src = allocateBuffer(srcLen); ByteBuffer dstTooSmall = allocateBuffer( - src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH - 1); + src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH - 1); ByteBuffer dst = allocateBuffer( - src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH); + src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH); // Check that we fail to wrap if the dst buffers capacity is not at least - // src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH + // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH SSLEngineResult result = clientEngine.wrap(src, dstTooSmall); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); assertEquals(0, result.bytesConsumed()); @@ -214,7 +218,7 @@ public class OpenSslEngineTest extends SSLEngineTest { assertEquals(dst.remaining(), dst.capacity()); // Check that we can wrap with a dst buffer that has the capacity of - // src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH + // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH result = clientEngine.wrap(src, dst); assertEquals(SSLEngineResult.Status.OK, result.getStatus()); assertEquals(srcLen, result.bytesConsumed()); @@ -249,7 +253,7 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer src2 = src.duplicate(); ByteBuffer dst = allocateBuffer(src.capacity() - + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH); + + MAX_TLS_RECORD_OVERHEAD_LENGTH); SSLEngineResult result = clientEngine.wrap(new ByteBuffer[] { src, src2 }, dst); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); @@ -284,7 +288,7 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer src = allocateBuffer(1024); List srcList = new ArrayList(); long srcsLen = 0; - long maxLen = ((long) Integer.MAX_VALUE) * 2; + long maxLen = ((long) MAX_VALUE) * 2; while (srcsLen < maxLen) { ByteBuffer dup = src.duplicate(); @@ -294,7 +298,7 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer[] srcs = srcList.toArray(new ByteBuffer[srcList.size()]); - ByteBuffer dst = allocateBuffer(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH - 1); + ByteBuffer dst = allocateBuffer(MAX_ENCRYPTED_PACKET_LENGTH - 1); SSLEngineResult result = clientEngine.wrap(srcs, dst); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); @@ -313,14 +317,14 @@ public class OpenSslEngineTest extends SSLEngineTest { @Test public void testCalculateOutNetBufSizeOverflow() { - assertEquals(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH, - ReferenceCountedOpenSslEngine.calculateOutNetBufSize(Integer.MAX_VALUE)); + assertEquals(MAX_VALUE, + ReferenceCountedOpenSslEngine.calculateOutNetBufSize(MAX_VALUE, 1)); } @Test public void testCalculateOutNetBufSize0() { - assertEquals(ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH, - ReferenceCountedOpenSslEngine.calculateOutNetBufSize(0)); + assertEquals(MAX_TLS_RECORD_OVERHEAD_LENGTH, + ReferenceCountedOpenSslEngine.calculateOutNetBufSize(0, 1)); } @Override @@ -538,9 +542,9 @@ public class OpenSslEngineTest extends SSLEngineTest { do { testWrapDstBigEnough(clientEngine, srcLen); srcLen += 64; - } while (srcLen < ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH); + } while (srcLen < MAX_PLAINTEXT_LENGTH); - testWrapDstBigEnough(clientEngine, ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH); + testWrapDstBigEnough(clientEngine, MAX_PLAINTEXT_LENGTH); } finally { cleanupClientSslEngine(clientEngine); cleanupServerSslEngine(serverEngine); @@ -549,7 +553,7 @@ public class OpenSslEngineTest extends SSLEngineTest { private void testWrapDstBigEnough(SSLEngine engine, int srcLen) throws SSLException { ByteBuffer src = allocateBuffer(srcLen); - ByteBuffer dst = allocateBuffer(srcLen + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH); + ByteBuffer dst = allocateBuffer(srcLen + MAX_TLS_RECORD_OVERHEAD_LENGTH); SSLEngineResult result = engine.wrap(src, dst); assertEquals(SSLEngineResult.Status.OK, result.getStatus()); diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index b39de0fced..89ff9223b0 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -32,7 +32,11 @@ import javax.net.ssl.X509TrustManager; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelInitializer; import io.netty.channel.EventLoopGroup; @@ -69,6 +73,7 @@ import java.security.KeyStore; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; @@ -560,4 +565,129 @@ public class SslHandlerTest { ReferenceCountUtil.release(sslClientCtx); } } + + @Test(timeout = 30000) + public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite() + throws CertificateException, SSLException, ExecutionException, InterruptedException { + SslProvider[] providers = SslProvider.values(); + for (int i = 0; i < providers.length; ++i) { + for (int j = 0; j < providers.length; ++j) { + compositeBufSizeEstimationGuaranteesSynchronousWrite(providers[i], providers[j]); + } + } + } + + private void compositeBufSizeEstimationGuaranteesSynchronousWrite( + SslProvider serverProvider, SslProvider clientProvider) + throws CertificateException, SSLException, ExecutionException, InterruptedException { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(serverProvider) + .build(); + + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(clientProvider).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + try { + final Promise donePromise = group.next().newPromise(); + final int expectedBytes = 469 + 1024 + 1024; + + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt; + if (sslEvt.isSuccess()) { + final ByteBuf input = ctx.alloc().buffer(); + input.writeBytes(new byte[expectedBytes]); + CompositeByteBuf content = ctx.alloc().compositeBuffer(); + content.addComponent(true, input.readRetainedSlice(469)); + content.addComponent(true, input.readRetainedSlice(1024)); + content.addComponent(true, input.readRetainedSlice(1024)); + ctx.writeAndFlush(content).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + input.release(); + } + }); + } else { + donePromise.tryFailure(sslEvt.cause()); + } + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + donePromise.tryFailure(cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + donePromise.tryFailure(new IllegalStateException("server closed")); + } + }); + } + }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + private int bytesSeen; + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof ByteBuf) { + bytesSeen += ((ByteBuf) msg).readableBytes(); + if (bytesSeen == expectedBytes) { + donePromise.trySuccess(null); + } + } + ReferenceCountUtil.release(msg); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + donePromise.tryFailure(cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + donePromise.tryFailure(new IllegalStateException("client closed")); + } + }); + } + }).connect(sc.localAddress()).syncUninterruptibly().channel(); + + donePromise.get(); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + ssc.delete(); + } + } }