diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index d50cd3525f..ea22b2d473 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -379,10 +379,9 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc * * Calling this function with src.remaining == 0 is undefined. */ - private int writePlaintextData(final ByteBuffer src) { + private int writePlaintextData(final ByteBuffer src, int len) { final int pos = src.position(); final int limit = src.limit(); - final int len = Math.min(limit - pos, MAX_PLAINTEXT_LENGTH); final int sslWrote; if (src.isDirect()) { @@ -614,7 +613,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc int bytesProduced = 0; int bytesConsumed = 0; int endOffset = offset + length; - for (int i = offset; i < endOffset; ++i) { + + loop: for (int i = offset; i < endOffset; ++i) { final ByteBuffer src = srcs[i]; if (src == null) { throw new IllegalArgumentException("srcs[" + i + "] is null"); @@ -622,7 +622,9 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc while (src.hasRemaining()) { final SSLEngineResult pendingNetResult; // Write plaintext application data to the SSL engine - int result = writePlaintextData(src); + int result = writePlaintextData( + src, Math.min(src.remaining(), MAX_PLAINTEXT_LENGTH - bytesConsumed)); + if (result > 0) { bytesConsumed += result; @@ -633,6 +635,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } bytesProduced = pendingNetResult.bytesProduced(); } + if (bytesConsumed == MAX_PLAINTEXT_LENGTH) { + // If we consumed the maximum amount of bytes for the plaintext length break out of the + // loop and start to fill the dst buffer. + break loop; + } } else { int sslError = SSL.getError(ssl, result); switch (sslError) { diff --git a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java index 8f1cbee3f0..b6307fc141 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java @@ -694,25 +694,71 @@ public abstract class SSLEngineTest { SSLEngineResult clientResult; SSLEngineResult serverResult; + boolean clientHandshakeFinished = false; + boolean serverHandshakeFinished = false; + do { + int cTOsPos = cTOs.position(); + int sTOcPos = sTOc.position(); + clientResult = clientEngine.wrap(empty, cTOs); runDelegatedTasks(clientResult, clientEngine); serverResult = serverEngine.wrap(empty, sTOc); runDelegatedTasks(serverResult, serverEngine); + + // Verify that the consumed and produced number match what is in the buffers now. + assertEquals(empty.remaining(), clientResult.bytesConsumed()); + assertEquals(empty.remaining(), serverResult.bytesConsumed()); + assertEquals(cTOs.position() - cTOsPos, clientResult.bytesProduced()); + assertEquals(sTOc.position() - sTOcPos, serverResult.bytesProduced()); + cTOs.flip(); sTOc.flip(); + + // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED + if (isHandshakeFinished(clientResult)) { + assertFalse(clientHandshakeFinished); + clientHandshakeFinished = true; + } + if (isHandshakeFinished(serverResult)) { + assertFalse(serverHandshakeFinished); + serverHandshakeFinished = true; + } + + cTOsPos = cTOs.position(); + sTOcPos = sTOc.position(); + + int clientAppReadBufferPos = clientAppReadBuffer.position(); + int serverAppReadBufferPos = serverAppReadBuffer.position(); + clientResult = clientEngine.unwrap(sTOc, clientAppReadBuffer); runDelegatedTasks(clientResult, clientEngine); serverResult = serverEngine.unwrap(cTOs, serverAppReadBuffer); runDelegatedTasks(serverResult, serverEngine); + + // Verify that the consumed and produced number match what is in the buffers now. + assertEquals(sTOc.position() - sTOcPos, clientResult.bytesConsumed()); + assertEquals(cTOs.position() - cTOsPos, serverResult.bytesConsumed()); + assertEquals(clientAppReadBuffer.position() - clientAppReadBufferPos, clientResult.bytesProduced()); + assertEquals(serverAppReadBuffer.position() - serverAppReadBufferPos, serverResult.bytesProduced()); + cTOs.compact(); sTOc.compact(); - } while (isHandshaking(clientResult) || isHandshaking(serverResult)); + + // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED + if (isHandshakeFinished(clientResult)) { + assertFalse(clientHandshakeFinished); + clientHandshakeFinished = true; + } + if (isHandshakeFinished(serverResult)) { + assertFalse(serverHandshakeFinished); + serverHandshakeFinished = true; + } + } while (!clientHandshakeFinished || !serverHandshakeFinished); } - private static boolean isHandshaking(SSLEngineResult result) { - return result.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING && - result.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.FINISHED; + private static boolean isHandshakeFinished(SSLEngineResult result) { + return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED; } private static void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) { @@ -994,4 +1040,46 @@ public abstract class SSLEngineTest { cleanupServerSslEngine(server); } } + + @Test + public void testPacketBufferSizeLimit() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .build(); + SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + + serverSslCtx = SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + + try { + // Allocate an buffer that is bigger then the max plain record size. + ByteBuffer plainServerOut = ByteBuffer.allocate(server.getSession().getApplicationBufferSize() * 2); + + handshake(client, server); + + // Fill the whole buffer and flip it. + plainServerOut.position(plainServerOut.capacity()); + plainServerOut.flip(); + + ByteBuffer encryptedServerToClient = ByteBuffer.allocate(server.getSession().getPacketBufferSize()); + + int encryptedServerToClientPos = encryptedServerToClient.position(); + int plainServerOutPos = plainServerOut.position(); + SSLEngineResult result = server.wrap(plainServerOut, encryptedServerToClient); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(plainServerOut.position() - plainServerOutPos, result.bytesConsumed()); + assertEquals(encryptedServerToClient.position() - encryptedServerToClientPos, result.bytesProduced()); + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + cert.delete(); + } + } }