Expose SniHandler's replaceHandler() so that users can implement custom behavior.
Motivation The SniHandler is currently hiding its replaceHandler() method and everything that comes with it. The user has no easy way of getting a hold onto the SslContext for the purpose of reference counting for example. The SniHandler does have getter methods for the SslContext and hostname but they're not very practical or useful. For one the SniHandler will remove itself from the pipeline and we'd have to track a reference of it externally and as we saw in #5745 it'll possibly leave its internal "selection" object with the "EMPTY_SELECTION" value (i.e. we've just lost track of the SslContext). Modifications Expose replaceHandler() and allow the user to override it and get a hold onto the hostname, SslContext and SslHandler that will replace the SniHandler. Result It's possible to get a hold onto the SslContext, the hostname and the SslHandler that is about to replace the SniHandler. Users can add additional behavior.
This commit is contained in:
parent
1208b90f57
commit
f97866dbc6
@ -15,6 +15,13 @@
|
|||||||
*/
|
*/
|
||||||
package io.netty.handler.ssl;
|
package io.netty.handler.ssl;
|
||||||
|
|
||||||
|
import java.net.IDN;
|
||||||
|
import java.net.SocketAddress;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import javax.net.ssl.SSLEngine;
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
@ -34,13 +41,6 @@ import io.netty.util.internal.ObjectUtil;
|
|||||||
import io.netty.util.internal.logging.InternalLogger;
|
import io.netty.util.internal.logging.InternalLogger;
|
||||||
import io.netty.util.internal.logging.InternalLoggerFactory;
|
import io.netty.util.internal.logging.InternalLoggerFactory;
|
||||||
|
|
||||||
import java.net.IDN;
|
|
||||||
import java.net.SocketAddress;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Locale;
|
|
||||||
|
|
||||||
import javax.net.ssl.SSLEngine;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
|
* <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
|
||||||
* (Server Name Indication)</a> extension for server side SSL. For clients
|
* (Server Name Indication)</a> extension for server side SSL. For clients
|
||||||
@ -277,7 +277,7 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
Future<SslContext> future = mapping.map(hostname, ctx.executor().<SslContext>newPromise());
|
Future<SslContext> future = mapping.map(hostname, ctx.executor().<SslContext>newPromise());
|
||||||
if (future.isDone()) {
|
if (future.isDone()) {
|
||||||
if (future.isSuccess()) {
|
if (future.isSuccess()) {
|
||||||
replaceHandler(ctx, new Selection(future.getNow(), hostname));
|
onSslContext(ctx, hostname, future.getNow());
|
||||||
} else {
|
} else {
|
||||||
throw new DecoderException("failed to get the SslContext for " + hostname, future.cause());
|
throw new DecoderException("failed to get the SslContext for " + hostname, future.cause());
|
||||||
}
|
}
|
||||||
@ -289,7 +289,7 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
try {
|
try {
|
||||||
suppressRead = false;
|
suppressRead = false;
|
||||||
if (future.isSuccess()) {
|
if (future.isSuccess()) {
|
||||||
replaceHandler(ctx, new Selection(future.getNow(), hostname));
|
onSslContext(ctx, hostname, future.getNow());
|
||||||
} else {
|
} else {
|
||||||
ctx.fireExceptionCaught(new DecoderException("failed to get the SslContext for "
|
ctx.fireExceptionCaught(new DecoderException("failed to get the SslContext for "
|
||||||
+ hostname, future.cause()));
|
+ hostname, future.cause()));
|
||||||
@ -305,19 +305,41 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void replaceHandler(ChannelHandlerContext ctx, Selection selection) {
|
/**
|
||||||
SSLEngine sslEngine = null;
|
* Called upon successful completion of the {@link AsyncMapping}'s {@link Future}.
|
||||||
this.selection = selection;
|
*
|
||||||
|
* @see #select(ChannelHandlerContext, String)
|
||||||
|
*/
|
||||||
|
private void onSslContext(ChannelHandlerContext ctx, String hostname, SslContext sslContext) {
|
||||||
|
this.selection = new Selection(sslContext, hostname);
|
||||||
try {
|
try {
|
||||||
sslEngine = selection.context.newEngine(ctx.alloc());
|
replaceHandler(ctx, hostname, sslContext);
|
||||||
ctx.pipeline().replace(this, SslHandler.class.getName(), selection.context.newHandler(sslEngine));
|
|
||||||
} catch (Throwable cause) {
|
} catch (Throwable cause) {
|
||||||
this.selection = EMPTY_SELECTION;
|
this.selection = EMPTY_SELECTION;
|
||||||
|
ctx.fireExceptionCaught(cause);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The default implementation of this method will simply replace {@code this} {@link SniHandler}
|
||||||
|
* instance with a {@link SslHandler}. Users may override this method to implement custom behavior.
|
||||||
|
*
|
||||||
|
* Please be aware that this method may get called after a client has already disconnected and
|
||||||
|
* custom implementations must take it into consideration when overriding this method.
|
||||||
|
*
|
||||||
|
* It's also possible for the hostname argument to be {@code null}.
|
||||||
|
*/
|
||||||
|
protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslContext sslContext) throws Exception {
|
||||||
|
SSLEngine sslEngine = null;
|
||||||
|
try {
|
||||||
|
sslEngine = sslContext.newEngine(ctx.alloc());
|
||||||
|
ctx.pipeline().replace(this, SslHandler.class.getName(), SslContext.newHandler(sslEngine));
|
||||||
|
sslEngine = null;
|
||||||
|
} finally {
|
||||||
// Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not
|
// Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not
|
||||||
// transferred to the SslHandler.
|
// transferred to the SslHandler.
|
||||||
// See https://github.com/netty/netty/issues/5678
|
// See https://github.com/netty/netty/issues/5678
|
||||||
ReferenceCountUtil.safeRelease(sslEngine);
|
ReferenceCountUtil.safeRelease(sslEngine);
|
||||||
ctx.fireExceptionCaught(cause);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,36 +16,52 @@
|
|||||||
|
|
||||||
package io.netty.handler.ssl;
|
package io.netty.handler.ssl;
|
||||||
|
|
||||||
|
import static org.hamcrest.CoreMatchers.is;
|
||||||
|
import static org.hamcrest.CoreMatchers.nullValue;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertThat;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
import static org.junit.Assume.assumeTrue;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.net.InetSocketAddress;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import javax.net.ssl.SSLEngine;
|
||||||
|
import javax.xml.bind.DatatypeConverter;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
import io.netty.bootstrap.Bootstrap;
|
import io.netty.bootstrap.Bootstrap;
|
||||||
import io.netty.bootstrap.ServerBootstrap;
|
import io.netty.bootstrap.ServerBootstrap;
|
||||||
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.Channel;
|
import io.netty.channel.Channel;
|
||||||
import io.netty.channel.ChannelFuture;
|
import io.netty.channel.ChannelFuture;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelInitializer;
|
import io.netty.channel.ChannelInitializer;
|
||||||
import io.netty.channel.ChannelPipeline;
|
import io.netty.channel.ChannelPipeline;
|
||||||
|
import io.netty.channel.DefaultEventLoopGroup;
|
||||||
import io.netty.channel.EventLoopGroup;
|
import io.netty.channel.EventLoopGroup;
|
||||||
import io.netty.channel.embedded.EmbeddedChannel;
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.local.LocalChannel;
|
||||||
|
import io.netty.channel.local.LocalServerChannel;
|
||||||
import io.netty.channel.nio.NioEventLoopGroup;
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||||
import io.netty.handler.codec.DecoderException;
|
import io.netty.handler.codec.DecoderException;
|
||||||
|
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
|
||||||
|
import io.netty.handler.ssl.util.SelfSignedCertificate;
|
||||||
import io.netty.util.DomainNameMapping;
|
import io.netty.util.DomainNameMapping;
|
||||||
import io.netty.util.DomainNameMappingBuilder;
|
import io.netty.util.DomainNameMappingBuilder;
|
||||||
|
import io.netty.util.Mapping;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import org.junit.Test;
|
import io.netty.util.ReferenceCounted;
|
||||||
|
import io.netty.util.concurrent.Promise;
|
||||||
import javax.xml.bind.DatatypeConverter;
|
import io.netty.util.internal.ObjectUtil;
|
||||||
import java.io.File;
|
|
||||||
import java.net.InetSocketAddress;
|
|
||||||
import java.util.concurrent.CountDownLatch;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
|
|
||||||
import static org.hamcrest.CoreMatchers.is;
|
|
||||||
import static org.hamcrest.CoreMatchers.nullValue;
|
|
||||||
import static org.junit.Assert.assertThat;
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
import static org.junit.Assert.fail;
|
|
||||||
|
|
||||||
public class SniHandlerTest {
|
public class SniHandlerTest {
|
||||||
|
|
||||||
@ -60,10 +76,16 @@ public class SniHandlerTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static SslContext makeSslContext() throws Exception {
|
private static SslContext makeSslContext() throws Exception {
|
||||||
|
return makeSslContext(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static SslContext makeSslContext(SslProvider provider) throws Exception {
|
||||||
File keyFile = new File(SniHandlerTest.class.getResource("test_encrypted.pem").getFile());
|
File keyFile = new File(SniHandlerTest.class.getResource("test_encrypted.pem").getFile());
|
||||||
File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile());
|
File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile());
|
||||||
|
|
||||||
return SslContextBuilder.forServer(crtFile, keyFile, "12345").applicationProtocolConfig(newApnConfig()).build();
|
return SslContextBuilder.forServer(crtFile, keyFile, "12345")
|
||||||
|
.sslProvider(provider)
|
||||||
|
.applicationProtocolConfig(newApnConfig()).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static SslContext makeSslClientContext() throws Exception {
|
private static SslContext makeSslClientContext() throws Exception {
|
||||||
@ -229,4 +251,137 @@ public class SniHandlerTest {
|
|||||||
group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS);
|
group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test(timeout = 10L * 1000L)
|
||||||
|
public void testReplaceHandler() throws Exception {
|
||||||
|
|
||||||
|
assumeTrue(OpenSsl.isAvailable());
|
||||||
|
|
||||||
|
final String sniHost = "sni.netty.io";
|
||||||
|
LocalAddress address = new LocalAddress("testReplaceHandler-" + Math.random());
|
||||||
|
EventLoopGroup group = new DefaultEventLoopGroup(1);
|
||||||
|
Channel sc = null;
|
||||||
|
Channel cc = null;
|
||||||
|
|
||||||
|
SelfSignedCertificate cert = new SelfSignedCertificate();
|
||||||
|
|
||||||
|
try {
|
||||||
|
final SslContext sslServerContext = SslContextBuilder
|
||||||
|
.forServer(cert.key(), cert.cert())
|
||||||
|
.sslProvider(SslProvider.OPENSSL)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final Mapping<String, SslContext> mapping = new Mapping<String, SslContext>() {
|
||||||
|
@Override
|
||||||
|
public SslContext map(String input) {
|
||||||
|
return sslServerContext;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
final Promise<Void> releasePromise = group.next().newPromise();
|
||||||
|
|
||||||
|
final SniHandler handler = new SniHandler(mapping) {
|
||||||
|
@Override
|
||||||
|
protected void replaceHandler(ChannelHandlerContext ctx,
|
||||||
|
String hostname, final SslContext sslContext)
|
||||||
|
throws Exception {
|
||||||
|
|
||||||
|
boolean success = false;
|
||||||
|
try {
|
||||||
|
// The SniHandler's replaceHandler() method allows us to implement custom behavior.
|
||||||
|
// As an example, we want to release() the SslContext upon channelInactive() or rather
|
||||||
|
// when the SslHandler closes it's SslEngine. If you take a close look at SslHandler
|
||||||
|
// you'll see that it's doing it in the #handlerRemoved0() method.
|
||||||
|
|
||||||
|
SSLEngine sslEngine = sslContext.newEngine(ctx.alloc());
|
||||||
|
try {
|
||||||
|
SslHandler customSslHandler = new CustomSslHandler(sslContext, sslEngine) {
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||||
|
try {
|
||||||
|
super.handlerRemoved0(ctx);
|
||||||
|
} finally {
|
||||||
|
releasePromise.trySuccess(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
ctx.pipeline().replace(this, CustomSslHandler.class.getName(), customSslHandler);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (!success) {
|
||||||
|
ReferenceCountUtil.safeRelease(sslEngine);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
if (!success) {
|
||||||
|
ReferenceCountUtil.safeRelease(sslContext);
|
||||||
|
releasePromise.cancel(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ServerBootstrap sb = new ServerBootstrap();
|
||||||
|
sc = sb.group(group).channel(LocalServerChannel.class).childHandler(new ChannelInitializer<Channel>() {
|
||||||
|
@Override
|
||||||
|
protected void initChannel(Channel ch) throws Exception {
|
||||||
|
ch.pipeline().addFirst(handler);
|
||||||
|
}
|
||||||
|
}).bind(address).syncUninterruptibly().channel();
|
||||||
|
|
||||||
|
SslContext sslContext = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Bootstrap cb = new Bootstrap();
|
||||||
|
cc = cb.group(group).channel(LocalChannel.class).handler(new SslHandler(
|
||||||
|
sslContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1)))
|
||||||
|
.connect(address).syncUninterruptibly().channel();
|
||||||
|
|
||||||
|
cc.writeAndFlush(Unpooled.wrappedBuffer("Hello, World!".getBytes()))
|
||||||
|
.syncUninterruptibly();
|
||||||
|
|
||||||
|
// Notice how the server's SslContext refCnt is 1
|
||||||
|
assertEquals(1, ((ReferenceCounted) sslServerContext).refCnt());
|
||||||
|
|
||||||
|
// The client disconnects
|
||||||
|
cc.close().syncUninterruptibly();
|
||||||
|
if (!releasePromise.awaitUninterruptibly(10L, TimeUnit.SECONDS)) {
|
||||||
|
throw new IllegalStateException("It doesn't seem #replaceHandler() got called.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should have successfully release() the SslContext
|
||||||
|
assertEquals(0, ((ReferenceCounted) sslServerContext).refCnt());
|
||||||
|
} finally {
|
||||||
|
if (cc != null) {
|
||||||
|
cc.close().syncUninterruptibly();
|
||||||
|
}
|
||||||
|
if (sc != null) {
|
||||||
|
sc.close().syncUninterruptibly();
|
||||||
|
}
|
||||||
|
group.shutdownGracefully();
|
||||||
|
|
||||||
|
cert.delete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is a {@link SslHandler} that will call {@code release()} on the {@link SslContext} when
|
||||||
|
* the client disconnects.
|
||||||
|
*
|
||||||
|
* @see SniHandlerTest#testReplaceHandler()
|
||||||
|
*/
|
||||||
|
private static class CustomSslHandler extends SslHandler {
|
||||||
|
private final SslContext sslContext;
|
||||||
|
|
||||||
|
public CustomSslHandler(SslContext sslContext, SSLEngine sslEngine) {
|
||||||
|
super(sslEngine);
|
||||||
|
this.sslContext = ObjectUtil.checkNotNull(sslContext, "sslContext");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||||
|
super.handlerRemoved0(ctx);
|
||||||
|
ReferenceCountUtil.release(sslContext);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user