ApplicationProtocolNegotiationHandler doesn't work with SniHandler

Motivation:
ApplicationProtocolNegotiationHandler attempts to get a reference to an SslHandler in handlerAdded, but when SNI is in use the actual SslHandler will be added to the pipeline dynamically at some later time. When the handshake completes ApplicationProtocolNegotiationHandler throws an IllegalStateException because its reference to SslHandler is null.

Modifications:
- Instead of saving a reference to SslHandler in handlerAdded just search the pipeline when the SslHandler is needed

Result:
ApplicationProtocolNegotiationHandler support SniHandler.
Fixes https://github.com/netty/netty/issues/5066
This commit is contained in:
Scott Mitchell 2016-04-01 16:56:03 -07:00 committed by Norman Maurer
parent 516e4933c4
commit fcbeebf6df
2 changed files with 117 additions and 17 deletions

View File

@ -65,7 +65,6 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou
InternalLoggerFactory.getInstance(ApplicationProtocolNegotiationHandler.class);
private final String fallbackProtocol;
private SslHandler sslHandler;
/**
* Creates a new instance with the specified fallback protocol name.
@ -77,18 +76,6 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou
this.fallbackProtocol = ObjectUtil.checkNotNull(fallbackProtocol, "fallbackProtocol");
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
// FIXME: There is no way to tell if the SSL handler is placed before the negotiation handler.
final SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
if (sslHandler == null) {
throw new IllegalStateException(
"cannot find a SslHandler in the pipeline (required for application-level protocol negotiation)");
}
this.sslHandler = sslHandler;
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) {
@ -96,6 +83,11 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
if (handshakeEvent.isSuccess()) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
if (sslHandler == null) {
throw new IllegalStateException("cannot find a SslHandler in the pipeline (required for " +
"application-level protocol negotiation)");
}
String protocol = sslHandler.applicationProtocol();
configurePipeline(ctx, protocol != null? protocol : fallbackProtocol);
} else {

View File

@ -16,26 +16,61 @@
package io.netty.handler.ssl;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.DecoderException;
import io.netty.util.DomainMappingBuilder;
import io.netty.util.DomainNameMapping;
import io.netty.util.ReferenceCountUtil;
import org.junit.Test;
import javax.xml.bind.DatatypeConverter;
import java.io.File;
import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;
import javax.xml.bind.DatatypeConverter;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class SniHandlerTest {
private static ApplicationProtocolConfig newApnConfig() {
return new ApplicationProtocolConfig(
ApplicationProtocolConfig.Protocol.ALPN,
// NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers.
ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
// ACCEPT is currently the only mode supported by both OpenSsl and JDK providers.
ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
"myprotocol");
}
private static SslContext makeSslContext() throws Exception {
File keyFile = new File(SniHandlerTest.class.getResource("test_encrypted.pem").getFile());
File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile());
return SslContextBuilder.forServer(crtFile, keyFile, "12345").build();
return SslContextBuilder.forServer(crtFile, keyFile, "12345").applicationProtocolConfig(newApnConfig()).build();
}
private static SslContext makeSslClientContext() throws Exception {
File crtFile = new File(SniHandlerTest.class.getResource("test.crt").getFile());
return SslContextBuilder.forClient().trustManager(crtFile).applicationProtocolConfig(newApnConfig()).build();
}
@Test
@ -123,4 +158,77 @@ public class SniHandlerTest {
assertThat(handler.hostname(), nullValue());
assertThat(handler.sslContext(), is(nettyContext));
}
@Test
public void testSniWithApnHandler() throws Exception {
SslContext nettyContext = makeSslContext();
SslContext leanContext = makeSslContext();
final SslContext clientContext = makeSslClientContext();
SslContext wrappedLeanContext = null;
final CountDownLatch apnDoneLatch = new CountDownLatch(1);
final DomainNameMapping<SslContext> mapping = new DomainMappingBuilder(nettyContext)
.add("*.netty.io", nettyContext)
.add("*.LEANCLOUD.CN", leanContext).build();
// hex dump of a client hello packet, which contains hostname "CHAT4。LEANCLOUD。CN"
final String tlsHandshakeMessageHex1 = "16030100";
// part 2
final String tlsHandshakeMessageHex = "bd010000b90303a74225676d1814ba57faff3b366" +
"3656ed05ee9dbb2a4dbb1bb1c32d2ea5fc39e0000000100008c0000001700150000164348" +
"415434E380824C45414E434C4F5544E38082434E000b000403000102000a00340032000e0" +
"00d0019000b000c00180009000a0016001700080006000700140015000400050012001300" +
"0100020003000f0010001100230000000d0020001e0601060206030501050205030401040" +
"20403030103020303020102020203000f00010133740000";
EventLoopGroup group = new NioEventLoopGroup(2);
Channel serverChannel = null;
Channel clientChannel = null;
try {
ServerBootstrap sb = new ServerBootstrap();
sb.group(group);
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new SniHandler(mapping));
p.addLast(new ApplicationProtocolNegotiationHandler("foo") {
@Override
protected void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception {
apnDoneLatch.countDown();
}
});
}
});
Bootstrap cb = new Bootstrap();
cb.group(group);
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
ch.write(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex1)));
ch.writeAndFlush(Unpooled.wrappedBuffer(DatatypeConverter.parseHexBinary(tlsHandshakeMessageHex)));
ch.pipeline().addLast(clientContext.newHandler(ch.alloc()));
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
ChannelFuture ccf = cb.connect(serverChannel.localAddress());
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
assertTrue(apnDoneLatch.await(5, TimeUnit.SECONDS));
} finally {
if (serverChannel != null) {
serverChannel.close().sync();
}
if (clientChannel != null) {
clientChannel.close().sync();
}
group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS);
}
}
}