diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index 80b9990936..679a57e9d6 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -57,6 +57,7 @@ import javax.net.ssl.SSLSessionContext; import javax.security.cert.X509Certificate; import static io.netty.handler.ssl.OpenSsl.memoryAddress; +import static io.netty.handler.ssl.SslUtils.SSL_RECORD_HEADER_LENGTH; import static io.netty.util.internal.EmptyArrays.EMPTY_CERTIFICATES; import static io.netty.util.internal.EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES; import static io.netty.util.internal.ObjectUtil.checkNotNull; @@ -107,15 +108,14 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc * allow up to 255 bytes. 16 bytes is the max for PKC#5 (which handles it the same way as PKC#7) as we use a block * size of 16. See rfc5652#section-6.3. * - * 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) + 2 (ProtocolVersion) + 2 (Length) + * TLS Header (5) + 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) + + * 2 (ProtocolVersion) + 2 (Length) * * TODO: We may need to review this calculation once TLS 1.3 becomes available. */ - static final int MAX_ENCRYPTION_OVERHEAD_LENGTH = 15 + 48 + 1 + 16 + 1 + 2 + 2; + static final int MAX_TLS_RECORD_OVERHEAD_LENGTH = SSL_RECORD_HEADER_LENGTH + 16 + 48 + 1 + 15 + 1 + 2 + 2; - static final int MAX_ENCRYPTED_PACKET_LENGTH = MAX_PLAINTEXT_LENGTH + MAX_ENCRYPTION_OVERHEAD_LENGTH; - - private static final int MAX_ENCRYPTION_OVERHEAD_DIFF = Integer.MAX_VALUE - MAX_ENCRYPTION_OVERHEAD_LENGTH; + static final int MAX_ENCRYPTED_PACKET_LENGTH = MAX_PLAINTEXT_LENGTH + MAX_TLS_RECORD_OVERHEAD_LENGTH; private static final AtomicIntegerFieldUpdater DESTROYED_UPDATER = AtomicIntegerFieldUpdater.newUpdater(ReferenceCountedOpenSslEngine.class, "destroyed"); @@ -561,7 +561,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - if (dst.remaining() < calculateOutNetBufSize(srcsLen)) { + if (dst.remaining() < calculateOutNetBufSize(srcsLen, endOffset - offset)) { // Can not hold the maximum packet so we need to tell the caller to use a bigger destination // buffer. return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); @@ -772,7 +772,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - if (len < SslUtils.SSL_RECORD_HEADER_LENGTH) { + if (len < SSL_RECORD_HEADER_LENGTH) { return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0); } @@ -782,7 +782,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc throw new NotSslRecordException("not an SSL/TLS record"); } - if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) { + if (packetLength - SSL_RECORD_HEADER_LENGTH > capacity) { // No enough space in the destination buffer so signal the caller // that the buffer needs to be increased. return newResultMayFinishHandshake(BUFFER_OVERFLOW, status, 0, 0); @@ -1606,9 +1606,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc return destroyed != 0; } - static int calculateOutNetBufSize(int pendingBytes) { - return min(MAX_ENCRYPTED_PACKET_LENGTH, MAX_ENCRYPTION_OVERHEAD_LENGTH - + min(MAX_ENCRYPTION_OVERHEAD_DIFF, pendingBytes)); + static int calculateOutNetBufSize(int pendingBytes, int numComponents) { + return (int) min(Integer.MAX_VALUE, pendingBytes + (long) MAX_TLS_RECORD_OVERHEAD_LENGTH * numComponents); } private final class OpenSslSession implements SSLSession, ApplicationProtocolAccessor { diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 601f8dfd29..ac9843b0e8 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -193,7 +193,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * we can use a special {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} method * that accepts multiple {@link ByteBuffer}s without additional memory copies. */ - OpenSslEngine opensslEngine = (OpenSslEngine) handler.engine; + ReferenceCountedOpenSslEngine opensslEngine = (ReferenceCountedOpenSslEngine) handler.engine; try { handler.singleBuffer[0] = toByteBuffer(out, writerIndex, out.writableBytes()); @@ -210,8 +210,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } @Override - int calculateOutNetBufSize(SslHandler handler, int pendingBytes) { - return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes); + int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) { + return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes, numComponents); } }, JDK(false, MERGE_CUMULATOR) { @@ -226,16 +226,13 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } @Override - int calculateOutNetBufSize(SslHandler handler, int pendingBytes) { + int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) { return handler.maxPacketBufferSize; } }; static SslEngineType forEngine(SSLEngine engine) { - if (engine instanceof OpenSslEngine) { - return TCNATIVE; - } - return JDK; + return engine instanceof ReferenceCountedOpenSslEngine ? TCNATIVE : JDK; } SslEngineType(boolean wantsDirectBuffer, Cumulator cumulator) { @@ -246,7 +243,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out) throws SSLException; - abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes); + abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents); // BEGIN Platform-dependent flags @@ -652,7 +649,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH ByteBuf buf = (ByteBuf) msg; if (out == null) { - out = allocateOutNetBuf(ctx, buf.readableBytes()); + out = allocateOutNetBuf(ctx, buf.readableBytes(), buf.nioBufferCount()); } SSLEngineResult result = wrap(alloc, engine, buf, out); @@ -741,7 +738,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // As this is called for the handshake we have no real idea how big the buffer needs to be. // That said 2048 should give us enough room to include everything like ALPN / NPN data. // If this is not enough we will increase the buffer in wrap(...). - out = allocateOutNetBuf(ctx, 2048); + out = allocateOutNetBuf(ctx, 2048, 1); } SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out); @@ -1694,8 +1691,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which can encrypt * the specified amount of pending bytes. */ - private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) { - return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes)); + private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes, int numComponents) { + return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes, numComponents)); } private final class LazyChannelPromise extends DefaultPromise { diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java index 517271441c..07fcc48380 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java @@ -39,7 +39,11 @@ import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; +import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH; +import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH; +import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH; import static io.netty.internal.tcnative.SSL.SSL_CVERIFY_IGNORED; +import static java.lang.Integer.MAX_VALUE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -200,12 +204,12 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer src = allocateBuffer(srcLen); ByteBuffer dstTooSmall = allocateBuffer( - src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH - 1); + src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH - 1); ByteBuffer dst = allocateBuffer( - src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH); + src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH); // Check that we fail to wrap if the dst buffers capacity is not at least - // src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH + // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH SSLEngineResult result = clientEngine.wrap(src, dstTooSmall); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); assertEquals(0, result.bytesConsumed()); @@ -214,7 +218,7 @@ public class OpenSslEngineTest extends SSLEngineTest { assertEquals(dst.remaining(), dst.capacity()); // Check that we can wrap with a dst buffer that has the capacity of - // src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH + // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH result = clientEngine.wrap(src, dst); assertEquals(SSLEngineResult.Status.OK, result.getStatus()); assertEquals(srcLen, result.bytesConsumed()); @@ -249,7 +253,7 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer src2 = src.duplicate(); ByteBuffer dst = allocateBuffer(src.capacity() - + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH); + + MAX_TLS_RECORD_OVERHEAD_LENGTH); SSLEngineResult result = clientEngine.wrap(new ByteBuffer[] { src, src2 }, dst); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); @@ -284,7 +288,7 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer src = allocateBuffer(1024); List srcList = new ArrayList(); long srcsLen = 0; - long maxLen = ((long) Integer.MAX_VALUE) * 2; + long maxLen = ((long) MAX_VALUE) * 2; while (srcsLen < maxLen) { ByteBuffer dup = src.duplicate(); @@ -294,7 +298,7 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer[] srcs = srcList.toArray(new ByteBuffer[srcList.size()]); - ByteBuffer dst = allocateBuffer(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH - 1); + ByteBuffer dst = allocateBuffer(MAX_ENCRYPTED_PACKET_LENGTH - 1); SSLEngineResult result = clientEngine.wrap(srcs, dst); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); @@ -313,14 +317,14 @@ public class OpenSslEngineTest extends SSLEngineTest { @Test public void testCalculateOutNetBufSizeOverflow() { - assertEquals(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH, - ReferenceCountedOpenSslEngine.calculateOutNetBufSize(Integer.MAX_VALUE)); + assertEquals(MAX_VALUE, + ReferenceCountedOpenSslEngine.calculateOutNetBufSize(MAX_VALUE, 1)); } @Test public void testCalculateOutNetBufSize0() { - assertEquals(ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH, - ReferenceCountedOpenSslEngine.calculateOutNetBufSize(0)); + assertEquals(MAX_TLS_RECORD_OVERHEAD_LENGTH, + ReferenceCountedOpenSslEngine.calculateOutNetBufSize(0, 1)); } @Override @@ -538,9 +542,9 @@ public class OpenSslEngineTest extends SSLEngineTest { do { testWrapDstBigEnough(clientEngine, srcLen); srcLen += 64; - } while (srcLen < ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH); + } while (srcLen < MAX_PLAINTEXT_LENGTH); - testWrapDstBigEnough(clientEngine, ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH); + testWrapDstBigEnough(clientEngine, MAX_PLAINTEXT_LENGTH); } finally { cleanupClientSslEngine(clientEngine); cleanupServerSslEngine(serverEngine); @@ -549,7 +553,7 @@ public class OpenSslEngineTest extends SSLEngineTest { private void testWrapDstBigEnough(SSLEngine engine, int srcLen) throws SSLException { ByteBuffer src = allocateBuffer(srcLen); - ByteBuffer dst = allocateBuffer(srcLen + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH); + ByteBuffer dst = allocateBuffer(srcLen + MAX_TLS_RECORD_OVERHEAD_LENGTH); SSLEngineResult result = engine.wrap(src, dst); assertEquals(SSLEngineResult.Status.OK, result.getStatus()); diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index b39de0fced..89ff9223b0 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -32,7 +32,11 @@ import javax.net.ssl.X509TrustManager; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelInitializer; import io.netty.channel.EventLoopGroup; @@ -69,6 +73,7 @@ import java.security.KeyStore; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; @@ -560,4 +565,129 @@ public class SslHandlerTest { ReferenceCountUtil.release(sslClientCtx); } } + + @Test(timeout = 30000) + public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite() + throws CertificateException, SSLException, ExecutionException, InterruptedException { + SslProvider[] providers = SslProvider.values(); + for (int i = 0; i < providers.length; ++i) { + for (int j = 0; j < providers.length; ++j) { + compositeBufSizeEstimationGuaranteesSynchronousWrite(providers[i], providers[j]); + } + } + } + + private void compositeBufSizeEstimationGuaranteesSynchronousWrite( + SslProvider serverProvider, SslProvider clientProvider) + throws CertificateException, SSLException, ExecutionException, InterruptedException { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(serverProvider) + .build(); + + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(clientProvider).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + try { + final Promise donePromise = group.next().newPromise(); + final int expectedBytes = 469 + 1024 + 1024; + + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt; + if (sslEvt.isSuccess()) { + final ByteBuf input = ctx.alloc().buffer(); + input.writeBytes(new byte[expectedBytes]); + CompositeByteBuf content = ctx.alloc().compositeBuffer(); + content.addComponent(true, input.readRetainedSlice(469)); + content.addComponent(true, input.readRetainedSlice(1024)); + content.addComponent(true, input.readRetainedSlice(1024)); + ctx.writeAndFlush(content).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + input.release(); + } + }); + } else { + donePromise.tryFailure(sslEvt.cause()); + } + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + donePromise.tryFailure(cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + donePromise.tryFailure(new IllegalStateException("server closed")); + } + }); + } + }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + private int bytesSeen; + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof ByteBuf) { + bytesSeen += ((ByteBuf) msg).readableBytes(); + if (bytesSeen == expectedBytes) { + donePromise.trySuccess(null); + } + } + ReferenceCountUtil.release(msg); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + donePromise.tryFailure(cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + donePromise.tryFailure(new IllegalStateException("client closed")); + } + }); + } + }).connect(sc.localAddress()).syncUninterruptibly().channel(); + + donePromise.get(); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + ssc.delete(); + } + } }