diff --git a/codec-http2/pom.xml b/codec-http2/pom.xml index 4ce8699476..b5ee563f30 100644 --- a/codec-http2/pom.xml +++ b/codec-http2/pom.xml @@ -76,6 +76,18 @@ org.mockito mockito-core + + org.bouncycastle + bcpkix-jdk15on + test + + + ${project.groupId} + ${tcnative.artifactId} + ${tcnative.classifier} + test + true + diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexHandler.java index e9e86408e9..a98751f518 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexHandler.java @@ -25,11 +25,13 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoop; import io.netty.channel.ServerChannel; +import io.netty.handler.codec.DecoderException; import io.netty.handler.codec.http2.Http2FrameCodec.DefaultHttp2FrameStream; import io.netty.util.ReferenceCounted; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.UnstableApi; +import javax.net.ssl.SSLException; import java.util.ArrayDeque; import java.util.Queue; @@ -264,7 +266,7 @@ public final class Http2MultiplexHandler extends Http2ChannelDuplexHandler { } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + public void exceptionCaught(ChannelHandlerContext ctx, final Throwable cause) throws Exception { if (cause instanceof Http2FrameStreamException) { Http2FrameStreamException exception = (Http2FrameStreamException) cause; Http2FrameStream stream = exception.stream(); @@ -277,6 +279,17 @@ public final class Http2MultiplexHandler extends Http2ChannelDuplexHandler { } return; } + if (cause.getCause() instanceof SSLException) { + forEachActiveStream(new Http2FrameStreamVisitor() { + @Override + public boolean visit(Http2FrameStream stream) { + AbstractHttp2StreamChannel childChannel = (AbstractHttp2StreamChannel) + ((DefaultHttp2FrameStream) stream).attachment; + childChannel.pipeline().fireExceptionCaught(cause); + return true; + } + }); + } ctx.fireExceptionCaught(cause); } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTransportTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTransportTest.java index 62aaacd9df..77045683e9 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTransportTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTransportTest.java @@ -29,14 +29,32 @@ import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.ssl.ApplicationProtocolConfig; +import io.netty.handler.ssl.ApplicationProtocolNames; +import io.netty.handler.ssl.ClientAuth; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; +import io.netty.handler.ssl.SslProvider; +import io.netty.handler.ssl.SupportedCipherSuiteFilter; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.util.CharsetUtil; import io.netty.util.NetUtil; import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; +import javax.net.ssl.SSLException; +import javax.net.ssl.X509TrustManager; import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.security.cert.CertificateExpiredException; +import java.security.cert.X509Certificate; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -45,6 +63,7 @@ import java.util.concurrent.atomic.AtomicReference; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertFalse; +import static org.junit.Assume.assumeTrue; public class Http2MultiplexTransportTest { private static final ChannelHandler DISCARD_HANDLER = new ChannelInboundHandlerAdapter() { @@ -254,4 +273,174 @@ public class Http2MultiplexTransportTest { executorService.shutdown(); } } + + @Test(timeout = 5000L) + public void testSSLExceptionOpenSslTLSv12() throws Exception { + testSslException(SslProvider.OPENSSL, false); + } + + @Test(timeout = 5000L) + public void testSSLExceptionOpenSslTLSv13() throws Exception { + testSslException(SslProvider.OPENSSL, true); + } + + @Ignore("JDK SSLEngine does not produce an alert") + @Test(timeout = 5000L) + public void testSSLExceptionJDKTLSv12() throws Exception { + testSslException(SslProvider.JDK, false); + } + + @Ignore("JDK SSLEngine does not produce an alert") + @Test(timeout = 5000L) + public void testSSLExceptionJDKTLSv13() throws Exception { + testSslException(SslProvider.JDK, true); + } + + private void testSslException(SslProvider provider, final boolean tlsv13) throws Exception { + assumeTrue(SslProvider.isAlpnSupported(provider)); + if (tlsv13) { + assumeTrue(SslProvider.isTlsv13Supported(provider)); + } + final String protocol = tlsv13 ? "TLSv1.3" : "TLSv1.2"; + SelfSignedCertificate ssc = null; + try { + ssc = new SelfSignedCertificate(); + final SslContext sslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .trustManager(new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateExpiredException(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateExpiredException(); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }).sslProvider(provider) + .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE) + .protocols(protocol) + .applicationProtocolConfig(new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + // NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + // ACCEPT is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_2, + ApplicationProtocolNames.HTTP_1_1)).clientAuth(ClientAuth.REQUIRE) + .build(); + + ServerBootstrap sb = new ServerBootstrap(); + sb.group(eventLoopGroup); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new Http2FrameCodecBuilder(true).build()); + ch.pipeline().addLast(new Http2MultiplexHandler(DISCARD_HANDLER)); + } + }); + serverChannel = sb.bind(new InetSocketAddress(NetUtil.LOCALHOST, 0)).syncUninterruptibly().channel(); + + final SslContext clientCtx = SslContextBuilder.forClient() + .keyManager(ssc.key(), ssc.cert()) + .sslProvider(provider) + /* NOTE: the cipher filter may not include all ciphers required by the HTTP/2 specification. + * Please refer to the HTTP/2 specification for cipher requirements. */ + .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .protocols(protocol) + .applicationProtocolConfig(new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + // NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + // ACCEPT is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_2, + ApplicationProtocolNames.HTTP_1_1)) + .build(); + + final CountDownLatch latch = new CountDownLatch(2); + final AtomicReference errorRef = new AtomicReference(); + Bootstrap bs = new Bootstrap(); + bs.group(eventLoopGroup); + bs.channel(NioSocketChannel.class); + bs.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(clientCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new Http2FrameCodecBuilder(false).build()); + ch.pipeline().addLast(new Http2MultiplexHandler(DISCARD_HANDLER)); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent handshakeCompletionEvent = + (SslHandshakeCompletionEvent) evt; + if (handshakeCompletionEvent.isSuccess()) { + // In case of TLSv1.3 we should succeed the handshake. The alert for + // the mTLS failure will be send in the next round-trip. + if (!tlsv13) { + errorRef.set(new AssertionError("TLSv1.3 expected")); + } + + Http2StreamChannelBootstrap h2Bootstrap = + new Http2StreamChannelBootstrap(ctx.channel()); + h2Bootstrap.handler(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (cause.getCause() instanceof SSLException) { + latch.countDown(); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + latch.countDown(); + } + }); + h2Bootstrap.open().addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + if (future.isSuccess()) { + future.getNow().writeAndFlush(new DefaultHttp2HeadersFrame( + new DefaultHttp2Headers(), false)); + } + } + }); + + } else if (handshakeCompletionEvent.cause() instanceof SSLException) { + // In case of TLSv1.2 we should never see the handshake succeed as the alert for + // the mTLS failure will be send in the same round-trip. + if (tlsv13) { + errorRef.set(new AssertionError("TLSv1.2 expected")); + } + latch.countDown(); + latch.countDown(); + } + } + } + }); + } + }); + clientChannel = bs.connect(serverChannel.localAddress()).syncUninterruptibly().channel(); + latch.await(); + AssertionError error = errorRef.get(); + if (error != null) { + throw error; + } + } finally { + if (ssc != null) { + ssc.delete(); + } + } + } }