Correctly calculate the produced bytes in all cases when calling Refe… (#10063)

Motivation:

We did not correctly account for produced bytes when SSL_write(...) returns -1 in all cases. This could lead to lost data and so a corrupt SSL connection.

Modifications:

- Always ensure we calculate the produced bytes correctly
- Add unit tests

Result:

Fixes https://github.com/netty/netty/issues/10041
This commit is contained in:
Norman Maurer 2020-02-27 08:55:25 +01:00
parent 32970dc3d7
commit 1c28cf3a14
2 changed files with 123 additions and 5 deletions

View File

@ -859,14 +859,18 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
bytesWritten = writePlaintextData(src, min(remaining, availableCapacityForWrap));
}
// Determine how much encrypted data was generated.
//
// Even if SSL_write doesn't consume any application data it is possible that OpenSSL will
// produce non-application data into the BIO. For example session tickets....
// See https://github.com/netty/netty/issues/10041
final int pendingNow = SSL.bioLengthByteBuffer(networkBIO);
bytesProduced += bioLengthBefore - pendingNow;
bioLengthBefore = pendingNow;
if (bytesWritten > 0) {
bytesConsumed += bytesWritten;
// Determine how much encrypted data was generated:
final int pendingNow = SSL.bioLengthByteBuffer(networkBIO);
bytesProduced += bioLengthBefore - pendingNow;
bioLengthBefore = pendingNow;
if (jdkCompatibilityMode || bytesProduced == dst.remaining()) {
return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced);
}

View File

@ -41,6 +41,7 @@ import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioHandler;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.UnsupportedMessageTypeException;
@ -54,12 +55,14 @@ import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.ImmediateExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.PlatformDependent;
import org.hamcrest.CoreMatchers;
import org.junit.Test;
import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletionException;
@ -70,6 +73,7 @@ import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@ -1038,4 +1042,114 @@ public class SslHandlerTest {
ReferenceCountUtil.release(sslClientCtx);
}
}
@Test(timeout = 5000L)
public void testSessionTicketsWithTLSv12() throws Throwable {
testSessionTickets(SslUtils.PROTOCOL_TLS_V1_2);
}
@Test(timeout = 5000L)
public void testSessionTicketsWithTLSv13() throws Throwable {
assumeTrue(OpenSsl.isTlsv13Supported());
testSessionTickets(SslUtils.PROTOCOL_TLS_V1_3);
}
private static void testSessionTickets(String protocol) throws Throwable {
assumeTrue(OpenSsl.isAvailable());
final SslContext sslClientCtx = SslContextBuilder.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(SslProvider.OPENSSL)
.protocols(protocol)
.build();
final SelfSignedCertificate cert = new SelfSignedCertificate();
final SslContext sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert())
.sslProvider(SslProvider.OPENSSL)
.protocols(protocol)
.build();
OpenSslSessionTicketKey key = new OpenSslSessionTicketKey(new byte[OpenSslSessionTicketKey.NAME_SIZE],
new byte[OpenSslSessionTicketKey.HMAC_KEY_SIZE], new byte[OpenSslSessionTicketKey.AES_KEY_SIZE]);
((OpenSslSessionContext) sslClientCtx.sessionContext()).setTicketKeys(key);
((OpenSslSessionContext) sslServerCtx.sessionContext()).setTicketKeys(key);
EventLoopGroup group = new MultithreadEventLoopGroup(NioHandler.newFactory());
Channel sc = null;
Channel cc = null;
final SslHandler clientSslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT);
final SslHandler serverSslHandler = sslServerCtx.newHandler(UnpooledByteBufAllocator.DEFAULT);
final BlockingQueue<Object> queue = new LinkedBlockingQueue<Object>();
final byte[] bytes = new byte[96];
ThreadLocalRandom.current().nextBytes(bytes);
try {
sc = new ServerBootstrap()
.group(group)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(serverSslHandler);
ch.pipeline().addLast(new ChannelHandler() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslHandshakeCompletionEvent) {
ctx.writeAndFlush(Unpooled.wrappedBuffer(bytes));
}
}
});
}
})
.bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
ChannelFuture future = new Bootstrap()
.group(group)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(clientSslHandler);
ch.pipeline().addLast(new ByteToMessageDecoder() {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in) {
if (in.readableBytes() == bytes.length) {
queue.add(in.readBytes(bytes.length));
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
queue.add(cause);
}
});
}
}).connect(sc.localAddress());
cc = future.syncUninterruptibly().channel();
assertTrue(clientSslHandler.handshakeFuture().await().isSuccess());
assertTrue(serverSslHandler.handshakeFuture().await().isSuccess());
Object obj = queue.take();
if (obj instanceof ByteBuf) {
ByteBuf buffer = (ByteBuf) obj;
ByteBuf expected = Unpooled.wrappedBuffer(bytes);
try {
assertEquals(expected, buffer);
} finally {
expected.release();
}
} else {
throw (Throwable) obj;
}
} finally {
if (cc != null) {
cc.close().syncUninterruptibly();
}
if (sc != null) {
sc.close().syncUninterruptibly();
}
group.shutdownGracefully();
ReferenceCountUtil.release(sslClientCtx);
}
}
}