Optimize SslHandler.unwrap() so that it doesn't produce unnecessarily many buffers
- Adapted from 4c7fa950cc4f4c52eeaae5887335b1f3047592f8 - Related issue: #1905
This commit is contained in:
parent
92bcbcd0e1
commit
7347bfec50
@ -241,7 +241,7 @@ public class SslHandler extends FrameDecoder
|
||||
|
||||
private boolean closeOnSSLException;
|
||||
|
||||
private int packetLength = Integer.MIN_VALUE;
|
||||
private int packetLength;
|
||||
|
||||
private final Timer timer;
|
||||
private final long handshakeTimeoutInMillis;
|
||||
@ -786,7 +786,7 @@ public class SslHandler extends FrameDecoder
|
||||
* Is thrown if the given {@link ChannelBuffer} has not at least 5 bytes to read.
|
||||
*/
|
||||
public static boolean isEncrypted(ChannelBuffer buffer) {
|
||||
return getEncryptedPacketLength(buffer) != -1;
|
||||
return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -802,33 +802,29 @@ public class SslHandler extends FrameDecoder
|
||||
* @throws IllegalArgumentException
|
||||
* Is thrown if the given {@link ChannelBuffer} has not at least 5 bytes to read.
|
||||
*/
|
||||
private static int getEncryptedPacketLength(ChannelBuffer buffer) {
|
||||
if (buffer.readableBytes() < 5) {
|
||||
throw new IllegalArgumentException("buffer must have at least 5 readable bytes");
|
||||
}
|
||||
|
||||
private static int getEncryptedPacketLength(ChannelBuffer buffer, int offset) {
|
||||
int packetLength = 0;
|
||||
|
||||
// SSLv3 or TLS - Check ContentType
|
||||
boolean tls;
|
||||
switch (buffer.getUnsignedByte(buffer.readerIndex())) {
|
||||
case 20: // change_cipher_spec
|
||||
case 21: // alert
|
||||
case 22: // handshake
|
||||
case 23: // application_data
|
||||
tls = true;
|
||||
break;
|
||||
default:
|
||||
// SSLv2 or bad data
|
||||
tls = false;
|
||||
switch (buffer.getUnsignedByte(offset)) {
|
||||
case 20: // change_cipher_spec
|
||||
case 21: // alert
|
||||
case 22: // handshake
|
||||
case 23: // application_data
|
||||
tls = true;
|
||||
break;
|
||||
default:
|
||||
// SSLv2 or bad data
|
||||
tls = false;
|
||||
}
|
||||
|
||||
if (tls) {
|
||||
// SSLv3 or TLS - Check ProtocolVersion
|
||||
int majorVersion = buffer.getUnsignedByte(buffer.readerIndex() + 1);
|
||||
int majorVersion = buffer.getUnsignedByte(offset + 1);
|
||||
if (majorVersion == 3) {
|
||||
// SSLv3 or TLS
|
||||
packetLength = (getShort(buffer, buffer.readerIndex() + 3) & 0xFFFF) + 5;
|
||||
packetLength = (getShort(buffer, offset + 3) & 0xFFFF) + 5;
|
||||
if (packetLength <= 5) {
|
||||
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
|
||||
tls = false;
|
||||
@ -842,16 +838,14 @@ public class SslHandler extends FrameDecoder
|
||||
if (!tls) {
|
||||
// SSLv2 or bad data - Check the version
|
||||
boolean sslv2 = true;
|
||||
int headerLength = (buffer.getUnsignedByte(
|
||||
buffer.readerIndex()) & 0x80) != 0 ? 2 : 3;
|
||||
int majorVersion = buffer.getUnsignedByte(
|
||||
buffer.readerIndex() + headerLength + 1);
|
||||
int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
|
||||
int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
|
||||
if (majorVersion == 2 || majorVersion == 3) {
|
||||
// SSLv2
|
||||
if (headerLength == 2) {
|
||||
packetLength = (getShort(buffer, buffer.readerIndex()) & 0x7FFF) + 2;
|
||||
packetLength = (getShort(buffer, offset) & 0x7FFF) + 2;
|
||||
} else {
|
||||
packetLength = (getShort(buffer, buffer.readerIndex()) & 0x3FFF) + 3;
|
||||
packetLength = (getShort(buffer, offset) & 0x3FFF) + 3;
|
||||
}
|
||||
if (packetLength <= headerLength) {
|
||||
sslv2 = false;
|
||||
@ -869,64 +863,82 @@ public class SslHandler extends FrameDecoder
|
||||
|
||||
@Override
|
||||
protected Object decode(
|
||||
final ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception {
|
||||
final ChannelHandlerContext ctx, Channel channel, ChannelBuffer in) throws Exception {
|
||||
|
||||
// Check if the packet length was parsed yet, if so we can skip the parsing
|
||||
if (packetLength == Integer.MIN_VALUE) {
|
||||
if (buffer.readableBytes() < 5) {
|
||||
final int startOffset = in.readerIndex();
|
||||
final int endOffset = in.writerIndex();
|
||||
int offset = startOffset;
|
||||
|
||||
// If we calculated the length of the current SSL record before, use that information.
|
||||
if (packetLength > 0) {
|
||||
if (endOffset - startOffset < packetLength) {
|
||||
return null;
|
||||
} else {
|
||||
offset += packetLength;
|
||||
packetLength = 0;
|
||||
}
|
||||
int packetLength = getEncryptedPacketLength(buffer);
|
||||
}
|
||||
|
||||
boolean nonSslRecord = false;
|
||||
|
||||
for (;;) {
|
||||
final int readableBytes = endOffset - offset;
|
||||
if (readableBytes < 5) {
|
||||
break;
|
||||
}
|
||||
|
||||
final int packetLength = getEncryptedPacketLength(in, offset);
|
||||
if (packetLength == -1) {
|
||||
// Bad data - discard the buffer and raise an exception.
|
||||
NotSslRecordException e = new NotSslRecordException(
|
||||
"not an SSL/TLS record: " + ChannelBuffers.hexDump(buffer));
|
||||
buffer.skipBytes(buffer.readableBytes());
|
||||
|
||||
if (closeOnSSLException) {
|
||||
// first trigger the exception and then close the channel
|
||||
fireExceptionCaught(ctx, e);
|
||||
Channels.close(ctx, future(channel));
|
||||
|
||||
// just return null as we closed the channel before, that
|
||||
// will take care of cleanup etc
|
||||
return null;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
nonSslRecord = true;
|
||||
break;
|
||||
}
|
||||
|
||||
assert packetLength > 0;
|
||||
this.packetLength = packetLength;
|
||||
|
||||
if (packetLength > readableBytes) {
|
||||
// wait until the whole packet can be read
|
||||
this.packetLength = packetLength;
|
||||
break;
|
||||
}
|
||||
|
||||
offset += packetLength;
|
||||
}
|
||||
|
||||
if (buffer.readableBytes() < packetLength) {
|
||||
return null;
|
||||
final int length = offset - startOffset;
|
||||
ChannelBuffer unwrapped = null;
|
||||
if (length > 0) {
|
||||
// The buffer contains one or more full SSL records.
|
||||
// Slice out the whole packet so unwrap will only be called with complete packets.
|
||||
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger
|
||||
// decode(...) again via:
|
||||
// 1) unwrap(..) is called
|
||||
// 2) wrap(...) is called from within unwrap(...)
|
||||
// 3) wrap(...) calls unwrapLater(...)
|
||||
// 4) unwrapLater(...) calls decode(...)
|
||||
//
|
||||
// See https://github.com/netty/netty/issues/1534
|
||||
unwrapped = unwrap(ctx, channel, in, startOffset, length);
|
||||
}
|
||||
|
||||
// We advance the buffer's readerIndex before calling unwrap() because
|
||||
// unwrap() can trigger FrameDecoder call decode(), this method, recursively.
|
||||
// The recursive call results in decoding the same packet twice if
|
||||
// the readerIndex is advanced *after* decode().
|
||||
//
|
||||
// Here's an example:
|
||||
// 1) An SSL packet is received from the wire.
|
||||
// 2) SslHandler.decode() deciphers the packet and calls the user code.
|
||||
// 3) The user closes the channel in the same thread.
|
||||
// 4) The same thread triggers a channelDisconnected() event.
|
||||
// 5) FrameDecoder.cleanup() is called, and it calls SslHandler.decode().
|
||||
// 6) SslHandler.decode() will feed the same packet with what was
|
||||
// deciphered at the step 2 again if the readerIndex was not advanced
|
||||
// before calling the user code.
|
||||
final int packetOffset = buffer.readerIndex();
|
||||
buffer.skipBytes(packetLength);
|
||||
try {
|
||||
return unwrap(ctx, channel, buffer, packetOffset, packetLength);
|
||||
} finally {
|
||||
// reset the packet length so it will be parsed again on the next call
|
||||
packetLength = Integer.MIN_VALUE;
|
||||
if (nonSslRecord) {
|
||||
// Not an SSL/TLS packet
|
||||
NotSslRecordException e = new NotSslRecordException(
|
||||
"not an SSL/TLS record: " + ChannelBuffers.hexDump(in));
|
||||
in.skipBytes(in.readableBytes());
|
||||
if (closeOnSSLException) {
|
||||
// first trigger the exception and then close the channel
|
||||
fireExceptionCaught(ctx, e);
|
||||
Channels.close(ctx, future(channel));
|
||||
|
||||
// just return null as we closed the channel before, that
|
||||
// will take care of cleanup etc
|
||||
return null;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
return unwrapped;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1212,12 +1224,15 @@ public class SslHandler extends FrameDecoder
|
||||
private ChannelBuffer unwrap(
|
||||
ChannelHandlerContext ctx, Channel channel,
|
||||
ChannelBuffer buffer, int offset, int length) throws SSLException {
|
||||
ByteBuffer inNetBuf = buffer.toByteBuffer(offset, length);
|
||||
ByteBuffer outAppBuf = bufferPool.acquireBuffer();
|
||||
|
||||
final ByteBuffer inNetBuf = buffer.toByteBuffer(offset, length);
|
||||
final ByteBuffer outAppBuf = bufferPool.acquireBuffer();
|
||||
final int bufferStartOffset = buffer.readerIndex();
|
||||
final int inNetBufStartOffset = inNetBuf.position();
|
||||
|
||||
ChannelBuffer frame = null;
|
||||
try {
|
||||
boolean needsWrap = false;
|
||||
loop:
|
||||
for (;;) {
|
||||
SSLEngineResult result;
|
||||
boolean needsHandshake = false;
|
||||
@ -1234,26 +1249,47 @@ public class SslHandler extends FrameDecoder
|
||||
}
|
||||
|
||||
synchronized (handshakeLock) {
|
||||
result = engine.unwrap(inNetBuf, outAppBuf);
|
||||
// Decrypt at least one record in the inbound network buffer.
|
||||
// It is impossible to consume no record here because we made sure the inbound network buffer
|
||||
// always contain at least one record in decode(). Therefore, if SSLEngine.unwrap() returns
|
||||
// BUFFER_OVERFLOW, it is always resolved by retrying after emptying the application buffer.
|
||||
for (;;) {
|
||||
try {
|
||||
result = engine.unwrap(inNetBuf, outAppBuf);
|
||||
switch (result.getStatus()) {
|
||||
case CLOSED:
|
||||
// notify about the CLOSED state of the SSLEngine. See #137
|
||||
sslEngineCloseFuture.setClosed();
|
||||
break;
|
||||
case BUFFER_OVERFLOW:
|
||||
// Flush the unwrapped data in the outAppBuf into frame and try again.
|
||||
// See the finally block.
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (result.getStatus()) {
|
||||
case CLOSED:
|
||||
// notify about the CLOSED state of the SSLEngine. See #137
|
||||
sslEngineCloseFuture.setClosed();
|
||||
break;
|
||||
case BUFFER_OVERFLOW:
|
||||
throw new SSLException("SSLEngine.unwrap() reported an impossible buffer overflow.");
|
||||
} finally {
|
||||
outAppBuf.flip();
|
||||
|
||||
// Sync the offset of the inbound buffer.
|
||||
buffer.readerIndex(bufferStartOffset + inNetBuf.position() - inNetBufStartOffset);
|
||||
|
||||
// Copy the unwrapped data into a smaller buffer.
|
||||
if (outAppBuf.hasRemaining()) {
|
||||
if (frame == null) {
|
||||
frame = ctx.getChannel().getConfig().getBufferFactory().getBuffer(length);
|
||||
}
|
||||
frame.writeBytes(outAppBuf);
|
||||
}
|
||||
outAppBuf.clear();
|
||||
}
|
||||
}
|
||||
|
||||
final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
|
||||
handleRenegotiation(handshakeStatus);
|
||||
switch (handshakeStatus) {
|
||||
case NEED_UNWRAP:
|
||||
if (inNetBuf.hasRemaining() && !engine.isInboundDone()) {
|
||||
break;
|
||||
} else {
|
||||
break loop;
|
||||
}
|
||||
break;
|
||||
case NEED_WRAP:
|
||||
wrapNonAppData(ctx, channel);
|
||||
break;
|
||||
@ -1263,16 +1299,21 @@ public class SslHandler extends FrameDecoder
|
||||
case FINISHED:
|
||||
setHandshakeSuccess(channel);
|
||||
needsWrap = true;
|
||||
break loop;
|
||||
continue;
|
||||
case NOT_HANDSHAKING:
|
||||
needsWrap = true;
|
||||
break loop;
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException(
|
||||
"Unknown handshake status: " + handshakeStatus);
|
||||
}
|
||||
|
||||
if (result.getStatus() == Status.BUFFER_UNDERFLOW ||
|
||||
result.bytesConsumed() == 0 && result.bytesProduced() == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (needsWrap) {
|
||||
// wrap() acquires pendingUnencryptedWrites first and then
|
||||
// handshakeLock. If handshakeLock is already hold by the
|
||||
@ -1287,26 +1328,18 @@ public class SslHandler extends FrameDecoder
|
||||
wrap(ctx, channel);
|
||||
}
|
||||
}
|
||||
outAppBuf.flip();
|
||||
|
||||
if (outAppBuf.hasRemaining()) {
|
||||
ChannelBuffer frame = ctx.getChannel().getConfig().getBufferFactory().getBuffer(outAppBuf.remaining());
|
||||
// Transfer the bytes to the new ChannelBuffer using some safe method that will also
|
||||
// work with "non" heap buffers
|
||||
//
|
||||
// See https://github.com/netty/netty/issues/329
|
||||
frame.writeBytes(outAppBuf);
|
||||
|
||||
return frame;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
} catch (SSLException e) {
|
||||
setHandshakeFailure(channel, e);
|
||||
throw e;
|
||||
} finally {
|
||||
bufferPool.releaseBuffer(outAppBuf);
|
||||
}
|
||||
|
||||
if (frame != null && frame.readable()) {
|
||||
return frame;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private void handleRenegotiation(HandshakeStatus handshakeStatus) {
|
||||
|
@ -15,18 +15,6 @@
|
||||
*/
|
||||
package org.jboss.netty.handler.ssl;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import javax.net.ssl.SSLEngine;
|
||||
|
||||
import org.jboss.netty.bootstrap.ClientBootstrap;
|
||||
import org.jboss.netty.bootstrap.ServerBootstrap;
|
||||
import org.jboss.netty.buffer.ChannelBuffer;
|
||||
@ -47,6 +35,17 @@ import org.jboss.netty.logging.InternalLoggerFactory;
|
||||
import org.jboss.netty.util.TestUtil;
|
||||
import org.junit.Test;
|
||||
|
||||
import javax.net.ssl.SSLEngine;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public abstract class AbstractSocketSslEchoTest {
|
||||
static final InternalLogger logger =
|
||||
InternalLoggerFactory.getInstance(AbstractSocketSslEchoTest.class);
|
||||
|
Loading…
x
Reference in New Issue
Block a user