Fix a bug where SslHandler clients would not process Server Hello messages in a timely manner (#11472)

Motivation:
The TLS handshake must be able to finish on its own, without being driven by outside read calls.
This is currently not the case when TCP FastOpen is enabled.
Reads must be permitted and marked as pending, even when a channel is not active.

This is important because, with TCP FastOpen, the handshake processing of a TLS connection will start
before the connection has been established -- before the process of connecting has even been started.

The SslHandler on the client side will add the Client Hello message to the ChannelOutboundBuffer, then
issue a `ctx.read` call for the anticipated Server Hello response, and then flush the Client Hello
message which, in the case of TCP FastOpen, will cause the TCP connection to be established.

In this transaction, it is important that the `ctx.read` call is not ignored since, if auto-read is
turned off, this could delay or even prevent the Server Hello message from being processed, causing
the server-side handshake to time out.

Modification:
Attach a listener to the SslHandler.handshakeFuture in the EchoClient, that will call ctx.read.

Result:
The SocketSslEchoTest now tests that the SslHandler can finish handshakes on its own, without being driven by 3rd party ctx.read calls.
The various channel implementations have been updated to comply with this behaviour.
This commit is contained in:
Chris Vest 2021-07-15 09:02:03 +02:00
parent 23bf55206c
commit c982d45493
3 changed files with 88 additions and 18 deletions

View File

@ -16,21 +16,34 @@
package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SingleThreadEventLoop;
import io.netty.channel.nio.NioHandler;
import io.netty.channel.socket.SocketChannel;
import java.util.concurrent.TimeUnit;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.concurrent.DefaultThreadFactory;
import org.junit.jupiter.api.Test;
import java.net.SocketException;
import java.nio.channels.NotYetConnectedException;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout;
import java.net.SocketException;
import java.nio.channels.NotYetConnectedException;
import java.util.concurrent.CountDownLatch;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.fail;
public class SocketChannelNotYetConnectedTest extends AbstractClientSocketTest {
@Test
@Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
@Timeout(30)
public void testShutdownNotYetConnected(TestInfo testInfo) throws Throwable {
run(testInfo, this::testShutdownNotYetConnected);
}
@ -63,4 +76,45 @@ public class SocketChannelNotYetConnectedTest extends AbstractClientSocketTest {
throw cause;
}
}
@Test
@Timeout(30)
public void readMustBePendingUntilChannelIsActive(TestInfo info) throws Throwable {
run(info, new Runner<Bootstrap>() {
@Override
public void run(Bootstrap bootstrap) throws Throwable {
SingleThreadEventLoop group = new SingleThreadEventLoop(
new DefaultThreadFactory(getClass()), NioHandler.newFactory().newHandler());
ServerBootstrap sb = new ServerBootstrap().group(group);
Channel serverChannel = sb.childHandler(new ChannelHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
ctx.writeAndFlush(Unpooled.copyInt(42));
}
}).channel(NioServerSocketChannel.class).bind(0).sync().channel();
final CountDownLatch readLatch = new CountDownLatch(1);
bootstrap.handler(new ByteToMessageDecoder() {
@Override
public void handlerAdded0(ChannelHandlerContext ctx) throws Exception {
assertFalse(ctx.channel().isActive());
ctx.read();
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
assertThat(in.readableBytes()).isLessThanOrEqualTo(Integer.BYTES);
if (in.readableBytes() == Integer.BYTES) {
assertThat(in.readInt()).isEqualTo(42);
readLatch.countDown();
}
}
});
bootstrap.connect(serverChannel.localAddress()).sync();
readLatch.await();
group.shutdownGracefully().await();
}
});
}
}

View File

@ -38,6 +38,7 @@ import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.testsuite.util.TestUtils;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.jupiter.api.AfterAll;
@ -46,6 +47,7 @@ import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import javax.net.ssl.SSLEngine;
import java.io.File;
import java.io.IOException;
import java.security.cert.CertificateException;
@ -60,8 +62,6 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.is;
@ -272,7 +272,7 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (useChunkedWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler());
}
sch.pipeline().addLast("handler", serverHandler);
sch.pipeline().addLast("serverHandler", serverHandler);
}
});
@ -294,7 +294,7 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (useChunkedWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler());
}
sch.pipeline().addLast("handler", clientHandler);
sch.pipeline().addLast("clientHandler", clientHandler);
sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
@ -467,14 +467,6 @@ public class SocketSslEchoTest extends AbstractSocketTest {
this.exception = exception;
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (!autoRead) {
ctx.read();
}
ctx.fireChannelActive();
}
@Override
public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
// We intentionally do not ctx.flush() here because we want to verify the SslHandler correctly flushing
@ -519,6 +511,19 @@ public class SocketSslEchoTest extends AbstractSocketTest {
super(recvCounter, negoCounter, exception);
}
@Override
public void handlerAdded(final ChannelHandlerContext ctx) {
if (!autoRead) {
ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
new GenericFutureListener<Future<? super Channel>>() {
@Override
public void operationComplete(Future<? super Channel> future) {
ctx.read();
}
});
}
}
@Override
public void messageReceived(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
byte[] actual = new byte[in.readableBytes()];
@ -548,6 +553,14 @@ public class SocketSslEchoTest extends AbstractSocketTest {
renegoFuture = null;
}
@Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
if (!autoRead) {
ctx.read();
}
ctx.fireChannelActive();
}
@Override
public void messageReceived(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
byte[] actual = new byte[in.readableBytes()];

View File

@ -60,6 +60,7 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
private volatile boolean registered;
private boolean closeInitiated;
private Throwable initialCloseCause;
private boolean readBeforeActive;
/** Cache for the string representation of this channel */
private boolean strValActive;
@ -433,7 +434,8 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
}
protected final void readIfIsAutoRead() {
if (config().isAutoRead()) {
if (config().isAutoRead() || readBeforeActive) {
readBeforeActive = false;
read();
}
}
@ -807,6 +809,7 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
assertEventLoop();
if (!isActive()) {
readBeforeActive = true;
return;
}