Reduce memory copies when using OpenSslEngine with SslHandler

Motivation:

When using OpenSslEngine with the SslHandler it is possible to reduce memory copies by unwrap(...) multiple ByteBuffers at the same time. This way we can eliminate a memory copy that is needed otherwise to cumulate partial received data.

Modifications:

- Add OpenSslEngine.unwrap(ByteBuffer[],...) method that can be used to unwrap multiple src ByteBuffer a the same time
- Use a CompositeByteBuffer in SslHandler for inbound data so we not need to memory copy
- Add OpenSslEngine.unwrap(ByteBuffer[],...) in SslHandler if OpenSslEngine is used and the inbound ByteBuf is backed by more then one ByteBuffer
- Reduce object allocation

Result:

SslHandler is faster when using OpenSslEngine and produce less GC
This commit is contained in:
Norman Maurer 2014-11-24 20:26:39 +01:00 committed by Norman Maurer
parent 017a5ef4e4
commit f46a3d74d0
2 changed files with 181 additions and 61 deletions

View File

@ -17,6 +17,7 @@ package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
@ -120,6 +121,8 @@ public final class OpenSslEngine extends SSLEngine {
private static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL"; private static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL";
private static final long EMPTY_ADDR = Buffer.address(Unpooled.EMPTY_BUFFER.nioBuffer());
// OpenSSL state // OpenSSL state
private long ssl; private long ssl;
private long networkBIO; private long networkBIO;
@ -147,8 +150,6 @@ public final class OpenSslEngine extends SSLEngine {
private boolean isOutboundDone; private boolean isOutboundDone;
private boolean engineClosed; private boolean engineClosed;
private int lastPrimingReadResult;
private final boolean clientMode; private final boolean clientMode;
private final ByteBufAllocator alloc; private final ByteBufAllocator alloc;
private final String fallbackApplicationProtocol; private final String fallbackApplicationProtocol;
@ -252,7 +253,7 @@ public final class OpenSslEngine extends SSLEngine {
} }
/** /**
* Write encrypted data to the OpenSSL network BIO * Write encrypted data to the OpenSSL network BIO.
*/ */
private int writeEncryptedData(final ByteBuffer src) { private int writeEncryptedData(final ByteBuffer src) {
final int pos = src.position(); final int pos = src.position();
@ -262,7 +263,6 @@ public final class OpenSslEngine extends SSLEngine {
final int netWrote = SSL.writeToBIO(networkBIO, addr, len); final int netWrote = SSL.writeToBIO(networkBIO, addr, len);
if (netWrote >= 0) { if (netWrote >= 0) {
src.position(pos + netWrote); src.position(pos + netWrote);
lastPrimingReadResult = SSL.readFromSSL(ssl, addr, 0); // priming read
return netWrote; return netWrote;
} }
} else { } else {
@ -275,7 +275,6 @@ public final class OpenSslEngine extends SSLEngine {
final int netWrote = SSL.writeToBIO(networkBIO, addr, len); final int netWrote = SSL.writeToBIO(networkBIO, addr, len);
if (netWrote >= 0) { if (netWrote >= 0) {
src.position(pos + netWrote); src.position(pos + netWrote);
lastPrimingReadResult = SSL.readFromSSL(ssl, addr, 0); // priming read
return netWrote; return netWrote;
} else { } else {
src.position(pos); src.position(pos);
@ -285,7 +284,7 @@ public final class OpenSslEngine extends SSLEngine {
} }
} }
return 0; return -1;
} }
/** /**
@ -464,9 +463,9 @@ public final class OpenSslEngine extends SSLEngine {
return new SSLEngineResult(getEngineStatus(), getHandshakeStatus(), bytesConsumed, bytesProduced); return new SSLEngineResult(getEngineStatus(), getHandshakeStatus(), bytesConsumed, bytesProduced);
} }
@Override
public synchronized SSLEngineResult unwrap( public synchronized SSLEngineResult unwrap(
final ByteBuffer src, final ByteBuffer[] dsts, final int offset, final int length) throws SSLException { final ByteBuffer[] srcs, int srcsOffset, final int srcsLength,
final ByteBuffer[] dsts, final int dstsOffset, final int dstsLength) throws SSLException {
// Check to make sure the engine has not been closed // Check to make sure the engine has not been closed
if (destroyed != 0) { if (destroyed != 0) {
@ -474,21 +473,26 @@ public final class OpenSslEngine extends SSLEngine {
} }
// Throw requried runtime exceptions // Throw requried runtime exceptions
if (src == null) { if (srcs == null) {
throw new NullPointerException("src"); throw new NullPointerException("srcs");
}
if (srcsOffset >= srcs.length
|| srcsOffset + srcsLength > srcs.length) {
throw new IndexOutOfBoundsException(
"offset: " + srcsOffset + ", length: " + srcsLength +
" (expected: offset <= offset + length <= srcs.length (" + srcs.length + "))");
} }
if (dsts == null) { if (dsts == null) {
throw new NullPointerException("dsts"); throw new NullPointerException("dsts");
} }
if (offset >= dsts.length || offset + length > dsts.length) { if (dstsOffset >= dsts.length || dstsOffset + dstsLength > dsts.length) {
throw new IndexOutOfBoundsException( throw new IndexOutOfBoundsException(
"offset: " + offset + ", length: " + length + "offset: " + dstsOffset + ", length: " + dstsLength +
" (expected: offset <= offset + length <= dsts.length (" + dsts.length + "))"); " (expected: offset <= offset + length <= dsts.length (" + dsts.length + "))");
} }
int capacity = 0; int capacity = 0;
final int endOffset = offset + length; final int endOffset = dstsOffset + dstsLength;
for (int i = offset; i < endOffset; i ++) { for (int i = dstsOffset; i < endOffset; i ++) {
ByteBuffer dst = dsts[i]; ByteBuffer dst = dsts[i];
if (dst == null) { if (dst == null) {
throw new IllegalArgumentException(); throw new IllegalArgumentException();
@ -511,8 +515,18 @@ public final class OpenSslEngine extends SSLEngine {
return new SSLEngineResult(getEngineStatus(), NEED_WRAP, 0, 0); return new SSLEngineResult(getEngineStatus(), NEED_WRAP, 0, 0);
} }
final int srcsEndOffset = srcsOffset + srcsLength;
int len = 0;
for (int i = srcsOffset; i < srcsEndOffset; i++) {
ByteBuffer src = srcs[i];
if (src == null) {
throw new NullPointerException("srcs[" + i + ']');
}
len += src.remaining();
}
// protect against protocol overflow attack vector // protect against protocol overflow attack vector
if (src.remaining() > MAX_ENCRYPTED_PACKET_LENGTH) { if (len > MAX_ENCRYPTED_PACKET_LENGTH) {
isInboundDone = true; isInboundDone = true;
isOutboundDone = true; isOutboundDone = true;
engineClosed = true; engineClosed = true;
@ -521,13 +535,37 @@ public final class OpenSslEngine extends SSLEngine {
} }
// Write encrypted data to network BIO // Write encrypted data to network BIO
int bytesConsumed = 0; int bytesConsumed = -1;
lastPrimingReadResult = 0; int lastPrimingReadResult = 0;
try { try {
bytesConsumed += writeEncryptedData(src); while (srcsOffset < srcsEndOffset) {
ByteBuffer src = srcs[srcsOffset];
int remaining = src.remaining();
int written = writeEncryptedData(src);
if (written >= 0) {
if (bytesConsumed == -1) {
bytesConsumed = written;
} else {
bytesConsumed += written;
}
if (written == remaining) {
srcsOffset ++;
} else if (written == 0) {
break;
}
} else {
break;
}
}
} catch (Exception e) { } catch (Exception e) {
throw new SSLException(e); throw new SSLException(e);
} }
if (bytesConsumed >= 0) {
lastPrimingReadResult = SSL.readFromSSL(ssl, EMPTY_ADDR, 0); // priming read
} else {
// Reset to 0 as -1 is used to signal that nothing was written and no priming read needs to be done
bytesConsumed = 0;
}
// Check for OpenSSL errors caused by the priming read // Check for OpenSSL errors caused by the priming read
long error = SSL.getLastErrorNumber(); long error = SSL.getLastErrorNumber();
@ -554,7 +592,7 @@ public final class OpenSslEngine extends SSLEngine {
// Write decrypted data to dsts buffers // Write decrypted data to dsts buffers
int bytesProduced = 0; int bytesProduced = 0;
int idx = offset; int idx = dstsOffset;
while (idx < endOffset) { while (idx < endOffset) {
ByteBuffer dst = dsts[idx]; ByteBuffer dst = dsts[idx];
if (!dst.hasRemaining()) { if (!dst.hasRemaining()) {
@ -595,6 +633,16 @@ public final class OpenSslEngine extends SSLEngine {
return new SSLEngineResult(getEngineStatus(), getHandshakeStatus(), bytesConsumed, bytesProduced); return new SSLEngineResult(getEngineStatus(), getHandshakeStatus(), bytesConsumed, bytesProduced);
} }
public SSLEngineResult unwrap(final ByteBuffer[] srcs, final ByteBuffer[] dsts) throws SSLException {
return unwrap(srcs, 0, srcs.length, dsts, 0, dsts.length);
}
@Override
public SSLEngineResult unwrap(
final ByteBuffer src, final ByteBuffer[] dsts, final int offset, final int length) throws SSLException {
return unwrap(new ByteBuffer[] { src }, 0, 1, dsts, offset, length);
}
@Override @Override
public Runnable getDelegatedTask() { public Runnable getDelegatedTask() {
// Currently, we do not delegate SSL computation tasks // Currently, we do not delegate SSL computation tasks

View File

@ -169,10 +169,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
/** /**
* Used in {@link #unwrapNonAppData(ChannelHandlerContext)} as input for * Used in {@link #unwrapNonAppData(ChannelHandlerContext)} as input for
* {@link #unwrap(ChannelHandlerContext, ByteBuffer, int)}. Using this static instance reduce object * {@link #unwrap(ChannelHandlerContext, ByteBuf, int, int)}. Using this static instance reduce object
* creation as {@link Unpooled#EMPTY_BUFFER#nioBuffer()} creates a new {@link ByteBuffer} everytime. * creation as {@link Unpooled#EMPTY_BUFFER#nioBuffer()} creates a new {@link ByteBuffer} everytime.
*/ */
private static final ByteBuffer EMPTY_DIRECT_BYTEBUFFER = Unpooled.EMPTY_BUFFER.nioBuffer();
private static final SSLException SSLENGINE_CLOSED = new SSLException("SSLEngine closed already"); private static final SSLException SSLENGINE_CLOSED = new SSLException("SSLEngine closed already");
private static final SSLException HANDSHAKE_TIMED_OUT = new SSLException("handshake timed out"); private static final SSLException HANDSHAKE_TIMED_OUT = new SSLException("handshake timed out");
private static final ClosedChannelException CHANNEL_CLOSED = new ClosedChannelException(); private static final ClosedChannelException CHANNEL_CLOSED = new ClosedChannelException();
@ -189,10 +188,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private final Executor delegatedTaskExecutor; private final Executor delegatedTaskExecutor;
/** /**
* Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} should be called with a {@link ByteBuf} that is only * Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} and {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer[])}
* backed by one {@link ByteBuffer} to reduce the object creation. * should be called with a {@link ByteBuf} that is only backed by one {@link ByteBuffer} to reduce the object
* creation.
*/ */
private final ByteBuffer[] singleWrapBuffer = new ByteBuffer[1]; private final ByteBuffer[] singleBuffer = new ByteBuffer[1];
// BEGIN Platform-dependent flags // BEGIN Platform-dependent flags
@ -282,8 +282,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
this.startTls = startTls; this.startTls = startTls;
maxPacketBufferSize = engine.getSession().getPacketBufferSize(); maxPacketBufferSize = engine.getSession().getPacketBufferSize();
wantsDirectBuffer = engine instanceof OpenSslEngine; boolean opensslEngine = engine instanceof OpenSslEngine;
wantsLargeOutboundNetworkBuffer = !(engine instanceof OpenSslEngine); wantsDirectBuffer = opensslEngine;
wantsLargeOutboundNetworkBuffer = !opensslEngine;
/**
* When using JDK {@link SSLEngine}, we use {@link #MERGE_CUMULATOR} because it works only with
* one {@link ByteBuffer}.
*
* When using {@link OpenSslEngine}, we can use {@link #COMPOSITE_CUMULATOR} because it has
* {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} which works with multiple {@link ByteBuffer}s
* and which does not need to do extra memory copies.
*/
setCumulator(opensslEngine ? COMPOSITE_CUMULATOR : MERGE_CUMULATOR);
} }
public long getHandshakeTimeoutMillis() { public long getHandshakeTimeoutMillis() {
@ -613,7 +624,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// The worst that can happen is that we allocate an extra ByteBuffer[] in CompositeByteBuf.nioBuffers() // The worst that can happen is that we allocate an extra ByteBuffer[] in CompositeByteBuf.nioBuffers()
// which is better then walking the composed ByteBuf in most cases. // which is better then walking the composed ByteBuf in most cases.
if (!(in instanceof CompositeByteBuf) && in.nioBufferCount() == 1) { if (!(in instanceof CompositeByteBuf) && in.nioBufferCount() == 1) {
in0 = singleWrapBuffer; in0 = singleBuffer;
// We know its only backed by 1 ByteBuffer so use internalNioBuffer to keep object allocation // We know its only backed by 1 ByteBuffer so use internalNioBuffer to keep object allocation
// to a minimum. // to a minimum.
in0[0] = in.internalNioBuffer(readerIndex, readableBytes); in0[0] = in.internalNioBuffer(readerIndex, readableBytes);
@ -626,7 +637,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// CompositeByteBuf to keep the complexity to a minimum // CompositeByteBuf to keep the complexity to a minimum
newDirectIn = alloc.directBuffer(readableBytes); newDirectIn = alloc.directBuffer(readableBytes);
newDirectIn.writeBytes(in, readerIndex, readableBytes); newDirectIn.writeBytes(in, readerIndex, readableBytes);
in0 = singleWrapBuffer; in0 = singleBuffer;
in0[0] = newDirectIn.internalNioBuffer(0, readableBytes); in0[0] = newDirectIn.internalNioBuffer(0, readableBytes);
} }
@ -646,7 +657,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
} finally { } finally {
// Null out to allow GC of ByteBuffer // Null out to allow GC of ByteBuffer
singleWrapBuffer[0] = null; singleBuffer[0] = null;
if (newDirectIn != null) { if (newDirectIn != null) {
newDirectIn.release(); newDirectIn.release();
@ -842,7 +853,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException { protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
final int startOffset = in.readerIndex(); final int startOffset = in.readerIndex();
final int endOffset = in.writerIndex(); final int endOffset = in.writerIndex();
int offset = startOffset; int offset = startOffset;
@ -906,9 +916,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// See https://github.com/netty/netty/issues/1534 // See https://github.com/netty/netty/issues/1534
in.skipBytes(totalLength); in.skipBytes(totalLength);
final ByteBuffer inNetBuf = in.nioBuffer(startOffset, totalLength); unwrap(ctx, in, startOffset, totalLength);
unwrap(ctx, inNetBuf, totalLength);
assert !inNetBuf.hasRemaining() || engine.isInboundDone();
} }
if (nonSslRecord) { if (nonSslRecord) {
@ -940,24 +948,24 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc. * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
*/ */
private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException { private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException {
unwrap(ctx, EMPTY_DIRECT_BYTEBUFFER, 0); unwrap(ctx, Unpooled.EMPTY_BUFFER, 0, 0);
} }
/** /**
* Unwraps inbound SSL records. * Unwraps inbound SSL records.
*/ */
private void unwrap( private void unwrap(ChannelHandlerContext ctx, ByteBuf packet,
ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException { int readerIndex, int initialOutAppBufCapacity) throws SSLException {
int len = initialOutAppBufCapacity;
// If SSLEngine expects a heap buffer for unwrapping, do the conversion. // If SSLEngine expects a heap buffer for unwrapping, do the conversion.
final ByteBuffer oldPacket; final ByteBuf oldPacket;
final ByteBuf newPacket; final ByteBuf newPacket;
final int oldPos = packet.position();
if (wantsInboundHeapBuffer && packet.isDirect()) { if (wantsInboundHeapBuffer && packet.isDirect()) {
newPacket = ctx.alloc().heapBuffer(packet.limit() - oldPos); newPacket = ctx.alloc().heapBuffer(packet.readableBytes());
newPacket.writeBytes(packet); newPacket.writeBytes(packet, readerIndex, len);
oldPacket = packet; oldPacket = packet;
packet = newPacket.nioBuffer(); packet = newPacket;
} else { } else {
oldPacket = null; oldPacket = null;
newPacket = null; newPacket = null;
@ -968,12 +976,16 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
ByteBuf decodeOut = allocate(ctx, initialOutAppBufCapacity); ByteBuf decodeOut = allocate(ctx, initialOutAppBufCapacity);
try { try {
for (;;) { for (;;) {
final SSLEngineResult result = unwrap(engine, packet, decodeOut); final SSLEngineResult result = unwrap(engine, packet, readerIndex, len, decodeOut);
final Status status = result.getStatus(); final Status status = result.getStatus();
final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
final int produced = result.bytesProduced(); final int produced = result.bytesProduced();
final int consumed = result.bytesConsumed(); final int consumed = result.bytesConsumed();
// Update indexes for the next iteration
readerIndex += consumed;
len -= consumed;
if (status == Status.CLOSED) { if (status == Status.CLOSED) {
// notify about the CLOSED state of the SSLEngine. See #137 // notify about the CLOSED state of the SSLEngine. See #137
notifyClosure = true; notifyClosure = true;
@ -1029,7 +1041,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// If we converted packet into a heap buffer at the beginning of this method, // If we converted packet into a heap buffer at the beginning of this method,
// we should synchronize the position of the original buffer. // we should synchronize the position of the original buffer.
if (newPacket != null) { if (newPacket != null) {
oldPacket.position(oldPos + packet.position()); oldPacket.readerIndex(readerIndex + packet.readerIndex());
newPacket.release(); newPacket.release();
} }
@ -1041,25 +1053,85 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
} }
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException { private SSLEngineResult unwrap(
int overflows = 0; SSLEngine engine, ByteBuf in, int readerIndex, int len, ByteBuf out) throws SSLException {
for (;;) { int nioBufferCount = in.nioBufferCount();
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes()); if (engine instanceof OpenSslEngine && nioBufferCount > 1) {
SSLEngineResult result = engine.unwrap(in, out0); /**
out.writerIndex(out.writerIndex() + result.bytesProduced()); * If {@link OpenSslEngine} is in use,
switch (result.getStatus()) { * we can use a special {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} method
case BUFFER_OVERFLOW: * that accepts multiple {@link ByteBuffer}s without additional memory copies.
int max = engine.getSession().getApplicationBufferSize(); */
switch (overflows ++) { OpenSslEngine opensslEngine = (OpenSslEngine) engine;
case 0: int overflows = 0;
out.ensureWritable(Math.min(max, in.remaining())); ByteBuffer[] in0 = in.nioBuffers(readerIndex, len);
try {
for (;;) {
int writerIndex = out.writerIndex();
int writableBytes = out.writableBytes();
ByteBuffer out0;
if (out.nioBufferCount() == 1) {
out0 = out.internalNioBuffer(writerIndex, writableBytes);
} else {
out0 = out.nioBuffer(writerIndex, writableBytes);
}
singleBuffer[0] = out0;
SSLEngineResult result = opensslEngine.unwrap(in0, singleBuffer);
out.writerIndex(out.writerIndex() + result.bytesProduced());
switch (result.getStatus()) {
case BUFFER_OVERFLOW:
int max = engine.getSession().getApplicationBufferSize();
switch (overflows ++) {
case 0:
out.ensureWritable(Math.min(max, in.readableBytes()));
break;
default:
out.ensureWritable(max);
}
break; break;
default: default:
out.ensureWritable(max); return result;
} }
break; }
default: } finally {
return result; singleBuffer[0] = null;
}
} else {
int overflows = 0;
ByteBuffer in0;
if (nioBufferCount == 1) {
// Use internalNioBuffer to reduce object creation.
in0 = in.internalNioBuffer(readerIndex, len);
} else {
// This should never be true as this is only the case when OpenSslEngine is used, anyway lets
// guard against it.
in0 = in.nioBuffer(readerIndex, len);
}
for (;;) {
int writerIndex = out.writerIndex();
int writableBytes = out.writableBytes();
ByteBuffer out0;
if (out.nioBufferCount() == 1) {
out0 = out.internalNioBuffer(writerIndex, writableBytes);
} else {
out0 = out.nioBuffer(writerIndex, writableBytes);
}
SSLEngineResult result = engine.unwrap(in0, out0);
out.writerIndex(out.writerIndex() + result.bytesProduced());
switch (result.getStatus()) {
case BUFFER_OVERFLOW:
int max = engine.getSession().getApplicationBufferSize();
switch (overflows ++) {
case 0:
out.ensureWritable(Math.min(max, in.readableBytes()));
break;
default:
out.ensureWritable(max);
}
break;
default:
return result;
}
} }
} }
} }