From f175ce06535e748b7df9b1bc8eb7c028d41780bc Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Sun, 20 Apr 2014 17:20:00 +0900 Subject: [PATCH] Feed only a single SSL record to SSLEngine.unwrap() Motivation: Some SSLEngine implementations violate the contract and raises an exception when SslHandler feeds an input buffer that contains multiple SSL records to SSLEngine.unwrap(), while the expected behavior is to decode the first record and return. Modification: - Modify SslHandler.decode() to keep the lengths of each record and feed SSLEngine.unwrap() record by record to work around the forementioned issue. - Rename unwrap() to unwrapMultiple() and unwrapNonApp() - Rename unwrap0() to unwrapSingle() Result: SslHandler now works OpenSSLEngine from finagle-native. Performance impact remains unnoticeable. Slightly better readability. Fixes #2116. --- .../java/io/netty/handler/ssl/SslHandler.java | 96 +++++++++++++------ 1 file changed, 69 insertions(+), 27 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 3296be54bb..b38725ecee 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -36,7 +36,6 @@ import io.netty.util.concurrent.ImmediateExecutor; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PendingWrite; import io.netty.util.internal.PlatformDependent; -import io.netty.util.internal.RecyclableArrayList; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -53,6 +52,7 @@ import java.nio.channels.DatagramChannel; import java.nio.channels.SocketChannel; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; import java.util.Deque; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -480,7 +480,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH case NEED_UNWRAP: return; default: - throw new IllegalStateException("Unknown handshake status: " + result.getHandshakeStatus()); + throw new IllegalStateException( + "Unknown handshake status: " + result.getHandshakeStatus()); } } } @@ -537,7 +538,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH break; case NEED_UNWRAP: if (!inUnwrap) { - unwrap(ctx); + unwrapNonApp(ctx); } break; case NEED_WRAP: @@ -547,7 +548,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // Workaround for TLS False Start problem reported at: // https://github.com/netty/netty/issues/1108#issuecomment-14266970 if (!inUnwrap) { - unwrap(ctx); + unwrapNonApp(ctx); } break; default: @@ -774,6 +775,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws SSLException { + + // Keeps the list of the length of every SSL record in the input buffer. + int[] recordLengths = null; + int nRecords = 0; + final int startOffset = in.readerIndex(); final int endOffset = in.writerIndex(); int offset = startOffset; @@ -783,6 +789,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH if (endOffset - startOffset < packetLength) { return; } else { + recordLengths = new int[4]; + recordLengths[0] = packetLength; + nRecords = 1; + offset += packetLength; packetLength = 0; } @@ -810,11 +820,22 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH break; } + // We have a whole packet. + // Remember the length of the current packet. + if (recordLengths == null) { + recordLengths = new int[4]; + } + if (nRecords == recordLengths.length) { + recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1); + } + recordLengths[nRecords ++] = packetLength; + + // Increment the offset to handle the next packet. offset += packetLength; } - final int length = offset - startOffset; - if (length > 0) { + final int totalLength = offset - startOffset; + if (totalLength > 0) { // The buffer contains one or more full SSL records. // Slice out the whole packet so unwrap will only be called with complete packets. // Also directly reset the packetLength. This is needed as unwrap(..) may trigger @@ -825,9 +846,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // 4) unwrapLater(...) calls decode(...) // // See https://github.com/netty/netty/issues/1534 - in.skipBytes(length); - ByteBuffer buffer = in.nioBuffer(startOffset, length); - unwrap(ctx, buffer, out); + in.skipBytes(totalLength); + ByteBuffer buffer = in.nioBuffer(startOffset, totalLength); + unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out); } if (nonSslRecord) { @@ -849,26 +870,54 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH super.channelReadComplete(ctx); } - private void unwrap(ChannelHandlerContext ctx) throws SSLException { - RecyclableArrayList out = RecyclableArrayList.newInstance(); + /** + * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc. + */ + private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException { try { - unwrap(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), out); - final int size = out.size(); - for (int i = 0; i < size; i++) { - ctx.fireChannelRead(out.get(i)); - } + unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0); } finally { - out.recycle(); + ByteBuf decodeOut = this.decodeOut; + if (decodeOut != null && decodeOut.isReadable()) { + this.decodeOut = null; + ctx.fireChannelRead(decodeOut); + } } } - private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, List out) throws SSLException { + /** + * Unwraps multiple inbound SSL records. + */ + private void unwrapMultiple( + ChannelHandlerContext ctx, ByteBuffer packet, int totalLength, + int[] recordLengths, int nRecords, List out) throws SSLException { + + for (int i = 0; i < nRecords; i ++) { + packet.limit(packet.position() + recordLengths[i]); + try { + unwrapSingle(ctx, packet, totalLength); + assert !packet.hasRemaining(); + } finally { + ByteBuf decodeOut = this.decodeOut; + if (decodeOut != null && decodeOut.isReadable()) { + this.decodeOut = null; + out.add(decodeOut); + } + } + } + } + + /** + * Unwraps a single SSL record. + */ + private void unwrapSingle( + ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException { + boolean wrapLater = false; - int totalProduced = 0; try { for (;;) { if (decodeOut == null) { - decodeOut = ctx.alloc().buffer(packet.remaining()); + decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity); } final SSLEngineResult result = unwrap(engine, packet, decodeOut); @@ -877,7 +926,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH final int produced = result.bytesProduced(); final int consumed = result.bytesConsumed(); - totalProduced += produced; if (status == Status.CLOSED) { // notify about the CLOSED state of the SSLEngine. See #137 sslCloseFuture.trySuccess(ctx.channel()); @@ -918,12 +966,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } catch (SSLException e) { setHandshakeFailure(e); throw e; - } finally { - if (totalProduced > 0) { - ByteBuf decodeOut = this.decodeOut; - this.decodeOut = null; - out.add(decodeOut); - } } }