Correct guard against non SSL data in ReferenceCountedOpenSslEngine

Motivation:

When non SSL data is passed into SSLEngine.unwrap(...) we need to throw an SSLException. This was not done at the moment. Even worse we threw an IllegalArgumentException as we tried to allocate a direct buffer with capacity of -1.

Modifications:

- Guard against non SSL data and added an unit test.
- Make code more consistent

Result:

Correct behaving SSLEngine implementation.
This commit is contained in:
Norman Maurer 2016-12-01 09:30:26 +01:00 committed by Norman Maurer
parent aa89f37c2a
commit 4d327463c7
5 changed files with 69 additions and 22 deletions

View File

@ -775,6 +775,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
} }
int packetLength = SslUtils.getEncryptedPacketLength(srcs, srcsOffset); int packetLength = SslUtils.getEncryptedPacketLength(srcs, srcsOffset);
if (packetLength == SslUtils.NOT_ENCRYPTED) {
throw new NotSslRecordException("not an SSL/TLS record");
}
if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) { if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) {
// No enough space in the destination buffer so signal the caller // No enough space in the destination buffer so signal the caller
// that the buffer needs to be increased. // that the buffer needs to be increased.

View File

@ -104,7 +104,7 @@ public class SniHandler extends ByteToMessageDecoder {
final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
// Not an SSL/TLS packet // Not an SSL/TLS packet
if (len == -1) { if (len == SslUtils.NOT_ENCRYPTED) {
handshakeFailed = true; handshakeFailed = true;
NotSslRecordException e = new NotSslRecordException( NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
@ -114,7 +114,8 @@ public class SniHandler extends ByteToMessageDecoder {
SslUtils.notifyHandshakeFailure(ctx, e); SslUtils.notifyHandshakeFailure(ctx, e);
return; return;
} }
if (writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) { if (len == SslUtils.NOT_ENOUGH_DATA ||
writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) {
// Not enough data // Not enough data
return; return;
} }

View File

@ -877,7 +877,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
throw new IllegalArgumentException( throw new IllegalArgumentException(
"buffer must have at least " + SslUtils.SSL_RECORD_HEADER_LENGTH + " readable bytes"); "buffer must have at least " + SslUtils.SSL_RECORD_HEADER_LENGTH + " readable bytes");
} }
return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1; return getEncryptedPacketLength(buffer, buffer.readerIndex()) != SslUtils.NOT_ENCRYPTED;
} }
@Override @Override
@ -907,7 +907,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
final int packetLength = getEncryptedPacketLength(in, offset); final int packetLength = getEncryptedPacketLength(in, offset);
if (packetLength == -1) { if (packetLength == SslUtils.NOT_ENCRYPTED) {
nonSslRecord = true; nonSslRecord = true;
break; break;
} }

View File

@ -32,27 +32,37 @@ final class SslUtils {
/** /**
* change cipher spec * change cipher spec
*/ */
public static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20; static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;
/** /**
* alert * alert
*/ */
public static final int SSL_CONTENT_TYPE_ALERT = 21; static final int SSL_CONTENT_TYPE_ALERT = 21;
/** /**
* handshake * handshake
*/ */
public static final int SSL_CONTENT_TYPE_HANDSHAKE = 22; static final int SSL_CONTENT_TYPE_HANDSHAKE = 22;
/** /**
* application data * application data
*/ */
public static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23; static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;
/** /**
* the length of the ssl record header (in bytes) * the length of the ssl record header (in bytes)
*/ */
public static final int SSL_RECORD_HEADER_LENGTH = 5; static final int SSL_RECORD_HEADER_LENGTH = 5;
/**
* Not enough data in buffer to parse the record length
*/
static final int NOT_ENOUGH_DATA = -1;
/**
* data is not encrypted
*/
static final int NOT_ENCRYPTED = -2;
/** /**
* Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase * Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase
@ -63,8 +73,10 @@ final class SslUtils {
* {@link #SSL_RECORD_HEADER_LENGTH} bytes to read, * {@link #SSL_RECORD_HEADER_LENGTH} bytes to read,
* otherwise it will throw an {@link IllegalArgumentException}. * otherwise it will throw an {@link IllegalArgumentException}.
* @return length * @return length
* The length of the encrypted packet that is included in the buffer. This will * The length of the encrypted packet that is included in the buffer or
* return {@code -1} if the given {@link ByteBuf} is not encrypted at all. * {@link #SslUtils#NOT_ENOUGH_DATA} if not enought data is present in the
* {@link ByteBuf}. This will return {@link SslUtils#NOT_ENCRYPTED} if
* the given {@link ByteBuf} is not encrypted at all.
* @throws IllegalArgumentException * @throws IllegalArgumentException
* Is thrown if the given {@link ByteBuf} has not at least {@link #SSL_RECORD_HEADER_LENGTH} * Is thrown if the given {@link ByteBuf} has not at least {@link #SSL_RECORD_HEADER_LENGTH}
* bytes to read. * bytes to read.
@ -114,10 +126,10 @@ final class SslUtils {
packetLength = (buffer.getShort(offset) & 0x3FFF) + 3; packetLength = (buffer.getShort(offset) & 0x3FFF) + 3;
} }
if (packetLength <= headerLength) { if (packetLength <= headerLength) {
return -1; return NOT_ENOUGH_DATA;
} }
} else { } else {
return -1; return NOT_ENCRYPTED;
} }
} }
return packetLength; return packetLength;
@ -135,7 +147,7 @@ final class SslUtils {
ByteBuffer buffer = buffers[offset]; ByteBuffer buffer = buffers[offset];
// Check if everything we need is in one ByteBuffer. If so we can make use of the fast-path. // Check if everything we need is in one ByteBuffer. If so we can make use of the fast-path.
if (buffer.remaining() >= SslUtils.SSL_RECORD_HEADER_LENGTH) { if (buffer.remaining() >= SSL_RECORD_HEADER_LENGTH) {
return getEncryptedPacketLength(buffer); return getEncryptedPacketLength(buffer);
} }
@ -161,10 +173,10 @@ final class SslUtils {
// SSLv3 or TLS - Check ContentType // SSLv3 or TLS - Check ContentType
boolean tls; boolean tls;
switch (unsignedByte(buffer.get(pos))) { switch (unsignedByte(buffer.get(pos))) {
case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SslUtils.SSL_CONTENT_TYPE_ALERT: case SSL_CONTENT_TYPE_ALERT:
case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: case SSL_CONTENT_TYPE_HANDSHAKE:
case SslUtils.SSL_CONTENT_TYPE_APPLICATION_DATA: case SSL_CONTENT_TYPE_APPLICATION_DATA:
tls = true; tls = true;
break; break;
default: default:
@ -177,8 +189,8 @@ final class SslUtils {
int majorVersion = unsignedByte(buffer.get(pos + 1)); int majorVersion = unsignedByte(buffer.get(pos + 1));
if (majorVersion == 3) { if (majorVersion == 3) {
// SSLv3 or TLS // SSLv3 or TLS
packetLength = unsignedShort(buffer.getShort(pos + 3)) + SslUtils.SSL_RECORD_HEADER_LENGTH; packetLength = unsignedShort(buffer.getShort(pos + 3)) + SSL_RECORD_HEADER_LENGTH;
if (packetLength <= SslUtils.SSL_RECORD_HEADER_LENGTH) { if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false; tls = false;
} }
@ -200,10 +212,10 @@ final class SslUtils {
packetLength = (buffer.getShort(pos) & 0x3FFF) + 3; packetLength = (buffer.getShort(pos) & 0x3FFF) + 3;
} }
if (packetLength <= headerLength) { if (packetLength <= headerLength) {
return -1; return NOT_ENOUGH_DATA;
} }
} else { } else {
return -1; return NOT_ENCRYPTED;
} }
} }
return packetLength; return packetLength;

View File

@ -74,6 +74,7 @@ import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
public abstract class SSLEngineTest { public abstract class SSLEngineTest {
@ -1082,4 +1083,32 @@ public abstract class SSLEngineTest {
cert.delete(); cert.delete();
} }
} }
@Test
public void testSSLEngineUnwrapNoSslRecord() throws Exception {
clientSslCtx = SslContextBuilder
.forClient()
.sslProvider(sslClientProvider())
.build();
SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);
try {
ByteBuffer src = ByteBuffer.allocate(client.getSession().getApplicationBufferSize());
ByteBuffer dst = ByteBuffer.allocate(client.getSession().getPacketBufferSize());
ByteBuffer empty = ByteBuffer.allocateDirect(0);
SSLEngineResult clientResult = client.wrap(empty, dst);
assertEquals(SSLEngineResult.Status.OK, clientResult.getStatus());
assertEquals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, clientResult.getHandshakeStatus());
try {
client.unwrap(src, dst);
fail();
} catch (SSLException expected) {
// expected
}
} finally {
cleanupClientSslEngine(client);
}
}
} }