Reduce memory copies when using OpenSslEngine with SslHandler

Motivation:

When using OpenSslEngine with the SslHandler it is possible to reduce memory copies by unwrap(...) multiple ByteBuffers at the same time. This way we can eliminate a memory copy that is needed otherwise to cumulate partial received data.

Modifications:

- Add OpenSslEngine.unwrap(ByteBuffer[],...) method that can be used to unwrap multiple src ByteBuffer a the same time
- Use a CompositeByteBuffer in SslHandler for inbound data so we not need to memory copy
- Add OpenSslEngine.unwrap(ByteBuffer[],...) in SslHandler if OpenSslEngine is used and the inbound ByteBuf is backed by more then one ByteBuffer
- Reduce object allocation

Result:

SslHandler is faster when using OpenSslEngine and produce less GC
This commit is contained in:
Norman Maurer 2014-11-24 20:26:39 +01:00 committed by Norman Maurer
parent 017a5ef4e4
commit f46a3d74d0
2 changed files with 181 additions and 61 deletions

View File

@ -17,6 +17,7 @@ package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
@ -120,6 +121,8 @@ public final class OpenSslEngine extends SSLEngine {
private static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL";
private static final long EMPTY_ADDR = Buffer.address(Unpooled.EMPTY_BUFFER.nioBuffer());
// OpenSSL state
private long ssl;
private long networkBIO;
@ -147,8 +150,6 @@ public final class OpenSslEngine extends SSLEngine {
private boolean isOutboundDone;
private boolean engineClosed;
private int lastPrimingReadResult;
private final boolean clientMode;
private final ByteBufAllocator alloc;
private final String fallbackApplicationProtocol;
@ -252,7 +253,7 @@ public final class OpenSslEngine extends SSLEngine {
}
/**
* Write encrypted data to the OpenSSL network BIO
* Write encrypted data to the OpenSSL network BIO.
*/
private int writeEncryptedData(final ByteBuffer src) {
final int pos = src.position();
@ -262,7 +263,6 @@ public final class OpenSslEngine extends SSLEngine {
final int netWrote = SSL.writeToBIO(networkBIO, addr, len);
if (netWrote >= 0) {
src.position(pos + netWrote);
lastPrimingReadResult = SSL.readFromSSL(ssl, addr, 0); // priming read
return netWrote;
}
} else {
@ -275,7 +275,6 @@ public final class OpenSslEngine extends SSLEngine {
final int netWrote = SSL.writeToBIO(networkBIO, addr, len);
if (netWrote >= 0) {
src.position(pos + netWrote);
lastPrimingReadResult = SSL.readFromSSL(ssl, addr, 0); // priming read
return netWrote;
} else {
src.position(pos);
@ -285,7 +284,7 @@ public final class OpenSslEngine extends SSLEngine {
}
}
return 0;
return -1;
}
/**
@ -464,9 +463,9 @@ public final class OpenSslEngine extends SSLEngine {
return new SSLEngineResult(getEngineStatus(), getHandshakeStatus(), bytesConsumed, bytesProduced);
}
@Override
public synchronized SSLEngineResult unwrap(
final ByteBuffer src, final ByteBuffer[] dsts, final int offset, final int length) throws SSLException {
final ByteBuffer[] srcs, int srcsOffset, final int srcsLength,
final ByteBuffer[] dsts, final int dstsOffset, final int dstsLength) throws SSLException {
// Check to make sure the engine has not been closed
if (destroyed != 0) {
@ -474,21 +473,26 @@ public final class OpenSslEngine extends SSLEngine {
}
// Throw requried runtime exceptions
if (src == null) {
throw new NullPointerException("src");
if (srcs == null) {
throw new NullPointerException("srcs");
}
if (srcsOffset >= srcs.length
|| srcsOffset + srcsLength > srcs.length) {
throw new IndexOutOfBoundsException(
"offset: " + srcsOffset + ", length: " + srcsLength +
" (expected: offset <= offset + length <= srcs.length (" + srcs.length + "))");
}
if (dsts == null) {
throw new NullPointerException("dsts");
}
if (offset >= dsts.length || offset + length > dsts.length) {
if (dstsOffset >= dsts.length || dstsOffset + dstsLength > dsts.length) {
throw new IndexOutOfBoundsException(
"offset: " + offset + ", length: " + length +
" (expected: offset <= offset + length <= dsts.length (" + dsts.length + "))");
"offset: " + dstsOffset + ", length: " + dstsLength +
" (expected: offset <= offset + length <= dsts.length (" + dsts.length + "))");
}
int capacity = 0;
final int endOffset = offset + length;
for (int i = offset; i < endOffset; i ++) {
final int endOffset = dstsOffset + dstsLength;
for (int i = dstsOffset; i < endOffset; i ++) {
ByteBuffer dst = dsts[i];
if (dst == null) {
throw new IllegalArgumentException();
@ -511,8 +515,18 @@ public final class OpenSslEngine extends SSLEngine {
return new SSLEngineResult(getEngineStatus(), NEED_WRAP, 0, 0);
}
final int srcsEndOffset = srcsOffset + srcsLength;
int len = 0;
for (int i = srcsOffset; i < srcsEndOffset; i++) {
ByteBuffer src = srcs[i];
if (src == null) {
throw new NullPointerException("srcs[" + i + ']');
}
len += src.remaining();
}
// protect against protocol overflow attack vector
if (src.remaining() > MAX_ENCRYPTED_PACKET_LENGTH) {
if (len > MAX_ENCRYPTED_PACKET_LENGTH) {
isInboundDone = true;
isOutboundDone = true;
engineClosed = true;
@ -521,13 +535,37 @@ public final class OpenSslEngine extends SSLEngine {
}
// Write encrypted data to network BIO
int bytesConsumed = 0;
lastPrimingReadResult = 0;
int bytesConsumed = -1;
int lastPrimingReadResult = 0;
try {
bytesConsumed += writeEncryptedData(src);
while (srcsOffset < srcsEndOffset) {
ByteBuffer src = srcs[srcsOffset];
int remaining = src.remaining();
int written = writeEncryptedData(src);
if (written >= 0) {
if (bytesConsumed == -1) {
bytesConsumed = written;
} else {
bytesConsumed += written;
}
if (written == remaining) {
srcsOffset ++;
} else if (written == 0) {
break;
}
} else {
break;
}
}
} catch (Exception e) {
throw new SSLException(e);
}
if (bytesConsumed >= 0) {
lastPrimingReadResult = SSL.readFromSSL(ssl, EMPTY_ADDR, 0); // priming read
} else {
// Reset to 0 as -1 is used to signal that nothing was written and no priming read needs to be done
bytesConsumed = 0;
}
// Check for OpenSSL errors caused by the priming read
long error = SSL.getLastErrorNumber();
@ -554,7 +592,7 @@ public final class OpenSslEngine extends SSLEngine {
// Write decrypted data to dsts buffers
int bytesProduced = 0;
int idx = offset;
int idx = dstsOffset;
while (idx < endOffset) {
ByteBuffer dst = dsts[idx];
if (!dst.hasRemaining()) {
@ -595,6 +633,16 @@ public final class OpenSslEngine extends SSLEngine {
return new SSLEngineResult(getEngineStatus(), getHandshakeStatus(), bytesConsumed, bytesProduced);
}
public SSLEngineResult unwrap(final ByteBuffer[] srcs, final ByteBuffer[] dsts) throws SSLException {
return unwrap(srcs, 0, srcs.length, dsts, 0, dsts.length);
}
@Override
public SSLEngineResult unwrap(
final ByteBuffer src, final ByteBuffer[] dsts, final int offset, final int length) throws SSLException {
return unwrap(new ByteBuffer[] { src }, 0, 1, dsts, offset, length);
}
@Override
public Runnable getDelegatedTask() {
// Currently, we do not delegate SSL computation tasks

View File

@ -169,10 +169,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
/**
* Used in {@link #unwrapNonAppData(ChannelHandlerContext)} as input for
* {@link #unwrap(ChannelHandlerContext, ByteBuffer, int)}. Using this static instance reduce object
* {@link #unwrap(ChannelHandlerContext, ByteBuf, int, int)}. Using this static instance reduce object
* creation as {@link Unpooled#EMPTY_BUFFER#nioBuffer()} creates a new {@link ByteBuffer} everytime.
*/
private static final ByteBuffer EMPTY_DIRECT_BYTEBUFFER = Unpooled.EMPTY_BUFFER.nioBuffer();
private static final SSLException SSLENGINE_CLOSED = new SSLException("SSLEngine closed already");
private static final SSLException HANDSHAKE_TIMED_OUT = new SSLException("handshake timed out");
private static final ClosedChannelException CHANNEL_CLOSED = new ClosedChannelException();
@ -189,10 +188,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private final Executor delegatedTaskExecutor;
/**
* Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} should be called with a {@link ByteBuf} that is only
* backed by one {@link ByteBuffer} to reduce the object creation.
* Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} and {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer[])}
* should be called with a {@link ByteBuf} that is only backed by one {@link ByteBuffer} to reduce the object
* creation.
*/
private final ByteBuffer[] singleWrapBuffer = new ByteBuffer[1];
private final ByteBuffer[] singleBuffer = new ByteBuffer[1];
// BEGIN Platform-dependent flags
@ -282,8 +282,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
this.startTls = startTls;
maxPacketBufferSize = engine.getSession().getPacketBufferSize();
wantsDirectBuffer = engine instanceof OpenSslEngine;
wantsLargeOutboundNetworkBuffer = !(engine instanceof OpenSslEngine);
boolean opensslEngine = engine instanceof OpenSslEngine;
wantsDirectBuffer = opensslEngine;
wantsLargeOutboundNetworkBuffer = !opensslEngine;
/**
* When using JDK {@link SSLEngine}, we use {@link #MERGE_CUMULATOR} because it works only with
* one {@link ByteBuffer}.
*
* When using {@link OpenSslEngine}, we can use {@link #COMPOSITE_CUMULATOR} because it has
* {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} which works with multiple {@link ByteBuffer}s
* and which does not need to do extra memory copies.
*/
setCumulator(opensslEngine ? COMPOSITE_CUMULATOR : MERGE_CUMULATOR);
}
public long getHandshakeTimeoutMillis() {
@ -613,7 +624,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// The worst that can happen is that we allocate an extra ByteBuffer[] in CompositeByteBuf.nioBuffers()
// which is better then walking the composed ByteBuf in most cases.
if (!(in instanceof CompositeByteBuf) && in.nioBufferCount() == 1) {
in0 = singleWrapBuffer;
in0 = singleBuffer;
// We know its only backed by 1 ByteBuffer so use internalNioBuffer to keep object allocation
// to a minimum.
in0[0] = in.internalNioBuffer(readerIndex, readableBytes);
@ -626,7 +637,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// CompositeByteBuf to keep the complexity to a minimum
newDirectIn = alloc.directBuffer(readableBytes);
newDirectIn.writeBytes(in, readerIndex, readableBytes);
in0 = singleWrapBuffer;
in0 = singleBuffer;
in0[0] = newDirectIn.internalNioBuffer(0, readableBytes);
}
@ -646,7 +657,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
} finally {
// Null out to allow GC of ByteBuffer
singleWrapBuffer[0] = null;
singleBuffer[0] = null;
if (newDirectIn != null) {
newDirectIn.release();
@ -842,7 +853,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
final int startOffset = in.readerIndex();
final int endOffset = in.writerIndex();
int offset = startOffset;
@ -906,9 +916,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// See https://github.com/netty/netty/issues/1534
in.skipBytes(totalLength);
final ByteBuffer inNetBuf = in.nioBuffer(startOffset, totalLength);
unwrap(ctx, inNetBuf, totalLength);
assert !inNetBuf.hasRemaining() || engine.isInboundDone();
unwrap(ctx, in, startOffset, totalLength);
}
if (nonSslRecord) {
@ -940,24 +948,24 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
*/
private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException {
unwrap(ctx, EMPTY_DIRECT_BYTEBUFFER, 0);
unwrap(ctx, Unpooled.EMPTY_BUFFER, 0, 0);
}
/**
* Unwraps inbound SSL records.
*/
private void unwrap(
ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException {
private void unwrap(ChannelHandlerContext ctx, ByteBuf packet,
int readerIndex, int initialOutAppBufCapacity) throws SSLException {
int len = initialOutAppBufCapacity;
// If SSLEngine expects a heap buffer for unwrapping, do the conversion.
final ByteBuffer oldPacket;
final ByteBuf oldPacket;
final ByteBuf newPacket;
final int oldPos = packet.position();
if (wantsInboundHeapBuffer && packet.isDirect()) {
newPacket = ctx.alloc().heapBuffer(packet.limit() - oldPos);
newPacket.writeBytes(packet);
newPacket = ctx.alloc().heapBuffer(packet.readableBytes());
newPacket.writeBytes(packet, readerIndex, len);
oldPacket = packet;
packet = newPacket.nioBuffer();
packet = newPacket;
} else {
oldPacket = null;
newPacket = null;
@ -968,12 +976,16 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
ByteBuf decodeOut = allocate(ctx, initialOutAppBufCapacity);
try {
for (;;) {
final SSLEngineResult result = unwrap(engine, packet, decodeOut);
final SSLEngineResult result = unwrap(engine, packet, readerIndex, len, decodeOut);
final Status status = result.getStatus();
final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
final int produced = result.bytesProduced();
final int consumed = result.bytesConsumed();
// Update indexes for the next iteration
readerIndex += consumed;
len -= consumed;
if (status == Status.CLOSED) {
// notify about the CLOSED state of the SSLEngine. See #137
notifyClosure = true;
@ -1029,7 +1041,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// If we converted packet into a heap buffer at the beginning of this method,
// we should synchronize the position of the original buffer.
if (newPacket != null) {
oldPacket.position(oldPos + packet.position());
oldPacket.readerIndex(readerIndex + packet.readerIndex());
newPacket.release();
}
@ -1041,25 +1053,85 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
}
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException {
int overflows = 0;
for (;;) {
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
SSLEngineResult result = engine.unwrap(in, out0);
out.writerIndex(out.writerIndex() + result.bytesProduced());
switch (result.getStatus()) {
case BUFFER_OVERFLOW:
int max = engine.getSession().getApplicationBufferSize();
switch (overflows ++) {
case 0:
out.ensureWritable(Math.min(max, in.remaining()));
private SSLEngineResult unwrap(
SSLEngine engine, ByteBuf in, int readerIndex, int len, ByteBuf out) throws SSLException {
int nioBufferCount = in.nioBufferCount();
if (engine instanceof OpenSslEngine && nioBufferCount > 1) {
/**
* If {@link OpenSslEngine} is in use,
* we can use a special {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} method
* that accepts multiple {@link ByteBuffer}s without additional memory copies.
*/
OpenSslEngine opensslEngine = (OpenSslEngine) engine;
int overflows = 0;
ByteBuffer[] in0 = in.nioBuffers(readerIndex, len);
try {
for (;;) {
int writerIndex = out.writerIndex();
int writableBytes = out.writableBytes();
ByteBuffer out0;
if (out.nioBufferCount() == 1) {
out0 = out.internalNioBuffer(writerIndex, writableBytes);
} else {
out0 = out.nioBuffer(writerIndex, writableBytes);
}
singleBuffer[0] = out0;
SSLEngineResult result = opensslEngine.unwrap(in0, singleBuffer);
out.writerIndex(out.writerIndex() + result.bytesProduced());
switch (result.getStatus()) {
case BUFFER_OVERFLOW:
int max = engine.getSession().getApplicationBufferSize();
switch (overflows ++) {
case 0:
out.ensureWritable(Math.min(max, in.readableBytes()));
break;
default:
out.ensureWritable(max);
}
break;
default:
out.ensureWritable(max);
return result;
}
break;
default:
return result;
}
} finally {
singleBuffer[0] = null;
}
} else {
int overflows = 0;
ByteBuffer in0;
if (nioBufferCount == 1) {
// Use internalNioBuffer to reduce object creation.
in0 = in.internalNioBuffer(readerIndex, len);
} else {
// This should never be true as this is only the case when OpenSslEngine is used, anyway lets
// guard against it.
in0 = in.nioBuffer(readerIndex, len);
}
for (;;) {
int writerIndex = out.writerIndex();
int writableBytes = out.writableBytes();
ByteBuffer out0;
if (out.nioBufferCount() == 1) {
out0 = out.internalNioBuffer(writerIndex, writableBytes);
} else {
out0 = out.nioBuffer(writerIndex, writableBytes);
}
SSLEngineResult result = engine.unwrap(in0, out0);
out.writerIndex(out.writerIndex() + result.bytesProduced());
switch (result.getStatus()) {
case BUFFER_OVERFLOW:
int max = engine.getSession().getApplicationBufferSize();
switch (overflows ++) {
case 0:
out.ensureWritable(Math.min(max, in.readableBytes()));
break;
default:
out.ensureWritable(max);
}
break;
default:
return result;
}
}
}
}