diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java
index 2fc2ed3dd6..f39aa205bc 100644
--- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java
+++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java
@@ -67,6 +67,7 @@ import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
+import static javax.net.ssl.SSLEngineResult.Status.BUFFER_UNDERFLOW;
import static javax.net.ssl.SSLEngineResult.Status.CLOSED;
import static javax.net.ssl.SSLEngineResult.Status.OK;
@@ -395,9 +396,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
/**
* Write encrypted data to the OpenSSL network BIO.
*/
- private int writeEncryptedData(final ByteBuffer src) {
+ private int writeEncryptedData(final ByteBuffer src, int len) {
final int pos = src.position();
- final int len = src.remaining();
final int netWrote;
if (src.isDirect()) {
final long addr = Buffer.address(src) + pos;
@@ -409,8 +409,12 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
final ByteBuf buf = alloc.directBuffer(len);
try {
final long addr = memoryAddress(buf);
-
- buf.setBytes(0, src);
+ int newLimit = pos + len;
+ if (newLimit != src.remaining()) {
+ buf.setBytes(0, (ByteBuffer) src.duplicate().position(pos).limit(newLimit));
+ } else {
+ buf.setBytes(0, src);
+ }
netWrote = SSL.writeToBIO(networkBIO, addr, len);
if (netWrote >= 0) {
@@ -580,6 +584,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}
}
+ if (dst.remaining() < MAX_ENCRYPTED_PACKET_LENGTH) {
+ // Can not hold the maximum packet so we need to tell the caller to use a bigger destination
+ // buffer.
+ return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
+ }
// There was no pending data in the network BIO -- encrypt any application data
int bytesProduced = 0;
int bytesConsumed = 0;
@@ -754,9 +763,29 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}
}
- // Write encrypted data to network BIO
+ if (len < SslUtils.SSL_RECORD_HEADER_LENGTH) {
+ return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
+ }
+
+ int packetLength = SslUtils.getEncryptedPacketLength(srcs, srcsOffset);
+ if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) {
+ // No enough space in the destination buffer so signal the caller
+ // that the buffer needs to be increased.
+ return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
+ }
+
+ if (len < packetLength) {
+ // We either have no enough data to read the packet length at all or not enough for reading
+ // the whole packet.
+ return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
+ }
+
int bytesConsumed = 0;
if (srcsOffset < srcsEndOffset) {
+
+ // Write encrypted data to network BIO
+ int packetLengthRemaining = packetLength;
+
do {
ByteBuffer src = srcs[srcsOffset];
int remaining = src.remaining();
@@ -766,9 +795,15 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
srcsOffset++;
continue;
}
- int written = writeEncryptedData(src);
+ // Write more encrypted data into the BIO. Ensure we only read one packet at a time as
+ // stated in the SSLEngine javadocs.
+ int written = writeEncryptedData(src, Math.min(packetLengthRemaining, src.remaining()));
if (written > 0) {
- bytesConsumed += written;
+ packetLengthRemaining -= written;
+ if (packetLengthRemaining == 0) {
+ // A whole packet has been consumed.
+ break;
+ }
if (written == remaining) {
srcsOffset++;
@@ -787,6 +822,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
break;
}
} while (srcsOffset < srcsEndOffset);
+ bytesConsumed = packetLength - packetLengthRemaining;
}
// Number of produced bytes
diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java
index 55c09aef43..b06e3a7a52 100644
--- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java
+++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java
@@ -197,14 +197,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* {@code true} if and only if {@link SSLEngine} expects a direct buffer.
*/
private final boolean wantsDirectBuffer;
- /**
- * {@code true} if and only if {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} requires the output buffer
- * to be always as large as {@link #maxPacketBufferSize} even if the input buffer contains small amount of data.
- *
- * If this flag is {@code false}, we allocate a smaller output buffer.
- *
- */
- private final boolean wantsLargeOutboundNetworkBuffer;
// END Platform-dependent flags
@@ -283,7 +275,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
boolean opensslEngine = engine instanceof OpenSslEngine;
wantsDirectBuffer = opensslEngine;
- wantsLargeOutboundNetworkBuffer = !opensslEngine;
/**
* When using JDK {@link SSLEngine}, we use {@link #MERGE_CUMULATOR} because it works only with
@@ -516,7 +507,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
ByteBuf buf = (ByteBuf) msg;
if (out == null) {
- out = allocateOutNetBuf(ctx, buf.readableBytes());
+ out = allocateOutNetBuf(ctx);
}
SSLEngineResult result = wrap(alloc, engine, buf, out);
@@ -599,7 +590,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// See https://github.com/netty/netty/issues/5860
while (!ctx.isRemoved()) {
if (out == null) {
- out = allocateOutNetBuf(ctx, 0);
+ out = allocateOutNetBuf(ctx);
}
SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out);
@@ -1477,14 +1468,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which can encrypt
* the specified amount of pending bytes.
*/
- private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) {
- if (wantsLargeOutboundNetworkBuffer) {
- return allocate(ctx, maxPacketBufferSize);
- } else {
- return allocate(ctx, Math.min(
- pendingBytes + OpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH,
- maxPacketBufferSize));
- }
+ private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx) {
+ return allocate(ctx, maxPacketBufferSize);
}
private final class LazyChannelPromise extends DefaultPromise {
diff --git a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java
index 86e8040c3b..47e40a4a2c 100644
--- a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java
+++ b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java
@@ -22,6 +22,8 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.base64.Base64;
import io.netty.handler.codec.base64.Base64Dialect;
+import java.nio.ByteBuffer;
+
/**
* Constants for SSL packets.
*/
@@ -121,6 +123,92 @@ final class SslUtils {
return packetLength;
}
+ private static short unsignedByte(byte b) {
+ return (short) (b & 0xFF);
+ }
+
+ private static int unsignedShort(short s) {
+ return s & 0xFFFF;
+ }
+
+ static int getEncryptedPacketLength(ByteBuffer[] buffers, int offset) {
+ ByteBuffer buffer = buffers[offset];
+
+ // Check if everything we need is in one ByteBuffer. If so we can make use of the fast-path.
+ if (buffer.remaining() >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
+ return getEncryptedPacketLength(buffer);
+ }
+
+ // We need to copy 5 bytes into a temporary buffer so we can parse out the packet length easily.
+ ByteBuffer tmp = ByteBuffer.allocate(5);
+
+ do {
+ buffer = buffers[offset++].duplicate();
+ if (buffer.remaining() > tmp.remaining()) {
+ buffer.limit(buffer.position() + tmp.remaining());
+ }
+ tmp.put(buffer);
+ } while (tmp.hasRemaining());
+
+ // Done, flip the buffer so we can read from it.
+ tmp.flip();
+ return getEncryptedPacketLength(tmp);
+ }
+
+ private static int getEncryptedPacketLength(ByteBuffer buffer) {
+ int packetLength = 0;
+ int pos = buffer.position();
+ // SSLv3 or TLS - Check ContentType
+ boolean tls;
+ switch (unsignedByte(buffer.get(pos))) {
+ case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
+ case SslUtils.SSL_CONTENT_TYPE_ALERT:
+ case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
+ case SslUtils.SSL_CONTENT_TYPE_APPLICATION_DATA:
+ tls = true;
+ break;
+ default:
+ // SSLv2 or bad data
+ tls = false;
+ }
+
+ if (tls) {
+ // SSLv3 or TLS - Check ProtocolVersion
+ int majorVersion = unsignedByte(buffer.get(pos + 1));
+ if (majorVersion == 3) {
+ // SSLv3 or TLS
+ packetLength = unsignedShort(buffer.getShort(pos + 3)) + SslUtils.SSL_RECORD_HEADER_LENGTH;
+ if (packetLength <= SslUtils.SSL_RECORD_HEADER_LENGTH) {
+ // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
+ tls = false;
+ }
+ } else {
+ // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
+ tls = false;
+ }
+ }
+
+ if (!tls) {
+ // SSLv2 or bad data - Check the version
+ int headerLength = (unsignedByte(buffer.get(pos)) & 0x80) != 0 ? 2 : 3;
+ int majorVersion = unsignedByte(buffer.get(pos + headerLength + 1));
+ if (majorVersion == 2 || majorVersion == 3) {
+ // SSLv2
+ if (headerLength == 2) {
+ packetLength = (buffer.getShort(pos) & 0x7FFF) + 2;
+ } else {
+ packetLength = (buffer.getShort(pos) & 0x3FFF) + 3;
+ }
+ if (packetLength <= headerLength) {
+ return -1;
+ }
+ } else {
+ return -1;
+ }
+ }
+ return packetLength;
+ }
+
static void notifyHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) {
// We have may haven written some parts of data before an exception was thrown so ensure we always flush.
// See https://github.com/netty/netty/issues/3900#issuecomment-172481830
diff --git a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java
index df54cd2e8f..5e12033a50 100644
--- a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java
+++ b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java
@@ -678,9 +678,8 @@ public abstract class SSLEngineTest {
}
protected static void handshake(SSLEngine clientEngine, SSLEngine serverEngine) throws SSLException {
- int netBufferSize = 17 * 1024;
- ByteBuffer cTOs = ByteBuffer.allocateDirect(netBufferSize);
- ByteBuffer sTOc = ByteBuffer.allocateDirect(netBufferSize);
+ ByteBuffer cTOs = ByteBuffer.allocateDirect(clientEngine.getSession().getPacketBufferSize());
+ ByteBuffer sTOc = ByteBuffer.allocateDirect(serverEngine.getSession().getPacketBufferSize());
ByteBuffer serverAppReadBuffer = ByteBuffer.allocateDirect(
serverEngine.getSession().getApplicationBufferSize());
@@ -915,4 +914,84 @@ public abstract class SSLEngineTest {
promise.syncUninterruptibly();
}
+
+ @Test
+ public void testUnwrapBehavior() throws Exception {
+ SelfSignedCertificate cert = new SelfSignedCertificate();
+
+ clientSslCtx = SslContextBuilder
+ .forClient()
+ .trustManager(cert.cert())
+ .sslProvider(sslClientProvider())
+ .build();
+ SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);
+
+ serverSslCtx = SslContextBuilder
+ .forServer(cert.certificate(), cert.privateKey())
+ .sslProvider(sslServerProvider())
+ .build();
+ SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);
+
+ byte[] bytes = "Hello World".getBytes(CharsetUtil.US_ASCII);
+
+ try {
+ ByteBuffer plainClientOut = ByteBuffer.allocate(client.getSession().getApplicationBufferSize());
+ ByteBuffer encryptedClientToServer = ByteBuffer.allocate(server.getSession().getPacketBufferSize() * 2);
+ ByteBuffer plainServerIn = ByteBuffer.allocate(server.getSession().getApplicationBufferSize());
+
+ handshake(client, server);
+
+ // create two TLS frames
+
+ // first frame
+ plainClientOut.put(bytes, 0, 5);
+ plainClientOut.flip();
+
+ SSLEngineResult result = client.wrap(plainClientOut, encryptedClientToServer);
+ assertEquals(SSLEngineResult.Status.OK, result.getStatus());
+ assertEquals(5, result.bytesConsumed());
+ assertTrue(result.bytesProduced() > 0);
+
+ assertFalse(plainClientOut.hasRemaining());
+
+ // second frame
+ plainClientOut.clear();
+ plainClientOut.put(bytes, 5, 6);
+ plainClientOut.flip();
+
+ result = client.wrap(plainClientOut, encryptedClientToServer);
+ assertEquals(SSLEngineResult.Status.OK, result.getStatus());
+ assertEquals(6, result.bytesConsumed());
+ assertTrue(result.bytesProduced() > 0);
+
+ // send over to server
+ encryptedClientToServer.flip();
+
+ // try with too small output buffer first (to check BUFFER_OVERFLOW case)
+ int remaining = encryptedClientToServer.remaining();
+ ByteBuffer small = ByteBuffer.allocate(3);
+ result = server.unwrap(encryptedClientToServer, small);
+ assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
+ assertEquals(remaining, encryptedClientToServer.remaining());
+
+ // now with big enough buffer
+ result = server.unwrap(encryptedClientToServer, plainServerIn);
+ assertEquals(SSLEngineResult.Status.OK, result.getStatus());
+
+ assertEquals(5, result.bytesProduced());
+ assertTrue(encryptedClientToServer.hasRemaining());
+
+ result = server.unwrap(encryptedClientToServer, plainServerIn);
+ assertEquals(SSLEngineResult.Status.OK, result.getStatus());
+ assertEquals(6, result.bytesProduced());
+ assertFalse(encryptedClientToServer.hasRemaining());
+
+ plainServerIn.flip();
+
+ assertEquals(ByteBuffer.wrap(bytes), plainServerIn);
+ } finally {
+ cleanupClientSslEngine(client);
+ cleanupServerSslEngine(server);
+ }
+ }
}