diff --git a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java index 3df430770e..53e265a895 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java @@ -26,6 +26,8 @@ import static org.junit.Assume.assumeTrue; import java.io.File; import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -62,7 +64,10 @@ import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; import io.netty.util.concurrent.Promise; import io.netty.util.internal.ObjectUtil; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +@RunWith(Parameterized.class) public class SniHandlerTest { private static ApplicationProtocolConfig newApnConfig() { @@ -75,292 +80,360 @@ public class SniHandlerTest { "myprotocol"); } - private static SslContext makeSslContext() throws Exception { - return makeSslContext(null); + private static void assumeApnSupported(SslProvider provider) { + switch (provider) { + case OPENSSL: + case OPENSSL_REFCNT: + assumeTrue(OpenSsl.isAlpnSupported()); + break; + case JDK: + assumeTrue(JdkAlpnSslEngine.isAvailable()); + break; + default: + throw new Error(); + } } - private static SslContext makeSslContext(SslProvider provider) throws Exception { + private static SslContext makeSslContext(SslProvider provider, boolean apn) throws Exception { + if (apn) { + assumeApnSupported(provider); + } + File keyFile = new File(SniHandlerTest.class.getResource("test_encrypted.pem").getFile()); File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile()); - return SslContextBuilder.forServer(crtFile, keyFile, "12345") - .sslProvider(provider) - .applicationProtocolConfig(newApnConfig()).build(); + SslContextBuilder sslCtxBuilder = SslContextBuilder.forServer(crtFile, keyFile, "12345") + .sslProvider(provider); + if (apn) { + sslCtxBuilder.applicationProtocolConfig(newApnConfig()); + } + return sslCtxBuilder.build(); } - private static SslContext makeSslClientContext() throws Exception { + private static SslContext makeSslClientContext(SslProvider provider, boolean apn) throws Exception { + if (apn) { + assumeApnSupported(provider); + } + File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile()); - return SslContextBuilder.forClient().trustManager(crtFile).applicationProtocolConfig(newApnConfig()).build(); + SslContextBuilder sslCtxBuilder = SslContextBuilder.forClient().trustManager(crtFile).sslProvider(provider); + if (apn) { + sslCtxBuilder.applicationProtocolConfig(newApnConfig()); + } + return sslCtxBuilder.build(); + } + + @Parameterized.Parameters(name = "{index}: sslProvider={0}") + public static Iterable data() { + List params = new ArrayList(3); + if (OpenSsl.isAvailable()) { + params.add(SslProvider.OPENSSL); + params.add(SslProvider.OPENSSL_REFCNT); + } + params.add(SslProvider.JDK); + return params; + } + + private final SslProvider provider; + + public SniHandlerTest(SslProvider provider) { + this.provider = provider; } @Test public void testServerNameParsing() throws Exception { - SslContext nettyContext = makeSslContext(); - SslContext leanContext = makeSslContext(); - SslContext leanContext2 = makeSslContext(); - - DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) - .add("*.netty.io", nettyContext) - // input with custom cases - .add("*.LEANCLOUD.CN", leanContext) - // a hostname conflict with previous one, since we are using order-sensitive config, the engine won't - // be used with the handler. - .add("chat4.leancloud.cn", leanContext2) - .build(); - - SniHandler handler = new SniHandler(mapping); - EmbeddedChannel ch = new EmbeddedChannel(handler); - - // hex dump of a client hello packet, which contains hostname "CHAT4。LEANCLOUD。CN" - String tlsHandshakeMessageHex1 = "16030100"; - // part 2 - String tlsHandshakeMessageHex = "bd010000b90303a74225676d1814ba57faff3b366" + - "3656ed05ee9dbb2a4dbb1bb1c32d2ea5fc39e0000000100008c0000001700150000164348" + - "415434E380824C45414E434C4F5544E38082434E000b000403000102000a00340032000e0" + - "00d0019000b000c00180009000a0016001700080006000700140015000400050012001300" + - "0100020003000f0010001100230000000d0020001e0601060206030501050205030401040" + - "20403030103020303020102020203000f00010133740000"; + SslContext nettyContext = makeSslContext(provider, false); + SslContext leanContext = makeSslContext(provider, false); + SslContext leanContext2 = makeSslContext(provider, false); try { - // Push the handshake message. - // Decode should fail because SNI error - ch.writeInbound(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex1))); - ch.writeInbound(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex))); - fail(); - } catch (DecoderException e) { - // expected - } + DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) + .add("*.netty.io", nettyContext) + // input with custom cases + .add("*.LEANCLOUD.CN", leanContext) + // a hostname conflict with previous one, since we are using order-sensitive config, + // the engine won't be used with the handler. + .add("chat4.leancloud.cn", leanContext2) + .build(); - assertThat(ch.finish(), is(true)); - assertThat(handler.hostname(), is("chat4.leancloud.cn")); - assertThat(handler.sslContext(), is(leanContext)); + SniHandler handler = new SniHandler(mapping); + EmbeddedChannel ch = new EmbeddedChannel(handler); - for (;;) { - Object msg = ch.readOutbound(); - if (msg == null) { - break; + // hex dump of a client hello packet, which contains hostname "CHAT4。LEANCLOUD。CN" + String tlsHandshakeMessageHex1 = "16030100"; + // part 2 + String tlsHandshakeMessageHex = "bd010000b90303a74225676d1814ba57faff3b366" + + "3656ed05ee9dbb2a4dbb1bb1c32d2ea5fc39e0000000100008c0000001700150000164348" + + "415434E380824C45414E434C4F5544E38082434E000b000403000102000a00340032000e0" + + "00d0019000b000c00180009000a0016001700080006000700140015000400050012001300" + + "0100020003000f0010001100230000000d0020001e0601060206030501050205030401040" + + "20403030103020303020102020203000f00010133740000"; + + try { + // Push the handshake message. + // Decode should fail because SNI error + ch.writeInbound(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex1))); + ch.writeInbound(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex))); + fail(); + } catch (DecoderException e) { + // expected } - ReferenceCountUtil.release(msg); + + // Just call finish and not assert the return value. This is because OpenSSL correct produces an alert + // while the JDK SSLEngineImpl does not atm. + // See https://github.com/netty/netty/issues/5874 + ch.finish(); + + assertThat(handler.hostname(), is("chat4.leancloud.cn")); + assertThat(handler.sslContext(), is(leanContext)); + + for (;;) { + Object msg = ch.readOutbound(); + if (msg == null) { + break; + } + ReferenceCountUtil.release(msg); + } + } finally { + releaseAll(leanContext, leanContext2, nettyContext); } } @Test public void testFallbackToDefaultContext() throws Exception { - SslContext nettyContext = makeSslContext(); - SslContext leanContext = makeSslContext(); - SslContext leanContext2 = makeSslContext(); - - DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) - .add("*.netty.io", nettyContext) - // input with custom cases - .add("*.LEANCLOUD.CN", leanContext) - // a hostname conflict with previous one, since we are using order-sensitive config, the engine won't - // be used with the handler. - .add("chat4.leancloud.cn", leanContext2) - .build(); - - SniHandler handler = new SniHandler(mapping); - EmbeddedChannel ch = new EmbeddedChannel(handler); - - // invalid - byte[] message = { 22, 3, 1, 0, 0 }; + SslContext nettyContext = makeSslContext(provider, false); + SslContext leanContext = makeSslContext(provider, false); + SslContext leanContext2 = makeSslContext(provider, false); try { - // Push the handshake message. - ch.writeInbound(Unpooled.wrappedBuffer(message)); - } catch (Exception e) { - // expected - } + DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) + .add("*.netty.io", nettyContext) + // input with custom cases + .add("*.LEANCLOUD.CN", leanContext) + // a hostname conflict with previous one, since we are using order-sensitive config, + // the engine won't be used with the handler. + .add("chat4.leancloud.cn", leanContext2) + .build(); - assertThat(ch.finish(), is(false)); - assertThat(handler.hostname(), nullValue()); - assertThat(handler.sslContext(), is(nettyContext)); + SniHandler handler = new SniHandler(mapping); + EmbeddedChannel ch = new EmbeddedChannel(handler); + + // invalid + byte[] message = {22, 3, 1, 0, 0}; + + try { + // Push the handshake message. + ch.writeInbound(Unpooled.wrappedBuffer(message)); + } catch (Exception e) { + // expected + } + + assertThat(ch.finish(), is(false)); + assertThat(handler.hostname(), nullValue()); + assertThat(handler.sslContext(), is(nettyContext)); + } finally { + releaseAll(leanContext, leanContext2, nettyContext); + } } @Test public void testSniWithApnHandler() throws Exception { - SslContext nettyContext = makeSslContext(); - SslContext sniContext = makeSslContext(); - final SslContext clientContext = makeSslClientContext(); - final CountDownLatch serverApnDoneLatch = new CountDownLatch(1); - final CountDownLatch clientApnDoneLatch = new CountDownLatch(1); - - final DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) - .add("*.netty.io", nettyContext) - .add("sni.fake.site", sniContext).build(); - final SniHandler handler = new SniHandler(mapping); - EventLoopGroup group = new NioEventLoopGroup(2); - Channel serverChannel = null; - Channel clientChannel = null; + SslContext nettyContext = makeSslContext(provider, true); + SslContext sniContext = makeSslContext(provider, true); + final SslContext clientContext = makeSslClientContext(provider, true); try { - ServerBootstrap sb = new ServerBootstrap(); - sb.group(group); - sb.channel(NioServerSocketChannel.class); - sb.childHandler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) throws Exception { - ChannelPipeline p = ch.pipeline(); - // Server side SNI. - p.addLast(handler); - // Catch the notification event that APN has completed successfully. - p.addLast(new ApplicationProtocolNegotiationHandler("foo") { - @Override - protected void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception { - serverApnDoneLatch.countDown(); - } - }); + final CountDownLatch serverApnDoneLatch = new CountDownLatch(1); + final CountDownLatch clientApnDoneLatch = new CountDownLatch(1); + + final DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) + .add("*.netty.io", nettyContext) + .add("sni.fake.site", sniContext).build(); + final SniHandler handler = new SniHandler(mapping); + EventLoopGroup group = new NioEventLoopGroup(2); + Channel serverChannel = null; + Channel clientChannel = null; + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + // Server side SNI. + p.addLast(handler); + // Catch the notification event that APN has completed successfully. + p.addLast(new ApplicationProtocolNegotiationHandler("foo") { + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { + serverApnDoneLatch.countDown(); + } + }); + } + }); + + Bootstrap cb = new Bootstrap(); + cb.group(group); + cb.channel(NioSocketChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(new SslHandler(clientContext.newEngine( + ch.alloc(), "sni.fake.site", -1))); + // Catch the notification event that APN has completed successfully. + ch.pipeline().addLast(new ApplicationProtocolNegotiationHandler("foo") { + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { + clientApnDoneLatch.countDown(); + } + }); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); + + ChannelFuture ccf = cb.connect(serverChannel.localAddress()); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + + assertTrue(serverApnDoneLatch.await(5, TimeUnit.SECONDS)); + assertTrue(clientApnDoneLatch.await(5, TimeUnit.SECONDS)); + assertThat(handler.hostname(), is("sni.fake.site")); + assertThat(handler.sslContext(), is(sniContext)); + } finally { + if (serverChannel != null) { + serverChannel.close().sync(); } - }); - - Bootstrap cb = new Bootstrap(); - cb.group(group); - cb.channel(NioSocketChannel.class); - cb.handler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) throws Exception { - ch.pipeline().addLast(new SslHandler(clientContext.newEngine( - ch.alloc(), "sni.fake.site", -1))); - // Catch the notification event that APN has completed successfully. - ch.pipeline().addLast(new ApplicationProtocolNegotiationHandler("foo") { - @Override - protected void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception { - clientApnDoneLatch.countDown(); - } - }); + if (clientChannel != null) { + clientChannel.close().sync(); } - }); - - serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); - - ChannelFuture ccf = cb.connect(serverChannel.localAddress()); - assertTrue(ccf.awaitUninterruptibly().isSuccess()); - clientChannel = ccf.channel(); - - assertTrue(serverApnDoneLatch.await(5, TimeUnit.SECONDS)); - assertTrue(clientApnDoneLatch.await(5, TimeUnit.SECONDS)); - assertThat(handler.hostname(), is("sni.fake.site")); - assertThat(handler.sslContext(), is(sniContext)); + group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS); + } } finally { - if (serverChannel != null) { - serverChannel.close().sync(); - } - if (clientChannel != null) { - clientChannel.close().sync(); - } - group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS); + releaseAll(clientContext, nettyContext, sniContext); } } @Test(timeout = 10L * 1000L) public void testReplaceHandler() throws Exception { + switch (provider) { + case OPENSSL: + case OPENSSL_REFCNT: + final String sniHost = "sni.netty.io"; + LocalAddress address = new LocalAddress("testReplaceHandler-" + Math.random()); + EventLoopGroup group = new DefaultEventLoopGroup(1); + Channel sc = null; + Channel cc = null; + SslContext sslContext = null; - assumeTrue(OpenSsl.isAvailable()); + SelfSignedCertificate cert = new SelfSignedCertificate(); - final String sniHost = "sni.netty.io"; - LocalAddress address = new LocalAddress("testReplaceHandler-" + Math.random()); - EventLoopGroup group = new DefaultEventLoopGroup(1); - Channel sc = null; - Channel cc = null; + try { + final SslContext sslServerContext = SslContextBuilder + .forServer(cert.key(), cert.cert()) + .sslProvider(provider) + .build(); - SelfSignedCertificate cert = new SelfSignedCertificate(); + final Mapping mapping = new Mapping() { + @Override + public SslContext map(String input) { + return sslServerContext; + } + }; - try { - final SslContext sslServerContext = SslContextBuilder - .forServer(cert.key(), cert.cert()) - .sslProvider(SslProvider.OPENSSL) - .build(); + final Promise releasePromise = group.next().newPromise(); - final Mapping mapping = new Mapping() { - @Override - public SslContext map(String input) { - return sslServerContext; - } - }; + final SniHandler handler = new SniHandler(mapping) { + @Override + protected void replaceHandler(ChannelHandlerContext ctx, + String hostname, final SslContext sslContext) + throws Exception { - final Promise releasePromise = group.next().newPromise(); + boolean success = false; + try { + // The SniHandler's replaceHandler() method allows us to implement custom behavior. + // As an example, we want to release() the SslContext upon channelInactive() or rather + // when the SslHandler closes it's SslEngine. If you take a close look at SslHandler + // you'll see that it's doing it in the #handlerRemoved0() method. - final SniHandler handler = new SniHandler(mapping) { - @Override - protected void replaceHandler(ChannelHandlerContext ctx, - String hostname, final SslContext sslContext) - throws Exception { - - boolean success = false; - try { - // The SniHandler's replaceHandler() method allows us to implement custom behavior. - // As an example, we want to release() the SslContext upon channelInactive() or rather - // when the SslHandler closes it's SslEngine. If you take a close look at SslHandler - // you'll see that it's doing it in the #handlerRemoved0() method. - - SSLEngine sslEngine = sslContext.newEngine(ctx.alloc()); - try { - SslHandler customSslHandler = new CustomSslHandler(sslContext, sslEngine) { - @Override - public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { - try { - super.handlerRemoved0(ctx); - } finally { - releasePromise.trySuccess(null); + SSLEngine sslEngine = sslContext.newEngine(ctx.alloc()); + try { + SslHandler customSslHandler = new CustomSslHandler(sslContext, sslEngine) { + @Override + public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + try { + super.handlerRemoved0(ctx); + } finally { + releasePromise.trySuccess(null); + } + } + }; + ctx.pipeline().replace(this, CustomSslHandler.class.getName(), customSslHandler); + success = true; + } finally { + if (!success) { + ReferenceCountUtil.safeRelease(sslEngine); } } - }; - ctx.pipeline().replace(this, CustomSslHandler.class.getName(), customSslHandler); - success = true; - } finally { - if (!success) { - ReferenceCountUtil.safeRelease(sslEngine); + } finally { + if (!success) { + ReferenceCountUtil.safeRelease(sslContext); + releasePromise.cancel(true); + } } } - } finally { - if (!success) { - ReferenceCountUtil.safeRelease(sslContext); - releasePromise.cancel(true); + }; + + ServerBootstrap sb = new ServerBootstrap(); + sc = sb.group(group).channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addFirst(handler); } + }).bind(address).syncUninterruptibly().channel(); + + sslContext = SslContextBuilder.forClient().sslProvider(provider) + .trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + + Bootstrap cb = new Bootstrap(); + cc = cb.group(group).channel(LocalChannel.class).handler(new SslHandler( + sslContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1))) + .connect(address).syncUninterruptibly().channel(); + + cc.writeAndFlush(Unpooled.wrappedBuffer("Hello, World!".getBytes())) + .syncUninterruptibly(); + + // Notice how the server's SslContext refCnt is 1 + assertEquals(1, ((ReferenceCounted) sslServerContext).refCnt()); + + // The client disconnects + cc.close().syncUninterruptibly(); + if (!releasePromise.awaitUninterruptibly(10L, TimeUnit.SECONDS)) { + throw new IllegalStateException("It doesn't seem #replaceHandler() got called."); } + + // We should have successfully release() the SslContext + assertEquals(0, ((ReferenceCounted) sslServerContext).refCnt()); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + if (sslContext != null) { + ReferenceCountUtil.release(sslContext); + } + group.shutdownGracefully(); + + cert.delete(); } - }; - - ServerBootstrap sb = new ServerBootstrap(); - sc = sb.group(group).channel(LocalServerChannel.class).childHandler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) throws Exception { - ch.pipeline().addFirst(handler); - } - }).bind(address).syncUninterruptibly().channel(); - - SslContext sslContext = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE) - .build(); - - Bootstrap cb = new Bootstrap(); - cc = cb.group(group).channel(LocalChannel.class).handler(new SslHandler( - sslContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1))) - .connect(address).syncUninterruptibly().channel(); - - cc.writeAndFlush(Unpooled.wrappedBuffer("Hello, World!".getBytes())) - .syncUninterruptibly(); - - // Notice how the server's SslContext refCnt is 1 - assertEquals(1, ((ReferenceCounted) sslServerContext).refCnt()); - - // The client disconnects - cc.close().syncUninterruptibly(); - if (!releasePromise.awaitUninterruptibly(10L, TimeUnit.SECONDS)) { - throw new IllegalStateException("It doesn't seem #replaceHandler() got called."); - } - - // We should have successfully release() the SslContext - assertEquals(0, ((ReferenceCounted) sslServerContext).refCnt()); - } finally { - if (cc != null) { - cc.close().syncUninterruptibly(); - } - if (sc != null) { - sc.close().syncUninterruptibly(); - } - group.shutdownGracefully(); - - cert.delete(); + case JDK: + return; + default: + throw new Error(); } } @@ -384,4 +457,10 @@ public class SniHandlerTest { ReferenceCountUtil.release(sslContext); } } + + private static void releaseAll(SslContext... contexts) { + for (SslContext ctx: contexts) { + ReferenceCountUtil.release(ctx); + } + } }