Optimize SslHandler in an OpenSslEngine-friendly way

Motivation:

Previous fix for the OpenSslEngine compatibility issue (#2216 and
18b0e95659) was to feed SSL records one by
one to OpenSslEngine.unwrap().  It is not optimal because it will result
in more JNI calls.

Modifications:

- Do not feed SSL records one by one.
- Feed as many records as possible up to MAX_ENCRYPTED_PACKET_LENGTH
- Deduplicate MAX_ENCRYPTED_PACKET_LENGTH definitions

Result:

- No allocation of intemediary arrays
- Reduced number of calls to SSLEngine and thus its underlying JNI calls
- A tad bit increase in throughput, probably reverting the tiny drop
  caused by 18b0e95659
This commit is contained in:
Trustin Lee 2014-05-18 03:32:35 +09:00
parent c58f28dfdd
commit 097ea8b5b5

View File

@ -16,6 +16,7 @@
package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
@ -52,7 +53,6 @@ import java.nio.channels.DatagramChannel;
import java.nio.channels.SocketChannel;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.CountDownLatch;
@ -195,7 +195,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private boolean needsFlush;
private int packetLength;
private ByteBuf decodeOut;
private volatile long handshakeTimeoutMillis = 10000;
private volatile long closeNotifyTimeoutMillis = 3000;
@ -345,10 +344,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
if (decodeOut != null) {
decodeOut.release();
decodeOut = null;
}
for (;;) {
PendingWrite write = pendingUnencryptedWrites.poll();
if (write == null) {
@ -432,7 +427,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break;
}
if (out == null) {
out = ctx.alloc().buffer(maxPacketBufferSize);
out = allocate(ctx, maxPacketBufferSize);
}
if (!(pending.msg() instanceof ByteBuf)) {
@ -518,7 +513,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
try {
for (;;) {
if (out == null) {
out = ctx.alloc().buffer(maxPacketBufferSize);
out = allocate(ctx, maxPacketBufferSize);
}
SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out);
@ -539,7 +534,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break;
case NEED_UNWRAP:
if (!inUnwrap) {
unwrapNonApp(ctx);
unwrapNonAppData(ctx);
}
break;
case NEED_WRAP:
@ -549,7 +544,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// Workaround for TLS False Start problem reported at:
// https://github.com/netty/netty/issues/1108#issuecomment-14266970
if (!inUnwrap) {
unwrapNonApp(ctx);
unwrapNonAppData(ctx);
}
break;
default:
@ -777,31 +772,25 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
// Keeps the list of the length of every SSL record in the input buffer.
int[] recordLengths = null;
int nRecords = 0;
final int startOffset = in.readerIndex();
final int endOffset = in.writerIndex();
int offset = startOffset;
int totalLength = 0;
// If we calculated the length of the current SSL record before, use that information.
if (packetLength > 0) {
if (endOffset - startOffset < packetLength) {
return;
} else {
recordLengths = new int[4];
recordLengths[0] = packetLength;
nRecords = 1;
offset += packetLength;
totalLength = packetLength;
packetLength = 0;
}
}
boolean nonSslRecord = false;
for (;;) {
while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
final int readableBytes = endOffset - offset;
if (readableBytes < 5) {
break;
@ -821,21 +810,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break;
}
// We have a whole packet.
// Remember the length of the current packet.
if (recordLengths == null) {
recordLengths = new int[4];
int newTotalLength = totalLength + packetLength;
if (newTotalLength > OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
// Don't read too much.
break;
}
if (nRecords == recordLengths.length) {
recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1);
}
recordLengths[nRecords ++] = packetLength;
// We have a whole packet.
// Increment the offset to handle the next packet.
offset += packetLength;
totalLength = newTotalLength;
}
final int totalLength = offset - startOffset;
if (totalLength > 0) {
// The buffer contains one or more full SSL records.
// Slice out the whole packet so unwrap will only be called with complete packets.
@ -847,9 +833,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// 4) unwrapLater(...) calls decode(...)
//
// See https://github.com/netty/netty/issues/1534
in.skipBytes(totalLength);
ByteBuffer buffer = in.nioBuffer(startOffset, totalLength);
unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out);
final ByteBuffer inNetBuf = in.nioBuffer(startOffset, totalLength);
unwrap(ctx, inNetBuf, totalLength);
assert !inNetBuf.hasRemaining() || engine.isInboundDone();
}
if (nonSslRecord) {
@ -874,52 +862,20 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
/**
* Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
*/
private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException {
try {
unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
} finally {
ByteBuf decodeOut = this.decodeOut;
if (decodeOut != null && decodeOut.isReadable()) {
this.decodeOut = null;
ctx.fireChannelRead(decodeOut);
}
}
private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException {
unwrap(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
}
/**
* Unwraps multiple inbound SSL records.
* Unwraps inbound SSL records.
*/
private void unwrapMultiple(
ChannelHandlerContext ctx, ByteBuffer packet, int totalLength,
int[] recordLengths, int nRecords, List<Object> out) throws SSLException {
for (int i = 0; i < nRecords; i ++) {
packet.limit(packet.position() + recordLengths[i]);
try {
unwrapSingle(ctx, packet, totalLength);
assert !packet.hasRemaining() || engine.isInboundDone();
} finally {
ByteBuf decodeOut = this.decodeOut;
if (decodeOut != null && decodeOut.isReadable()) {
this.decodeOut = null;
out.add(decodeOut);
}
}
}
}
/**
* Unwraps a single SSL record.
*/
private void unwrapSingle(
private void unwrap(
ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException {
boolean wrapLater = false;
ByteBuf decodeOut = allocate(ctx, initialOutAppBufCapacity);
try {
for (;;) {
if (decodeOut == null) {
decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity);
}
final SSLEngineResult result = unwrap(engine, packet, decodeOut);
final Status status = result.getStatus();
final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
@ -974,6 +930,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} catch (SSLException e) {
setHandshakeFailure(e);
throw e;
} finally {
if (decodeOut.isReadable()) {
ctx.fireChannelRead(decodeOut);
} else {
decodeOut.release();
}
}
}
@ -1253,6 +1215,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
});
}
/**
* Always prefer a direct buffer when it's pooled, so that we reduce the number of memory copies
* in {@link OpenSslEngine}.
*/
private static ByteBuf allocate(ChannelHandlerContext ctx, int capacity) {
ByteBufAllocator alloc = ctx.alloc();
if (alloc.isDirectBufferPooled()) {
return alloc.directBuffer(capacity);
} else {
return alloc.buffer(capacity);
}
}
private final class LazyChannelPromise extends DefaultPromise<Channel> {
@Override