Optimize SslHandler in an OpenSslEngine-friendly way
Motivation: Previous fix for the OpenSslEngine compatibility issue (#2216 and18b0e95659
) was to feed SSL records one by one to OpenSslEngine.unwrap(). It is not optimal because it will result in more JNI calls. Modifications: - Do not feed SSL records one by one. - Feed as many records as possible up to MAX_ENCRYPTED_PACKET_LENGTH - Deduplicate MAX_ENCRYPTED_PACKET_LENGTH definitions Result: - No allocation of intemediary arrays - Reduced number of calls to SSLEngine and thus its underlying JNI calls - A tad bit increase in throughput, probably reverting the tiny drop caused by18b0e95659
This commit is contained in:
parent
a72230061d
commit
1095e4486f
@ -16,6 +16,7 @@
|
|||||||
package io.netty.handler.ssl;
|
package io.netty.handler.ssl;
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.Channel;
|
import io.netty.channel.Channel;
|
||||||
@ -48,7 +49,6 @@ import java.nio.channels.ClosedChannelException;
|
|||||||
import java.nio.channels.DatagramChannel;
|
import java.nio.channels.DatagramChannel;
|
||||||
import java.nio.channels.SocketChannel;
|
import java.nio.channels.SocketChannel;
|
||||||
import java.util.ArrayDeque;
|
import java.util.ArrayDeque;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Deque;
|
import java.util.Deque;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.ScheduledFuture;
|
import java.util.concurrent.ScheduledFuture;
|
||||||
@ -189,7 +189,6 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
private boolean needsFlush;
|
private boolean needsFlush;
|
||||||
|
|
||||||
private int packetLength;
|
private int packetLength;
|
||||||
private ByteBuf decodeOut;
|
|
||||||
|
|
||||||
private volatile long handshakeTimeoutMillis = 10000;
|
private volatile long handshakeTimeoutMillis = 10000;
|
||||||
private volatile long closeNotifyTimeoutMillis = 3000;
|
private volatile long closeNotifyTimeoutMillis = 3000;
|
||||||
@ -318,10 +317,6 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||||
if (decodeOut != null) {
|
|
||||||
decodeOut.release();
|
|
||||||
decodeOut = null;
|
|
||||||
}
|
|
||||||
for (;;) {
|
for (;;) {
|
||||||
PendingWrite write = pendingUnencryptedWrites.poll();
|
PendingWrite write = pendingUnencryptedWrites.poll();
|
||||||
if (write == null) {
|
if (write == null) {
|
||||||
@ -384,7 +379,7 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (out == null) {
|
if (out == null) {
|
||||||
out = ctx.alloc().buffer(maxPacketBufferSize);
|
out = allocate(ctx, maxPacketBufferSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!(pending.msg() instanceof ByteBuf)) {
|
if (!(pending.msg() instanceof ByteBuf)) {
|
||||||
@ -470,7 +465,7 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
try {
|
try {
|
||||||
for (;;) {
|
for (;;) {
|
||||||
if (out == null) {
|
if (out == null) {
|
||||||
out = ctx.alloc().buffer(maxPacketBufferSize);
|
out = allocate(ctx, maxPacketBufferSize);
|
||||||
}
|
}
|
||||||
SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out);
|
SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out);
|
||||||
|
|
||||||
@ -491,7 +486,7 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
break;
|
break;
|
||||||
case NEED_UNWRAP:
|
case NEED_UNWRAP:
|
||||||
if (!inUnwrap) {
|
if (!inUnwrap) {
|
||||||
unwrapNonApp(ctx);
|
unwrapNonAppData(ctx);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case NEED_WRAP:
|
case NEED_WRAP:
|
||||||
@ -501,7 +496,7 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
// Workaround for TLS False Start problem reported at:
|
// Workaround for TLS False Start problem reported at:
|
||||||
// https://github.com/netty/netty/issues/1108#issuecomment-14266970
|
// https://github.com/netty/netty/issues/1108#issuecomment-14266970
|
||||||
if (!inUnwrap) {
|
if (!inUnwrap) {
|
||||||
unwrapNonApp(ctx);
|
unwrapNonAppData(ctx);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
@ -729,31 +724,25 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
@Override
|
@Override
|
||||||
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
|
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
|
||||||
|
|
||||||
// Keeps the list of the length of every SSL record in the input buffer.
|
|
||||||
int[] recordLengths = null;
|
|
||||||
int nRecords = 0;
|
|
||||||
|
|
||||||
final int startOffset = in.readerIndex();
|
final int startOffset = in.readerIndex();
|
||||||
final int endOffset = in.writerIndex();
|
final int endOffset = in.writerIndex();
|
||||||
int offset = startOffset;
|
int offset = startOffset;
|
||||||
|
int totalLength = 0;
|
||||||
|
|
||||||
// If we calculated the length of the current SSL record before, use that information.
|
// If we calculated the length of the current SSL record before, use that information.
|
||||||
if (packetLength > 0) {
|
if (packetLength > 0) {
|
||||||
if (endOffset - startOffset < packetLength) {
|
if (endOffset - startOffset < packetLength) {
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
recordLengths = new int[4];
|
|
||||||
recordLengths[0] = packetLength;
|
|
||||||
nRecords = 1;
|
|
||||||
|
|
||||||
offset += packetLength;
|
offset += packetLength;
|
||||||
|
totalLength = packetLength;
|
||||||
packetLength = 0;
|
packetLength = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean nonSslRecord = false;
|
boolean nonSslRecord = false;
|
||||||
|
|
||||||
for (;;) {
|
while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
|
||||||
final int readableBytes = endOffset - offset;
|
final int readableBytes = endOffset - offset;
|
||||||
if (readableBytes < 5) {
|
if (readableBytes < 5) {
|
||||||
break;
|
break;
|
||||||
@ -773,21 +762,18 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// We have a whole packet.
|
int newTotalLength = totalLength + packetLength;
|
||||||
// Remember the length of the current packet.
|
if (newTotalLength > OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
|
||||||
if (recordLengths == null) {
|
// Don't read too much.
|
||||||
recordLengths = new int[4];
|
break;
|
||||||
}
|
}
|
||||||
if (nRecords == recordLengths.length) {
|
|
||||||
recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1);
|
|
||||||
}
|
|
||||||
recordLengths[nRecords ++] = packetLength;
|
|
||||||
|
|
||||||
|
// We have a whole packet.
|
||||||
// Increment the offset to handle the next packet.
|
// Increment the offset to handle the next packet.
|
||||||
offset += packetLength;
|
offset += packetLength;
|
||||||
|
totalLength = newTotalLength;
|
||||||
}
|
}
|
||||||
|
|
||||||
final int totalLength = offset - startOffset;
|
|
||||||
if (totalLength > 0) {
|
if (totalLength > 0) {
|
||||||
// The buffer contains one or more full SSL records.
|
// The buffer contains one or more full SSL records.
|
||||||
// Slice out the whole packet so unwrap will only be called with complete packets.
|
// Slice out the whole packet so unwrap will only be called with complete packets.
|
||||||
@ -799,9 +785,11 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
// 4) unwrapLater(...) calls decode(...)
|
// 4) unwrapLater(...) calls decode(...)
|
||||||
//
|
//
|
||||||
// See https://github.com/netty/netty/issues/1534
|
// See https://github.com/netty/netty/issues/1534
|
||||||
|
|
||||||
in.skipBytes(totalLength);
|
in.skipBytes(totalLength);
|
||||||
ByteBuffer buffer = in.nioBuffer(startOffset, totalLength);
|
final ByteBuffer inNetBuf = in.nioBuffer(startOffset, totalLength);
|
||||||
unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out);
|
unwrap(ctx, inNetBuf, totalLength);
|
||||||
|
assert !inNetBuf.hasRemaining() || engine.isInboundDone();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nonSslRecord) {
|
if (nonSslRecord) {
|
||||||
@ -826,52 +814,20 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
/**
|
/**
|
||||||
* Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
|
* Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
|
||||||
*/
|
*/
|
||||||
private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException {
|
private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException {
|
||||||
try {
|
unwrap(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
|
||||||
unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
|
|
||||||
} finally {
|
|
||||||
ByteBuf decodeOut = this.decodeOut;
|
|
||||||
if (decodeOut != null && decodeOut.isReadable()) {
|
|
||||||
this.decodeOut = null;
|
|
||||||
ctx.fireChannelRead(decodeOut);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unwraps multiple inbound SSL records.
|
* Unwraps inbound SSL records.
|
||||||
*/
|
*/
|
||||||
private void unwrapMultiple(
|
private void unwrap(
|
||||||
ChannelHandlerContext ctx, ByteBuffer packet, int totalLength,
|
|
||||||
int[] recordLengths, int nRecords, List<Object> out) throws SSLException {
|
|
||||||
for (int i = 0; i < nRecords; i ++) {
|
|
||||||
packet.limit(packet.position() + recordLengths[i]);
|
|
||||||
try {
|
|
||||||
unwrapSingle(ctx, packet, totalLength);
|
|
||||||
assert !packet.hasRemaining() || engine.isInboundDone();
|
|
||||||
} finally {
|
|
||||||
ByteBuf decodeOut = this.decodeOut;
|
|
||||||
if (decodeOut != null && decodeOut.isReadable()) {
|
|
||||||
this.decodeOut = null;
|
|
||||||
out.add(decodeOut);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Unwraps a single SSL record.
|
|
||||||
*/
|
|
||||||
private void unwrapSingle(
|
|
||||||
ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException {
|
ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException {
|
||||||
|
|
||||||
boolean wrapLater = false;
|
boolean wrapLater = false;
|
||||||
|
ByteBuf decodeOut = allocate(ctx, initialOutAppBufCapacity);
|
||||||
try {
|
try {
|
||||||
for (;;) {
|
for (;;) {
|
||||||
if (decodeOut == null) {
|
|
||||||
decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity);
|
|
||||||
}
|
|
||||||
|
|
||||||
final SSLEngineResult result = unwrap(engine, packet, decodeOut);
|
final SSLEngineResult result = unwrap(engine, packet, decodeOut);
|
||||||
final Status status = result.getStatus();
|
final Status status = result.getStatus();
|
||||||
final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
|
final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
|
||||||
@ -926,6 +882,12 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
} catch (SSLException e) {
|
} catch (SSLException e) {
|
||||||
setHandshakeFailure(e);
|
setHandshakeFailure(e);
|
||||||
throw e;
|
throw e;
|
||||||
|
} finally {
|
||||||
|
if (decodeOut.isReadable()) {
|
||||||
|
ctx.fireChannelRead(decodeOut);
|
||||||
|
} else {
|
||||||
|
decodeOut.release();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1156,6 +1118,19 @@ public class SslHandler extends ByteToMessageDecoder {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Always prefer a direct buffer when it's pooled, so that we reduce the number of memory copies
|
||||||
|
* in {@link OpenSslEngine}.
|
||||||
|
*/
|
||||||
|
private static ByteBuf allocate(ChannelHandlerContext ctx, int capacity) {
|
||||||
|
ByteBufAllocator alloc = ctx.alloc();
|
||||||
|
if (alloc.isDirectBufferPooled()) {
|
||||||
|
return alloc.directBuffer(capacity);
|
||||||
|
} else {
|
||||||
|
return alloc.buffer(capacity);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private final class LazyChannelPromise extends DefaultPromise<Channel> {
|
private final class LazyChannelPromise extends DefaultPromise<Channel> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
Loading…
Reference in New Issue
Block a user