OpenSslEngine wrap may generate bad data if multiple src buffers

Motivation:
SSL_write requires a fixed amount of bytes for overhead related to the encryption process for each call. OpenSslEngine#wrap(..) will attempt to encrypt multiple input buffers until MAX_PLAINTEXT_LENGTH are consumed, but the size estimation provided by calculateOutNetBufSize may not leave enough room for each call to SSL_write. If SSL_write is not able to completely write results to the destination buffer it will keep state and attempt to write it later. Netty doesn't account for SSL_write keeping state and assumes all writes will complete synchronously (by attempting to allocate enough space to account for the overhead) and feeds the same data to SSL_write again later which results in corrupted data being generated.

Modifications:
- OpenSslEngine#wrap should only produce a single TLS packet according to the SSLEngine API specificaiton [1].
[1] https://docs.oracle.com/javase/8/docs/api/javax/net/ssl/SSLEngine.html#wrap-java.nio.ByteBuffer:A-int-int-java.nio.ByteBuffer-
- OpenSslEngine#wrap should only consider a single buffer when determining if there is enough space to write, because only a single buffer will ever be consumed.

Result:
OpenSslEngine#wrap will no longer produce corrupted data due to incorrect accounting of space required in the destination buffers.
This commit is contained in:
Scott Mitchell 2017-05-08 10:13:53 -07:00
parent cd80b6c2d8
commit 141089998f
3 changed files with 73 additions and 57 deletions

View File

@ -610,9 +610,9 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}
}
if (dst.remaining() < calculateOutNetBufSize(srcsLen, endOffset - offset)) {
// Can not hold the maximum packet so we need to tell the caller to use a bigger destination
// buffer.
// we will only produce a single TLS packet, and we don't aggregate src buffers,
// so we always fix the number of buffers to 1 when checking if the dst buffer is large enough.
if (dst.remaining() < calculateOutNetBufSize(srcsLen, 1)) {
return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
}
@ -638,9 +638,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
bytesProduced += bioLengthBefore - pendingNow;
bioLengthBefore = pendingNow;
if (bytesConsumed == MAX_PLAINTEXT_LENGTH || bytesProduced == dst.remaining()) {
return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced);
}
return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced);
} else {
int sslError = SSL.getError(ssl, bytesWritten);
if (sslError == SSL.SSL_ERROR_ZERO_RETURN) {

View File

@ -210,7 +210,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) {
return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes, numComponents);
}
},
@ -242,7 +242,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) {
return ((ConscryptAlpnSslEngine) handler.engine).calculateOutNetBufSize(pendingBytes, numComponents);
}
},
@ -258,7 +258,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) {
return handler.maxPacketBufferSize;
}
};
@ -281,7 +281,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out)
throws SSLException;
abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents);
abstract int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents);
// BEGIN Platform-dependent flags
@ -1719,7 +1719,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* the specified amount of pending bytes.
*/
private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes, int numComponents) {
return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes, numComponents));
return allocate(ctx, engineType.calculateWrapBufferCapacity(this, pendingBytes, numComponents));
}
private final class LazyChannelPromise extends DefaultPromise<Channel> {

View File

@ -16,37 +16,32 @@
package io.netty.handler.ssl;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.*;
import static org.junit.Assume.assumeTrue;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLProtocolException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise;
@ -54,18 +49,6 @@ import io.netty.util.concurrent.PromiseNotifier;
import io.netty.util.internal.EmptyArrays;
import org.junit.Test;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import java.io.File;
import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
@ -76,6 +59,23 @@ import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLProtocolException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;
public class SslHandlerTest {
@ -566,7 +566,7 @@ public class SslHandlerTest {
}
}
@Test(timeout = 30000)
@Test(timeout = 300000)
public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite()
throws CertificateException, SSLException, ExecutionException, InterruptedException {
SslProvider[] providers = SslProvider.values();
@ -576,7 +576,14 @@ public class SslHandlerTest {
for (int j = 0; j < providers.length; ++j) {
SslProvider clientProvider = providers[j];
if (isSupported(clientProvider)) {
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
true, true);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
true, false);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
false, true);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
false, false);
}
}
}
@ -584,7 +591,8 @@ public class SslHandlerTest {
}
private static void compositeBufSizeEstimationGuaranteesSynchronousWrite(
SslProvider serverProvider, SslProvider clientProvider)
SslProvider serverProvider, SslProvider clientProvider,
final boolean letHandlerCreateServerEngine, final boolean letHandlerCreateClientEngine)
throws CertificateException, SSLException, ExecutionException, InterruptedException {
SelfSignedCertificate ssc = new SelfSignedCertificate();
@ -601,7 +609,14 @@ public class SslHandlerTest {
Channel cc = null;
try {
final Promise<Void> donePromise = group.next().newPromise();
final int expectedBytes = 469 + 1024 + 1024;
// The goal is to provide the SSLEngine with many ByteBuf components to ensure that the overhead for wrap
// is correctly accounted for on each component.
final int numComponents = 150;
// This is the TLS packet size. The goal is to divide the maximum amount of application data that can fit
// into a single TLS packet into many components to ensure the overhead is correctly taken into account.
final int desiredBytes = 16384;
final int singleComponentSize = desiredBytes / numComponents;
final int expectedBytes = numComponents * singleComponentSize;
sc = new ServerBootstrap()
.group(group)
@ -609,25 +624,24 @@ public class SslHandlerTest {
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc()));
if (letHandlerCreateServerEngine) {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc()));
} else {
ch.pipeline().addLast(new SslHandler(sslServerCtx.newEngine(ch.alloc())));
}
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt;
if (sslEvt.isSuccess()) {
final ByteBuf input = ctx.alloc().buffer();
input.writeBytes(new byte[expectedBytes]);
CompositeByteBuf content = ctx.alloc().compositeBuffer();
content.addComponent(true, input.readRetainedSlice(469));
content.addComponent(true, input.readRetainedSlice(1024));
content.addComponent(true, input.readRetainedSlice(1024));
ctx.writeAndFlush(content).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
input.release();
}
});
CompositeByteBuf content = ctx.alloc().compositeDirectBuffer(numComponents);
for (int i = 0; i < numComponents; ++i) {
ByteBuf buf = ctx.alloc().directBuffer(singleComponentSize);
buf.writerIndex(buf.writerIndex() + singleComponentSize);
content.addComponent(true, buf);
}
ctx.writeAndFlush(content);
} else {
donePromise.tryFailure(sslEvt.cause());
}
@ -654,7 +668,11 @@ public class SslHandlerTest {
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc()));
if (letHandlerCreateClientEngine) {
ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc()));
} else {
ch.pipeline().addLast(new SslHandler(sslClientCtx.newEngine(ch.alloc())));
}
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
private int bytesSeen;
@Override