SslHandler and OpenSslEngine miscalculation of wrap destination buffer size

Motivation:
When we do a wrap operation we calculate the maximum size of the destination buffer ahead of time, and return a BUFFER_OVERFLOW exception if the destination buffer is not big enough. However if there is a CompositeByteBuf the wrap operation may consist of multiple ByteBuffers and each incurs its own overhead during the encryption. We currently don't account for the overhead required for encryption if there are multiple ByteBuffers and we assume the overhead will only apply once to the entire input size. If there is not enough room to write an entire encrypted packed into the BIO SSL_write will return -1 despite having actually written content to the BIO. We then attempt to retry the write with a bigger buffer, but because SSL_write is stateful the remaining bytes from the previous operation are put into the BIO. This results in sending the second half of the encrypted data being sent to the peer which is not of proper format and the peer will be confused and ultimately not get the expected data (which may result in a fatal error). In this case because SSL_write returns -1 we have no way to know how many bytes were actually consumed and so the best we can do is ensure that we always allocate a destination buffer with enough space so we are guaranteed to complete the write operation synchronously.

Modifications:
- SslHandler#allocateNetBuf should take into account how many ByteBuffers will be wrapped and apply the encryption overhead for each
- Include the TLS header length in the overhead computation

Result:
Fixes https://github.com/netty/netty/issues/6481
This commit is contained in:
Scott Mitchell 2017-03-03 09:22:26 -08:00
parent f343de8fb1
commit 53fc693901
4 changed files with 168 additions and 38 deletions

View File

@ -57,6 +57,7 @@ import javax.net.ssl.SSLSessionContext;
import javax.security.cert.X509Certificate;
import static io.netty.handler.ssl.OpenSsl.memoryAddress;
import static io.netty.handler.ssl.SslUtils.SSL_RECORD_HEADER_LENGTH;
import static io.netty.util.internal.EmptyArrays.EMPTY_CERTIFICATES;
import static io.netty.util.internal.EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
@ -107,15 +108,14 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
* allow up to 255 bytes. 16 bytes is the max for PKC#5 (which handles it the same way as PKC#7) as we use a block
* size of 16. See <a href="https://tools.ietf.org/html/rfc5652#section-6.3">rfc5652#section-6.3</a>.
*
* 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) + 2 (ProtocolVersion) + 2 (Length)
* TLS Header (5) + 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) +
* 2 (ProtocolVersion) + 2 (Length)
*
* TODO: We may need to review this calculation once TLS 1.3 becomes available.
*/
static final int MAX_ENCRYPTION_OVERHEAD_LENGTH = 15 + 48 + 1 + 16 + 1 + 2 + 2;
static final int MAX_TLS_RECORD_OVERHEAD_LENGTH = SSL_RECORD_HEADER_LENGTH + 16 + 48 + 1 + 15 + 1 + 2 + 2;
static final int MAX_ENCRYPTED_PACKET_LENGTH = MAX_PLAINTEXT_LENGTH + MAX_ENCRYPTION_OVERHEAD_LENGTH;
private static final int MAX_ENCRYPTION_OVERHEAD_DIFF = Integer.MAX_VALUE - MAX_ENCRYPTION_OVERHEAD_LENGTH;
static final int MAX_ENCRYPTED_PACKET_LENGTH = MAX_PLAINTEXT_LENGTH + MAX_TLS_RECORD_OVERHEAD_LENGTH;
private static final AtomicIntegerFieldUpdater<ReferenceCountedOpenSslEngine> DESTROYED_UPDATER =
AtomicIntegerFieldUpdater.newUpdater(ReferenceCountedOpenSslEngine.class, "destroyed");
@ -561,7 +561,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}
}
if (dst.remaining() < calculateOutNetBufSize(srcsLen)) {
if (dst.remaining() < calculateOutNetBufSize(srcsLen, endOffset - offset)) {
// Can not hold the maximum packet so we need to tell the caller to use a bigger destination
// buffer.
return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
@ -772,7 +772,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}
}
if (len < SslUtils.SSL_RECORD_HEADER_LENGTH) {
if (len < SSL_RECORD_HEADER_LENGTH) {
return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0);
}
@ -782,7 +782,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
throw new NotSslRecordException("not an SSL/TLS record");
}
if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) {
if (packetLength - SSL_RECORD_HEADER_LENGTH > capacity) {
// No enough space in the destination buffer so signal the caller
// that the buffer needs to be increased.
return newResultMayFinishHandshake(BUFFER_OVERFLOW, status, 0, 0);
@ -1606,9 +1606,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
return destroyed != 0;
}
static int calculateOutNetBufSize(int pendingBytes) {
return min(MAX_ENCRYPTED_PACKET_LENGTH, MAX_ENCRYPTION_OVERHEAD_LENGTH
+ min(MAX_ENCRYPTION_OVERHEAD_DIFF, pendingBytes));
static int calculateOutNetBufSize(int pendingBytes, int numComponents) {
return (int) min(Integer.MAX_VALUE, pendingBytes + (long) MAX_TLS_RECORD_OVERHEAD_LENGTH * numComponents);
}
private final class OpenSslSession implements SSLSession, ApplicationProtocolAccessor {

View File

@ -193,7 +193,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* we can use a special {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} method
* that accepts multiple {@link ByteBuffer}s without additional memory copies.
*/
OpenSslEngine opensslEngine = (OpenSslEngine) handler.engine;
ReferenceCountedOpenSslEngine opensslEngine = (ReferenceCountedOpenSslEngine) handler.engine;
try {
handler.singleBuffer[0] = toByteBuffer(out, writerIndex,
out.writableBytes());
@ -210,8 +210,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes) {
return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes);
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes, numComponents);
}
},
JDK(false, MERGE_CUMULATOR) {
@ -226,16 +226,13 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes) {
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
return handler.maxPacketBufferSize;
}
};
static SslEngineType forEngine(SSLEngine engine) {
if (engine instanceof OpenSslEngine) {
return TCNATIVE;
}
return JDK;
return engine instanceof ReferenceCountedOpenSslEngine ? TCNATIVE : JDK;
}
SslEngineType(boolean wantsDirectBuffer, Cumulator cumulator) {
@ -246,7 +243,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out)
throws SSLException;
abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes);
abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents);
// BEGIN Platform-dependent flags
@ -652,7 +649,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
ByteBuf buf = (ByteBuf) msg;
if (out == null) {
out = allocateOutNetBuf(ctx, buf.readableBytes());
out = allocateOutNetBuf(ctx, buf.readableBytes(), buf.nioBufferCount());
}
SSLEngineResult result = wrap(alloc, engine, buf, out);
@ -741,7 +738,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// As this is called for the handshake we have no real idea how big the buffer needs to be.
// That said 2048 should give us enough room to include everything like ALPN / NPN data.
// If this is not enough we will increase the buffer in wrap(...).
out = allocateOutNetBuf(ctx, 2048);
out = allocateOutNetBuf(ctx, 2048, 1);
}
SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out);
@ -1694,8 +1691,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which can encrypt
* the specified amount of pending bytes.
*/
private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) {
return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes));
private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes, int numComponents) {
return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes, numComponents));
}
private final class LazyChannelPromise extends DefaultPromise<Channel> {

View File

@ -39,7 +39,11 @@ import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH;
import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH;
import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH;
import static io.netty.internal.tcnative.SSL.SSL_CVERIFY_IGNORED;
import static java.lang.Integer.MAX_VALUE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
@ -200,12 +204,12 @@ public class OpenSslEngineTest extends SSLEngineTest {
ByteBuffer src = allocateBuffer(srcLen);
ByteBuffer dstTooSmall = allocateBuffer(
src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH - 1);
src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH - 1);
ByteBuffer dst = allocateBuffer(
src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH);
src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH);
// Check that we fail to wrap if the dst buffers capacity is not at least
// src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH
// src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH
SSLEngineResult result = clientEngine.wrap(src, dstTooSmall);
assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
assertEquals(0, result.bytesConsumed());
@ -214,7 +218,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
assertEquals(dst.remaining(), dst.capacity());
// Check that we can wrap with a dst buffer that has the capacity of
// src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH
// src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH
result = clientEngine.wrap(src, dst);
assertEquals(SSLEngineResult.Status.OK, result.getStatus());
assertEquals(srcLen, result.bytesConsumed());
@ -249,7 +253,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
ByteBuffer src2 = src.duplicate();
ByteBuffer dst = allocateBuffer(src.capacity()
+ ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH);
+ MAX_TLS_RECORD_OVERHEAD_LENGTH);
SSLEngineResult result = clientEngine.wrap(new ByteBuffer[] { src, src2 }, dst);
assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
@ -284,7 +288,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
ByteBuffer src = allocateBuffer(1024);
List<ByteBuffer> srcList = new ArrayList<ByteBuffer>();
long srcsLen = 0;
long maxLen = ((long) Integer.MAX_VALUE) * 2;
long maxLen = ((long) MAX_VALUE) * 2;
while (srcsLen < maxLen) {
ByteBuffer dup = src.duplicate();
@ -294,7 +298,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
ByteBuffer[] srcs = srcList.toArray(new ByteBuffer[srcList.size()]);
ByteBuffer dst = allocateBuffer(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH - 1);
ByteBuffer dst = allocateBuffer(MAX_ENCRYPTED_PACKET_LENGTH - 1);
SSLEngineResult result = clientEngine.wrap(srcs, dst);
assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
@ -313,14 +317,14 @@ public class OpenSslEngineTest extends SSLEngineTest {
@Test
public void testCalculateOutNetBufSizeOverflow() {
assertEquals(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH,
ReferenceCountedOpenSslEngine.calculateOutNetBufSize(Integer.MAX_VALUE));
assertEquals(MAX_VALUE,
ReferenceCountedOpenSslEngine.calculateOutNetBufSize(MAX_VALUE, 1));
}
@Test
public void testCalculateOutNetBufSize0() {
assertEquals(ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH,
ReferenceCountedOpenSslEngine.calculateOutNetBufSize(0));
assertEquals(MAX_TLS_RECORD_OVERHEAD_LENGTH,
ReferenceCountedOpenSslEngine.calculateOutNetBufSize(0, 1));
}
@Override
@ -538,9 +542,9 @@ public class OpenSslEngineTest extends SSLEngineTest {
do {
testWrapDstBigEnough(clientEngine, srcLen);
srcLen += 64;
} while (srcLen < ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH);
} while (srcLen < MAX_PLAINTEXT_LENGTH);
testWrapDstBigEnough(clientEngine, ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH);
testWrapDstBigEnough(clientEngine, MAX_PLAINTEXT_LENGTH);
} finally {
cleanupClientSslEngine(clientEngine);
cleanupServerSslEngine(serverEngine);
@ -549,7 +553,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
private void testWrapDstBigEnough(SSLEngine engine, int srcLen) throws SSLException {
ByteBuffer src = allocateBuffer(srcLen);
ByteBuffer dst = allocateBuffer(srcLen + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH);
ByteBuffer dst = allocateBuffer(srcLen + MAX_TLS_RECORD_OVERHEAD_LENGTH);
SSLEngineResult result = engine.wrap(src, dst);
assertEquals(SSLEngineResult.Status.OK, result.getStatus());

View File

@ -32,7 +32,11 @@ import javax.net.ssl.X509TrustManager;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
@ -69,6 +73,7 @@ import java.security.KeyStore;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
@ -560,4 +565,129 @@ public class SslHandlerTest {
ReferenceCountUtil.release(sslClientCtx);
}
}
@Test(timeout = 30000)
public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite()
throws CertificateException, SSLException, ExecutionException, InterruptedException {
SslProvider[] providers = SslProvider.values();
for (int i = 0; i < providers.length; ++i) {
for (int j = 0; j < providers.length; ++j) {
compositeBufSizeEstimationGuaranteesSynchronousWrite(providers[i], providers[j]);
}
}
}
private void compositeBufSizeEstimationGuaranteesSynchronousWrite(
SslProvider serverProvider, SslProvider clientProvider)
throws CertificateException, SSLException, ExecutionException, InterruptedException {
SelfSignedCertificate ssc = new SelfSignedCertificate();
final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
.sslProvider(serverProvider)
.build();
final SslContext sslClientCtx = SslContextBuilder.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(clientProvider).build();
EventLoopGroup group = new NioEventLoopGroup();
Channel sc = null;
Channel cc = null;
try {
final Promise<Void> donePromise = group.next().newPromise();
final int expectedBytes = 469 + 1024 + 1024;
sc = new ServerBootstrap()
.group(group)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc()));
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt;
if (sslEvt.isSuccess()) {
final ByteBuf input = ctx.alloc().buffer();
input.writeBytes(new byte[expectedBytes]);
CompositeByteBuf content = ctx.alloc().compositeBuffer();
content.addComponent(true, input.readRetainedSlice(469));
content.addComponent(true, input.readRetainedSlice(1024));
content.addComponent(true, input.readRetainedSlice(1024));
ctx.writeAndFlush(content).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
input.release();
}
});
} else {
donePromise.tryFailure(sslEvt.cause());
}
}
ctx.fireUserEventTriggered(evt);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
donePromise.tryFailure(cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
donePromise.tryFailure(new IllegalStateException("server closed"));
}
});
}
}).bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
cc = new Bootstrap()
.group(group)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc()));
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
private int bytesSeen;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof ByteBuf) {
bytesSeen += ((ByteBuf) msg).readableBytes();
if (bytesSeen == expectedBytes) {
donePromise.trySuccess(null);
}
}
ReferenceCountUtil.release(msg);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
donePromise.tryFailure(cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
donePromise.tryFailure(new IllegalStateException("client closed"));
}
});
}
}).connect(sc.localAddress()).syncUninterruptibly().channel();
donePromise.get();
} finally {
if (cc != null) {
cc.close().syncUninterruptibly();
}
if (sc != null) {
sc.close().syncUninterruptibly();
}
group.shutdownGracefully();
ReferenceCountUtil.release(sslServerCtx);
ReferenceCountUtil.release(sslClientCtx);
ssc.delete();
}
}
}