From fcbeebf6dff759a1e0cfe2e20fe5103e4c40075c Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Fri, 1 Apr 2016 16:56:03 -0700 Subject: [PATCH] ApplicationProtocolNegotiationHandler doesn't work with SniHandler Motivation: ApplicationProtocolNegotiationHandler attempts to get a reference to an SslHandler in handlerAdded, but when SNI is in use the actual SslHandler will be added to the pipeline dynamically at some later time. When the handshake completes ApplicationProtocolNegotiationHandler throws an IllegalStateException because its reference to SslHandler is null. Modifications: - Instead of saving a reference to SslHandler in handlerAdded just search the pipeline when the SslHandler is needed Result: ApplicationProtocolNegotiationHandler support SniHandler. Fixes https://github.com/netty/netty/issues/5066 --- ...ApplicationProtocolNegotiationHandler.java | 18 +-- .../io/netty/handler/ssl/SniHandlerTest.java | 116 +++++++++++++++++- 2 files changed, 117 insertions(+), 17 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java b/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java index 7810d5aa4a..0e3ea0f39a 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java @@ -65,7 +65,6 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou InternalLoggerFactory.getInstance(ApplicationProtocolNegotiationHandler.class); private final String fallbackProtocol; - private SslHandler sslHandler; /** * Creates a new instance with the specified fallback protocol name. @@ -77,18 +76,6 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou this.fallbackProtocol = ObjectUtil.checkNotNull(fallbackProtocol, "fallbackProtocol"); } - @Override - public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - // FIXME: There is no way to tell if the SSL handler is placed before the negotiation handler. - final SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); - if (sslHandler == null) { - throw new IllegalStateException( - "cannot find a SslHandler in the pipeline (required for application-level protocol negotiation)"); - } - - this.sslHandler = sslHandler; - } - @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof SslHandshakeCompletionEvent) { @@ -96,6 +83,11 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt; if (handshakeEvent.isSuccess()) { + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + if (sslHandler == null) { + throw new IllegalStateException("cannot find a SslHandler in the pipeline (required for " + + "application-level protocol negotiation)"); + } String protocol = sslHandler.applicationProtocol(); configurePipeline(ctx, protocol != null? protocol : fallbackProtocol); } else { diff --git a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java index a6ae52345f..0496e21e88 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java @@ -16,26 +16,61 @@ package io.netty.handler.ssl; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.DecoderException; +import io.netty.util.DomainMappingBuilder; import io.netty.util.DomainNameMapping; import io.netty.util.ReferenceCountUtil; import org.junit.Test; -import javax.xml.bind.DatatypeConverter; import java.io.File; +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; -import static org.hamcrest.CoreMatchers.*; -import static org.junit.Assert.*; +import javax.xml.bind.DatatypeConverter; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class SniHandlerTest { + private static ApplicationProtocolConfig newApnConfig() { + return new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + // NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + // ACCEPT is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + "myprotocol"); + } + private static SslContext makeSslContext() throws Exception { File keyFile = new File(SniHandlerTest.class.getResource("test_encrypted.pem").getFile()); File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile()); - return SslContextBuilder.forServer(crtFile, keyFile, "12345").build(); + return SslContextBuilder.forServer(crtFile, keyFile, "12345").applicationProtocolConfig(newApnConfig()).build(); + } + + private static SslContext makeSslClientContext() throws Exception { + File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile()); + + return SslContextBuilder.forClient().trustManager(crtFile).applicationProtocolConfig(newApnConfig()).build(); } @Test @@ -123,4 +158,77 @@ public class SniHandlerTest { assertThat(handler.hostname(), nullValue()); assertThat(handler.sslContext(), is(nettyContext)); } + + @Test + public void testSniWithApnHandler() throws Exception { + SslContext nettyContext = makeSslContext(); + SslContext leanContext = makeSslContext(); + final SslContext clientContext = makeSslClientContext(); + SslContext wrappedLeanContext = null; + final CountDownLatch apnDoneLatch = new CountDownLatch(1); + + final DomainNameMapping mapping = new DomainMappingBuilder(nettyContext) + .add("*.netty.io", nettyContext) + .add("*.LEANCLOUD.CN", leanContext).build(); + // hex dump of a client hello packet, which contains hostname "CHAT4。LEANCLOUD。CN" + final String tlsHandshakeMessageHex1 = "16030100"; + // part 2 + final String tlsHandshakeMessageHex = "bd010000b90303a74225676d1814ba57faff3b366" + + "3656ed05ee9dbb2a4dbb1bb1c32d2ea5fc39e0000000100008c0000001700150000164348" + + "415434E380824C45414E434C4F5544E38082434E000b000403000102000a00340032000e0" + + "00d0019000b000c00180009000a0016001700080006000700140015000400050012001300" + + "0100020003000f0010001100230000000d0020001e0601060206030501050205030401040" + + "20403030103020303020102020203000f00010133740000"; + + EventLoopGroup group = new NioEventLoopGroup(2); + Channel serverChannel = null; + Channel clientChannel = null; + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + p.addLast(new SniHandler(mapping)); + p.addLast(new ApplicationProtocolNegotiationHandler("foo") { + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception { + apnDoneLatch.countDown(); + } + }); + } + }); + + Bootstrap cb = new Bootstrap(); + cb.group(group); + cb.channel(NioSocketChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + ch.write(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex1))); + ch.writeAndFlush(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex))); + ch.pipeline().addLast(clientContext.newHandler(ch.alloc())); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); + + ChannelFuture ccf = cb.connect(serverChannel.localAddress()); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + + assertTrue(apnDoneLatch.await(5, TimeUnit.SECONDS)); + } finally { + if (serverChannel != null) { + serverChannel.close().sync(); + } + if (clientChannel != null) { + clientChannel.close().sync(); + } + group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS); + } + } }