From f87d0698f37e5360af7c7cb9b4f1de7cbf66db64 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Fri, 4 Nov 2016 10:09:24 +0100 Subject: [PATCH] Let OpenSslEngine.wrap(...) / OpenSslEngine.unwrap(...) behave like stated in the javadocs. Motivation: OpenSslEngine.wrap(...) and OpenSslEngie.unwrap(...) may consume bytes even if an BUFFER_OVERFLOW / BUFFER_UNDERFLOW is detected. This is not correct as it should only consume bytes if it can process them without storing data between unwrap(...) / wrap(...) calls. Beside this it also should only process one record at a time. Modifications: - Correctly detect BUFFER_OVERFLOW / BUFFER_UNDERFLOW and only consume bytes if non of them is detected. - Only process one record per call. Result: OpenSslEngine behaves like stated in the javadocs of SSLEngine. --- .../ssl/ReferenceCountedOpenSslEngine.java | 50 +++++++++-- .../java/io/netty/handler/ssl/SslHandler.java | 23 +---- .../java/io/netty/handler/ssl/SslUtils.java | 88 +++++++++++++++++++ .../io/netty/handler/ssl/SSLEngineTest.java | 85 +++++++++++++++++- 4 files changed, 217 insertions(+), 29 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index 2fc2ed3dd6..f39aa205bc 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -67,6 +67,7 @@ import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW; +import static javax.net.ssl.SSLEngineResult.Status.BUFFER_UNDERFLOW; import static javax.net.ssl.SSLEngineResult.Status.CLOSED; import static javax.net.ssl.SSLEngineResult.Status.OK; @@ -395,9 +396,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc /** * Write encrypted data to the OpenSSL network BIO. */ - private int writeEncryptedData(final ByteBuffer src) { + private int writeEncryptedData(final ByteBuffer src, int len) { final int pos = src.position(); - final int len = src.remaining(); final int netWrote; if (src.isDirect()) { final long addr = Buffer.address(src) + pos; @@ -409,8 +409,12 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc final ByteBuf buf = alloc.directBuffer(len); try { final long addr = memoryAddress(buf); - - buf.setBytes(0, src); + int newLimit = pos + len; + if (newLimit != src.remaining()) { + buf.setBytes(0, (ByteBuffer) src.duplicate().position(pos).limit(newLimit)); + } else { + buf.setBytes(0, src); + } netWrote = SSL.writeToBIO(networkBIO, addr, len); if (netWrote >= 0) { @@ -580,6 +584,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } + if (dst.remaining() < MAX_ENCRYPTED_PACKET_LENGTH) { + // Can not hold the maximum packet so we need to tell the caller to use a bigger destination + // buffer. + return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); + } // There was no pending data in the network BIO -- encrypt any application data int bytesProduced = 0; int bytesConsumed = 0; @@ -754,9 +763,29 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - // Write encrypted data to network BIO + if (len < SslUtils.SSL_RECORD_HEADER_LENGTH) { + return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0); + } + + int packetLength = SslUtils.getEncryptedPacketLength(srcs, srcsOffset); + if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) { + // No enough space in the destination buffer so signal the caller + // that the buffer needs to be increased. + return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); + } + + if (len < packetLength) { + // We either have no enough data to read the packet length at all or not enough for reading + // the whole packet. + return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0); + } + int bytesConsumed = 0; if (srcsOffset < srcsEndOffset) { + + // Write encrypted data to network BIO + int packetLengthRemaining = packetLength; + do { ByteBuffer src = srcs[srcsOffset]; int remaining = src.remaining(); @@ -766,9 +795,15 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc srcsOffset++; continue; } - int written = writeEncryptedData(src); + // Write more encrypted data into the BIO. Ensure we only read one packet at a time as + // stated in the SSLEngine javadocs. + int written = writeEncryptedData(src, Math.min(packetLengthRemaining, src.remaining())); if (written > 0) { - bytesConsumed += written; + packetLengthRemaining -= written; + if (packetLengthRemaining == 0) { + // A whole packet has been consumed. + break; + } if (written == remaining) { srcsOffset++; @@ -787,6 +822,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc break; } } while (srcsOffset < srcsEndOffset); + bytesConsumed = packetLength - packetLengthRemaining; } // Number of produced bytes diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 55c09aef43..b06e3a7a52 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -197,14 +197,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * {@code true} if and only if {@link SSLEngine} expects a direct buffer. */ private final boolean wantsDirectBuffer; - /** - * {@code true} if and only if {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} requires the output buffer - * to be always as large as {@link #maxPacketBufferSize} even if the input buffer contains small amount of data. - *

- * If this flag is {@code false}, we allocate a smaller output buffer. - *

- */ - private final boolean wantsLargeOutboundNetworkBuffer; // END Platform-dependent flags @@ -283,7 +275,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH boolean opensslEngine = engine instanceof OpenSslEngine; wantsDirectBuffer = opensslEngine; - wantsLargeOutboundNetworkBuffer = !opensslEngine; /** * When using JDK {@link SSLEngine}, we use {@link #MERGE_CUMULATOR} because it works only with @@ -516,7 +507,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH ByteBuf buf = (ByteBuf) msg; if (out == null) { - out = allocateOutNetBuf(ctx, buf.readableBytes()); + out = allocateOutNetBuf(ctx); } SSLEngineResult result = wrap(alloc, engine, buf, out); @@ -599,7 +590,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // See https://github.com/netty/netty/issues/5860 while (!ctx.isRemoved()) { if (out == null) { - out = allocateOutNetBuf(ctx, 0); + out = allocateOutNetBuf(ctx); } SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out); @@ -1477,14 +1468,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which can encrypt * the specified amount of pending bytes. */ - private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) { - if (wantsLargeOutboundNetworkBuffer) { - return allocate(ctx, maxPacketBufferSize); - } else { - return allocate(ctx, Math.min( - pendingBytes + OpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH, - maxPacketBufferSize)); - } + private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx) { + return allocate(ctx, maxPacketBufferSize); } private final class LazyChannelPromise extends DefaultPromise { diff --git a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java index 86e8040c3b..47e40a4a2c 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java @@ -22,6 +22,8 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.base64.Base64; import io.netty.handler.codec.base64.Base64Dialect; +import java.nio.ByteBuffer; + /** * Constants for SSL packets. */ @@ -121,6 +123,92 @@ final class SslUtils { return packetLength; } + private static short unsignedByte(byte b) { + return (short) (b & 0xFF); + } + + private static int unsignedShort(short s) { + return s & 0xFFFF; + } + + static int getEncryptedPacketLength(ByteBuffer[] buffers, int offset) { + ByteBuffer buffer = buffers[offset]; + + // Check if everything we need is in one ByteBuffer. If so we can make use of the fast-path. + if (buffer.remaining() >= SslUtils.SSL_RECORD_HEADER_LENGTH) { + return getEncryptedPacketLength(buffer); + } + + // We need to copy 5 bytes into a temporary buffer so we can parse out the packet length easily. + ByteBuffer tmp = ByteBuffer.allocate(5); + + do { + buffer = buffers[offset++].duplicate(); + if (buffer.remaining() > tmp.remaining()) { + buffer.limit(buffer.position() + tmp.remaining()); + } + tmp.put(buffer); + } while (tmp.hasRemaining()); + + // Done, flip the buffer so we can read from it. + tmp.flip(); + return getEncryptedPacketLength(tmp); + } + + private static int getEncryptedPacketLength(ByteBuffer buffer) { + int packetLength = 0; + int pos = buffer.position(); + // SSLv3 or TLS - Check ContentType + boolean tls; + switch (unsignedByte(buffer.get(pos))) { + case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: + case SslUtils.SSL_CONTENT_TYPE_ALERT: + case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: + case SslUtils.SSL_CONTENT_TYPE_APPLICATION_DATA: + tls = true; + break; + default: + // SSLv2 or bad data + tls = false; + } + + if (tls) { + // SSLv3 or TLS - Check ProtocolVersion + int majorVersion = unsignedByte(buffer.get(pos + 1)); + if (majorVersion == 3) { + // SSLv3 or TLS + packetLength = unsignedShort(buffer.getShort(pos + 3)) + SslUtils.SSL_RECORD_HEADER_LENGTH; + if (packetLength <= SslUtils.SSL_RECORD_HEADER_LENGTH) { + // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) + tls = false; + } + } else { + // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) + tls = false; + } + } + + if (!tls) { + // SSLv2 or bad data - Check the version + int headerLength = (unsignedByte(buffer.get(pos)) & 0x80) != 0 ? 2 : 3; + int majorVersion = unsignedByte(buffer.get(pos + headerLength + 1)); + if (majorVersion == 2 || majorVersion == 3) { + // SSLv2 + if (headerLength == 2) { + packetLength = (buffer.getShort(pos) & 0x7FFF) + 2; + } else { + packetLength = (buffer.getShort(pos) & 0x3FFF) + 3; + } + if (packetLength <= headerLength) { + return -1; + } + } else { + return -1; + } + } + return packetLength; + } + static void notifyHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) { // We have may haven written some parts of data before an exception was thrown so ensure we always flush. // See https://github.com/netty/netty/issues/3900#issuecomment-172481830 diff --git a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java index df54cd2e8f..5e12033a50 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java @@ -678,9 +678,8 @@ public abstract class SSLEngineTest { } protected static void handshake(SSLEngine clientEngine, SSLEngine serverEngine) throws SSLException { - int netBufferSize = 17 * 1024; - ByteBuffer cTOs = ByteBuffer.allocateDirect(netBufferSize); - ByteBuffer sTOc = ByteBuffer.allocateDirect(netBufferSize); + ByteBuffer cTOs = ByteBuffer.allocateDirect(clientEngine.getSession().getPacketBufferSize()); + ByteBuffer sTOc = ByteBuffer.allocateDirect(serverEngine.getSession().getPacketBufferSize()); ByteBuffer serverAppReadBuffer = ByteBuffer.allocateDirect( serverEngine.getSession().getApplicationBufferSize()); @@ -915,4 +914,84 @@ public abstract class SSLEngineTest { promise.syncUninterruptibly(); } + + @Test + public void testUnwrapBehavior() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .build(); + SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + + serverSslCtx = SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + + byte[] bytes = "Hello World".getBytes(CharsetUtil.US_ASCII); + + try { + ByteBuffer plainClientOut = ByteBuffer.allocate(client.getSession().getApplicationBufferSize()); + ByteBuffer encryptedClientToServer = ByteBuffer.allocate(server.getSession().getPacketBufferSize() * 2); + ByteBuffer plainServerIn = ByteBuffer.allocate(server.getSession().getApplicationBufferSize()); + + handshake(client, server); + + // create two TLS frames + + // first frame + plainClientOut.put(bytes, 0, 5); + plainClientOut.flip(); + + SSLEngineResult result = client.wrap(plainClientOut, encryptedClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(5, result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + assertFalse(plainClientOut.hasRemaining()); + + // second frame + plainClientOut.clear(); + plainClientOut.put(bytes, 5, 6); + plainClientOut.flip(); + + result = client.wrap(plainClientOut, encryptedClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(6, result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + // send over to server + encryptedClientToServer.flip(); + + // try with too small output buffer first (to check BUFFER_OVERFLOW case) + int remaining = encryptedClientToServer.remaining(); + ByteBuffer small = ByteBuffer.allocate(3); + result = server.unwrap(encryptedClientToServer, small); + assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); + assertEquals(remaining, encryptedClientToServer.remaining()); + + // now with big enough buffer + result = server.unwrap(encryptedClientToServer, plainServerIn); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + + assertEquals(5, result.bytesProduced()); + assertTrue(encryptedClientToServer.hasRemaining()); + + result = server.unwrap(encryptedClientToServer, plainServerIn); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(6, result.bytesProduced()); + assertFalse(encryptedClientToServer.hasRemaining()); + + plainServerIn.flip(); + + assertEquals(ByteBuffer.wrap(bytes), plainServerIn); + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } }