Optimize SslHandler in an OpenSslEngine-friendly way
Motivation: Previous fix for the OpenSslEngine compatibility issue (#2216 and 18b0e95659c057b126653bad2f018a8ce5385255) was to feed SSL records one by one to OpenSslEngine.unwrap(). It is not optimal because it will result in more JNI calls. Modifications: - Do not feed SSL records one by one. - Feed as many records as possible up to MAX_ENCRYPTED_PACKET_LENGTH - Deduplicate MAX_ENCRYPTED_PACKET_LENGTH definitions Result: - No allocation of intemediary arrays - Reduced number of calls to SSLEngine and thus its underlying JNI calls - A tad bit increase in throughput, probably reverting the tiny drop caused by 18b0e95659c057b126653bad2f018a8ce5385255
This commit is contained in:
parent
c58f28dfdd
commit
097ea8b5b5
@ -16,6 +16,7 @@
|
||||
package io.netty.handler.ssl;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
@ -52,7 +53,6 @@ import java.nio.channels.DatagramChannel;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Deque;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
@ -195,7 +195,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
private boolean needsFlush;
|
||||
|
||||
private int packetLength;
|
||||
private ByteBuf decodeOut;
|
||||
|
||||
private volatile long handshakeTimeoutMillis = 10000;
|
||||
private volatile long closeNotifyTimeoutMillis = 3000;
|
||||
@ -345,10 +344,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
|
||||
@Override
|
||||
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||
if (decodeOut != null) {
|
||||
decodeOut.release();
|
||||
decodeOut = null;
|
||||
}
|
||||
for (;;) {
|
||||
PendingWrite write = pendingUnencryptedWrites.poll();
|
||||
if (write == null) {
|
||||
@ -432,7 +427,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
break;
|
||||
}
|
||||
if (out == null) {
|
||||
out = ctx.alloc().buffer(maxPacketBufferSize);
|
||||
out = allocate(ctx, maxPacketBufferSize);
|
||||
}
|
||||
|
||||
if (!(pending.msg() instanceof ByteBuf)) {
|
||||
@ -518,7 +513,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
try {
|
||||
for (;;) {
|
||||
if (out == null) {
|
||||
out = ctx.alloc().buffer(maxPacketBufferSize);
|
||||
out = allocate(ctx, maxPacketBufferSize);
|
||||
}
|
||||
SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out);
|
||||
|
||||
@ -539,7 +534,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
break;
|
||||
case NEED_UNWRAP:
|
||||
if (!inUnwrap) {
|
||||
unwrapNonApp(ctx);
|
||||
unwrapNonAppData(ctx);
|
||||
}
|
||||
break;
|
||||
case NEED_WRAP:
|
||||
@ -549,7 +544,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
// Workaround for TLS False Start problem reported at:
|
||||
// https://github.com/netty/netty/issues/1108#issuecomment-14266970
|
||||
if (!inUnwrap) {
|
||||
unwrapNonApp(ctx);
|
||||
unwrapNonAppData(ctx);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
@ -777,31 +772,25 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
@Override
|
||||
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
|
||||
|
||||
// Keeps the list of the length of every SSL record in the input buffer.
|
||||
int[] recordLengths = null;
|
||||
int nRecords = 0;
|
||||
|
||||
final int startOffset = in.readerIndex();
|
||||
final int endOffset = in.writerIndex();
|
||||
int offset = startOffset;
|
||||
int totalLength = 0;
|
||||
|
||||
// If we calculated the length of the current SSL record before, use that information.
|
||||
if (packetLength > 0) {
|
||||
if (endOffset - startOffset < packetLength) {
|
||||
return;
|
||||
} else {
|
||||
recordLengths = new int[4];
|
||||
recordLengths[0] = packetLength;
|
||||
nRecords = 1;
|
||||
|
||||
offset += packetLength;
|
||||
totalLength = packetLength;
|
||||
packetLength = 0;
|
||||
}
|
||||
}
|
||||
|
||||
boolean nonSslRecord = false;
|
||||
|
||||
for (;;) {
|
||||
while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
|
||||
final int readableBytes = endOffset - offset;
|
||||
if (readableBytes < 5) {
|
||||
break;
|
||||
@ -821,21 +810,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
break;
|
||||
}
|
||||
|
||||
// We have a whole packet.
|
||||
// Remember the length of the current packet.
|
||||
if (recordLengths == null) {
|
||||
recordLengths = new int[4];
|
||||
int newTotalLength = totalLength + packetLength;
|
||||
if (newTotalLength > OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
|
||||
// Don't read too much.
|
||||
break;
|
||||
}
|
||||
if (nRecords == recordLengths.length) {
|
||||
recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1);
|
||||
}
|
||||
recordLengths[nRecords ++] = packetLength;
|
||||
|
||||
// We have a whole packet.
|
||||
// Increment the offset to handle the next packet.
|
||||
offset += packetLength;
|
||||
totalLength = newTotalLength;
|
||||
}
|
||||
|
||||
final int totalLength = offset - startOffset;
|
||||
if (totalLength > 0) {
|
||||
// The buffer contains one or more full SSL records.
|
||||
// Slice out the whole packet so unwrap will only be called with complete packets.
|
||||
@ -847,9 +833,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
// 4) unwrapLater(...) calls decode(...)
|
||||
//
|
||||
// See https://github.com/netty/netty/issues/1534
|
||||
|
||||
in.skipBytes(totalLength);
|
||||
ByteBuffer buffer = in.nioBuffer(startOffset, totalLength);
|
||||
unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out);
|
||||
final ByteBuffer inNetBuf = in.nioBuffer(startOffset, totalLength);
|
||||
unwrap(ctx, inNetBuf, totalLength);
|
||||
assert !inNetBuf.hasRemaining() || engine.isInboundDone();
|
||||
}
|
||||
|
||||
if (nonSslRecord) {
|
||||
@ -874,52 +862,20 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
/**
|
||||
* Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
|
||||
*/
|
||||
private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException {
|
||||
try {
|
||||
unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
|
||||
} finally {
|
||||
ByteBuf decodeOut = this.decodeOut;
|
||||
if (decodeOut != null && decodeOut.isReadable()) {
|
||||
this.decodeOut = null;
|
||||
ctx.fireChannelRead(decodeOut);
|
||||
}
|
||||
}
|
||||
private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException {
|
||||
unwrap(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unwraps multiple inbound SSL records.
|
||||
* Unwraps inbound SSL records.
|
||||
*/
|
||||
private void unwrapMultiple(
|
||||
ChannelHandlerContext ctx, ByteBuffer packet, int totalLength,
|
||||
int[] recordLengths, int nRecords, List<Object> out) throws SSLException {
|
||||
for (int i = 0; i < nRecords; i ++) {
|
||||
packet.limit(packet.position() + recordLengths[i]);
|
||||
try {
|
||||
unwrapSingle(ctx, packet, totalLength);
|
||||
assert !packet.hasRemaining() || engine.isInboundDone();
|
||||
} finally {
|
||||
ByteBuf decodeOut = this.decodeOut;
|
||||
if (decodeOut != null && decodeOut.isReadable()) {
|
||||
this.decodeOut = null;
|
||||
out.add(decodeOut);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Unwraps a single SSL record.
|
||||
*/
|
||||
private void unwrapSingle(
|
||||
private void unwrap(
|
||||
ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException {
|
||||
|
||||
boolean wrapLater = false;
|
||||
ByteBuf decodeOut = allocate(ctx, initialOutAppBufCapacity);
|
||||
try {
|
||||
for (;;) {
|
||||
if (decodeOut == null) {
|
||||
decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity);
|
||||
}
|
||||
|
||||
final SSLEngineResult result = unwrap(engine, packet, decodeOut);
|
||||
final Status status = result.getStatus();
|
||||
final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
|
||||
@ -974,6 +930,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
} catch (SSLException e) {
|
||||
setHandshakeFailure(e);
|
||||
throw e;
|
||||
} finally {
|
||||
if (decodeOut.isReadable()) {
|
||||
ctx.fireChannelRead(decodeOut);
|
||||
} else {
|
||||
decodeOut.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1253,6 +1215,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Always prefer a direct buffer when it's pooled, so that we reduce the number of memory copies
|
||||
* in {@link OpenSslEngine}.
|
||||
*/
|
||||
private static ByteBuf allocate(ChannelHandlerContext ctx, int capacity) {
|
||||
ByteBufAllocator alloc = ctx.alloc();
|
||||
if (alloc.isDirectBufferPooled()) {
|
||||
return alloc.directBuffer(capacity);
|
||||
} else {
|
||||
return alloc.buffer(capacity);
|
||||
}
|
||||
}
|
||||
|
||||
private final class LazyChannelPromise extends DefaultPromise<Channel> {
|
||||
|
||||
@Override
|
||||
|
Loading…
x
Reference in New Issue
Block a user