Optimize SslHandler
- Fixes #1905 - Call ctx.flush() only when necessary - Improve the estimation of application and packet buffer sizes - decode() method now tries to call unwrap() with as many SSL records as possible to reduce the number of events triggered
This commit is contained in:
parent
db78581bbb
commit
11f95c78e2
@ -176,6 +176,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
|
||||
private volatile ChannelHandlerContext ctx;
|
||||
private final SSLEngine engine;
|
||||
private final int maxPacketBufferSize;
|
||||
private final Executor delegatedTaskExecutor;
|
||||
|
||||
private final boolean startTls;
|
||||
@ -185,6 +186,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
|
||||
private final Deque<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
|
||||
|
||||
/**
|
||||
* Set by wrap*() methods when something is produced.
|
||||
* {@link #channelReadComplete(ChannelHandlerContext)} will check this flag, clear it, and call ctx.flush().
|
||||
*/
|
||||
private boolean needsFlush;
|
||||
|
||||
private int packetLength;
|
||||
private ByteBuf decodeOut;
|
||||
|
||||
@ -246,6 +253,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
this.engine = engine;
|
||||
this.delegatedTaskExecutor = delegatedTaskExecutor;
|
||||
this.startTls = startTls;
|
||||
maxPacketBufferSize = engine.getSession().getPacketBufferSize();
|
||||
}
|
||||
|
||||
public long getHandshakeTimeoutMillis() {
|
||||
@ -372,6 +380,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
}
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
|
||||
ctx.deregister(promise);
|
||||
}
|
||||
@ -417,10 +426,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
if (pendingUnencryptedWrites.isEmpty()) {
|
||||
pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null));
|
||||
}
|
||||
flush0(ctx);
|
||||
wrap(ctx, false);
|
||||
ctx.flush();
|
||||
}
|
||||
|
||||
private void flush0(ChannelHandlerContext ctx) throws SSLException {
|
||||
private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
|
||||
ByteBuf out = null;
|
||||
ChannelPromise promise = null;
|
||||
try {
|
||||
@ -430,8 +440,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
break;
|
||||
}
|
||||
if (out == null) {
|
||||
out = ctx.alloc().buffer();
|
||||
out = ctx.alloc().buffer(maxPacketBufferSize);
|
||||
}
|
||||
|
||||
ByteBuf buf = (ByteBuf) pending.msg();
|
||||
SSLEngineResult result = wrap(engine, buf, out);
|
||||
|
||||
@ -464,13 +475,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
// deliberate fall-through
|
||||
case NOT_HANDSHAKING:
|
||||
case NEED_WRAP:
|
||||
if (promise != null) {
|
||||
ctx.writeAndFlush(out, promise);
|
||||
promise = null;
|
||||
} else {
|
||||
ctx.writeAndFlush(out);
|
||||
}
|
||||
out = ctx.alloc().buffer();
|
||||
finishWrap(ctx, out, promise, inUnwrap);
|
||||
promise = null;
|
||||
out = null;
|
||||
break;
|
||||
case NEED_UNWRAP:
|
||||
return;
|
||||
@ -483,33 +490,46 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
setHandshakeFailure(e);
|
||||
throw e;
|
||||
} finally {
|
||||
if (out != null && out.isReadable()) {
|
||||
finishWrap(ctx, out, promise, inUnwrap);
|
||||
}
|
||||
}
|
||||
|
||||
private void finishWrap(ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap) {
|
||||
if (out != null) {
|
||||
if (out.isReadable()) {
|
||||
if (promise != null) {
|
||||
ctx.writeAndFlush(out, promise);
|
||||
ctx.write(out, promise);
|
||||
} else {
|
||||
ctx.writeAndFlush(out);
|
||||
ctx.write(out);
|
||||
}
|
||||
out = null;
|
||||
} else if (promise != null) {
|
||||
promise.trySuccess();
|
||||
}
|
||||
if (out != null) {
|
||||
if (inUnwrap) {
|
||||
needsFlush = true;
|
||||
}
|
||||
} else {
|
||||
out.release();
|
||||
}
|
||||
} else if (promise != null) {
|
||||
ctx.write(Unpooled.EMPTY_BUFFER, promise);
|
||||
if (inUnwrap) {
|
||||
needsFlush = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void flushNonAppData0(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
|
||||
private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
|
||||
ByteBuf out = null;
|
||||
try {
|
||||
for (;;) {
|
||||
if (out == null) {
|
||||
out = ctx.alloc().buffer();
|
||||
out = ctx.alloc().buffer(maxPacketBufferSize);
|
||||
}
|
||||
SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out);
|
||||
|
||||
if (result.bytesProduced() > 0) {
|
||||
ctx.writeAndFlush(out);
|
||||
ctx.write(out);
|
||||
if (inUnwrap) {
|
||||
needsFlush = true;
|
||||
}
|
||||
out = null;
|
||||
}
|
||||
|
||||
@ -552,17 +572,20 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
}
|
||||
}
|
||||
|
||||
private static SSLEngineResult wrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException {
|
||||
private SSLEngineResult wrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException {
|
||||
ByteBuffer in0 = in.nioBuffer();
|
||||
for (;;) {
|
||||
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
|
||||
SSLEngineResult result = engine.wrap(in0, out0);
|
||||
in.skipBytes(result.bytesConsumed());
|
||||
out.writerIndex(out.writerIndex() + result.bytesProduced());
|
||||
if (result.getStatus() == Status.BUFFER_OVERFLOW) {
|
||||
out.ensureWritable(engine.getSession().getPacketBufferSize());
|
||||
} else {
|
||||
return result;
|
||||
|
||||
switch (result.getStatus()) {
|
||||
case BUFFER_OVERFLOW:
|
||||
out.ensureWritable(maxPacketBufferSize);
|
||||
break;
|
||||
default:
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -678,7 +701,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
if (buffer.readableBytes() < 5) {
|
||||
throw new IllegalArgumentException("buffer must have at least 5 readable bytes");
|
||||
}
|
||||
return getEncryptedPacketLength(buffer) != -1;
|
||||
return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -694,13 +717,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
* @throws IllegalArgumentException
|
||||
* Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read.
|
||||
*/
|
||||
private static int getEncryptedPacketLength(ByteBuf buffer) {
|
||||
int first = buffer.readerIndex();
|
||||
private static int getEncryptedPacketLength(ByteBuf buffer, int offset) {
|
||||
int packetLength = 0;
|
||||
|
||||
// SSLv3 or TLS - Check ContentType
|
||||
boolean tls;
|
||||
switch (buffer.getUnsignedByte(first)) {
|
||||
switch (buffer.getUnsignedByte(offset)) {
|
||||
case 20: // change_cipher_spec
|
||||
case 21: // alert
|
||||
case 22: // handshake
|
||||
@ -714,10 +736,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
|
||||
if (tls) {
|
||||
// SSLv3 or TLS - Check ProtocolVersion
|
||||
int majorVersion = buffer.getUnsignedByte(first + 1);
|
||||
int majorVersion = buffer.getUnsignedByte(offset + 1);
|
||||
if (majorVersion == 3) {
|
||||
// SSLv3 or TLS
|
||||
packetLength = buffer.getUnsignedShort(first + 3) + 5;
|
||||
packetLength = buffer.getUnsignedShort(offset + 3) + 5;
|
||||
if (packetLength <= 5) {
|
||||
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
|
||||
tls = false;
|
||||
@ -731,14 +753,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
if (!tls) {
|
||||
// SSLv2 or bad data - Check the version
|
||||
boolean sslv2 = true;
|
||||
int headerLength = (buffer.getUnsignedByte(first) & 0x80) != 0 ? 2 : 3;
|
||||
int majorVersion = buffer.getUnsignedByte(first + headerLength + 1);
|
||||
int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
|
||||
int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
|
||||
if (majorVersion == 2 || majorVersion == 3) {
|
||||
// SSLv2
|
||||
if (headerLength == 2) {
|
||||
packetLength = (buffer.getShort(first) & 0x7FFF) + 2;
|
||||
packetLength = (buffer.getShort(offset) & 0x7FFF) + 2;
|
||||
} else {
|
||||
packetLength = (buffer.getShort(first) & 0x3FFF) + 3;
|
||||
packetLength = (buffer.getShort(offset) & 0x3FFF) + 3;
|
||||
}
|
||||
if (packetLength <= headerLength) {
|
||||
sslv2 = false;
|
||||
@ -756,51 +778,79 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
|
||||
@Override
|
||||
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
|
||||
int packetLength = this.packetLength;
|
||||
final int readableBytes = in.readableBytes();
|
||||
final int startOffset = in.readerIndex();
|
||||
final int endOffset = in.writerIndex();
|
||||
int offset = startOffset;
|
||||
|
||||
if (packetLength == 0) {
|
||||
// the previous packet was consumed so try to read the length of the next packet
|
||||
if (readableBytes < 5) {
|
||||
// not enough bytes readable to read the packet length
|
||||
// If we calculated the length of the current SSL record before, use that information.
|
||||
if (packetLength > 0) {
|
||||
if (endOffset - startOffset < packetLength) {
|
||||
return;
|
||||
} else {
|
||||
offset += packetLength;
|
||||
packetLength = 0;
|
||||
}
|
||||
}
|
||||
|
||||
boolean nonSslRecord = false;
|
||||
|
||||
for (;;) {
|
||||
final int readableBytes = endOffset - offset;
|
||||
if (readableBytes < 5) {
|
||||
break;
|
||||
}
|
||||
|
||||
packetLength = getEncryptedPacketLength(in);
|
||||
final int packetLength = getEncryptedPacketLength(in, offset);
|
||||
if (packetLength == -1) {
|
||||
// Not an SSL/TLS packet
|
||||
NotSslRecordException e = new NotSslRecordException(
|
||||
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
|
||||
in.skipBytes(readableBytes);
|
||||
ctx.fireExceptionCaught(e);
|
||||
setHandshakeFailure(e);
|
||||
return;
|
||||
nonSslRecord = true;
|
||||
break;
|
||||
}
|
||||
|
||||
assert packetLength > 0;
|
||||
this.packetLength = packetLength;
|
||||
|
||||
if (packetLength > readableBytes) {
|
||||
// wait until the whole packet can be read
|
||||
this.packetLength = packetLength;
|
||||
break;
|
||||
}
|
||||
|
||||
offset += packetLength;
|
||||
}
|
||||
|
||||
if (readableBytes < packetLength) {
|
||||
// wait until the whole packet can be read
|
||||
return;
|
||||
final int length = offset - startOffset;
|
||||
if (length > 0) {
|
||||
// The buffer contains one or more full SSL records.
|
||||
// Slice out the whole packet so unwrap will only be called with complete packets.
|
||||
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger
|
||||
// decode(...) again via:
|
||||
// 1) unwrap(..) is called
|
||||
// 2) wrap(...) is called from within unwrap(...)
|
||||
// 3) wrap(...) calls unwrapLater(...)
|
||||
// 4) unwrapLater(...) calls decode(...)
|
||||
//
|
||||
// See https://github.com/netty/netty/issues/1534
|
||||
in.skipBytes(length);
|
||||
ByteBuffer buffer = in.nioBuffer(startOffset, length);
|
||||
unwrap(ctx, buffer, out);
|
||||
}
|
||||
|
||||
// Slice out the whole packet so unwrap will only be called with complete packets.
|
||||
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger
|
||||
// decode(...) again via:
|
||||
// 1) unwrap(..) is called
|
||||
// 2) wrap(...) is called from within unwrap(...)
|
||||
// 3) wrap(...) calls unwrapLater(...)
|
||||
// 4) unwrapLater(...) calls decode(...)
|
||||
//
|
||||
// See https://github.com/netty/netty/issues/1534
|
||||
int readerIndex = in.readerIndex();
|
||||
in.skipBytes(packetLength);
|
||||
ByteBuffer buffer = in.nioBuffer(readerIndex, packetLength);
|
||||
this.packetLength = 0;
|
||||
if (nonSslRecord) {
|
||||
// Not an SSL/TLS packet
|
||||
NotSslRecordException e = new NotSslRecordException(
|
||||
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
|
||||
in.skipBytes(in.readableBytes());
|
||||
ctx.fireExceptionCaught(e);
|
||||
setHandshakeFailure(e);
|
||||
}
|
||||
}
|
||||
|
||||
unwrap(ctx, buffer, out);
|
||||
@Override
|
||||
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
|
||||
if (needsFlush) {
|
||||
needsFlush = false;
|
||||
ctx.flush();
|
||||
}
|
||||
super.channelReadComplete(ctx);
|
||||
}
|
||||
|
||||
private void unwrap(ChannelHandlerContext ctx) throws SSLException {
|
||||
@ -822,7 +872,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
try {
|
||||
for (;;) {
|
||||
if (decodeOut == null) {
|
||||
decodeOut = ctx.alloc().buffer();
|
||||
decodeOut = ctx.alloc().buffer(packet.remaining());
|
||||
}
|
||||
|
||||
final SSLEngineResult result = unwrap(engine, packet, decodeOut);
|
||||
@ -842,7 +892,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
case NEED_UNWRAP:
|
||||
break;
|
||||
case NEED_WRAP:
|
||||
flushNonAppData0(ctx, true);
|
||||
wrapNonAppData(ctx, true);
|
||||
break;
|
||||
case NEED_TASK:
|
||||
runDelegatedTasks();
|
||||
@ -863,7 +913,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
}
|
||||
|
||||
if (wrapLater) {
|
||||
flush0(ctx);
|
||||
wrap(ctx, true);
|
||||
}
|
||||
} catch (SSLException e) {
|
||||
setHandshakeFailure(e);
|
||||
@ -878,13 +928,21 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
}
|
||||
|
||||
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException {
|
||||
int overflows = 0;
|
||||
for (;;) {
|
||||
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
|
||||
SSLEngineResult result = engine.unwrap(in, out0);
|
||||
out.writerIndex(out.writerIndex() + result.bytesProduced());
|
||||
switch (result.getStatus()) {
|
||||
case BUFFER_OVERFLOW:
|
||||
out.ensureWritable(engine.getSession().getApplicationBufferSize());
|
||||
int max = engine.getSession().getApplicationBufferSize();
|
||||
switch (overflows ++) {
|
||||
case 0:
|
||||
out.ensureWritable(Math.min(max, in.remaining()));
|
||||
break;
|
||||
default:
|
||||
out.ensureWritable(max);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return result;
|
||||
@ -975,14 +1033,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
if (ctx.channel().isActive()) {
|
||||
// channelActive() event has been fired already, which means this.channelActive() will
|
||||
// not be invoked. We have to initialize here instead.
|
||||
handshake0();
|
||||
handshake();
|
||||
} else {
|
||||
// channelActive() event has not been fired yet. this.channelOpen() will be invoked
|
||||
// and initialization will occur there.
|
||||
}
|
||||
}
|
||||
|
||||
private Future<Channel> handshake0() {
|
||||
private Future<Channel> handshake() {
|
||||
final ScheduledFuture<?> timeoutFuture;
|
||||
if (handshakeTimeoutMillis > 0) {
|
||||
timeoutFuture = ctx.executor().schedule(new Runnable() {
|
||||
@ -1008,7 +1066,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
});
|
||||
try {
|
||||
engine.beginHandshake();
|
||||
flushNonAppData0(ctx, false);
|
||||
wrapNonAppData(ctx, false);
|
||||
ctx.flush();
|
||||
} catch (Exception e) {
|
||||
notifyHandshakeFailure(e);
|
||||
}
|
||||
@ -1023,7 +1082,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
if (!startTls && engine.getUseClientMode()) {
|
||||
// issue and handshake and add a listener to it which will fire an exception event if
|
||||
// an exception was thrown while doing the handshake
|
||||
handshake0().addListener(new GenericFutureListener<Future<Channel>>() {
|
||||
handshake().addListener(new GenericFutureListener<Future<Channel>>() {
|
||||
@Override
|
||||
public void operationComplete(Future<Channel> future) throws Exception {
|
||||
if (!future.isSuccess()) {
|
||||
|
Loading…
Reference in New Issue
Block a user