Correct guard against non SSL data in ReferenceCountedOpenSslEngine

Motivation:

When non SSL data is passed into SSLEngine.unwrap(...) we need to throw an SSLException. This was not done at the moment. Even worse we threw an IllegalArgumentException as we tried to allocate a direct buffer with capacity of -1.

Modifications:

- Guard against non SSL data and added an unit test.
- Make code more consistent

Result:

Correct behaving SSLEngine implementation.
This commit is contained in:
Norman Maurer 2016-12-01 09:30:26 +01:00 committed by Norman Maurer
parent ae1234c303
commit 0ca2c3016b
5 changed files with 69 additions and 22 deletions

View File

@ -796,6 +796,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}
int packetLength = SslUtils.getEncryptedPacketLength(srcs, srcsOffset);
if (packetLength == SslUtils.NOT_ENCRYPTED) {
throw new NotSslRecordException("not an SSL/TLS record");
}
if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) {
// No enough space in the destination buffer so signal the caller
// that the buffer needs to be increased.

View File

@ -130,7 +130,7 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH
final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
// Not an SSL/TLS packet
if (len == -1) {
if (len == SslUtils.NOT_ENCRYPTED) {
handshakeFailed = true;
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
@ -140,7 +140,8 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH
SslUtils.notifyHandshakeFailure(ctx, e);
return;
}
if (writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) {
if (len == SslUtils.NOT_ENOUGH_DATA ||
writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) {
// Not enough data
return;
}

View File

@ -877,7 +877,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
throw new IllegalArgumentException(
"buffer must have at least " + SslUtils.SSL_RECORD_HEADER_LENGTH + " readable bytes");
}
return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1;
return getEncryptedPacketLength(buffer, buffer.readerIndex()) != SslUtils.NOT_ENCRYPTED;
}
@Override
@ -907,7 +907,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
final int packetLength = getEncryptedPacketLength(in, offset);
if (packetLength == -1) {
if (packetLength == SslUtils.NOT_ENCRYPTED) {
nonSslRecord = true;
break;
}

View File

@ -31,27 +31,37 @@ final class SslUtils {
/**
* change cipher spec
*/
public static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;
static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;
/**
* alert
*/
public static final int SSL_CONTENT_TYPE_ALERT = 21;
static final int SSL_CONTENT_TYPE_ALERT = 21;
/**
* handshake
*/
public static final int SSL_CONTENT_TYPE_HANDSHAKE = 22;
static final int SSL_CONTENT_TYPE_HANDSHAKE = 22;
/**
* application data
*/
public static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;
static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;
/**
* the length of the ssl record header (in bytes)
*/
public static final int SSL_RECORD_HEADER_LENGTH = 5;
static final int SSL_RECORD_HEADER_LENGTH = 5;
/**
* Not enough data in buffer to parse the record length
*/
static final int NOT_ENOUGH_DATA = -1;
/**
* data is not encrypted
*/
static final int NOT_ENCRYPTED = -2;
/**
* Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase
@ -62,8 +72,10 @@ final class SslUtils {
* {@link #SSL_RECORD_HEADER_LENGTH} bytes to read,
* otherwise it will throw an {@link IllegalArgumentException}.
* @return length
* The length of the encrypted packet that is included in the buffer. This will
* return {@code -1} if the given {@link ByteBuf} is not encrypted at all.
* The length of the encrypted packet that is included in the buffer or
* {@link #SslUtils#NOT_ENOUGH_DATA} if not enought data is present in the
* {@link ByteBuf}. This will return {@link SslUtils#NOT_ENCRYPTED} if
* the given {@link ByteBuf} is not encrypted at all.
* @throws IllegalArgumentException
* Is thrown if the given {@link ByteBuf} has not at least {@link #SSL_RECORD_HEADER_LENGTH}
* bytes to read.
@ -113,10 +125,10 @@ final class SslUtils {
packetLength = (buffer.getShort(offset) & 0x3FFF) + 3;
}
if (packetLength <= headerLength) {
return -1;
return NOT_ENOUGH_DATA;
}
} else {
return -1;
return NOT_ENCRYPTED;
}
}
return packetLength;
@ -134,7 +146,7 @@ final class SslUtils {
ByteBuffer buffer = buffers[offset];
// Check if everything we need is in one ByteBuffer. If so we can make use of the fast-path.
if (buffer.remaining() >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
if (buffer.remaining() >= SSL_RECORD_HEADER_LENGTH) {
return getEncryptedPacketLength(buffer);
}
@ -160,10 +172,10 @@ final class SslUtils {
// SSLv3 or TLS - Check ContentType
boolean tls;
switch (unsignedByte(buffer.get(pos))) {
case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SslUtils.SSL_CONTENT_TYPE_ALERT:
case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
case SslUtils.SSL_CONTENT_TYPE_APPLICATION_DATA:
case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SSL_CONTENT_TYPE_ALERT:
case SSL_CONTENT_TYPE_HANDSHAKE:
case SSL_CONTENT_TYPE_APPLICATION_DATA:
tls = true;
break;
default:
@ -176,8 +188,8 @@ final class SslUtils {
int majorVersion = unsignedByte(buffer.get(pos + 1));
if (majorVersion == 3) {
// SSLv3 or TLS
packetLength = unsignedShort(buffer.getShort(pos + 3)) + SslUtils.SSL_RECORD_HEADER_LENGTH;
if (packetLength <= SslUtils.SSL_RECORD_HEADER_LENGTH) {
packetLength = unsignedShort(buffer.getShort(pos + 3)) + SSL_RECORD_HEADER_LENGTH;
if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
@ -199,10 +211,10 @@ final class SslUtils {
packetLength = (buffer.getShort(pos) & 0x3FFF) + 3;
}
if (packetLength <= headerLength) {
return -1;
return NOT_ENOUGH_DATA;
}
} else {
return -1;
return NOT_ENCRYPTED;
}
}
return packetLength;

View File

@ -74,6 +74,7 @@ import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.verify;
public abstract class SSLEngineTest {
@ -1082,4 +1083,32 @@ public abstract class SSLEngineTest {
cert.delete();
}
}
@Test
public void testSSLEngineUnwrapNoSslRecord() throws Exception {
clientSslCtx = SslContextBuilder
.forClient()
.sslProvider(sslClientProvider())
.build();
SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);
try {
ByteBuffer src = ByteBuffer.allocate(client.getSession().getApplicationBufferSize());
ByteBuffer dst = ByteBuffer.allocate(client.getSession().getPacketBufferSize());
ByteBuffer empty = ByteBuffer.allocateDirect(0);
SSLEngineResult clientResult = client.wrap(empty, dst);
assertEquals(SSLEngineResult.Status.OK, clientResult.getStatus());
assertEquals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, clientResult.getHandshakeStatus());
try {
client.unwrap(src, dst);
fail();
} catch (SSLException expected) {
// expected
}
} finally {
cleanupClientSslEngine(client);
}
}
}