Optimize SslHandler in an OpenSslEngine-friendly way

Motivation:

Previous fix for the OpenSslEngine compatibility issue (#2216 and
18b0e95659) was to feed SSL records one by
one to OpenSslEngine.unwrap().  It is not optimal because it will result
in more JNI calls.

Modifications:

- Do not feed SSL records one by one.
- Feed as many records as possible up to MAX_ENCRYPTED_PACKET_LENGTH
- Deduplicate MAX_ENCRYPTED_PACKET_LENGTH definitions

Result:

- No allocation of intemediary arrays
- Reduced number of calls to SSLEngine and thus its underlying JNI calls
- A tad bit increase in throughput, probably reverting the tiny drop
  caused by 18b0e95659
This commit is contained in:
Trustin Lee 2014-05-18 03:32:35 +09:00
parent c58f28dfdd
commit 097ea8b5b5

View File

@ -16,6 +16,7 @@
package io.netty.handler.ssl; package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -52,7 +53,6 @@ import java.nio.channels.DatagramChannel;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque; import java.util.Deque;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -195,7 +195,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private boolean needsFlush; private boolean needsFlush;
private int packetLength; private int packetLength;
private ByteBuf decodeOut;
private volatile long handshakeTimeoutMillis = 10000; private volatile long handshakeTimeoutMillis = 10000;
private volatile long closeNotifyTimeoutMillis = 3000; private volatile long closeNotifyTimeoutMillis = 3000;
@ -345,10 +344,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override @Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
if (decodeOut != null) {
decodeOut.release();
decodeOut = null;
}
for (;;) { for (;;) {
PendingWrite write = pendingUnencryptedWrites.poll(); PendingWrite write = pendingUnencryptedWrites.poll();
if (write == null) { if (write == null) {
@ -432,7 +427,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break; break;
} }
if (out == null) { if (out == null) {
out = ctx.alloc().buffer(maxPacketBufferSize); out = allocate(ctx, maxPacketBufferSize);
} }
if (!(pending.msg() instanceof ByteBuf)) { if (!(pending.msg() instanceof ByteBuf)) {
@ -518,7 +513,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
try { try {
for (;;) { for (;;) {
if (out == null) { if (out == null) {
out = ctx.alloc().buffer(maxPacketBufferSize); out = allocate(ctx, maxPacketBufferSize);
} }
SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out); SSLEngineResult result = wrap(engine, Unpooled.EMPTY_BUFFER, out);
@ -539,7 +534,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break; break;
case NEED_UNWRAP: case NEED_UNWRAP:
if (!inUnwrap) { if (!inUnwrap) {
unwrapNonApp(ctx); unwrapNonAppData(ctx);
} }
break; break;
case NEED_WRAP: case NEED_WRAP:
@ -549,7 +544,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// Workaround for TLS False Start problem reported at: // Workaround for TLS False Start problem reported at:
// https://github.com/netty/netty/issues/1108#issuecomment-14266970 // https://github.com/netty/netty/issues/1108#issuecomment-14266970
if (!inUnwrap) { if (!inUnwrap) {
unwrapNonApp(ctx); unwrapNonAppData(ctx);
} }
break; break;
default: default:
@ -777,31 +772,25 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException { protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
// Keeps the list of the length of every SSL record in the input buffer.
int[] recordLengths = null;
int nRecords = 0;
final int startOffset = in.readerIndex(); final int startOffset = in.readerIndex();
final int endOffset = in.writerIndex(); final int endOffset = in.writerIndex();
int offset = startOffset; int offset = startOffset;
int totalLength = 0;
// If we calculated the length of the current SSL record before, use that information. // If we calculated the length of the current SSL record before, use that information.
if (packetLength > 0) { if (packetLength > 0) {
if (endOffset - startOffset < packetLength) { if (endOffset - startOffset < packetLength) {
return; return;
} else { } else {
recordLengths = new int[4];
recordLengths[0] = packetLength;
nRecords = 1;
offset += packetLength; offset += packetLength;
totalLength = packetLength;
packetLength = 0; packetLength = 0;
} }
} }
boolean nonSslRecord = false; boolean nonSslRecord = false;
for (;;) { while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
final int readableBytes = endOffset - offset; final int readableBytes = endOffset - offset;
if (readableBytes < 5) { if (readableBytes < 5) {
break; break;
@ -821,21 +810,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break; break;
} }
// We have a whole packet. int newTotalLength = totalLength + packetLength;
// Remember the length of the current packet. if (newTotalLength > OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
if (recordLengths == null) { // Don't read too much.
recordLengths = new int[4]; break;
} }
if (nRecords == recordLengths.length) {
recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1);
}
recordLengths[nRecords ++] = packetLength;
// We have a whole packet.
// Increment the offset to handle the next packet. // Increment the offset to handle the next packet.
offset += packetLength; offset += packetLength;
totalLength = newTotalLength;
} }
final int totalLength = offset - startOffset;
if (totalLength > 0) { if (totalLength > 0) {
// The buffer contains one or more full SSL records. // The buffer contains one or more full SSL records.
// Slice out the whole packet so unwrap will only be called with complete packets. // Slice out the whole packet so unwrap will only be called with complete packets.
@ -847,9 +833,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// 4) unwrapLater(...) calls decode(...) // 4) unwrapLater(...) calls decode(...)
// //
// See https://github.com/netty/netty/issues/1534 // See https://github.com/netty/netty/issues/1534
in.skipBytes(totalLength); in.skipBytes(totalLength);
ByteBuffer buffer = in.nioBuffer(startOffset, totalLength); final ByteBuffer inNetBuf = in.nioBuffer(startOffset, totalLength);
unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out); unwrap(ctx, inNetBuf, totalLength);
assert !inNetBuf.hasRemaining() || engine.isInboundDone();
} }
if (nonSslRecord) { if (nonSslRecord) {
@ -874,52 +862,20 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
/** /**
* Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc. * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
*/ */
private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException { private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException {
try { unwrap(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
} finally {
ByteBuf decodeOut = this.decodeOut;
if (decodeOut != null && decodeOut.isReadable()) {
this.decodeOut = null;
ctx.fireChannelRead(decodeOut);
}
}
} }
/** /**
* Unwraps multiple inbound SSL records. * Unwraps inbound SSL records.
*/ */
private void unwrapMultiple( private void unwrap(
ChannelHandlerContext ctx, ByteBuffer packet, int totalLength,
int[] recordLengths, int nRecords, List<Object> out) throws SSLException {
for (int i = 0; i < nRecords; i ++) {
packet.limit(packet.position() + recordLengths[i]);
try {
unwrapSingle(ctx, packet, totalLength);
assert !packet.hasRemaining() || engine.isInboundDone();
} finally {
ByteBuf decodeOut = this.decodeOut;
if (decodeOut != null && decodeOut.isReadable()) {
this.decodeOut = null;
out.add(decodeOut);
}
}
}
}
/**
* Unwraps a single SSL record.
*/
private void unwrapSingle(
ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException { ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException {
boolean wrapLater = false; boolean wrapLater = false;
ByteBuf decodeOut = allocate(ctx, initialOutAppBufCapacity);
try { try {
for (;;) { for (;;) {
if (decodeOut == null) {
decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity);
}
final SSLEngineResult result = unwrap(engine, packet, decodeOut); final SSLEngineResult result = unwrap(engine, packet, decodeOut);
final Status status = result.getStatus(); final Status status = result.getStatus();
final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
@ -974,6 +930,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} catch (SSLException e) { } catch (SSLException e) {
setHandshakeFailure(e); setHandshakeFailure(e);
throw e; throw e;
} finally {
if (decodeOut.isReadable()) {
ctx.fireChannelRead(decodeOut);
} else {
decodeOut.release();
}
} }
} }
@ -1253,6 +1215,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}); });
} }
/**
* Always prefer a direct buffer when it's pooled, so that we reduce the number of memory copies
* in {@link OpenSslEngine}.
*/
private static ByteBuf allocate(ChannelHandlerContext ctx, int capacity) {
ByteBufAllocator alloc = ctx.alloc();
if (alloc.isDirectBufferPooled()) {
return alloc.directBuffer(capacity);
} else {
return alloc.buffer(capacity);
}
}
private final class LazyChannelPromise extends DefaultPromise<Channel> { private final class LazyChannelPromise extends DefaultPromise<Channel> {
@Override @Override