Optimize SslHandler

- Fixes #1905
- Call ctx.flush() only when necessary
- Improve the estimation of application and packet buffer sizes
- decode() method now tries to call unwrap() with as many SSL records as
  possible to reduce the number of events triggered
This commit is contained in:
Trustin Lee 2013-11-08 14:21:19 +09:00
parent db78581bbb
commit 11f95c78e2

View File

@ -176,6 +176,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private volatile ChannelHandlerContext ctx;
private final SSLEngine engine;
private final int maxPacketBufferSize;
private final Executor delegatedTaskExecutor;
private final boolean startTls;
@ -185,6 +186,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
private final Deque<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
/**
* Set by wrap*() methods when something is produced.
* {@link #channelReadComplete(ChannelHandlerContext)} will check this flag, clear it, and call ctx.flush().
*/
private boolean needsFlush;
private int packetLength;
private ByteBuf decodeOut;
@ -246,6 +253,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
this.engine = engine;
this.delegatedTaskExecutor = delegatedTaskExecutor;
this.startTls = startTls;
maxPacketBufferSize = engine.getSession().getPacketBufferSize();
}
public long getHandshakeTimeoutMillis() {
@ -372,6 +380,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
@Deprecated
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
ctx.deregister(promise);
}
@ -417,10 +426,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (pendingUnencryptedWrites.isEmpty()) {
pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null));
}
flush0(ctx);
wrap(ctx, false);
ctx.flush();
}
private void flush0(ChannelHandlerContext ctx) throws SSLException {
private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
ByteBuf out = null;
ChannelPromise promise = null;
try {
@ -430,8 +440,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break;
}
if (out == null) {
out = ctx.alloc().buffer();
out = ctx.alloc().buffer(maxPacketBufferSize);
}
ByteBuf buf = (ByteBuf) pending.msg();
SSLEngineResult result = wrap(engine, buf, out);
@ -464,13 +475,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// deliberate fall-through
case NOT_HANDSHAKING:
case NEED_WRAP:
if (promise != null) {
ctx.writeAndFlush(out, promise);
promise = null;
} else {
ctx.writeAndFlush(out);
}
out = ctx.alloc().buffer();
finishWrap(ctx, out, promise, inUnwrap);
promise = null;
out = null;
break;
case NEED_UNWRAP:
return;
@ -483,33 +490,46 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
setHandshakeFailure(e);
throw e;
} finally {
if (out != null && out.isReadable()) {
finishWrap(ctx, out, promise, inUnwrap);
}
}
private void finishWrap(ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap) {
if (out != null) {
if (out.isReadable()) {
if (promise != null) {
ctx.writeAndFlush(out, promise);
ctx.write(out, promise);
} else {
ctx.writeAndFlush(out);
ctx.write(out);
}
out = null;
} else if (promise != null) {
promise.trySuccess();
}
if (out != null) {
if (inUnwrap) {
needsFlush = true;
}
} else {
out.release();
}
} else if (promise != null) {
ctx.write(Unpooled.EMPTY_BUFFER, promise);
if (inUnwrap) {
needsFlush = true;
}
}
}
private void flushNonAppData0(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
ByteBuf out = null;
try {
for (;;) {
if (out == null) {
out = ctx.alloc().buffer();
out = ctx.alloc().buffer(maxPacketBufferSize);
}
SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out);
if (result.bytesProduced() > 0) {
ctx.writeAndFlush(out);
ctx.write(out);
if (inUnwrap) {
needsFlush = true;
}
out = null;
}
@ -552,17 +572,20 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
}
private static SSLEngineResult wrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException {
private SSLEngineResult wrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException {
ByteBuffer in0 = in.nioBuffer();
for (;;) {
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
SSLEngineResult result = engine.wrap(in0, out0);
in.skipBytes(result.bytesConsumed());
out.writerIndex(out.writerIndex() + result.bytesProduced());
if (result.getStatus() == Status.BUFFER_OVERFLOW) {
out.ensureWritable(engine.getSession().getPacketBufferSize());
} else {
return result;
switch (result.getStatus()) {
case BUFFER_OVERFLOW:
out.ensureWritable(maxPacketBufferSize);
break;
default:
return result;
}
}
}
@ -678,7 +701,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (buffer.readableBytes() < 5) {
throw new IllegalArgumentException("buffer must have at least 5 readable bytes");
}
return getEncryptedPacketLength(buffer) != -1;
return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1;
}
/**
@ -694,13 +717,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* @throws IllegalArgumentException
* Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read.
*/
private static int getEncryptedPacketLength(ByteBuf buffer) {
int first = buffer.readerIndex();
private static int getEncryptedPacketLength(ByteBuf buffer, int offset) {
int packetLength = 0;
// SSLv3 or TLS - Check ContentType
boolean tls;
switch (buffer.getUnsignedByte(first)) {
switch (buffer.getUnsignedByte(offset)) {
case 20: // change_cipher_spec
case 21: // alert
case 22: // handshake
@ -714,10 +736,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (tls) {
// SSLv3 or TLS - Check ProtocolVersion
int majorVersion = buffer.getUnsignedByte(first + 1);
int majorVersion = buffer.getUnsignedByte(offset + 1);
if (majorVersion == 3) {
// SSLv3 or TLS
packetLength = buffer.getUnsignedShort(first + 3) + 5;
packetLength = buffer.getUnsignedShort(offset + 3) + 5;
if (packetLength <= 5) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
@ -731,14 +753,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (!tls) {
// SSLv2 or bad data - Check the version
boolean sslv2 = true;
int headerLength = (buffer.getUnsignedByte(first) & 0x80) != 0 ? 2 : 3;
int majorVersion = buffer.getUnsignedByte(first + headerLength + 1);
int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
if (majorVersion == 2 || majorVersion == 3) {
// SSLv2
if (headerLength == 2) {
packetLength = (buffer.getShort(first) & 0x7FFF) + 2;
packetLength = (buffer.getShort(offset) & 0x7FFF) + 2;
} else {
packetLength = (buffer.getShort(first) & 0x3FFF) + 3;
packetLength = (buffer.getShort(offset) & 0x3FFF) + 3;
}
if (packetLength <= headerLength) {
sslv2 = false;
@ -756,51 +778,79 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
int packetLength = this.packetLength;
final int readableBytes = in.readableBytes();
final int startOffset = in.readerIndex();
final int endOffset = in.writerIndex();
int offset = startOffset;
if (packetLength == 0) {
// the previous packet was consumed so try to read the length of the next packet
if (readableBytes < 5) {
// not enough bytes readable to read the packet length
// If we calculated the length of the current SSL record before, use that information.
if (packetLength > 0) {
if (endOffset - startOffset < packetLength) {
return;
} else {
offset += packetLength;
packetLength = 0;
}
}
boolean nonSslRecord = false;
for (;;) {
final int readableBytes = endOffset - offset;
if (readableBytes < 5) {
break;
}
packetLength = getEncryptedPacketLength(in);
final int packetLength = getEncryptedPacketLength(in, offset);
if (packetLength == -1) {
// Not an SSL/TLS packet
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(readableBytes);
ctx.fireExceptionCaught(e);
setHandshakeFailure(e);
return;
nonSslRecord = true;
break;
}
assert packetLength > 0;
this.packetLength = packetLength;
if (packetLength > readableBytes) {
// wait until the whole packet can be read
this.packetLength = packetLength;
break;
}
offset += packetLength;
}
if (readableBytes < packetLength) {
// wait until the whole packet can be read
return;
final int length = offset - startOffset;
if (length > 0) {
// The buffer contains one or more full SSL records.
// Slice out the whole packet so unwrap will only be called with complete packets.
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger
// decode(...) again via:
// 1) unwrap(..) is called
// 2) wrap(...) is called from within unwrap(...)
// 3) wrap(...) calls unwrapLater(...)
// 4) unwrapLater(...) calls decode(...)
//
// See https://github.com/netty/netty/issues/1534
in.skipBytes(length);
ByteBuffer buffer = in.nioBuffer(startOffset, length);
unwrap(ctx, buffer, out);
}
// Slice out the whole packet so unwrap will only be called with complete packets.
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger
// decode(...) again via:
// 1) unwrap(..) is called
// 2) wrap(...) is called from within unwrap(...)
// 3) wrap(...) calls unwrapLater(...)
// 4) unwrapLater(...) calls decode(...)
//
// See https://github.com/netty/netty/issues/1534
int readerIndex = in.readerIndex();
in.skipBytes(packetLength);
ByteBuffer buffer = in.nioBuffer(readerIndex, packetLength);
this.packetLength = 0;
if (nonSslRecord) {
// Not an SSL/TLS packet
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(in.readableBytes());
ctx.fireExceptionCaught(e);
setHandshakeFailure(e);
}
}
unwrap(ctx, buffer, out);
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
if (needsFlush) {
needsFlush = false;
ctx.flush();
}
super.channelReadComplete(ctx);
}
private void unwrap(ChannelHandlerContext ctx) throws SSLException {
@ -822,7 +872,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
try {
for (;;) {
if (decodeOut == null) {
decodeOut = ctx.alloc().buffer();
decodeOut = ctx.alloc().buffer(packet.remaining());
}
final SSLEngineResult result = unwrap(engine, packet, decodeOut);
@ -842,7 +892,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
case NEED_UNWRAP:
break;
case NEED_WRAP:
flushNonAppData0(ctx, true);
wrapNonAppData(ctx, true);
break;
case NEED_TASK:
runDelegatedTasks();
@ -863,7 +913,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
if (wrapLater) {
flush0(ctx);
wrap(ctx, true);
}
} catch (SSLException e) {
setHandshakeFailure(e);
@ -878,13 +928,21 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException {
int overflows = 0;
for (;;) {
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
SSLEngineResult result = engine.unwrap(in, out0);
out.writerIndex(out.writerIndex() + result.bytesProduced());
switch (result.getStatus()) {
case BUFFER_OVERFLOW:
out.ensureWritable(engine.getSession().getApplicationBufferSize());
int max = engine.getSession().getApplicationBufferSize();
switch (overflows ++) {
case 0:
out.ensureWritable(Math.min(max, in.remaining()));
break;
default:
out.ensureWritable(max);
}
break;
default:
return result;
@ -975,14 +1033,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (ctx.channel().isActive()) {
// channelActive() event has been fired already, which means this.channelActive() will
// not be invoked. We have to initialize here instead.
handshake0();
handshake();
} else {
// channelActive() event has not been fired yet. this.channelOpen() will be invoked
// and initialization will occur there.
}
}
private Future<Channel> handshake0() {
private Future<Channel> handshake() {
final ScheduledFuture<?> timeoutFuture;
if (handshakeTimeoutMillis > 0) {
timeoutFuture = ctx.executor().schedule(new Runnable() {
@ -1008,7 +1066,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
});
try {
engine.beginHandshake();
flushNonAppData0(ctx, false);
wrapNonAppData(ctx, false);
ctx.flush();
} catch (Exception e) {
notifyHandshakeFailure(e);
}
@ -1023,7 +1082,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (!startTls && engine.getUseClientMode()) {
// issue and handshake and add a listener to it which will fire an exception event if
// an exception was thrown while doing the handshake
handshake0().addListener(new GenericFutureListener<Future<Channel>>() {
handshake().addListener(new GenericFutureListener<Future<Channel>>() {
@Override
public void operationComplete(Future<Channel> future) throws Exception {
if (!future.isSuccess()) {