diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index f5d4124e0e..27792f9139 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -16,6 +16,7 @@ package io.netty.handler.ssl; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; @@ -48,7 +49,6 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.DatagramChannel; import java.nio.channels.SocketChannel; import java.util.ArrayDeque; -import java.util.Arrays; import java.util.Deque; import java.util.List; import java.util.concurrent.ScheduledFuture; @@ -189,7 +189,6 @@ public class SslHandler extends ByteToMessageDecoder { private boolean needsFlush; private int packetLength; - private ByteBuf decodeOut; private volatile long handshakeTimeoutMillis = 10000; private volatile long closeNotifyTimeoutMillis = 3000; @@ -318,10 +317,6 @@ public class SslHandler extends ByteToMessageDecoder { @Override public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { - if (decodeOut != null) { - decodeOut.release(); - decodeOut = null; - } for (;;) { PendingWrite write = pendingUnencryptedWrites.poll(); if (write == null) { @@ -384,7 +379,7 @@ public class SslHandler extends ByteToMessageDecoder { break; } if (out == null) { - out = ctx.alloc().buffer(maxPacketBufferSize); + out = allocate(ctx, maxPacketBufferSize); } if (!(pending.msg() instanceof ByteBuf)) { @@ -470,7 +465,7 @@ public class SslHandler extends ByteToMessageDecoder { try { for (;;) { if (out == null) { - out = ctx.alloc().buffer(maxPacketBufferSize); + out = allocate(ctx, maxPacketBufferSize); } SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out); @@ -491,7 +486,7 @@ public class SslHandler extends ByteToMessageDecoder { break; case NEED_UNWRAP: if (!inUnwrap) { - unwrapNonApp(ctx); + unwrapNonAppData(ctx); } break; case NEED_WRAP: @@ -501,7 +496,7 @@ public class SslHandler extends ByteToMessageDecoder { // Workaround for TLS False Start problem reported at: // https://github.com/netty/netty/issues/1108#issuecomment-14266970 if (!inUnwrap) { - unwrapNonApp(ctx); + unwrapNonAppData(ctx); } break; default: @@ -729,31 +724,25 @@ public class SslHandler extends ByteToMessageDecoder { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws SSLException { - // Keeps the list of the length of every SSL record in the input buffer. - int[] recordLengths = null; - int nRecords = 0; - final int startOffset = in.readerIndex(); final int endOffset = in.writerIndex(); int offset = startOffset; + int totalLength = 0; // If we calculated the length of the current SSL record before, use that information. if (packetLength > 0) { if (endOffset - startOffset < packetLength) { return; } else { - recordLengths = new int[4]; - recordLengths[0] = packetLength; - nRecords = 1; - offset += packetLength; + totalLength = packetLength; packetLength = 0; } } boolean nonSslRecord = false; - for (;;) { + while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) { final int readableBytes = endOffset - offset; if (readableBytes < 5) { break; @@ -773,21 +762,18 @@ public class SslHandler extends ByteToMessageDecoder { break; } - // We have a whole packet. - // Remember the length of the current packet. - if (recordLengths == null) { - recordLengths = new int[4]; + int newTotalLength = totalLength + packetLength; + if (newTotalLength > OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) { + // Don't read too much. + break; } - if (nRecords == recordLengths.length) { - recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1); - } - recordLengths[nRecords ++] = packetLength; + // We have a whole packet. // Increment the offset to handle the next packet. offset += packetLength; + totalLength = newTotalLength; } - final int totalLength = offset - startOffset; if (totalLength > 0) { // The buffer contains one or more full SSL records. // Slice out the whole packet so unwrap will only be called with complete packets. @@ -799,9 +785,11 @@ public class SslHandler extends ByteToMessageDecoder { // 4) unwrapLater(...) calls decode(...) // // See https://github.com/netty/netty/issues/1534 + in.skipBytes(totalLength); - ByteBuffer buffer = in.nioBuffer(startOffset, totalLength); - unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out); + final ByteBuffer inNetBuf = in.nioBuffer(startOffset, totalLength); + unwrap(ctx, inNetBuf, totalLength); + assert !inNetBuf.hasRemaining() || engine.isInboundDone(); } if (nonSslRecord) { @@ -826,52 +814,20 @@ public class SslHandler extends ByteToMessageDecoder { /** * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc. */ - private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException { - try { - unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0); - } finally { - ByteBuf decodeOut = this.decodeOut; - if (decodeOut != null && decodeOut.isReadable()) { - this.decodeOut = null; - ctx.fireChannelRead(decodeOut); - } - } + private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException { + unwrap(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0); } /** - * Unwraps multiple inbound SSL records. + * Unwraps inbound SSL records. */ - private void unwrapMultiple( - ChannelHandlerContext ctx, ByteBuffer packet, int totalLength, - int[] recordLengths, int nRecords, List out) throws SSLException { - for (int i = 0; i < nRecords; i ++) { - packet.limit(packet.position() + recordLengths[i]); - try { - unwrapSingle(ctx, packet, totalLength); - assert !packet.hasRemaining() || engine.isInboundDone(); - } finally { - ByteBuf decodeOut = this.decodeOut; - if (decodeOut != null && decodeOut.isReadable()) { - this.decodeOut = null; - out.add(decodeOut); - } - } - } - } - - /** - * Unwraps a single SSL record. - */ - private void unwrapSingle( + private void unwrap( ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException { boolean wrapLater = false; + ByteBuf decodeOut = allocate(ctx, initialOutAppBufCapacity); try { for (;;) { - if (decodeOut == null) { - decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity); - } - final SSLEngineResult result = unwrap(engine, packet, decodeOut); final Status status = result.getStatus(); final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); @@ -926,6 +882,12 @@ public class SslHandler extends ByteToMessageDecoder { } catch (SSLException e) { setHandshakeFailure(e); throw e; + } finally { + if (decodeOut.isReadable()) { + ctx.fireChannelRead(decodeOut); + } else { + decodeOut.release(); + } } } @@ -1156,6 +1118,19 @@ public class SslHandler extends ByteToMessageDecoder { }); } + /** + * Always prefer a direct buffer when it's pooled, so that we reduce the number of memory copies + * in {@link OpenSslEngine}. + */ + private static ByteBuf allocate(ChannelHandlerContext ctx, int capacity) { + ByteBufAllocator alloc = ctx.alloc(); + if (alloc.isDirectBufferPooled()) { + return alloc.directBuffer(capacity); + } else { + return alloc.buffer(capacity); + } + } + private final class LazyChannelPromise extends DefaultPromise { @Override