Fix a bug in SslHandler where a truncated packet isn't handled correctly

- Fixes #1534
This commit is contained in:
Trustin Lee 2013-07-07 11:22:02 +09:00
parent 086ae3536c
commit 5010fe0a61

View File

@ -150,9 +150,7 @@ import java.util.regex.Pattern;
* For more details see * For more details see
* <a href="https://github.com/netty/netty/issues/832">#832</a> in our issue tracker. * <a href="https://github.com/netty/netty/issues/832">#832</a> in our issue tracker.
*/ */
public class SslHandler public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundHandler {
extends ByteToMessageDecoder
implements ChannelOutboundHandler {
private static final InternalLogger logger = private static final InternalLogger logger =
InternalLoggerFactory.getInstance(SslHandler.class); InternalLoggerFactory.getInstance(SslHandler.class);
@ -184,6 +182,8 @@ public class SslHandler
private final CloseNotifyListener closeNotifyWriteListener = new CloseNotifyListener(); private final CloseNotifyListener closeNotifyWriteListener = new CloseNotifyListener();
private final Queue<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>(); private final Queue<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
private int packetLength;
private volatile long handshakeTimeoutMillis = 10000; private volatile long handshakeTimeoutMillis = 10000;
private volatile long closeNotifyTimeoutMillis = 3000; private volatile long closeNotifyTimeoutMillis = 3000;
@ -702,6 +702,9 @@ public class SslHandler
* Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read. * Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read.
*/ */
public static boolean isEncrypted(ByteBuf buffer) { public static boolean isEncrypted(ByteBuf buffer) {
if (buffer.readableBytes() < 5) {
throw new IllegalArgumentException("buffer must have at least 5 readable bytes");
}
return getEncryptedPacketLength(buffer) != -1; return getEncryptedPacketLength(buffer) != -1;
} }
@ -719,15 +722,12 @@ public class SslHandler
* Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read. * Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read.
*/ */
private static int getEncryptedPacketLength(ByteBuf buffer) { private static int getEncryptedPacketLength(ByteBuf buffer) {
if (buffer.readableBytes() < 5) { int first = buffer.readerIndex();
throw new IllegalArgumentException("buffer must have at least 5 readable bytes");
}
int packetLength = 0; int packetLength = 0;
// SSLv3 or TLS - Check ContentType // SSLv3 or TLS - Check ContentType
boolean tls; boolean tls;
switch (buffer.getUnsignedByte(buffer.readerIndex())) { switch (buffer.getUnsignedByte(first)) {
case 20: // change_cipher_spec case 20: // change_cipher_spec
case 21: // alert case 21: // alert
case 22: // handshake case 22: // handshake
@ -741,10 +741,10 @@ public class SslHandler
if (tls) { if (tls) {
// SSLv3 or TLS - Check ProtocolVersion // SSLv3 or TLS - Check ProtocolVersion
int majorVersion = buffer.getUnsignedByte(buffer.readerIndex() + 1); int majorVersion = buffer.getUnsignedByte(first + 1);
if (majorVersion == 3) { if (majorVersion == 3) {
// SSLv3 or TLS // SSLv3 or TLS
packetLength = (getShort(buffer, buffer.readerIndex() + 3) & 0xFFFF) + 5; packetLength = (buffer.getUnsignedShort(first + 3)) + 5;
if (packetLength <= 5) { if (packetLength <= 5) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false; tls = false;
@ -758,16 +758,14 @@ public class SslHandler
if (!tls) { if (!tls) {
// SSLv2 or bad data - Check the version // SSLv2 or bad data - Check the version
boolean sslv2 = true; boolean sslv2 = true;
int headerLength = (buffer.getUnsignedByte( int headerLength = (buffer.getUnsignedByte(first) & 0x80) != 0 ? 2 : 3;
buffer.readerIndex()) & 0x80) != 0 ? 2 : 3; int majorVersion = buffer.getUnsignedByte(first + headerLength + 1);
int majorVersion = buffer.getUnsignedByte(
buffer.readerIndex() + headerLength + 1);
if (majorVersion == 2 || majorVersion == 3) { if (majorVersion == 2 || majorVersion == 3) {
// SSLv2 // SSLv2
if (headerLength == 2) { if (headerLength == 2) {
packetLength = (getShort(buffer, buffer.readerIndex()) & 0x7FFF) + 2; packetLength = (buffer.getShort(first) & 0x7FFF) + 2;
} else { } else {
packetLength = (getShort(buffer, buffer.readerIndex()) & 0x3FFF) + 3; packetLength = (buffer.getShort(first) & 0x3FFF) + 3;
} }
if (packetLength <= headerLength) { if (packetLength <= headerLength) {
sslv2 = false; sslv2 = false;
@ -793,24 +791,33 @@ public class SslHandler
private void decode0(final ChannelHandlerContext ctx) throws SSLException { private void decode0(final ChannelHandlerContext ctx) throws SSLException {
final ByteBuf in = internalBuffer(); final ByteBuf in = internalBuffer();
if (in.readableBytes() < 5) { // Check if the packet length was parsed yet, if so we can skip the parsing
return; final int readableBytes = in.readableBytes();
int packetLength = this.packetLength;
if (packetLength == 0) {
if (readableBytes < 5) {
return;
}
packetLength = getEncryptedPacketLength(in);
if (packetLength == -1) {
// Bad data - discard the buffer and raise an exception.
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(readableBytes);
ctx.fireExceptionCaught(e);
setHandshakeFailure(e);
return;
}
assert packetLength > 0;
this.packetLength = packetLength;
} }
int packetLength = getEncryptedPacketLength(in); if (readableBytes < packetLength) {
if (packetLength == -1) {
// Bad data - discard the buffer and raise an exception.
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(in.readableBytes());
ctx.fireExceptionCaught(e);
setHandshakeFailure(e);
return; return;
} }
assert packetLength > 0;
boolean wrapLater = false; boolean wrapLater = false;
int bytesProduced = 0; int bytesProduced = 0;
try { try {
@ -863,6 +870,9 @@ public class SslHandler
setHandshakeFailure(e); setHandshakeFailure(e);
throw e; throw e;
} finally { } finally {
// reset the packet length so it will be parsed again on the next call
this.packetLength = 0;
if (bytesProduced > 0) { if (bytesProduced > 0) {
ByteBuf decodeOut = this.decodeOut; ByteBuf decodeOut = this.decodeOut;
this.decodeOut = null; this.decodeOut = null;
@ -871,14 +881,6 @@ public class SslHandler
} }
} }
/**
* Reads a big-endian short integer from the buffer. Please note that we do not use
* {@link ByteBuf#getShort(int)} because it might be a little-endian buffer.
*/
private static short getShort(ByteBuf buf, int offset) {
return (short) (buf.getByte(offset) << 8 | buf.getByte(offset + 1) & 0xFF);
}
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException { private static SSLEngineResult unwrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException {
ByteBuffer in0 = in.nioBuffer(); ByteBuffer in0 = in.nioBuffer();
for (;;) { for (;;) {