diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 01ef75131a..ee8ae7554a 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; +import io.netty.channel.ChannelException; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; @@ -183,6 +184,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private final Queue pendingUnencryptedWrites = new ArrayDeque(); private int packetLength; + private ByteBuf decodeOut; private volatile long handshakeTimeoutMillis = 10000; private volatile long closeNotifyTimeoutMillis = 3000; @@ -343,9 +345,17 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } @Override - protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { if (decodeOut != null) { decodeOut.release(); + decodeOut = null; + } + for (;;) { + PendingWrite write = pendingUnencryptedWrites.poll(); + if (write == null) { + break; + } + write.fail(new ChannelException("Pending write on removal of SslHandler")); } } @@ -495,7 +505,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } if (unwrapLater) { - decode0(ctx); + unwrapLater(ctx); } } catch (SSLException e) { setHandshakeFailure(e); @@ -517,6 +527,16 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } } + private void unwrapLater(ChannelHandlerContext ctx) throws SSLException { + MessageList messageList = MessageList.newInstance(); + decode(ctx, internalBuffer(), messageList); + if (messageList.isEmpty()) { + messageList.recycle(); + } else { + ctx.fireMessageReceived(messageList); + } + } + private void flushNonAppData0(ChannelHandlerContext ctx) throws SSLException { boolean unwrapLater = false; ByteBuf out = null; @@ -567,7 +587,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } if (unwrapLater) { - decode0(ctx); + unwrapLater(ctx); } } catch (SSLException e) { setHandshakeFailure(e); @@ -782,26 +802,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } @Override - public void decode(final ChannelHandlerContext ctx, ByteBuf in, MessageList out) throws Exception { - decode0(ctx); - } - - private ByteBuf decodeOut; - - private void decode0(final ChannelHandlerContext ctx) throws SSLException { - final ByteBuf in = internalBuffer(); - - // Check if the packet length was parsed yet, if so we can skip the parsing - final int readableBytes = in.readableBytes(); + protected void decode(ChannelHandlerContext ctx, ByteBuf in, MessageList out) throws SSLException { int packetLength = this.packetLength; if (packetLength == 0) { + // the previous packet was consumed so try to read the length of the next packet + final int readableBytes = in.readableBytes(); if (readableBytes < 5) { + // not enough bytes readable to read the packet length return; } packetLength = getEncryptedPacketLength(in); if (packetLength == -1) { - // Bad data - discard the buffer and raise an exception. + // Not an SSL/TLS packet NotSslRecordException e = new NotSslRecordException( "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); in.skipBytes(readableBytes); @@ -814,10 +827,22 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH this.packetLength = packetLength; } - if (readableBytes < packetLength) { + if (in.readableBytes() < packetLength) { + // wait until the whole packet can be read return; } + try { + // slice out the whole packet so unwrap will only be called with complete packets + int readerIndex = in.readerIndex(); + in.skipBytes(packetLength); + unwrap(ctx, in.nioBuffer(readerIndex, packetLength), out); + } finally { + this.packetLength = 0; + } + } + + private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, MessageList out) throws SSLException { boolean wrapLater = false; int bytesProduced = 0; try { @@ -826,9 +851,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH if (decodeOut == null) { decodeOut = ctx.alloc().buffer(); } - SSLEngineResult result = unwrap(engine, in, decodeOut); + SSLEngineResult result = unwrap(engine, packet, decodeOut); bytesProduced += result.bytesProduced(); - switch (result.getStatus()) { case CLOSED: // notify about the CLOSED state of the SSLEngine. See #137 @@ -870,23 +894,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH setHandshakeFailure(e); throw e; } finally { - // reset the packet length so it will be parsed again on the next call - this.packetLength = 0; - if (bytesProduced > 0) { ByteBuf decodeOut = this.decodeOut; this.decodeOut = null; - ctx.fireMessageReceived(decodeOut); + out.add(decodeOut); } } } - private static SSLEngineResult unwrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException { - ByteBuffer in0 = in.nioBuffer(); + private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException { for (;;) { ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes()); - SSLEngineResult result = engine.unwrap(in0, out0); - in.skipBytes(result.bytesConsumed()); + SSLEngineResult result = engine.unwrap(in, out0); out.writerIndex(out.writerIndex() + result.bytesProduced()); switch (result.getStatus()) { case BUFFER_OVERFLOW: