From a56ea06e58c4d152a3f578311ec98d14fbb3cabc Mon Sep 17 00:00:00 2001 From: norman Date: Wed, 6 Jun 2012 08:20:30 +0200 Subject: [PATCH] Only parse packet length once per packet. See #382 --- .../java/io/netty/handler/ssl/SslHandler.java | 125 +++++++++--------- 1 file changed, 64 insertions(+), 61 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index c283368ef9..47968e19df 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -166,7 +166,6 @@ public class SslHandler extends StreamToStreamCodec { private final SslBufferPool bufferPool; private final Executor delegatedTaskExecutor; - // TODO: Fix STARTTLS private final boolean startTls; private boolean sentFirstMessage; @@ -183,6 +182,8 @@ public class SslHandler extends StreamToStreamCodec { private volatile boolean issueHandshake; private final SSLEngineInboundCloseFuture sslEngineCloseFuture = new SSLEngineInboundCloseFuture(); + + private int packetLength = -1; /** * Creates a new instance. @@ -610,82 +611,79 @@ public class SslHandler extends StreamToStreamCodec { @Override public void decode(ChannelInboundHandlerContext ctx, ChannelBuffer in, ChannelBuffer out) throws Exception { - if (in.readableBytes() < 5) { - return; - } - in.markReaderIndex(); - - int packetLength = 0; + // check if the packet lenght was read before + if (packetLength == -1) { + if (in.readableBytes() < 5) { + return; + } + // SSLv3 or TLS - Check ContentType + boolean tls; + switch (in.getUnsignedByte(in.readerIndex())) { + case 20: // change_cipher_spec + case 21: // alert + case 22: // handshake + case 23: // application_data + tls = true; + break; + default: + // SSLv2 or bad data + tls = false; + } - // SSLv3 or TLS - Check ContentType - boolean tls; - switch (in.getUnsignedByte(in.readerIndex())) { - case 20: // change_cipher_spec - case 21: // alert - case 22: // handshake - case 23: // application_data - tls = true; - break; - default: - // SSLv2 or bad data - tls = false; - } - - if (tls) { - // SSLv3 or TLS - Check ProtocolVersion - int majorVersion = in.getUnsignedByte(in.readerIndex() + 1); - if (majorVersion == 3) { - // SSLv3 or TLS - packetLength = (getShort(in, in.readerIndex() + 3) & 0xFFFF) + 5; - if (packetLength <= 5) { + if (tls) { + // SSLv3 or TLS - Check ProtocolVersion + int majorVersion = in.getUnsignedByte(in.readerIndex() + 1); + if (majorVersion == 3) { + // SSLv3 or TLS + packetLength = (getShort(in, in.readerIndex() + 3) & 0xFFFF) + 5; + if (packetLength <= 5) { + // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) + tls = false; + } + } else { // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) tls = false; } - } else { - // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) - tls = false; } - } - if (!tls) { - // SSLv2 or bad data - Check the version - boolean sslv2 = true; - int headerLength = (in.getUnsignedByte( + if (!tls) { + // SSLv2 or bad data - Check the version + boolean sslv2 = true; + int headerLength = (in.getUnsignedByte( in.readerIndex()) & 0x80) != 0 ? 2 : 3; - int majorVersion = in.getUnsignedByte( + int majorVersion = in.getUnsignedByte( in.readerIndex() + headerLength + 1); - if (majorVersion == 2 || majorVersion == 3) { - // SSLv2 - if (headerLength == 2) { - packetLength = (getShort(in, in.readerIndex()) & 0x7FFF) + 2; + if (majorVersion == 2 || majorVersion == 3) { + // SSLv2 + if (headerLength == 2) { + packetLength = (getShort(in, in.readerIndex()) & 0x7FFF) + 2; + } else { + packetLength = (getShort(in, in.readerIndex()) & 0x3FFF) + 3; + } + if (packetLength <= headerLength) { + sslv2 = false; + } } else { - packetLength = (getShort(in, in.readerIndex()) & 0x3FFF) + 3; - } - if (packetLength <= headerLength) { sslv2 = false; } - } else { - sslv2 = false; - } - if (!sslv2) { - // Bad data - discard the buffer and raise an exception. - SSLException e = new SSLException( - "not an SSL/TLS record: " + ChannelBuffers.hexDump(in)); - in.skipBytes(in.readableBytes()); - throw e; + if (!sslv2) { + // Bad data - discard the buffer and raise an exception. + SSLException e = new SSLException( + "not an SSL/TLS record: " + ChannelBuffers.hexDump(in)); + in.skipBytes(in.readableBytes()); + throw e; + } } + + assert packetLength > 0; } - - assert packetLength > 0; - + if (in.readableBytes() < packetLength) { - // not enough bytes so reset the reader index and return - // - // TODO: store the parsed packetlength and reuse it. This will safe us from re-parse it - in.resetReaderIndex(); + // not enough bytes left to read the packet + // so return here for now return; } @@ -705,7 +703,12 @@ public class SslHandler extends StreamToStreamCodec { // before calling the user code. final int packetOffset = in.readerIndex(); in.skipBytes(packetLength); - unwrap(ctx, ctx.channel(), in, packetOffset, packetLength, out); + try { + unwrap(ctx, ctx.channel(), in, packetOffset, packetLength, out); + } finally { + // reset packet length + packetLength = -1; + } } /**