diff --git a/common/src/main/java/io/netty/util/AsyncMapping.java b/common/src/main/java/io/netty/util/AsyncMapping.java new file mode 100644 index 0000000000..63745b2488 --- /dev/null +++ b/common/src/main/java/io/netty/util/AsyncMapping.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; + +public interface AsyncMapping { + + /** + * Returns the {@link Future} that will provide the result of the mapping. The given {@link Promise} will + * be fulfilled when the result is available. + */ + Future map(IN input, Promise promise); +} diff --git a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java index 0faa40f5b1..ccf638529b 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java @@ -18,14 +18,23 @@ package io.netty.handler.ssl; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DecoderException; +import io.netty.util.AsyncMapping; import io.netty.util.CharsetUtil; import io.netty.util.DomainNameMapping; import io.netty.util.Mapping; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import java.net.IDN; +import java.net.SocketAddress; import java.util.List; import java.util.Locale; @@ -34,21 +43,22 @@ import java.util.Locale; * (Server Name Indication) extension for server side SSL. For clients * support SNI, the server could have multiple host name bound on a single IP. * The client will send host name in the handshake data so server could decide - * which certificate to choose for the host name.

+ * which certificate to choose for the host name.

*/ -public class SniHandler extends ByteToMessageDecoder { +public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundHandler { // Maximal number of ssl records to inspect before fallback to the default SslContext. private static final int MAX_SSL_RECORDS = 4; private static final InternalLogger logger = InternalLoggerFactory.getInstance(SniHandler.class); - - private final Mapping mapping; private static final Selection EMPTY_SELECTION = new Selection(null, null); - private boolean handshakeFailed; + private final AsyncMapping mapping; + private boolean handshakeFailed; + private boolean suppressRead; + private boolean readPending; private volatile Selection selection = EMPTY_SELECTION; /** @@ -57,12 +67,8 @@ public class SniHandler extends ByteToMessageDecoder { * * @param mapping the mapping of domain name to {@link SslContext} */ - @SuppressWarnings("unchecked") public SniHandler(Mapping mapping) { - if (mapping == null) { - throw new NullPointerException("mapping"); - } - this.mapping = (Mapping) mapping; + this(new AsyncMappingAdapter(mapping)); } /** @@ -75,6 +81,17 @@ public class SniHandler extends ByteToMessageDecoder { this((Mapping) mapping); } + /** + * Creates a SNI detection handler with configured {@link SslContext} + * maintained by {@link AsyncMapping} + * + * @param mapping the mapping of domain name to {@link SslContext} + */ + @SuppressWarnings("unchecked") + public SniHandler(AsyncMapping mapping) { + this.mapping = (AsyncMapping) ObjectUtil.checkNotNull(mapping, "mapping"); + } + /** * @return the selected hostname */ @@ -91,11 +108,12 @@ public class SniHandler extends ByteToMessageDecoder { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - if (!handshakeFailed && in.readableBytes() >= SslUtils.SSL_RECORD_HEADER_LENGTH) { + if (!suppressRead && !handshakeFailed && in.readableBytes() >= SslUtils.SSL_RECORD_HEADER_LENGTH) { int writerIndex = in.writerIndex(); int readerIndex = in.readerIndex(); try { - loop: for (int i = 0; i < MAX_SSL_RECORDS; i++) { + loop: + for (int i = 0; i < MAX_SSL_RECORDS; i++) { int command = in.getUnsignedByte(readerIndex); // tls, but not handshake command @@ -183,6 +201,7 @@ public class SniHandler extends ByteToMessageDecoder { int serverNameLength = in.getUnsignedShort(offset + 3); String hostname = in.toString(offset + 5, serverNameLength, CharsetUtil.UTF_8); + select(ctx, IDN.toASCII(hostname, IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US)); return; @@ -212,13 +231,108 @@ public class SniHandler extends ByteToMessageDecoder { } } - private void select(ChannelHandlerContext ctx, String hostname) { - SslContext selectedContext = mapping.map(hostname); - selection = new Selection(selectedContext, hostname); - SslHandler sslHandler = selectedContext.newHandler(ctx.alloc()); + private void select(final ChannelHandlerContext ctx, final String hostname) { + Future future = mapping.map(hostname, ctx.executor().newPromise()); + if (future.isDone()) { + if (future.isSuccess()) { + replaceHandler(ctx, new Selection(future.getNow(), hostname)); + } else { + throw new DecoderException("failed to get the SslContext for " + hostname, future.cause()); + } + } else { + suppressRead = true; + future.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + try { + suppressRead = false; + if (future.isSuccess()) { + replaceHandler(ctx, new Selection(future.getNow(), hostname)); + } else { + ctx.fireExceptionCaught(new DecoderException("failed to get the SslContext for " + + hostname, future.cause())); + } + } finally { + if (readPending) { + readPending = false; + ctx.read(); + } + } + } + }); + } + } + + private void replaceHandler(ChannelHandlerContext ctx, Selection selection) { + this.selection = selection; + SslHandler sslHandler = selection.context.newHandler(ctx.alloc()); ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler); } + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { + ctx.bind(localAddress, promise); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) throws Exception { + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.close(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.deregister(promise); + } + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + if (suppressRead) { + readPending = true; + } else { + ctx.read(); + } + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ctx.write(msg, promise); + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + ctx.flush(); + } + + private static final class AsyncMappingAdapter implements AsyncMapping { + private final Mapping mapping; + + private AsyncMappingAdapter(Mapping mapping) { + this.mapping = ObjectUtil.checkNotNull(mapping, "mapping"); + } + + @Override + public Future map(String input, Promise promise) { + final SslContext context; + try { + context = mapping.map(input); + } catch (Throwable cause) { + return promise.setFailure(cause); + } + return promise.setSuccess(context); + } + } + private static final class Selection { final SslContext context; final String hostname;