Fix a bug where SslHandler clients would not process Server Hello messages in a timely manner (#11472)

Motivation:
The TLS handshake must be able to finish on its own, without being driven by outside read calls.
This is currently not the case when TCP FastOpen is enabled.
Reads must be permitted and marked as pending, even when a channel is not active.

This is important because, with TCP FastOpen, the handshake processing of a TLS connection will start
before the connection has been established -- before the process of connecting has even been started.

The SslHandler on the client side will add the Client Hello message to the ChannelOutboundBuffer, then
issue a `ctx.read` call for the anticipated Server Hello response, and then flush the Client Hello
message which, in the case of TCP FastOpen, will cause the TCP connection to be established.

In this transaction, it is important that the `ctx.read` call is not ignored since, if auto-read is
turned off, this could delay or even prevent the Server Hello message from being processed, causing
the server-side handshake to time out.

Modification:
Attach a listener to the SslHandler.handshakeFuture in the EchoClient, that will call ctx.read.

Result:
The SocketSslEchoTest now tests that the SslHandler can finish handshakes on its own, without being driven by 3rd party ctx.read calls.
The various channel implementations have been updated to comply with this behaviour.
This commit is contained in:
Chris Vest 2021-07-15 09:02:03 +02:00 committed by GitHub
parent 55957d3e75
commit 9e8b8ff53a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 91 additions and 22 deletions

View File

@ -16,21 +16,32 @@
package io.netty.testsuite.transport.socket; package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.SocketChannel;
import java.util.concurrent.TimeUnit; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.net.SocketException;
import java.nio.channels.NotYetConnectedException;
import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.Timeout;
import java.net.SocketException;
import java.nio.channels.NotYetConnectedException;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
public class SocketChannelNotYetConnectedTest extends AbstractClientSocketTest { public class SocketChannelNotYetConnectedTest extends AbstractClientSocketTest {
@Test @Test
@Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) @Timeout(30)
public void testShutdownNotYetConnected(TestInfo testInfo) throws Throwable { public void testShutdownNotYetConnected(TestInfo testInfo) throws Throwable {
run(testInfo, new Runner<Bootstrap>() { run(testInfo, new Runner<Bootstrap>() {
@Override @Override
@ -68,4 +79,44 @@ public class SocketChannelNotYetConnectedTest extends AbstractClientSocketTest {
throw cause; throw cause;
} }
} }
@Test
@Timeout(30)
public void readMustBePendingUntilChannelIsActive(TestInfo info) throws Throwable {
run(info, new Runner<Bootstrap>() {
@Override
public void run(Bootstrap bootstrap) throws Throwable {
NioEventLoopGroup group = new NioEventLoopGroup(1);
ServerBootstrap sb = new ServerBootstrap().group(group);
Channel serverChannel = sb.childHandler(new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
ctx.writeAndFlush(Unpooled.copyInt(42));
}
}).channel(NioServerSocketChannel.class).bind(0).sync().channel();
final CountDownLatch readLatch = new CountDownLatch(1);
bootstrap.handler(new ByteToMessageDecoder() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
assertFalse(ctx.channel().isActive());
ctx.read();
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
assertThat(in.readableBytes()).isLessThanOrEqualTo(Integer.BYTES);
if (in.readableBytes() == Integer.BYTES) {
assertThat(in.readInt()).isEqualTo(42);
readLatch.countDown();
}
}
});
bootstrap.connect(serverChannel.localAddress()).sync();
readLatch.await();
group.shutdownGracefully().await();
}
});
}
} }

View File

@ -38,6 +38,7 @@ import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.testsuite.util.TestUtils; import io.netty.testsuite.util.TestUtils;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
@ -46,6 +47,7 @@ import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import javax.net.ssl.SSLEngine;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
@ -60,8 +62,6 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
@ -276,7 +276,7 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (useChunkedWriteHandler) { if (useChunkedWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler()); sch.pipeline().addLast(new ChunkedWriteHandler());
} }
sch.pipeline().addLast("handler", serverHandler); sch.pipeline().addLast("serverHandler", serverHandler);
} }
}); });
@ -298,7 +298,7 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (useChunkedWriteHandler) { if (useChunkedWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler()); sch.pipeline().addLast(new ChunkedWriteHandler());
} }
sch.pipeline().addLast("handler", clientHandler); sch.pipeline().addLast("clientHandler", clientHandler);
sch.pipeline().addLast(new ChannelInboundHandlerAdapter() { sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
@ -471,14 +471,6 @@ public class SocketSslEchoTest extends AbstractSocketTest {
this.exception = exception; this.exception = exception;
} }
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (!autoRead) {
ctx.read();
}
ctx.fireChannelActive();
}
@Override @Override
public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception { public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
// We intentionally do not ctx.flush() here because we want to verify the SslHandler correctly flushing // We intentionally do not ctx.flush() here because we want to verify the SslHandler correctly flushing
@ -523,6 +515,19 @@ public class SocketSslEchoTest extends AbstractSocketTest {
super(recvCounter, negoCounter, exception); super(recvCounter, negoCounter, exception);
} }
@Override
public void handlerAdded(final ChannelHandlerContext ctx) {
if (!autoRead) {
ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
new GenericFutureListener<Future<? super Channel>>() {
@Override
public void operationComplete(Future<? super Channel> future) {
ctx.read();
}
});
}
}
@Override @Override
public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
byte[] actual = new byte[in.readableBytes()]; byte[] actual = new byte[in.readableBytes()];
@ -552,6 +557,14 @@ public class SocketSslEchoTest extends AbstractSocketTest {
renegoFuture = null; renegoFuture = null;
} }
@Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
if (!autoRead) {
ctx.read();
}
ctx.fireChannelActive();
}
@Override @Override
public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
byte[] actual = new byte[in.readableBytes()]; byte[] actual = new byte[in.readableBytes()];

View File

@ -848,10 +848,6 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
public final void beginRead() { public final void beginRead() {
assertEventLoop(); assertEventLoop();
if (!isActive()) {
return;
}
try { try {
doBeginRead(); doBeginRead();
} catch (final Exception e) { } catch (final Exception e) {

View File

@ -34,7 +34,8 @@ public abstract class AbstractOioChannel extends AbstractChannel {
protected static final int SO_TIMEOUT = 1000; protected static final int SO_TIMEOUT = 1000;
boolean readPending; boolean readPending;
private final Runnable readTask = new Runnable() { boolean readWhenInactive;
final Runnable readTask = new Runnable() {
@Override @Override
public void run() { public void run() {
doRead(); doRead();
@ -103,6 +104,10 @@ public abstract class AbstractOioChannel extends AbstractChannel {
if (readPending) { if (readPending) {
return; return;
} }
if (!isActive()) {
readWhenInactive = true;
return;
}
readPending = true; readPending = true;
eventLoop().execute(readTask); eventLoop().execute(readTask);

View File

@ -78,6 +78,10 @@ public abstract class OioByteStreamChannel extends AbstractOioByteChannel {
} }
this.is = ObjectUtil.checkNotNull(is, "is"); this.is = ObjectUtil.checkNotNull(is, "is");
this.os = ObjectUtil.checkNotNull(os, "os"); this.os = ObjectUtil.checkNotNull(os, "os");
if (readWhenInactive) {
eventLoop().execute(readTask);
readWhenInactive = false;
}
} }
@Override @Override