diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 342e60baff..68a2b00de1 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -34,7 +34,6 @@ import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PendingWrite; import io.netty.util.internal.PlatformDependent; -import io.netty.util.internal.RecyclableArrayList; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -50,6 +49,7 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.DatagramChannel; import java.nio.channels.SocketChannel; import java.util.ArrayDeque; +import java.util.Arrays; import java.util.Deque; import java.util.List; import java.util.concurrent.ScheduledFuture; @@ -448,7 +448,8 @@ public class SslHandler extends ByteToMessageDecoder { case NEED_UNWRAP: return; default: - throw new IllegalStateException("Unknown handshake status: " + result.getHandshakeStatus()); + throw new IllegalStateException( + "Unknown handshake status: " + result.getHandshakeStatus()); } } } @@ -505,7 +506,7 @@ public class SslHandler extends ByteToMessageDecoder { break; case NEED_UNWRAP: if (!inUnwrap) { - unwrap(ctx); + unwrapNonApp(ctx); } break; case NEED_WRAP: @@ -515,7 +516,7 @@ public class SslHandler extends ByteToMessageDecoder { // Workaround for TLS False Start problem reported at: // https://github.com/netty/netty/issues/1108#issuecomment-14266970 if (!inUnwrap) { - unwrap(ctx); + unwrapNonApp(ctx); } break; default: @@ -742,6 +743,11 @@ public class SslHandler extends ByteToMessageDecoder { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws SSLException { + + // Keeps the list of the length of every SSL record in the input buffer. + int[] recordLengths = null; + int nRecords = 0; + final int startOffset = in.readerIndex(); final int endOffset = in.writerIndex(); int offset = startOffset; @@ -751,6 +757,10 @@ public class SslHandler extends ByteToMessageDecoder { if (endOffset - startOffset < packetLength) { return; } else { + recordLengths = new int[4]; + recordLengths[0] = packetLength; + nRecords = 1; + offset += packetLength; packetLength = 0; } @@ -778,11 +788,22 @@ public class SslHandler extends ByteToMessageDecoder { break; } + // We have a whole packet. + // Remember the length of the current packet. + if (recordLengths == null) { + recordLengths = new int[4]; + } + if (nRecords == recordLengths.length) { + recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1); + } + recordLengths[nRecords ++] = packetLength; + + // Increment the offset to handle the next packet. offset += packetLength; } - final int length = offset - startOffset; - if (length > 0) { + final int totalLength = offset - startOffset; + if (totalLength > 0) { // The buffer contains one or more full SSL records. // Slice out the whole packet so unwrap will only be called with complete packets. // Also directly reset the packetLength. This is needed as unwrap(..) may trigger @@ -793,9 +814,9 @@ public class SslHandler extends ByteToMessageDecoder { // 4) unwrapLater(...) calls decode(...) // // See https://github.com/netty/netty/issues/1534 - in.skipBytes(length); - ByteBuffer buffer = in.nioBuffer(startOffset, length); - unwrap(ctx, buffer, out); + in.skipBytes(totalLength); + ByteBuffer buffer = in.nioBuffer(startOffset, totalLength); + unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out); } if (nonSslRecord) { @@ -817,26 +838,54 @@ public class SslHandler extends ByteToMessageDecoder { super.channelReadComplete(ctx); } - private void unwrap(ChannelHandlerContext ctx) throws SSLException { - RecyclableArrayList out = RecyclableArrayList.newInstance(); + /** + * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc. + */ + private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException { try { - unwrap(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), out); - final int size = out.size(); - for (int i = 0; i < size; i++) { - ctx.fireChannelRead(out.get(i)); - } + unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0); } finally { - out.recycle(); + ByteBuf decodeOut = this.decodeOut; + if (decodeOut != null && decodeOut.isReadable()) { + this.decodeOut = null; + ctx.fireChannelRead(decodeOut); + } } } - private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, List out) throws SSLException { + /** + * Unwraps multiple inbound SSL records. + */ + private void unwrapMultiple( + ChannelHandlerContext ctx, ByteBuffer packet, int totalLength, + int[] recordLengths, int nRecords, List out) throws SSLException { + + for (int i = 0; i < nRecords; i ++) { + packet.limit(packet.position() + recordLengths[i]); + try { + unwrapSingle(ctx, packet, totalLength); + assert !packet.hasRemaining(); + } finally { + ByteBuf decodeOut = this.decodeOut; + if (decodeOut != null && decodeOut.isReadable()) { + this.decodeOut = null; + out.add(decodeOut); + } + } + } + } + + /** + * Unwraps a single SSL record. + */ + private void unwrapSingle( + ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException { + boolean wrapLater = false; - int totalProduced = 0; try { for (;;) { if (decodeOut == null) { - decodeOut = ctx.alloc().buffer(packet.remaining()); + decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity); } final SSLEngineResult result = unwrap(engine, packet, decodeOut); @@ -845,7 +894,6 @@ public class SslHandler extends ByteToMessageDecoder { final int produced = result.bytesProduced(); final int consumed = result.bytesConsumed(); - totalProduced += produced; if (status == Status.CLOSED) { // notify about the CLOSED state of the SSLEngine. See #137 sslCloseFuture.trySuccess(ctx.channel()); @@ -886,12 +934,6 @@ public class SslHandler extends ByteToMessageDecoder { } catch (SSLException e) { setHandshakeFailure(e); throw e; - } finally { - if (totalProduced > 0) { - ByteBuf decodeOut = this.decodeOut; - this.decodeOut = null; - out.add(decodeOut); - } } }