Propagate SSLException to the Http2StreamChannels (#11023)

Motivation:

When TLSv1.3 is used (or TLS_FALSE_START) together with mTLS the handshake is considered successful before the server actually did verify the key material that was provided by the client. If the verification fails we currently will just close the stream without any extra information which makes it very hard to debug on the client side.

Modifications:

- Propagate SSLExceptions to the active streams
- Add unit test

Result:

Better visibility into why a stream was closed
This commit is contained in:
Norman Maurer 2021-02-19 08:09:22 +01:00 committed by GitHub
parent 329dae1ec5
commit bd8a337e99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 215 additions and 1 deletions

View File

@ -76,6 +76,18 @@
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcpkix-jdk15on</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>${tcnative.artifactId}</artifactId>
<classifier>${tcnative.classifier}</classifier>
<scope>test</scope>
<optional>true</optional>
</dependency>
</dependencies>
</project>

View File

@ -25,11 +25,13 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoop;
import io.netty.channel.ServerChannel;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.http2.Http2FrameCodec.DefaultHttp2FrameStream;
import io.netty.util.ReferenceCounted;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.UnstableApi;
import javax.net.ssl.SSLException;
import java.util.ArrayDeque;
import java.util.Queue;
@ -264,7 +266,7 @@ public final class Http2MultiplexHandler extends Http2ChannelDuplexHandler {
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
public void exceptionCaught(ChannelHandlerContext ctx, final Throwable cause) throws Exception {
if (cause instanceof Http2FrameStreamException) {
Http2FrameStreamException exception = (Http2FrameStreamException) cause;
Http2FrameStream stream = exception.stream();
@ -277,6 +279,17 @@ public final class Http2MultiplexHandler extends Http2ChannelDuplexHandler {
}
return;
}
if (cause.getCause() instanceof SSLException) {
forEachActiveStream(new Http2FrameStreamVisitor() {
@Override
public boolean visit(Http2FrameStream stream) {
AbstractHttp2StreamChannel childChannel = (AbstractHttp2StreamChannel)
((DefaultHttp2FrameStream) stream).attachment;
childChannel.pipeline().fireExceptionCaught(cause);
return true;
}
});
}
ctx.fireExceptionCaught(cause);
}

View File

@ -29,14 +29,32 @@ import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.ApplicationProtocolNames;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.SupportedCipherSuiteFilter;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import javax.net.ssl.SSLException;
import javax.net.ssl.X509TrustManager;
import java.net.InetSocketAddress;
import java.security.cert.CertificateException;
import java.security.cert.CertificateExpiredException;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
@ -45,6 +63,7 @@ import java.util.concurrent.atomic.AtomicReference;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.junit.Assert.assertFalse;
import static org.junit.Assume.assumeTrue;
public class Http2MultiplexTransportTest {
private static final ChannelHandler DISCARD_HANDLER = new ChannelInboundHandlerAdapter() {
@ -254,4 +273,174 @@ public class Http2MultiplexTransportTest {
executorService.shutdown();
}
}
@Test(timeout = 5000L)
public void testSSLExceptionOpenSslTLSv12() throws Exception {
testSslException(SslProvider.OPENSSL, false);
}
@Test(timeout = 5000L)
public void testSSLExceptionOpenSslTLSv13() throws Exception {
testSslException(SslProvider.OPENSSL, true);
}
@Ignore("JDK SSLEngine does not produce an alert")
@Test(timeout = 5000L)
public void testSSLExceptionJDKTLSv12() throws Exception {
testSslException(SslProvider.JDK, false);
}
@Ignore("JDK SSLEngine does not produce an alert")
@Test(timeout = 5000L)
public void testSSLExceptionJDKTLSv13() throws Exception {
testSslException(SslProvider.JDK, true);
}
private void testSslException(SslProvider provider, final boolean tlsv13) throws Exception {
assumeTrue(SslProvider.isAlpnSupported(provider));
if (tlsv13) {
assumeTrue(SslProvider.isTlsv13Supported(provider));
}
final String protocol = tlsv13 ? "TLSv1.3" : "TLSv1.2";
SelfSignedCertificate ssc = null;
try {
ssc = new SelfSignedCertificate();
final SslContext sslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
.trustManager(new X509TrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
throw new CertificateExpiredException();
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
throw new CertificateExpiredException();
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
}).sslProvider(provider)
.ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE)
.protocols(protocol)
.applicationProtocolConfig(new ApplicationProtocolConfig(
ApplicationProtocolConfig.Protocol.ALPN,
// NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers.
ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
// ACCEPT is currently the only mode supported by both OpenSsl and JDK providers.
ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
ApplicationProtocolNames.HTTP_2,
ApplicationProtocolNames.HTTP_1_1)).clientAuth(ClientAuth.REQUIRE)
.build();
ServerBootstrap sb = new ServerBootstrap();
sb.group(eventLoopGroup);
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(sslCtx.newHandler(ch.alloc()));
ch.pipeline().addLast(new Http2FrameCodecBuilder(true).build());
ch.pipeline().addLast(new Http2MultiplexHandler(DISCARD_HANDLER));
}
});
serverChannel = sb.bind(new InetSocketAddress(NetUtil.LOCALHOST, 0)).syncUninterruptibly().channel();
final SslContext clientCtx = SslContextBuilder.forClient()
.keyManager(ssc.key(), ssc.cert())
.sslProvider(provider)
/* NOTE: the cipher filter may not include all ciphers required by the HTTP/2 specification.
* Please refer to the HTTP/2 specification for cipher requirements. */
.ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE)
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.protocols(protocol)
.applicationProtocolConfig(new ApplicationProtocolConfig(
ApplicationProtocolConfig.Protocol.ALPN,
// NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers.
ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
// ACCEPT is currently the only mode supported by both OpenSsl and JDK providers.
ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
ApplicationProtocolNames.HTTP_2,
ApplicationProtocolNames.HTTP_1_1))
.build();
final CountDownLatch latch = new CountDownLatch(2);
final AtomicReference<AssertionError> errorRef = new AtomicReference<AssertionError>();
Bootstrap bs = new Bootstrap();
bs.group(eventLoopGroup);
bs.channel(NioSocketChannel.class);
bs.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(clientCtx.newHandler(ch.alloc()));
ch.pipeline().addLast(new Http2FrameCodecBuilder(false).build());
ch.pipeline().addLast(new Http2MultiplexHandler(DISCARD_HANDLER));
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent handshakeCompletionEvent =
(SslHandshakeCompletionEvent) evt;
if (handshakeCompletionEvent.isSuccess()) {
// In case of TLSv1.3 we should succeed the handshake. The alert for
// the mTLS failure will be send in the next round-trip.
if (!tlsv13) {
errorRef.set(new AssertionError("TLSv1.3 expected"));
}
Http2StreamChannelBootstrap h2Bootstrap =
new Http2StreamChannelBootstrap(ctx.channel());
h2Bootstrap.handler(new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (cause.getCause() instanceof SSLException) {
latch.countDown();
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
latch.countDown();
}
});
h2Bootstrap.open().addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> future) {
if (future.isSuccess()) {
future.getNow().writeAndFlush(new DefaultHttp2HeadersFrame(
new DefaultHttp2Headers(), false));
}
}
});
} else if (handshakeCompletionEvent.cause() instanceof SSLException) {
// In case of TLSv1.2 we should never see the handshake succeed as the alert for
// the mTLS failure will be send in the same round-trip.
if (tlsv13) {
errorRef.set(new AssertionError("TLSv1.2 expected"));
}
latch.countDown();
latch.countDown();
}
}
}
});
}
});
clientChannel = bs.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
latch.await();
AssertionError error = errorRef.get();
if (error != null) {
throw error;
}
} finally {
if (ssc != null) {
ssc.delete();
}
}
}
}