From 1bb818bb59612f1676bb34eb9cde4ea9fb1d3025 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Mon, 24 Nov 2014 20:26:39 +0100 Subject: [PATCH] Reduce memory copies when using OpenSslEngine with SslHandler Motivation: When using OpenSslEngine with the SslHandler it is possible to reduce memory copies by calling unwrap(...) on multiple ByteBuffers at the same time. This way we can eliminate a memory copy that is needed otherwise to cumulate partially received data. Modifications: - Add OpenSslEngine.unwrap(ByteBuffer[],...) method that can be used to unwrap multiple src ByteBuffers at the same time - Use a CompositeByteBuf in SslHandler for inbound data so we do not need to copy memory - Use OpenSslEngine.unwrap(ByteBuffer[],...) in SslHandler if OpenSslEngine is used and the inbound ByteBuf is backed by more than one ByteBuffer - Reduce object allocation Result: SslHandler is faster when using OpenSslEngine and produces less GC --- .../io/netty/handler/ssl/OpenSslEngine.java | 90 ++++++++--- .../java/io/netty/handler/ssl/SslHandler.java | 152 +++++++++++++----- 2 files changed, 181 insertions(+), 61 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java index 6a50a15e92..a42d0ed047 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java @@ -17,6 +17,7 @@ package io.netty.handler.ssl; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; @@ -120,6 +121,8 @@ public final class OpenSslEngine extends SSLEngine { private static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL"; + private static final long EMPTY_ADDR = Buffer.address(Unpooled.EMPTY_BUFFER.nioBuffer()); + // OpenSSL state private long 
ssl; private long networkBIO; @@ -147,8 +150,6 @@ public final class OpenSslEngine extends SSLEngine { private boolean isOutboundDone; private boolean engineClosed; - private int lastPrimingReadResult; - private final boolean clientMode; private final ByteBufAllocator alloc; private final String fallbackApplicationProtocol; @@ -252,7 +253,7 @@ public final class OpenSslEngine extends SSLEngine { } /** - * Write encrypted data to the OpenSSL network BIO + * Write encrypted data to the OpenSSL network BIO. */ private int writeEncryptedData(final ByteBuffer src) { final int pos = src.position(); @@ -262,7 +263,6 @@ public final class OpenSslEngine extends SSLEngine { final int netWrote = SSL.writeToBIO(networkBIO, addr, len); if (netWrote >= 0) { src.position(pos + netWrote); - lastPrimingReadResult = SSL.readFromSSL(ssl, addr, 0); // priming read return netWrote; } } else { @@ -275,7 +275,6 @@ public final class OpenSslEngine extends SSLEngine { final int netWrote = SSL.writeToBIO(networkBIO, addr, len); if (netWrote >= 0) { src.position(pos + netWrote); - lastPrimingReadResult = SSL.readFromSSL(ssl, addr, 0); // priming read return netWrote; } else { src.position(pos); @@ -285,7 +284,7 @@ public final class OpenSslEngine extends SSLEngine { } } - return 0; + return -1; } /** @@ -464,9 +463,9 @@ public final class OpenSslEngine extends SSLEngine { return new SSLEngineResult(getEngineStatus(), getHandshakeStatus(), bytesConsumed, bytesProduced); } - @Override public synchronized SSLEngineResult unwrap( - final ByteBuffer src, final ByteBuffer[] dsts, final int offset, final int length) throws SSLException { + final ByteBuffer[] srcs, int srcsOffset, final int srcsLength, + final ByteBuffer[] dsts, final int dstsOffset, final int dstsLength) throws SSLException { // Check to make sure the engine has not been closed if (destroyed != 0) { @@ -474,21 +473,26 @@ public final class OpenSslEngine extends SSLEngine { } // Throw requried runtime exceptions - if (src == null) { 
- throw new NullPointerException("src"); + if (srcs == null) { + throw new NullPointerException("srcs"); + } + if (srcsOffset >= srcs.length + || srcsOffset + srcsLength > srcs.length) { + throw new IndexOutOfBoundsException( + "offset: " + srcsOffset + ", length: " + srcsLength + + " (expected: offset <= offset + length <= srcs.length (" + srcs.length + "))"); } if (dsts == null) { throw new NullPointerException("dsts"); } - if (offset >= dsts.length || offset + length > dsts.length) { + if (dstsOffset >= dsts.length || dstsOffset + dstsLength > dsts.length) { throw new IndexOutOfBoundsException( - "offset: " + offset + ", length: " + length + - " (expected: offset <= offset + length <= dsts.length (" + dsts.length + "))"); + "offset: " + dstsOffset + ", length: " + dstsLength + + " (expected: offset <= offset + length <= dsts.length (" + dsts.length + "))"); } - int capacity = 0; - final int endOffset = offset + length; - for (int i = offset; i < endOffset; i ++) { + final int endOffset = dstsOffset + dstsLength; + for (int i = dstsOffset; i < endOffset; i ++) { ByteBuffer dst = dsts[i]; if (dst == null) { throw new IllegalArgumentException(); @@ -511,8 +515,18 @@ public final class OpenSslEngine extends SSLEngine { return new SSLEngineResult(getEngineStatus(), NEED_WRAP, 0, 0); } + final int srcsEndOffset = srcsOffset + srcsLength; + int len = 0; + for (int i = srcsOffset; i < srcsEndOffset; i++) { + ByteBuffer src = srcs[i]; + if (src == null) { + throw new NullPointerException("srcs[" + i + ']'); + } + len += src.remaining(); + } + // protect against protocol overflow attack vector - if (src.remaining() > MAX_ENCRYPTED_PACKET_LENGTH) { + if (len > MAX_ENCRYPTED_PACKET_LENGTH) { isInboundDone = true; isOutboundDone = true; engineClosed = true; @@ -521,13 +535,37 @@ public final class OpenSslEngine extends SSLEngine { } // Write encrypted data to network BIO - int bytesConsumed = 0; - lastPrimingReadResult = 0; + int bytesConsumed = -1; + int 
lastPrimingReadResult = 0; try { - bytesConsumed += writeEncryptedData(src); + while (srcsOffset < srcsEndOffset) { + ByteBuffer src = srcs[srcsOffset]; + int remaining = src.remaining(); + int written = writeEncryptedData(src); + if (written >= 0) { + if (bytesConsumed == -1) { + bytesConsumed = written; + } else { + bytesConsumed += written; + } + if (written == remaining) { + srcsOffset ++; + } else if (written == 0) { + break; + } + } else { + break; + } + } } catch (Exception e) { throw new SSLException(e); } + if (bytesConsumed >= 0) { + lastPrimingReadResult = SSL.readFromSSL(ssl, EMPTY_ADDR, 0); // priming read + } else { + // Reset to 0 as -1 is used to signal that nothing was written and no priming read needs to be done + bytesConsumed = 0; + } // Check for OpenSSL errors caused by the priming read long error = SSL.getLastErrorNumber(); @@ -554,7 +592,7 @@ public final class OpenSslEngine extends SSLEngine { // Write decrypted data to dsts buffers int bytesProduced = 0; - int idx = offset; + int idx = dstsOffset; while (idx < endOffset) { ByteBuffer dst = dsts[idx]; if (!dst.hasRemaining()) { @@ -595,6 +633,16 @@ public final class OpenSslEngine extends SSLEngine { return new SSLEngineResult(getEngineStatus(), getHandshakeStatus(), bytesConsumed, bytesProduced); } + public SSLEngineResult unwrap(final ByteBuffer[] srcs, final ByteBuffer[] dsts) throws SSLException { + return unwrap(srcs, 0, srcs.length, dsts, 0, dsts.length); + } + + @Override + public SSLEngineResult unwrap( + final ByteBuffer src, final ByteBuffer[] dsts, final int offset, final int length) throws SSLException { + return unwrap(new ByteBuffer[] { src }, 0, 1, dsts, offset, length); + } + @Override public Runnable getDelegatedTask() { // Currently, we do not delegate SSL computation tasks diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index c92f96567d..695c611353 100644 --- 
a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -169,10 +169,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH /** * Used in {@link #unwrapNonAppData(ChannelHandlerContext)} as input for - * {@link #unwrap(ChannelHandlerContext, ByteBuffer, int)}. Using this static instance reduce object + * {@link #unwrap(ChannelHandlerContext, ByteBuf, int, int)}. Using this static instance reduce object * creation as {@link Unpooled#EMPTY_BUFFER#nioBuffer()} creates a new {@link ByteBuffer} everytime. */ - private static final ByteBuffer EMPTY_DIRECT_BYTEBUFFER = Unpooled.EMPTY_BUFFER.nioBuffer(); private static final SSLException SSLENGINE_CLOSED = new SSLException("SSLEngine closed already"); private static final SSLException HANDSHAKE_TIMED_OUT = new SSLException("handshake timed out"); private static final ClosedChannelException CHANNEL_CLOSED = new ClosedChannelException(); @@ -189,10 +188,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private final Executor delegatedTaskExecutor; /** - * Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} should be called with a {@link ByteBuf} that is only - * backed by one {@link ByteBuffer} to reduce the object creation. + * Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} and {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer[])} + * should be called with a {@link ByteBuf} that is only backed by one {@link ByteBuffer} to reduce the object + * creation. 
*/ - private final ByteBuffer[] singleWrapBuffer = new ByteBuffer[1]; + private final ByteBuffer[] singleBuffer = new ByteBuffer[1]; // BEGIN Platform-dependent flags @@ -282,8 +282,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH this.startTls = startTls; maxPacketBufferSize = engine.getSession().getPacketBufferSize(); - wantsDirectBuffer = engine instanceof OpenSslEngine; - wantsLargeOutboundNetworkBuffer = !(engine instanceof OpenSslEngine); + boolean opensslEngine = engine instanceof OpenSslEngine; + wantsDirectBuffer = opensslEngine; + wantsLargeOutboundNetworkBuffer = !opensslEngine; + + /** + * When using JDK {@link SSLEngine}, we use {@link #MERGE_CUMULATOR} because it works only with + * one {@link ByteBuffer}. + * + * When using {@link OpenSslEngine}, we can use {@link #COMPOSITE_CUMULATOR} because it has + * {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} which works with multiple {@link ByteBuffer}s + * and which does not need to do extra memory copies. + */ + setCumulator(opensslEngine ? COMPOSITE_CUMULATOR : MERGE_CUMULATOR); } public long getHandshakeTimeoutMillis() { @@ -613,7 +624,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // The worst that can happen is that we allocate an extra ByteBuffer[] in CompositeByteBuf.nioBuffers() // which is better then walking the composed ByteBuf in most cases. if (!(in instanceof CompositeByteBuf) && in.nioBufferCount() == 1) { - in0 = singleWrapBuffer; + in0 = singleBuffer; // We know its only backed by 1 ByteBuffer so use internalNioBuffer to keep object allocation // to a minimum. 
in0[0] = in.internalNioBuffer(readerIndex, readableBytes); @@ -626,7 +637,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // CompositeByteBuf to keep the complexity to a minimum newDirectIn = alloc.directBuffer(readableBytes); newDirectIn.writeBytes(in, readerIndex, readableBytes); - in0 = singleWrapBuffer; + in0 = singleBuffer; in0[0] = newDirectIn.internalNioBuffer(0, readableBytes); } @@ -646,7 +657,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } } finally { // Null out to allow GC of ByteBuffer - singleWrapBuffer[0] = null; + singleBuffer[0] = null; if (newDirectIn != null) { newDirectIn.release(); @@ -842,7 +853,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws SSLException { - final int startOffset = in.readerIndex(); final int endOffset = in.writerIndex(); int offset = startOffset; @@ -906,9 +916,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // See https://github.com/netty/netty/issues/1534 in.skipBytes(totalLength); - final ByteBuffer inNetBuf = in.nioBuffer(startOffset, totalLength); - unwrap(ctx, inNetBuf, totalLength); - assert !inNetBuf.hasRemaining() || engine.isInboundDone(); + unwrap(ctx, in, startOffset, totalLength); } if (nonSslRecord) { @@ -940,24 +948,24 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc. */ private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException { - unwrap(ctx, EMPTY_DIRECT_BYTEBUFFER, 0); + unwrap(ctx, Unpooled.EMPTY_BUFFER, 0, 0); } /** * Unwraps inbound SSL records. 
*/ - private void unwrap( - ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException { + private void unwrap(ChannelHandlerContext ctx, ByteBuf packet, + int readerIndex, int initialOutAppBufCapacity) throws SSLException { + int len = initialOutAppBufCapacity; // If SSLEngine expects a heap buffer for unwrapping, do the conversion. - final ByteBuffer oldPacket; + final ByteBuf oldPacket; final ByteBuf newPacket; - final int oldPos = packet.position(); if (wantsInboundHeapBuffer && packet.isDirect()) { - newPacket = ctx.alloc().heapBuffer(packet.limit() - oldPos); - newPacket.writeBytes(packet); + newPacket = ctx.alloc().heapBuffer(packet.readableBytes()); + newPacket.writeBytes(packet, readerIndex, len); oldPacket = packet; - packet = newPacket.nioBuffer(); + packet = newPacket; } else { oldPacket = null; newPacket = null; @@ -968,12 +976,16 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH ByteBuf decodeOut = allocate(ctx, initialOutAppBufCapacity); try { for (;;) { - final SSLEngineResult result = unwrap(engine, packet, decodeOut); + final SSLEngineResult result = unwrap(engine, packet, readerIndex, len, decodeOut); final Status status = result.getStatus(); final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); final int produced = result.bytesProduced(); final int consumed = result.bytesConsumed(); + // Update indexes for the next iteration + readerIndex += consumed; + len -= consumed; + if (status == Status.CLOSED) { // notify about the CLOSED state of the SSLEngine. See #137 notifyClosure = true; @@ -1029,7 +1041,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // If we converted packet into a heap buffer at the beginning of this method, // we should synchronize the position of the original buffer. 
if (newPacket != null) { - oldPacket.position(oldPos + packet.position()); + oldPacket.readerIndex(readerIndex + packet.readerIndex()); newPacket.release(); } @@ -1041,25 +1053,85 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } } - private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException { - int overflows = 0; - for (;;) { - ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes()); - SSLEngineResult result = engine.unwrap(in, out0); - out.writerIndex(out.writerIndex() + result.bytesProduced()); - switch (result.getStatus()) { - case BUFFER_OVERFLOW: - int max = engine.getSession().getApplicationBufferSize(); - switch (overflows ++) { - case 0: - out.ensureWritable(Math.min(max, in.remaining())); + private SSLEngineResult unwrap( + SSLEngine engine, ByteBuf in, int readerIndex, int len, ByteBuf out) throws SSLException { + int nioBufferCount = in.nioBufferCount(); + if (engine instanceof OpenSslEngine && nioBufferCount > 1) { + /** + * If {@link OpenSslEngine} is in use, + * we can use a special {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} method + * that accepts multiple {@link ByteBuffer}s without additional memory copies. 
+ */ + OpenSslEngine opensslEngine = (OpenSslEngine) engine; + int overflows = 0; + ByteBuffer[] in0 = in.nioBuffers(readerIndex, len); + try { + for (;;) { + int writerIndex = out.writerIndex(); + int writableBytes = out.writableBytes(); + ByteBuffer out0; + if (out.nioBufferCount() == 1) { + out0 = out.internalNioBuffer(writerIndex, writableBytes); + } else { + out0 = out.nioBuffer(writerIndex, writableBytes); + } + singleBuffer[0] = out0; + SSLEngineResult result = opensslEngine.unwrap(in0, singleBuffer); + out.writerIndex(out.writerIndex() + result.bytesProduced()); + switch (result.getStatus()) { + case BUFFER_OVERFLOW: + int max = engine.getSession().getApplicationBufferSize(); + switch (overflows ++) { + case 0: + out.ensureWritable(Math.min(max, in.readableBytes())); + break; + default: + out.ensureWritable(max); + } break; default: - out.ensureWritable(max); + return result; } - break; - default: - return result; + } + } finally { + singleBuffer[0] = null; + } + } else { + int overflows = 0; + ByteBuffer in0; + if (nioBufferCount == 1) { + // Use internalNioBuffer to reduce object creation. + in0 = in.internalNioBuffer(readerIndex, len); + } else { + // This should never be true as this is only the case when OpenSslEngine is used, anyway lets + // guard against it. 
+ in0 = in.nioBuffer(readerIndex, len); + } + for (;;) { + int writerIndex = out.writerIndex(); + int writableBytes = out.writableBytes(); + ByteBuffer out0; + if (out.nioBufferCount() == 1) { + out0 = out.internalNioBuffer(writerIndex, writableBytes); + } else { + out0 = out.nioBuffer(writerIndex, writableBytes); + } + SSLEngineResult result = engine.unwrap(in0, out0); + out.writerIndex(out.writerIndex() + result.bytesProduced()); + switch (result.getStatus()) { + case BUFFER_OVERFLOW: + int max = engine.getSession().getApplicationBufferSize(); + switch (overflows ++) { + case 0: + out.ensureWritable(Math.min(max, in.readableBytes())); + break; + default: + out.ensureWritable(max); + } + break; + default: + return result; + } } } }