From f97866dbc6e7edc987de9bf7d31f563ce726833e Mon Sep 17 00:00:00 2001 From: Roger Kapsi Date: Fri, 26 Aug 2016 08:28:25 -0400 Subject: [PATCH] Expose SniHandler's replaceHandler() so that users can implement custom behavior. Motivation The SniHandler is currently hiding its replaceHandler() method and everything that comes with it. The user has no easy way of getting a hold onto the SslContext for the purpose of reference counting for example. The SniHandler does have getter methods for the SslContext and hostname but they're not very practical or useful. For one the SniHandler will remove itself from the pipeline and we'd have to track a reference of it externally and as we saw in #5745 it'll possibly leave its internal "selection" object with the "EMPTY_SELECTION" value (i.e. we've just lost track of the SslContext). Modifications Expose replaceHandler() and allow the user to override it and get a hold onto the hostname, SslContext and SslHandler that will replace the SniHandler. Result It's possible to get a hold onto the SslContext, the hostname and the SslHandler that is about to replace the SniHandler. Users can add additional behavior. --- .../java/io/netty/handler/ssl/SniHandler.java | 52 +++-- .../io/netty/handler/ssl/SniHandlerTest.java | 183 ++++++++++++++++-- 2 files changed, 206 insertions(+), 29 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java index 77d558f2a4..4cfe78145b 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java @@ -15,6 +15,13 @@ */ package io.netty.handler.ssl; +import java.net.IDN; +import java.net.SocketAddress; +import java.util.List; +import java.util.Locale; + +import javax.net.ssl.SSLEngine; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelHandlerContext; @@ -34,13 +41,6 @@ import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; -import java.net.IDN; -import java.net.SocketAddress; -import java.util.List; -import java.util.Locale; - -import javax.net.ssl.SSLEngine; - /** *

Enables SNI * (Server Name Indication) extension for server side SSL. For clients @@ -277,7 +277,7 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH Future future = mapping.map(hostname, ctx.executor().newPromise()); if (future.isDone()) { if (future.isSuccess()) { - replaceHandler(ctx, new Selection(future.getNow(), hostname)); + onSslContext(ctx, hostname, future.getNow()); } else { throw new DecoderException("failed to get the SslContext for " + hostname, future.cause()); } @@ -289,7 +289,7 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH try { suppressRead = false; if (future.isSuccess()) { - replaceHandler(ctx, new Selection(future.getNow(), hostname)); + onSslContext(ctx, hostname, future.getNow()); } else { ctx.fireExceptionCaught(new DecoderException("failed to get the SslContext for " + hostname, future.cause())); @@ -305,19 +305,41 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH } } - private void replaceHandler(ChannelHandlerContext ctx, Selection selection) { - SSLEngine sslEngine = null; - this.selection = selection; + /** + * Called upon successful completion of the {@link AsyncMapping}'s {@link Future}. + * + * @see #select(ChannelHandlerContext, String) + */ + private void onSslContext(ChannelHandlerContext ctx, String hostname, SslContext sslContext) { + this.selection = new Selection(sslContext, hostname); try { - sslEngine = selection.context.newEngine(ctx.alloc()); - ctx.pipeline().replace(this, SslHandler.class.getName(), selection.context.newHandler(sslEngine)); + replaceHandler(ctx, hostname, sslContext); } catch (Throwable cause) { this.selection = EMPTY_SELECTION; + ctx.fireExceptionCaught(cause); + } + } + + /** + * The default implementation of this method will simply replace {@code this} {@link SniHandler} + * instance with a {@link SslHandler}. Users may override this method to implement custom behavior. + * + * Please be aware that this method may get called after a client has already disconnected and + * custom implementations must take it into consideration when overriding this method. + * + * It's also possible for the hostname argument to be {@code null}. + */ + protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslContext sslContext) throws Exception { + SSLEngine sslEngine = null; + try { + sslEngine = sslContext.newEngine(ctx.alloc()); + ctx.pipeline().replace(this, SslHandler.class.getName(), SslContext.newHandler(sslEngine)); + sslEngine = null; + } finally { // Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not // transferred to the SslHandler. // See https://github.com/netty/netty/issues/5678 ReferenceCountUtil.safeRelease(sslEngine); - ctx.fireExceptionCaught(cause); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java index 8170376f38..3df430770e 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java @@ -16,36 +16,52 @@ package io.netty.handler.ssl; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeTrue; + +import java.io.File; +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLEngine; +import javax.xml.bind.DatatypeConverter; + +import org.junit.Test; + import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.EventLoopGroup; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.DecoderException; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.util.DomainNameMapping; import io.netty.util.DomainNameMappingBuilder; +import io.netty.util.Mapping; import io.netty.util.ReferenceCountUtil; -import org.junit.Test; - -import javax.xml.bind.DatatypeConverter; -import java.io.File; -import java.net.InetSocketAddress; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.nullValue; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.ObjectUtil; public class SniHandlerTest { @@ -60,10 +76,16 @@ public class SniHandlerTest { } private static SslContext makeSslContext() throws Exception { + return makeSslContext(null); + } + + private static SslContext makeSslContext(SslProvider provider) throws Exception { File keyFile = new File(SniHandlerTest.class.getResource("test_encrypted.pem").getFile()); File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile()); - return SslContextBuilder.forServer(crtFile, keyFile, "12345").applicationProtocolConfig(newApnConfig()).build(); + return SslContextBuilder.forServer(crtFile, keyFile, "12345") + .sslProvider(provider) + .applicationProtocolConfig(newApnConfig()).build(); } private static SslContext makeSslClientContext() throws Exception { @@ -229,4 +251,137 @@ public class SniHandlerTest { group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS); } } + + @Test(timeout = 10L * 1000L) + public void testReplaceHandler() throws Exception { + + assumeTrue(OpenSsl.isAvailable()); + + final String sniHost = "sni.netty.io"; + LocalAddress address = new LocalAddress("testReplaceHandler-" + Math.random()); + EventLoopGroup group = new DefaultEventLoopGroup(1); + Channel sc = null; + Channel cc = null; + + SelfSignedCertificate cert = new SelfSignedCertificate(); + + try { + final SslContext sslServerContext = SslContextBuilder + .forServer(cert.key(), cert.cert()) + .sslProvider(SslProvider.OPENSSL) + .build(); + + final Mapping mapping = new Mapping() { + @Override + public SslContext map(String input) { + return sslServerContext; + } + }; + + final Promise releasePromise = group.next().newPromise(); + + final SniHandler handler = new SniHandler(mapping) { + @Override + protected void replaceHandler(ChannelHandlerContext ctx, + String hostname, final SslContext sslContext) + throws Exception { + + boolean success = false; + try { + // The SniHandler's replaceHandler() method allows us to implement custom behavior. + // As an example, we want to release() the SslContext upon channelInactive() or rather + // when the SslHandler closes it's SslEngine. If you take a close look at SslHandler + // you'll see that it's doing it in the #handlerRemoved0() method. + + SSLEngine sslEngine = sslContext.newEngine(ctx.alloc()); + try { + SslHandler customSslHandler = new CustomSslHandler(sslContext, sslEngine) { + @Override + public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + try { + super.handlerRemoved0(ctx); + } finally { + releasePromise.trySuccess(null); + } + } + }; + ctx.pipeline().replace(this, CustomSslHandler.class.getName(), customSslHandler); + success = true; + } finally { + if (!success) { + ReferenceCountUtil.safeRelease(sslEngine); + } + } + } finally { + if (!success) { + ReferenceCountUtil.safeRelease(sslContext); + releasePromise.cancel(true); + } + } + } + }; + + ServerBootstrap sb = new ServerBootstrap(); + sc = sb.group(group).channel(LocalServerChannel.class).childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addFirst(handler); + } + }).bind(address).syncUninterruptibly().channel(); + + SslContext sslContext = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE) + .build(); + + Bootstrap cb = new Bootstrap(); + cc = cb.group(group).channel(LocalChannel.class).handler(new SslHandler( + sslContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1))) + .connect(address).syncUninterruptibly().channel(); + + cc.writeAndFlush(Unpooled.wrappedBuffer("Hello, World!".getBytes())) + .syncUninterruptibly(); + + // Notice how the server's SslContext refCnt is 1 + assertEquals(1, ((ReferenceCounted) sslServerContext).refCnt()); + + // The client disconnects + cc.close().syncUninterruptibly(); + if (!releasePromise.awaitUninterruptibly(10L, TimeUnit.SECONDS)) { + throw new IllegalStateException("It doesn't seem #replaceHandler() got called."); + } + + // We should have successfully release() the SslContext + assertEquals(0, ((ReferenceCounted) sslServerContext).refCnt()); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + cert.delete(); + } + } + + /** + * This is a {@link SslHandler} that will call {@code release()} on the {@link SslContext} when + * the client disconnects. + * + * @see SniHandlerTest#testReplaceHandler() + */ + private static class CustomSslHandler extends SslHandler { + private final SslContext sslContext; + + public CustomSslHandler(SslContext sslContext, SSLEngine sslEngine) { + super(sslEngine); + this.sslContext = ObjectUtil.checkNotNull(sslContext, "sslContext"); + } + + @Override + public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved0(ctx); + ReferenceCountUtil.release(sslContext); + } + } }