Optimize SslHandler.unwrap() so that it doesn't produce unnecessarily many buffers

- Adapted from 4c7fa950cc4f4c52eeaae5887335b1f3047592f8
- Related issue: #1905
This commit is contained in:
Trustin Lee 2013-11-25 18:03:00 +09:00
parent 92bcbcd0e1
commit 7347bfec50
2 changed files with 146 additions and 114 deletions

View File

@ -241,7 +241,7 @@ public class SslHandler extends FrameDecoder
private boolean closeOnSSLException;
private int packetLength = Integer.MIN_VALUE;
private int packetLength;
private final Timer timer;
private final long handshakeTimeoutInMillis;
@ -786,7 +786,7 @@ public class SslHandler extends FrameDecoder
* Is thrown if the given {@link ChannelBuffer} has not at least 5 bytes to read.
*/
public static boolean isEncrypted(ChannelBuffer buffer) {
return getEncryptedPacketLength(buffer) != -1;
return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1;
}
/**
@ -802,33 +802,29 @@ public class SslHandler extends FrameDecoder
* @throws IllegalArgumentException
* Is thrown if the given {@link ChannelBuffer} has not at least 5 bytes to read.
*/
private static int getEncryptedPacketLength(ChannelBuffer buffer) {
if (buffer.readableBytes() < 5) {
throw new IllegalArgumentException("buffer must have at least 5 readable bytes");
}
private static int getEncryptedPacketLength(ChannelBuffer buffer, int offset) {
int packetLength = 0;
// SSLv3 or TLS - Check ContentType
boolean tls;
switch (buffer.getUnsignedByte(buffer.readerIndex())) {
case 20: // change_cipher_spec
case 21: // alert
case 22: // handshake
case 23: // application_data
tls = true;
break;
default:
// SSLv2 or bad data
tls = false;
switch (buffer.getUnsignedByte(offset)) {
case 20: // change_cipher_spec
case 21: // alert
case 22: // handshake
case 23: // application_data
tls = true;
break;
default:
// SSLv2 or bad data
tls = false;
}
if (tls) {
// SSLv3 or TLS - Check ProtocolVersion
int majorVersion = buffer.getUnsignedByte(buffer.readerIndex() + 1);
int majorVersion = buffer.getUnsignedByte(offset + 1);
if (majorVersion == 3) {
// SSLv3 or TLS
packetLength = (getShort(buffer, buffer.readerIndex() + 3) & 0xFFFF) + 5;
packetLength = (getShort(buffer, offset + 3) & 0xFFFF) + 5;
if (packetLength <= 5) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
@ -842,16 +838,14 @@ public class SslHandler extends FrameDecoder
if (!tls) {
// SSLv2 or bad data - Check the version
boolean sslv2 = true;
int headerLength = (buffer.getUnsignedByte(
buffer.readerIndex()) & 0x80) != 0 ? 2 : 3;
int majorVersion = buffer.getUnsignedByte(
buffer.readerIndex() + headerLength + 1);
int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
if (majorVersion == 2 || majorVersion == 3) {
// SSLv2
if (headerLength == 2) {
packetLength = (getShort(buffer, buffer.readerIndex()) & 0x7FFF) + 2;
packetLength = (getShort(buffer, offset) & 0x7FFF) + 2;
} else {
packetLength = (getShort(buffer, buffer.readerIndex()) & 0x3FFF) + 3;
packetLength = (getShort(buffer, offset) & 0x3FFF) + 3;
}
if (packetLength <= headerLength) {
sslv2 = false;
@ -869,64 +863,82 @@ public class SslHandler extends FrameDecoder
@Override
protected Object decode(
final ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception {
final ChannelHandlerContext ctx, Channel channel, ChannelBuffer in) throws Exception {
// Check if the packet length was parsed yet, if so we can skip the parsing
if (packetLength == Integer.MIN_VALUE) {
if (buffer.readableBytes() < 5) {
final int startOffset = in.readerIndex();
final int endOffset = in.writerIndex();
int offset = startOffset;
// If we calculated the length of the current SSL record before, use that information.
if (packetLength > 0) {
if (endOffset - startOffset < packetLength) {
return null;
} else {
offset += packetLength;
packetLength = 0;
}
int packetLength = getEncryptedPacketLength(buffer);
}
boolean nonSslRecord = false;
for (;;) {
final int readableBytes = endOffset - offset;
if (readableBytes < 5) {
break;
}
final int packetLength = getEncryptedPacketLength(in, offset);
if (packetLength == -1) {
// Bad data - discard the buffer and raise an exception.
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ChannelBuffers.hexDump(buffer));
buffer.skipBytes(buffer.readableBytes());
if (closeOnSSLException) {
// first trigger the exception and then close the channel
fireExceptionCaught(ctx, e);
Channels.close(ctx, future(channel));
// just return null as we closed the channel before, that
// will take care of cleanup etc
return null;
} else {
throw e;
}
nonSslRecord = true;
break;
}
assert packetLength > 0;
this.packetLength = packetLength;
if (packetLength > readableBytes) {
// wait until the whole packet can be read
this.packetLength = packetLength;
break;
}
offset += packetLength;
}
if (buffer.readableBytes() < packetLength) {
return null;
final int length = offset - startOffset;
ChannelBuffer unwrapped = null;
if (length > 0) {
// The buffer contains one or more full SSL records.
// Slice out the whole packet so unwrap will only be called with complete packets.
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger
// decode(...) again via:
// 1) unwrap(..) is called
// 2) wrap(...) is called from within unwrap(...)
// 3) wrap(...) calls unwrapLater(...)
// 4) unwrapLater(...) calls decode(...)
//
// See https://github.com/netty/netty/issues/1534
unwrapped = unwrap(ctx, channel, in, startOffset, length);
}
// We advance the buffer's readerIndex before calling unwrap() because
// unwrap() can trigger FrameDecoder call decode(), this method, recursively.
// The recursive call results in decoding the same packet twice if
// the readerIndex is advanced *after* decode().
//
// Here's an example:
// 1) An SSL packet is received from the wire.
// 2) SslHandler.decode() deciphers the packet and calls the user code.
// 3) The user closes the channel in the same thread.
// 4) The same thread triggers a channelDisconnected() event.
// 5) FrameDecoder.cleanup() is called, and it calls SslHandler.decode().
// 6) SslHandler.decode() will feed the same packet with what was
// deciphered at the step 2 again if the readerIndex was not advanced
// before calling the user code.
final int packetOffset = buffer.readerIndex();
buffer.skipBytes(packetLength);
try {
return unwrap(ctx, channel, buffer, packetOffset, packetLength);
} finally {
// reset the packet length so it will be parsed again on the next call
packetLength = Integer.MIN_VALUE;
if (nonSslRecord) {
// Not an SSL/TLS packet
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ChannelBuffers.hexDump(in));
in.skipBytes(in.readableBytes());
if (closeOnSSLException) {
// first trigger the exception and then close the channel
fireExceptionCaught(ctx, e);
Channels.close(ctx, future(channel));
// just return null as we closed the channel before, that
// will take care of cleanup etc
return null;
} else {
throw e;
}
}
return unwrapped;
}
/**
@ -1212,12 +1224,15 @@ public class SslHandler extends FrameDecoder
private ChannelBuffer unwrap(
ChannelHandlerContext ctx, Channel channel,
ChannelBuffer buffer, int offset, int length) throws SSLException {
ByteBuffer inNetBuf = buffer.toByteBuffer(offset, length);
ByteBuffer outAppBuf = bufferPool.acquireBuffer();
final ByteBuffer inNetBuf = buffer.toByteBuffer(offset, length);
final ByteBuffer outAppBuf = bufferPool.acquireBuffer();
final int bufferStartOffset = buffer.readerIndex();
final int inNetBufStartOffset = inNetBuf.position();
ChannelBuffer frame = null;
try {
boolean needsWrap = false;
loop:
for (;;) {
SSLEngineResult result;
boolean needsHandshake = false;
@ -1234,26 +1249,47 @@ public class SslHandler extends FrameDecoder
}
synchronized (handshakeLock) {
result = engine.unwrap(inNetBuf, outAppBuf);
// Decrypt at least one record in the inbound network buffer.
// It is impossible to consume no record here because we made sure the inbound network buffer
// always contain at least one record in decode(). Therefore, if SSLEngine.unwrap() returns
// BUFFER_OVERFLOW, it is always resolved by retrying after emptying the application buffer.
for (;;) {
try {
result = engine.unwrap(inNetBuf, outAppBuf);
switch (result.getStatus()) {
case CLOSED:
// notify about the CLOSED state of the SSLEngine. See #137
sslEngineCloseFuture.setClosed();
break;
case BUFFER_OVERFLOW:
// Flush the unwrapped data in the outAppBuf into frame and try again.
// See the finally block.
continue;
}
switch (result.getStatus()) {
case CLOSED:
// notify about the CLOSED state of the SSLEngine. See #137
sslEngineCloseFuture.setClosed();
break;
case BUFFER_OVERFLOW:
throw new SSLException("SSLEngine.unwrap() reported an impossible buffer overflow.");
} finally {
outAppBuf.flip();
// Sync the offset of the inbound buffer.
buffer.readerIndex(bufferStartOffset + inNetBuf.position() - inNetBufStartOffset);
// Copy the unwrapped data into a smaller buffer.
if (outAppBuf.hasRemaining()) {
if (frame == null) {
frame = ctx.getChannel().getConfig().getBufferFactory().getBuffer(length);
}
frame.writeBytes(outAppBuf);
}
outAppBuf.clear();
}
}
final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
handleRenegotiation(handshakeStatus);
switch (handshakeStatus) {
case NEED_UNWRAP:
if (inNetBuf.hasRemaining() && !engine.isInboundDone()) {
break;
} else {
break loop;
}
break;
case NEED_WRAP:
wrapNonAppData(ctx, channel);
break;
@ -1263,16 +1299,21 @@ public class SslHandler extends FrameDecoder
case FINISHED:
setHandshakeSuccess(channel);
needsWrap = true;
break loop;
continue;
case NOT_HANDSHAKING:
needsWrap = true;
break loop;
break;
default:
throw new IllegalStateException(
"Unknown handshake status: " + handshakeStatus);
}
if (result.getStatus() == Status.BUFFER_UNDERFLOW ||
result.bytesConsumed() == 0 && result.bytesProduced() == 0) {
break;
}
}
}
if (needsWrap) {
// wrap() acquires pendingUnencryptedWrites first and then
// handshakeLock. If handshakeLock is already hold by the
@ -1287,26 +1328,18 @@ public class SslHandler extends FrameDecoder
wrap(ctx, channel);
}
}
outAppBuf.flip();
if (outAppBuf.hasRemaining()) {
ChannelBuffer frame = ctx.getChannel().getConfig().getBufferFactory().getBuffer(outAppBuf.remaining());
// Transfer the bytes to the new ChannelBuffer using some safe method that will also
// work with "non" heap buffers
//
// See https://github.com/netty/netty/issues/329
frame.writeBytes(outAppBuf);
return frame;
} else {
return null;
}
} catch (SSLException e) {
setHandshakeFailure(channel, e);
throw e;
} finally {
bufferPool.releaseBuffer(outAppBuf);
}
if (frame != null && frame.readable()) {
return frame;
} else {
return null;
}
}
private void handleRenegotiation(HandshakeStatus handshakeStatus) {

View File

@ -15,18 +15,6 @@
*/
package org.jboss.netty.handler.ssl;
import static org.junit.Assert.*;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Random;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.buffer.ChannelBuffer;
@ -47,6 +35,17 @@ import org.jboss.netty.logging.InternalLoggerFactory;
import org.jboss.netty.util.TestUtil;
import org.junit.Test;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Random;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.*;
public abstract class AbstractSocketSslEchoTest {
static final InternalLogger logger =
InternalLoggerFactory.getInstance(AbstractSocketSslEchoTest.class);