diff --git a/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java b/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java index 7810d5aa4a..0e3ea0f39a 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java @@ -65,7 +65,6 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou InternalLoggerFactory.getInstance(ApplicationProtocolNegotiationHandler.class); private final String fallbackProtocol; - private SslHandler sslHandler; /** * Creates a new instance with the specified fallback protocol name. @@ -77,18 +76,6 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou this.fallbackProtocol = ObjectUtil.checkNotNull(fallbackProtocol, "fallbackProtocol"); } - @Override - public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - // FIXME: There is no way to tell if the SSL handler is placed before the negotiation handler. - final SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); - if (sslHandler == null) { - throw new IllegalStateException( - "cannot find a SslHandler in the pipeline (required for application-level protocol negotiation)"); - } - - this.sslHandler = sslHandler; - } - @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof SslHandshakeCompletionEvent) { @@ -96,6 +83,11 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt; if (handshakeEvent.isSuccess()) { + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + if (sslHandler == null) { + throw new IllegalStateException("cannot find a SslHandler in the pipeline (required for " + + "application-level protocol negotiation)"); + } String protocol = sslHandler.applicationProtocol(); configurePipeline(ctx, protocol != null? protocol : fallbackProtocol); } else { diff --git a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java index a6ae52345f..0496e21e88 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java @@ -16,26 +16,61 @@ package io.netty.handler.ssl; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.DecoderException; +import io.netty.util.DomainMappingBuilder; import io.netty.util.DomainNameMapping; import io.netty.util.ReferenceCountUtil; import org.junit.Test; -import javax.xml.bind.DatatypeConverter; import java.io.File; +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; -import static org.hamcrest.CoreMatchers.*; -import static org.junit.Assert.*; +import javax.xml.bind.DatatypeConverter; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class SniHandlerTest { + private static ApplicationProtocolConfig newApnConfig() { + return new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + // NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + // ACCEPT is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + "myprotocol"); + } + private static SslContext makeSslContext() throws Exception { File keyFile = new File(SniHandlerTest.class.getResource("test_encrypted.pem").getFile()); File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile()); - return SslContextBuilder.forServer(crtFile, keyFile, "12345").build(); + return SslContextBuilder.forServer(crtFile, keyFile, "12345").applicationProtocolConfig(newApnConfig()).build(); + } + + private static SslContext makeSslClientContext() throws Exception { + File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile()); + + return SslContextBuilder.forClient().trustManager(crtFile).applicationProtocolConfig(newApnConfig()).build(); } @Test @@ -123,4 +158,77 @@ public class SniHandlerTest { assertThat(handler.hostname(), nullValue()); assertThat(handler.sslContext(), is(nettyContext)); } + + @Test + public void testSniWithApnHandler() throws Exception { + SslContext nettyContext = makeSslContext(); + SslContext leanContext = makeSslContext(); + final SslContext clientContext = makeSslClientContext(); + SslContext wrappedLeanContext = null; + final CountDownLatch apnDoneLatch = new CountDownLatch(1); + + final DomainNameMapping mapping = new DomainMappingBuilder(nettyContext) + .add("*.netty.io", nettyContext) + .add("*.LEANCLOUD.CN", leanContext).build(); + // hex dump of a client hello packet, which contains hostname "CHAT4。LEANCLOUD。CN" + final String tlsHandshakeMessageHex1 = "16030100"; + // part 2 + final String tlsHandshakeMessageHex = "bd010000b90303a74225676d1814ba57faff3b366" + + "3656ed05ee9dbb2a4dbb1bb1c32d2ea5fc39e0000000100008c0000001700150000164348" + + "415434E380824C45414E434C4F5544E38082434E000b000403000102000a00340032000e0" + + "00d0019000b000c00180009000a0016001700080006000700140015000400050012001300" + + "0100020003000f0010001100230000000d0020001e0601060206030501050205030401040" + + "20403030103020303020102020203000f00010133740000"; + + EventLoopGroup group = new NioEventLoopGroup(2); + Channel serverChannel = null; + Channel clientChannel = null; + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + p.addLast(new SniHandler(mapping)); + p.addLast(new ApplicationProtocolNegotiationHandler("foo") { + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception { + apnDoneLatch.countDown(); + } + }); + } + }); + + Bootstrap cb = new Bootstrap(); + cb.group(group); + cb.channel(NioSocketChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + ch.write(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex1))); + ch.writeAndFlush(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex))); + ch.pipeline().addLast(clientContext.newHandler(ch.alloc())); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); + + ChannelFuture ccf = cb.connect(serverChannel.localAddress()); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + + assertTrue(apnDoneLatch.await(5, TimeUnit.SECONDS)); + } finally { + if (serverChannel != null) { + serverChannel.close().sync(); + } + if (clientChannel != null) { + clientChannel.close().sync(); + } + group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS); + } + } }