From 9e8b8ff53ac432c77319f152a86a31664ef99167 Mon Sep 17 00:00:00 2001 From: Chris Vest Date: Thu, 15 Jul 2021 09:02:03 +0200 Subject: [PATCH] Fix a bug where SslHandler clients would not process Server Hello messages in a timely manner (#11472) Motivation: The TLS handshake must be able to finish on its own, without being driven by outside read calls. This is currently not the case when TCP FastOpen is enabled. Reads must be permitted and marked as pending, even when a channel is not active. This is important because, with TCP FastOpen, the handshake processing of a TLS connection will start before the connection has been established -- before the process of connecting has even been started. The SslHandler on the client side will add the Client Hello message to the ChannelOutboundBuffer, then issue a `ctx.read` call for the anticipated Server Hello response, and then flush the Client Hello message which, in the case of TCP FastOpen, will cause the TCP connection to be established. In this transaction, it is important that the `ctx.read` call is not ignored since, if auto-read is turned off, this could delay or even prevent the Server Hello message from being processed, causing the server-side handshake to time out. Modification: Attach a listener to the SslHandler.handshakeFuture in the EchoClient, that will call ctx.read. Result: The SocketSslEchoTest now tests that the SslHandler can finish handshakes on its own, without being driven by 3rd party ctx.read calls. The various channel implementations have been updated to comply with this behaviour. --- .../SocketChannelNotYetConnectedTest.java | 61 +++++++++++++++++-- .../transport/socket/SocketSslEchoTest.java | 37 +++++++---- .../io/netty/channel/AbstractChannel.java | 4 -- .../netty/channel/oio/AbstractOioChannel.java | 7 ++- .../channel/oio/OioByteStreamChannel.java | 4 ++ 5 files changed, 91 insertions(+), 22 deletions(-) diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketChannelNotYetConnectedTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketChannelNotYetConnectedTest.java index ea32b86e57..13171a83df 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketChannelNotYetConnectedTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketChannelNotYetConnectedTest.java @@ -16,21 +16,32 @@ package io.netty.testsuite.transport.socket; import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; -import java.util.concurrent.TimeUnit; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; import org.junit.jupiter.api.Test; - -import java.net.SocketException; -import java.nio.channels.NotYetConnectedException; import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.Timeout; +import java.net.SocketException; +import java.nio.channels.NotYetConnectedException; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.fail; public class SocketChannelNotYetConnectedTest extends AbstractClientSocketTest { @Test - @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + @Timeout(30) public void testShutdownNotYetConnected(TestInfo testInfo) throws Throwable { run(testInfo, new Runner() { @Override @@ -68,4 +79,44 @@ public class SocketChannelNotYetConnectedTest extends AbstractClientSocketTest { throw cause; } } + + @Test + @Timeout(30) + public void readMustBePendingUntilChannelIsActive(TestInfo info) throws Throwable { + run(info, new Runner() { + @Override + public void run(Bootstrap bootstrap) throws Throwable { + NioEventLoopGroup group = new NioEventLoopGroup(1); + ServerBootstrap sb = new ServerBootstrap().group(group); + Channel serverChannel = sb.childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.writeAndFlush(Unpooled.copyInt(42)); + } + }).channel(NioServerSocketChannel.class).bind(0).sync().channel(); + + final CountDownLatch readLatch = new CountDownLatch(1); + bootstrap.handler(new ByteToMessageDecoder() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + assertFalse(ctx.channel().isActive()); + ctx.read(); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertThat(in.readableBytes()).isLessThanOrEqualTo(Integer.BYTES); + if (in.readableBytes() == Integer.BYTES) { + assertThat(in.readInt()).isEqualTo(42); + readLatch.countDown(); + } + } + }); + bootstrap.connect(serverChannel.localAddress()).sync(); + + readLatch.await(); + group.shutdownGracefully().await(); + } + }); + } } diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java index d04ccd1d14..5c6663951d 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java @@ -38,6 +38,7 @@ import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.testsuite.util.TestUtils; import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import org.junit.jupiter.api.AfterAll; @@ -46,6 +47,7 @@ import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import javax.net.ssl.SSLEngine; import java.io.File; import java.io.IOException; import java.security.cert.CertificateException; @@ -60,8 +62,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import javax.net.ssl.SSLEngine; - import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.is; @@ -276,7 +276,7 @@ public class SocketSslEchoTest extends AbstractSocketTest { if (useChunkedWriteHandler) { sch.pipeline().addLast(new ChunkedWriteHandler()); } - sch.pipeline().addLast("handler", serverHandler); + sch.pipeline().addLast("serverHandler", serverHandler); } }); @@ -298,7 +298,7 @@ public class SocketSslEchoTest extends AbstractSocketTest { if (useChunkedWriteHandler) { sch.pipeline().addLast(new ChunkedWriteHandler()); } - sch.pipeline().addLast("handler", clientHandler); + sch.pipeline().addLast("clientHandler", clientHandler); sch.pipeline().addLast(new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { @@ -471,14 +471,6 @@ public class SocketSslEchoTest extends AbstractSocketTest { this.exception = exception; } - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - if (!autoRead) { - ctx.read(); - } - ctx.fireChannelActive(); - } - @Override public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception { // We intentionally do not ctx.flush() here because we want to verify the SslHandler correctly flushing @@ -523,6 +515,19 @@ public class SocketSslEchoTest extends AbstractSocketTest { super(recvCounter, negoCounter, exception); } + @Override + public void handlerAdded(final ChannelHandlerContext ctx) { + if (!autoRead) { + ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener( + new GenericFutureListener>() { + @Override + public void operationComplete(Future future) { + ctx.read(); + } + }); + } + } + @Override public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { byte[] actual = new byte[in.readableBytes()]; @@ -552,6 +557,14 @@ public class SocketSslEchoTest extends AbstractSocketTest { renegoFuture = null; } + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + if (!autoRead) { + ctx.read(); + } + ctx.fireChannelActive(); + } + @Override public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { byte[] actual = new byte[in.readableBytes()]; diff --git a/transport/src/main/java/io/netty/channel/AbstractChannel.java b/transport/src/main/java/io/netty/channel/AbstractChannel.java index 6589cd5bd5..5aa2ab92db 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannel.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannel.java @@ -848,10 +848,6 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha public final void beginRead() { assertEventLoop(); - if (!isActive()) { - return; - } - try { doBeginRead(); } catch (final Exception e) { diff --git a/transport/src/main/java/io/netty/channel/oio/AbstractOioChannel.java b/transport/src/main/java/io/netty/channel/oio/AbstractOioChannel.java index b68e586040..1c7dc0746f 100644 --- a/transport/src/main/java/io/netty/channel/oio/AbstractOioChannel.java +++ b/transport/src/main/java/io/netty/channel/oio/AbstractOioChannel.java @@ -34,7 +34,8 @@ public abstract class AbstractOioChannel extends AbstractChannel { protected static final int SO_TIMEOUT = 1000; boolean readPending; - private final Runnable readTask = new Runnable() { + boolean readWhenInactive; + final Runnable readTask = new Runnable() { @Override public void run() { doRead(); @@ -103,6 +104,10 @@ public abstract class AbstractOioChannel extends AbstractChannel { if (readPending) { return; } + if (!isActive()) { + readWhenInactive = true; + return; + } readPending = true; eventLoop().execute(readTask); diff --git a/transport/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java b/transport/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java index 1c5529328f..15f588cc1f 100644 --- a/transport/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java +++ b/transport/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java @@ -78,6 +78,10 @@ public abstract class OioByteStreamChannel extends AbstractOioByteChannel { } this.is = ObjectUtil.checkNotNull(is, "is"); this.os = ObjectUtil.checkNotNull(os, "os"); + if (readWhenInactive) { + eventLoop().execute(readTask); + readWhenInactive = false; + } } @Override