diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index ee6ac5983e..2c19502f93 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -150,9 +150,7 @@ import java.util.regex.Pattern; * For more details see * #832 in our issue tracker. */ -public class SslHandler - extends ByteToMessageDecoder - implements ChannelOutboundHandler { +public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundHandler { private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslHandler.class); @@ -184,6 +182,8 @@ public class SslHandler private final CloseNotifyListener closeNotifyWriteListener = new CloseNotifyListener(); private final Queue pendingUnencryptedWrites = new ArrayDeque(); + private int packetLength; + private volatile long handshakeTimeoutMillis = 10000; private volatile long closeNotifyTimeoutMillis = 3000; @@ -702,6 +702,9 @@ public class SslHandler * Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read. */ public static boolean isEncrypted(ByteBuf buffer) { + if (buffer.readableBytes() < 5) { + throw new IllegalArgumentException("buffer must have at least 5 readable bytes"); + } return getEncryptedPacketLength(buffer) != -1; } @@ -719,15 +722,12 @@ public class SslHandler * Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read. */ private static int getEncryptedPacketLength(ByteBuf buffer) { - if (buffer.readableBytes() < 5) { - throw new IllegalArgumentException("buffer must have at least 5 readable bytes"); - } - + int first = buffer.readerIndex(); int packetLength = 0; // SSLv3 or TLS - Check ContentType boolean tls; - switch (buffer.getUnsignedByte(buffer.readerIndex())) { + switch (buffer.getUnsignedByte(first)) { case 20: // change_cipher_spec case 21: // alert case 22: // handshake @@ -741,10 +741,10 @@ public class SslHandler if (tls) { // SSLv3 or TLS - Check ProtocolVersion - int majorVersion = buffer.getUnsignedByte(buffer.readerIndex() + 1); + int majorVersion = buffer.getUnsignedByte(first + 1); if (majorVersion == 3) { // SSLv3 or TLS - packetLength = (getShort(buffer, buffer.readerIndex() + 3) & 0xFFFF) + 5; + packetLength = (buffer.getUnsignedShort(first + 3)) + 5; if (packetLength <= 5) { // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) tls = false; @@ -758,16 +758,14 @@ public class SslHandler if (!tls) { // SSLv2 or bad data - Check the version boolean sslv2 = true; - int headerLength = (buffer.getUnsignedByte( - buffer.readerIndex()) & 0x80) != 0 ? 2 : 3; - int majorVersion = buffer.getUnsignedByte( - buffer.readerIndex() + headerLength + 1); + int headerLength = (buffer.getUnsignedByte(first) & 0x80) != 0 ? 2 : 3; + int majorVersion = buffer.getUnsignedByte(first + headerLength + 1); if (majorVersion == 2 || majorVersion == 3) { // SSLv2 if (headerLength == 2) { - packetLength = (getShort(buffer, buffer.readerIndex()) & 0x7FFF) + 2; + packetLength = (buffer.getShort(first) & 0x7FFF) + 2; } else { - packetLength = (getShort(buffer, buffer.readerIndex()) & 0x3FFF) + 3; + packetLength = (buffer.getShort(first) & 0x3FFF) + 3; } if (packetLength <= headerLength) { sslv2 = false; @@ -793,24 +791,33 @@ public class SslHandler private void decode0(final ChannelHandlerContext ctx) throws SSLException { final ByteBuf in = internalBuffer(); - if (in.readableBytes() < 5) { - return; + // Check if the packet length was parsed yet, if so we can skip the parsing + final int readableBytes = in.readableBytes(); + int packetLength = this.packetLength; + if (packetLength == 0) { + if (readableBytes < 5) { + return; + } + + packetLength = getEncryptedPacketLength(in); + if (packetLength == -1) { + // Bad data - discard the buffer and raise an exception. + NotSslRecordException e = new NotSslRecordException( + "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); + in.skipBytes(readableBytes); + ctx.fireExceptionCaught(e); + setHandshakeFailure(e); + return; + } + + assert packetLength > 0; + this.packetLength = packetLength; } - int packetLength = getEncryptedPacketLength(in); - - if (packetLength == -1) { - // Bad data - discard the buffer and raise an exception. - NotSslRecordException e = new NotSslRecordException( - "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); - in.skipBytes(in.readableBytes()); - ctx.fireExceptionCaught(e); - setHandshakeFailure(e); + if (readableBytes < packetLength) { return; } - assert packetLength > 0; - boolean wrapLater = false; int bytesProduced = 0; try { @@ -863,6 +870,9 @@ public class SslHandler setHandshakeFailure(e); throw e; } finally { + // reset the packet length so it will be parsed again on the next call + this.packetLength = 0; + if (bytesProduced > 0) { ByteBuf decodeOut = this.decodeOut; this.decodeOut = null; @@ -871,14 +881,6 @@ public class SslHandler } } - /** - * Reads a big-endian short integer from the buffer. Please note that we do not use - * {@link ByteBuf#getShort(int)} because it might be a little-endian buffer. - */ - private static short getShort(ByteBuf buf, int offset) { - return (short) (buf.getByte(offset) << 8 | buf.getByte(offset + 1) & 0xFF); - } - private static SSLEngineResult unwrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException { ByteBuffer in0 = in.nioBuffer(); for (;;) {