From 784722eff4540a2e81517c1b176f99e422385082 Mon Sep 17 00:00:00 2001 From: norman Date: Wed, 6 Jun 2012 08:42:36 +0200 Subject: [PATCH] Only parse the packet length one time per packet. See #382 --- .../jboss/netty/handler/ssl/SslHandler.java | 115 ++++++++++-------- 1 file changed, 61 insertions(+), 54 deletions(-) diff --git a/src/main/java/org/jboss/netty/handler/ssl/SslHandler.java b/src/main/java/org/jboss/netty/handler/ssl/SslHandler.java index 4d23d11e50..90d85d166f 100644 --- a/src/main/java/org/jboss/netty/handler/ssl/SslHandler.java +++ b/src/main/java/org/jboss/netty/handler/ssl/SslHandler.java @@ -199,6 +199,8 @@ public class SslHandler extends FrameDecoder private final SSLEngineInboundCloseFuture sslEngineCloseFuture = new SSLEngineInboundCloseFuture(); + private int packetLength = -1; + /** * Creates a new instance. * @@ -572,73 +574,74 @@ public class SslHandler extends FrameDecoder protected Object decode( final ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception { - if (buffer.readableBytes() < 5) { - return null; - } + if (packetLength == -1) { + if (buffer.readableBytes() < 5) { + return null; + } - int packetLength = 0; + // SSLv3 or TLS - Check ContentType + boolean tls; + switch (buffer.getUnsignedByte(buffer.readerIndex())) { + case 20: // change_cipher_spec + case 21: // alert + case 22: // handshake + case 23: // application_data + tls = true; + break; + default: + // SSLv2 or bad data + tls = false; + } - // SSLv3 or TLS - Check ContentType - boolean tls; - switch (buffer.getUnsignedByte(buffer.readerIndex())) { - case 20: // change_cipher_spec - case 21: // alert - case 22: // handshake - case 23: // application_data - tls = true; - break; - default: - // SSLv2 or bad data - tls = false; - } - - if (tls) { - // SSLv3 or TLS - Check ProtocolVersion - int majorVersion = buffer.getUnsignedByte(buffer.readerIndex() + 1); - if (majorVersion == 3) { - // SSLv3 or TLS - packetLength = (getShort(buffer, buffer.readerIndex() + 3) & 0xFFFF) + 5; - if (packetLength <= 5) { + if (tls) { + // SSLv3 or TLS - Check ProtocolVersion + int majorVersion = buffer.getUnsignedByte(buffer.readerIndex() + 1); + if (majorVersion == 3) { + // SSLv3 or TLS + packetLength = (getShort(buffer, buffer.readerIndex() + 3) & 0xFFFF) + 5; + if (packetLength <= 5) { + // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) + tls = false; + } + } else { // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) tls = false; } - } else { - // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) - tls = false; } - } - if (!tls) { - // SSLv2 or bad data - Check the version - boolean sslv2 = true; - int headerLength = (buffer.getUnsignedByte( - buffer.readerIndex()) & 0x80) != 0 ? 2 : 3; - int majorVersion = buffer.getUnsignedByte( - buffer.readerIndex() + headerLength + 1); - if (majorVersion == 2 || majorVersion == 3) { - // SSLv2 - if (headerLength == 2) { - packetLength = (getShort(buffer, buffer.readerIndex()) & 0x7FFF) + 2; + if (!tls) { + // SSLv2 or bad data - Check the version + boolean sslv2 = true; + int headerLength = (buffer.getUnsignedByte( + buffer.readerIndex()) & 0x80) != 0 ? 2 : 3; + int majorVersion = buffer.getUnsignedByte( + buffer.readerIndex() + headerLength + 1); + if (majorVersion == 2 || majorVersion == 3) { + // SSLv2 + if (headerLength == 2) { + packetLength = (getShort(buffer, buffer.readerIndex()) & 0x7FFF) + 2; + } else { + packetLength = (getShort(buffer, buffer.readerIndex()) & 0x3FFF) + 3; + } + if (packetLength <= headerLength) { + sslv2 = false; + } } else { - packetLength = (getShort(buffer, buffer.readerIndex()) & 0x3FFF) + 3; - } - if (packetLength <= headerLength) { sslv2 = false; } - } else { - sslv2 = false; + + if (!sslv2) { + // Bad data - discard the buffer and raise an exception. + SSLException e = new SSLException( + "not an SSL/TLS record: " + ChannelBuffers.hexDump(buffer)); + buffer.skipBytes(buffer.readableBytes()); + throw e; + } } - if (!sslv2) { - // Bad data - discard the buffer and raise an exception. - SSLException e = new SSLException( - "not an SSL/TLS record: " + ChannelBuffers.hexDump(buffer)); - buffer.skipBytes(buffer.readableBytes()); - throw e; - } + assert packetLength > 0; } - assert packetLength > 0; if (buffer.readableBytes() < packetLength) { return null; @@ -660,7 +663,11 @@ public class SslHandler extends FrameDecoder // before calling the user code. final int packetOffset = buffer.readerIndex(); buffer.skipBytes(packetLength); - return unwrap(ctx, channel, buffer, packetOffset, packetLength); + try { + return unwrap(ctx, channel, buffer, packetOffset, packetLength); + } finally { + packetLength = -1; + } } /**