diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java
index 80b9990936..679a57e9d6 100644
--- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java
+++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java
@@ -57,6 +57,7 @@ import javax.net.ssl.SSLSessionContext;
import javax.security.cert.X509Certificate;
import static io.netty.handler.ssl.OpenSsl.memoryAddress;
+import static io.netty.handler.ssl.SslUtils.SSL_RECORD_HEADER_LENGTH;
import static io.netty.util.internal.EmptyArrays.EMPTY_CERTIFICATES;
import static io.netty.util.internal.EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
@@ -107,15 +108,14 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
* allow up to 255 bytes. 16 bytes is the max for PKC#5 (which handles it the same way as PKC#7) as we use a block
* size of 16. See rfc5652#section-6.3.
*
- * 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) + 2 (ProtocolVersion) + 2 (Length)
+ * TLS Header (5) + 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) +
+ * 2 (ProtocolVersion) + 2 (Length)
*
* TODO: We may need to review this calculation once TLS 1.3 becomes available.
*/
- static final int MAX_ENCRYPTION_OVERHEAD_LENGTH = 15 + 48 + 1 + 16 + 1 + 2 + 2;
+ static final int MAX_TLS_RECORD_OVERHEAD_LENGTH = SSL_RECORD_HEADER_LENGTH + 16 + 48 + 1 + 15 + 1 + 2 + 2;
- static final int MAX_ENCRYPTED_PACKET_LENGTH = MAX_PLAINTEXT_LENGTH + MAX_ENCRYPTION_OVERHEAD_LENGTH;
-
- private static final int MAX_ENCRYPTION_OVERHEAD_DIFF = Integer.MAX_VALUE - MAX_ENCRYPTION_OVERHEAD_LENGTH;
+ static final int MAX_ENCRYPTED_PACKET_LENGTH = MAX_PLAINTEXT_LENGTH + MAX_TLS_RECORD_OVERHEAD_LENGTH;
private static final AtomicIntegerFieldUpdater DESTROYED_UPDATER =
AtomicIntegerFieldUpdater.newUpdater(ReferenceCountedOpenSslEngine.class, "destroyed");
@@ -561,7 +561,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}
}
- if (dst.remaining() < calculateOutNetBufSize(srcsLen)) {
+ if (dst.remaining() < calculateOutNetBufSize(srcsLen, endOffset - offset)) {
// Can not hold the maximum packet so we need to tell the caller to use a bigger destination
// buffer.
return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
@@ -772,7 +772,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}
}
- if (len < SslUtils.SSL_RECORD_HEADER_LENGTH) {
+ if (len < SSL_RECORD_HEADER_LENGTH) {
return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0);
}
@@ -782,7 +782,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
throw new NotSslRecordException("not an SSL/TLS record");
}
- if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) {
+ if (packetLength - SSL_RECORD_HEADER_LENGTH > capacity) {
// No enough space in the destination buffer so signal the caller
// that the buffer needs to be increased.
return newResultMayFinishHandshake(BUFFER_OVERFLOW, status, 0, 0);
@@ -1606,9 +1606,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
return destroyed != 0;
}
- static int calculateOutNetBufSize(int pendingBytes) {
- return min(MAX_ENCRYPTED_PACKET_LENGTH, MAX_ENCRYPTION_OVERHEAD_LENGTH
- + min(MAX_ENCRYPTION_OVERHEAD_DIFF, pendingBytes));
+ static int calculateOutNetBufSize(int pendingBytes, int numComponents) {
+ return (int) min(Integer.MAX_VALUE, pendingBytes + (long) MAX_TLS_RECORD_OVERHEAD_LENGTH * numComponents);
}
private final class OpenSslSession implements SSLSession, ApplicationProtocolAccessor {
diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java
index 601f8dfd29..ac9843b0e8 100644
--- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java
+++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java
@@ -193,7 +193,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* we can use a special {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} method
* that accepts multiple {@link ByteBuffer}s without additional memory copies.
*/
- OpenSslEngine opensslEngine = (OpenSslEngine) handler.engine;
+ ReferenceCountedOpenSslEngine opensslEngine = (ReferenceCountedOpenSslEngine) handler.engine;
try {
handler.singleBuffer[0] = toByteBuffer(out, writerIndex,
out.writableBytes());
@@ -210,8 +210,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
- int calculateOutNetBufSize(SslHandler handler, int pendingBytes) {
- return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes);
+ int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
+ return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes, numComponents);
}
},
JDK(false, MERGE_CUMULATOR) {
@@ -226,16 +226,13 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
- int calculateOutNetBufSize(SslHandler handler, int pendingBytes) {
+ int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
return handler.maxPacketBufferSize;
}
};
static SslEngineType forEngine(SSLEngine engine) {
- if (engine instanceof OpenSslEngine) {
- return TCNATIVE;
- }
- return JDK;
+ return engine instanceof ReferenceCountedOpenSslEngine ? TCNATIVE : JDK;
}
SslEngineType(boolean wantsDirectBuffer, Cumulator cumulator) {
@@ -246,7 +243,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out)
throws SSLException;
- abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes);
+ abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents);
// BEGIN Platform-dependent flags
@@ -652,7 +649,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
ByteBuf buf = (ByteBuf) msg;
if (out == null) {
- out = allocateOutNetBuf(ctx, buf.readableBytes());
+ out = allocateOutNetBuf(ctx, buf.readableBytes(), buf.nioBufferCount());
}
SSLEngineResult result = wrap(alloc, engine, buf, out);
@@ -741,7 +738,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// As this is called for the handshake we have no real idea how big the buffer needs to be.
// That said 2048 should give us enough room to include everything like ALPN / NPN data.
// If this is not enough we will increase the buffer in wrap(...).
- out = allocateOutNetBuf(ctx, 2048);
+ out = allocateOutNetBuf(ctx, 2048, 1);
}
SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out);
@@ -1694,8 +1691,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which can encrypt
* the specified amount of pending bytes.
*/
- private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) {
- return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes));
+ private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes, int numComponents) {
+ return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes, numComponents));
}
private final class LazyChannelPromise extends DefaultPromise {
diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java
index 517271441c..07fcc48380 100644
--- a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java
+++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java
@@ -39,7 +39,11 @@ import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
+import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH;
+import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH;
+import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH;
import static io.netty.internal.tcnative.SSL.SSL_CVERIFY_IGNORED;
+import static java.lang.Integer.MAX_VALUE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
@@ -200,12 +204,12 @@ public class OpenSslEngineTest extends SSLEngineTest {
ByteBuffer src = allocateBuffer(srcLen);
ByteBuffer dstTooSmall = allocateBuffer(
- src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH - 1);
+ src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH - 1);
ByteBuffer dst = allocateBuffer(
- src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH);
+ src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH);
// Check that we fail to wrap if the dst buffers capacity is not at least
- // src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH
+ // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH
SSLEngineResult result = clientEngine.wrap(src, dstTooSmall);
assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
assertEquals(0, result.bytesConsumed());
@@ -214,7 +218,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
assertEquals(dst.remaining(), dst.capacity());
// Check that we can wrap with a dst buffer that has the capacity of
- // src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH
+ // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH
result = clientEngine.wrap(src, dst);
assertEquals(SSLEngineResult.Status.OK, result.getStatus());
assertEquals(srcLen, result.bytesConsumed());
@@ -249,7 +253,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
ByteBuffer src2 = src.duplicate();
ByteBuffer dst = allocateBuffer(src.capacity()
- + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH);
+ + MAX_TLS_RECORD_OVERHEAD_LENGTH);
SSLEngineResult result = clientEngine.wrap(new ByteBuffer[] { src, src2 }, dst);
assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
@@ -284,7 +288,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
ByteBuffer src = allocateBuffer(1024);
List srcList = new ArrayList();
long srcsLen = 0;
- long maxLen = ((long) Integer.MAX_VALUE) * 2;
+ long maxLen = ((long) MAX_VALUE) * 2;
while (srcsLen < maxLen) {
ByteBuffer dup = src.duplicate();
@@ -294,7 +298,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
ByteBuffer[] srcs = srcList.toArray(new ByteBuffer[srcList.size()]);
- ByteBuffer dst = allocateBuffer(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH - 1);
+ ByteBuffer dst = allocateBuffer(MAX_ENCRYPTED_PACKET_LENGTH - 1);
SSLEngineResult result = clientEngine.wrap(srcs, dst);
assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
@@ -313,14 +317,14 @@ public class OpenSslEngineTest extends SSLEngineTest {
@Test
public void testCalculateOutNetBufSizeOverflow() {
- assertEquals(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH,
- ReferenceCountedOpenSslEngine.calculateOutNetBufSize(Integer.MAX_VALUE));
+ assertEquals(MAX_VALUE,
+ ReferenceCountedOpenSslEngine.calculateOutNetBufSize(MAX_VALUE, 1));
}
@Test
public void testCalculateOutNetBufSize0() {
- assertEquals(ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH,
- ReferenceCountedOpenSslEngine.calculateOutNetBufSize(0));
+ assertEquals(MAX_TLS_RECORD_OVERHEAD_LENGTH,
+ ReferenceCountedOpenSslEngine.calculateOutNetBufSize(0, 1));
}
@Override
@@ -538,9 +542,9 @@ public class OpenSslEngineTest extends SSLEngineTest {
do {
testWrapDstBigEnough(clientEngine, srcLen);
srcLen += 64;
- } while (srcLen < ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH);
+ } while (srcLen < MAX_PLAINTEXT_LENGTH);
- testWrapDstBigEnough(clientEngine, ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH);
+ testWrapDstBigEnough(clientEngine, MAX_PLAINTEXT_LENGTH);
} finally {
cleanupClientSslEngine(clientEngine);
cleanupServerSslEngine(serverEngine);
@@ -549,7 +553,7 @@ public class OpenSslEngineTest extends SSLEngineTest {
private void testWrapDstBigEnough(SSLEngine engine, int srcLen) throws SSLException {
ByteBuffer src = allocateBuffer(srcLen);
- ByteBuffer dst = allocateBuffer(srcLen + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH);
+ ByteBuffer dst = allocateBuffer(srcLen + MAX_TLS_RECORD_OVERHEAD_LENGTH);
SSLEngineResult result = engine.wrap(src, dst);
assertEquals(SSLEngineResult.Status.OK, result.getStatus());
diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java
index b39de0fced..89ff9223b0 100644
--- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java
+++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java
@@ -32,7 +32,11 @@ import javax.net.ssl.X509TrustManager;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
@@ -69,6 +73,7 @@ import java.security.KeyStore;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -560,4 +565,129 @@ public class SslHandlerTest {
ReferenceCountUtil.release(sslClientCtx);
}
}
+
+ @Test(timeout = 30000)
+ public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite()
+ throws CertificateException, SSLException, ExecutionException, InterruptedException {
+ SslProvider[] providers = SslProvider.values();
+ for (int i = 0; i < providers.length; ++i) {
+ for (int j = 0; j < providers.length; ++j) {
+ compositeBufSizeEstimationGuaranteesSynchronousWrite(providers[i], providers[j]);
+ }
+ }
+ }
+
+ private void compositeBufSizeEstimationGuaranteesSynchronousWrite(
+ SslProvider serverProvider, SslProvider clientProvider)
+ throws CertificateException, SSLException, ExecutionException, InterruptedException {
+ SelfSignedCertificate ssc = new SelfSignedCertificate();
+
+ final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
+ .sslProvider(serverProvider)
+ .build();
+
+ final SslContext sslClientCtx = SslContextBuilder.forClient()
+ .trustManager(InsecureTrustManagerFactory.INSTANCE)
+ .sslProvider(clientProvider).build();
+
+ EventLoopGroup group = new NioEventLoopGroup();
+ Channel sc = null;
+ Channel cc = null;
+ try {
+ final Promise donePromise = group.next().newPromise();
+ final int expectedBytes = 469 + 1024 + 1024;
+
+ sc = new ServerBootstrap()
+ .group(group)
+ .channel(NioServerSocketChannel.class)
+ .childHandler(new ChannelInitializer() {
+ @Override
+ protected void initChannel(Channel ch) throws Exception {
+ ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc()));
+ ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+ @Override
+ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
+ if (evt instanceof SslHandshakeCompletionEvent) {
+ SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt;
+ if (sslEvt.isSuccess()) {
+ final ByteBuf input = ctx.alloc().buffer();
+ input.writeBytes(new byte[expectedBytes]);
+ CompositeByteBuf content = ctx.alloc().compositeBuffer();
+ content.addComponent(true, input.readRetainedSlice(469));
+ content.addComponent(true, input.readRetainedSlice(1024));
+ content.addComponent(true, input.readRetainedSlice(1024));
+ ctx.writeAndFlush(content).addListener(new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) {
+ input.release();
+ }
+ });
+ } else {
+ donePromise.tryFailure(sslEvt.cause());
+ }
+ }
+ ctx.fireUserEventTriggered(evt);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ donePromise.tryFailure(cause);
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) {
+ donePromise.tryFailure(new IllegalStateException("server closed"));
+ }
+ });
+ }
+ }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
+
+ cc = new Bootstrap()
+ .group(group)
+ .channel(NioSocketChannel.class)
+ .handler(new ChannelInitializer() {
+ @Override
+ protected void initChannel(Channel ch) throws Exception {
+ ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc()));
+ ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+ private int bytesSeen;
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) {
+ if (msg instanceof ByteBuf) {
+ bytesSeen += ((ByteBuf) msg).readableBytes();
+ if (bytesSeen == expectedBytes) {
+ donePromise.trySuccess(null);
+ }
+ }
+ ReferenceCountUtil.release(msg);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ donePromise.tryFailure(cause);
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) {
+ donePromise.tryFailure(new IllegalStateException("client closed"));
+ }
+ });
+ }
+ }).connect(sc.localAddress()).syncUninterruptibly().channel();
+
+ donePromise.get();
+ } finally {
+ if (cc != null) {
+ cc.close().syncUninterruptibly();
+ }
+ if (sc != null) {
+ sc.close().syncUninterruptibly();
+ }
+ group.shutdownGracefully();
+
+ ReferenceCountUtil.release(sslServerCtx);
+ ReferenceCountUtil.release(sslClientCtx);
+ ssc.delete();
+ }
+ }
}