OpenSslEngine wrap may generate bad data if multiple src buffers

Motivation:
SSL_write requires a fixed amount of bytes for overhead related to the encryption process for each call. OpenSslEngine#wrap(..) will attempt to encrypt multiple input buffers until MAX_PLAINTEXT_LENGTH are consumed, but the size estimation provided by calculateOutNetBufSize may not leave enough room for each call to SSL_write. If SSL_write is not able to completely write results to the destination buffer it will keep state and attempt to write it later. Netty doesn't account for SSL_write keeping state and assumes all writes will complete synchronously (by attempting to allocate enough space to account for the overhead) and feeds the same data to SSL_write again later which results in corrupted data being generated.

Modifications:
- OpenSslEngine#wrap should only produce a single TLS packet according to the SSLEngine API specificaiton [1].
[1] https://docs.oracle.com/javase/8/docs/api/javax/net/ssl/SSLEngine.html#wrap-java.nio.ByteBuffer:A-int-int-java.nio.ByteBuffer-
- OpenSslEngine#wrap should only consider a single buffer when determining if there is enough space to write, because only a single buffer will ever be consumed.

Result:
OpenSslEngine#wrap will no longer produce corrupted data due to incorrect accounting of space required in the destination buffers.
This commit is contained in:
Scott Mitchell 2017-05-08 10:13:53 -07:00
parent cd80b6c2d8
commit 141089998f
3 changed files with 73 additions and 57 deletions

View File

@ -610,9 +610,9 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
} }
} }
if (dst.remaining() < calculateOutNetBufSize(srcsLen, endOffset - offset)) { // we will only produce a single TLS packet, and we don't aggregate src buffers,
// Can not hold the maximum packet so we need to tell the caller to use a bigger destination // so we always fix the number of buffers to 1 when checking if the dst buffer is large enough.
// buffer. if (dst.remaining() < calculateOutNetBufSize(srcsLen, 1)) {
return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
} }
@ -638,9 +638,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
bytesProduced += bioLengthBefore - pendingNow; bytesProduced += bioLengthBefore - pendingNow;
bioLengthBefore = pendingNow; bioLengthBefore = pendingNow;
if (bytesConsumed == MAX_PLAINTEXT_LENGTH || bytesProduced == dst.remaining()) { return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced);
return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced);
}
} else { } else {
int sslError = SSL.getError(ssl, bytesWritten); int sslError = SSL.getError(ssl, bytesWritten);
if (sslError == SSL.SSL_ERROR_ZERO_RETURN) { if (sslError == SSL.SSL_ERROR_ZERO_RETURN) {

View File

@ -210,7 +210,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
@Override @Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) { int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) {
return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes, numComponents); return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes, numComponents);
} }
}, },
@ -242,7 +242,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
@Override @Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) { int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) {
return ((ConscryptAlpnSslEngine) handler.engine).calculateOutNetBufSize(pendingBytes, numComponents); return ((ConscryptAlpnSslEngine) handler.engine).calculateOutNetBufSize(pendingBytes, numComponents);
} }
}, },
@ -258,7 +258,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
@Override @Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) { int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) {
return handler.maxPacketBufferSize; return handler.maxPacketBufferSize;
} }
}; };
@ -281,7 +281,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out) abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out)
throws SSLException; throws SSLException;
abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents); abstract int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents);
// BEGIN Platform-dependent flags // BEGIN Platform-dependent flags
@ -1719,7 +1719,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* the specified amount of pending bytes. * the specified amount of pending bytes.
*/ */
private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes, int numComponents) { private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes, int numComponents) {
return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes, numComponents)); return allocate(ctx, engineType.calculateWrapBufferCapacity(this, pendingBytes, numComponents));
} }
private final class LazyChannelPromise extends DefaultPromise<Channel> { private final class LazyChannelPromise extends DefaultPromise<Channel> {

View File

@ -16,37 +16,32 @@
package io.netty.handler.ssl; package io.netty.handler.ssl;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.*;
import static org.junit.Assume.assumeTrue;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLProtocolException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.CodecException; import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.ssl.util.SimpleTrustManagerFactory; import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
import io.netty.util.IllegalReferenceCountException; import io.netty.util.IllegalReferenceCountException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
@ -54,18 +49,6 @@ import io.netty.util.concurrent.PromiseNotifier;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import org.junit.Test; import org.junit.Test;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import java.io.File; import java.io.File;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
@ -76,6 +59,23 @@ import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLProtocolException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;
public class SslHandlerTest { public class SslHandlerTest {
@ -566,7 +566,7 @@ public class SslHandlerTest {
} }
} }
@Test(timeout = 30000) @Test(timeout = 300000)
public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite() public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite()
throws CertificateException, SSLException, ExecutionException, InterruptedException { throws CertificateException, SSLException, ExecutionException, InterruptedException {
SslProvider[] providers = SslProvider.values(); SslProvider[] providers = SslProvider.values();
@ -576,7 +576,14 @@ public class SslHandlerTest {
for (int j = 0; j < providers.length; ++j) { for (int j = 0; j < providers.length; ++j) {
SslProvider clientProvider = providers[j]; SslProvider clientProvider = providers[j];
if (isSupported(clientProvider)) { if (isSupported(clientProvider)) {
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider); compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
true, true);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
true, false);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
false, true);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
false, false);
} }
} }
} }
@ -584,7 +591,8 @@ public class SslHandlerTest {
} }
private static void compositeBufSizeEstimationGuaranteesSynchronousWrite( private static void compositeBufSizeEstimationGuaranteesSynchronousWrite(
SslProvider serverProvider, SslProvider clientProvider) SslProvider serverProvider, SslProvider clientProvider,
final boolean letHandlerCreateServerEngine, final boolean letHandlerCreateClientEngine)
throws CertificateException, SSLException, ExecutionException, InterruptedException { throws CertificateException, SSLException, ExecutionException, InterruptedException {
SelfSignedCertificate ssc = new SelfSignedCertificate(); SelfSignedCertificate ssc = new SelfSignedCertificate();
@ -601,7 +609,14 @@ public class SslHandlerTest {
Channel cc = null; Channel cc = null;
try { try {
final Promise<Void> donePromise = group.next().newPromise(); final Promise<Void> donePromise = group.next().newPromise();
final int expectedBytes = 469 + 1024 + 1024; // The goal is to provide the SSLEngine with many ByteBuf components to ensure that the overhead for wrap
// is correctly accounted for on each component.
final int numComponents = 150;
// This is the TLS packet size. The goal is to divide the maximum amount of application data that can fit
// into a single TLS packet into many components to ensure the overhead is correctly taken into account.
final int desiredBytes = 16384;
final int singleComponentSize = desiredBytes / numComponents;
final int expectedBytes = numComponents * singleComponentSize;
sc = new ServerBootstrap() sc = new ServerBootstrap()
.group(group) .group(group)
@ -609,25 +624,24 @@ public class SslHandlerTest {
.childHandler(new ChannelInitializer<Channel>() { .childHandler(new ChannelInitializer<Channel>() {
@Override @Override
protected void initChannel(Channel ch) throws Exception { protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); if (letHandlerCreateServerEngine) {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc()));
} else {
ch.pipeline().addLast(new SslHandler(sslServerCtx.newEngine(ch.alloc())));
}
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslHandshakeCompletionEvent) { if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt; SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt;
if (sslEvt.isSuccess()) { if (sslEvt.isSuccess()) {
final ByteBuf input = ctx.alloc().buffer(); CompositeByteBuf content = ctx.alloc().compositeDirectBuffer(numComponents);
input.writeBytes(new byte[expectedBytes]); for (int i = 0; i < numComponents; ++i) {
CompositeByteBuf content = ctx.alloc().compositeBuffer(); ByteBuf buf = ctx.alloc().directBuffer(singleComponentSize);
content.addComponent(true, input.readRetainedSlice(469)); buf.writerIndex(buf.writerIndex() + singleComponentSize);
content.addComponent(true, input.readRetainedSlice(1024)); content.addComponent(true, buf);
content.addComponent(true, input.readRetainedSlice(1024)); }
ctx.writeAndFlush(content).addListener(new ChannelFutureListener() { ctx.writeAndFlush(content);
@Override
public void operationComplete(ChannelFuture future) {
input.release();
}
});
} else { } else {
donePromise.tryFailure(sslEvt.cause()); donePromise.tryFailure(sslEvt.cause());
} }
@ -654,7 +668,11 @@ public class SslHandlerTest {
.handler(new ChannelInitializer<Channel>() { .handler(new ChannelInitializer<Channel>() {
@Override @Override
protected void initChannel(Channel ch) throws Exception { protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); if (letHandlerCreateClientEngine) {
ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc()));
} else {
ch.pipeline().addLast(new SslHandler(sslClientCtx.newEngine(ch.alloc())));
}
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
private int bytesSeen; private int bytesSeen;
@Override @Override