Allow to do async mappings in the SniHandler

Motivation:

Sometimes a user want to do async mappings in the SniHandler as it is not possible to populate a Mapping up front.

Modifications:

Add AsyncMapping interface and make SniHandler work with it.

Result:

It is possible to do async mappings for SNI
This commit is contained in:
Norman Maurer 2015-12-17 20:25:02 +01:00
parent 7bcae8919d
commit 7b51412c3c
2 changed files with 158 additions and 16 deletions

View File

@ -0,0 +1,28 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.util;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
public interface AsyncMapping<IN, OUT> {
/**
* Returns the {@link Future} that will provide the result of the mapping. The given {@link Promise} will
* be fulfilled when the result is available.
*/
Future<OUT> map(IN input, Promise<OUT> promise);
}

View File

@ -18,14 +18,23 @@ package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.DecoderException;
import io.netty.util.AsyncMapping;
import io.netty.util.CharsetUtil;
import io.netty.util.DomainNameMapping;
import io.netty.util.Mapping;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.net.IDN;
import java.net.SocketAddress;
import java.util.List;
import java.util.Locale;
@ -36,19 +45,20 @@ import java.util.Locale;
* The client will send host name in the handshake data so server could decide
* which certificate to choose for the host name.</p>
*/
public class SniHandler extends ByteToMessageDecoder {
public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundHandler {
// Maximal number of ssl records to inspect before fallback to the default SslContext.
private static final int MAX_SSL_RECORDS = 4;
private static final InternalLogger logger =
InternalLoggerFactory.getInstance(SniHandler.class);
private final Mapping<Object, SslContext> mapping;
private static final Selection EMPTY_SELECTION = new Selection(null, null);
private boolean handshakeFailed;
private final AsyncMapping<String, SslContext> mapping;
private boolean handshakeFailed;
private boolean suppressRead;
private boolean readPending;
private volatile Selection selection = EMPTY_SELECTION;
/**
@ -57,12 +67,8 @@ public class SniHandler extends ByteToMessageDecoder {
*
* @param mapping the mapping of domain name to {@link SslContext}
*/
@SuppressWarnings("unchecked")
public SniHandler(Mapping<? super String, ? extends SslContext> mapping) {
if (mapping == null) {
throw new NullPointerException("mapping");
}
this.mapping = (Mapping<Object, SslContext>) mapping;
this(new AsyncMappingAdapter(mapping));
}
/**
@ -75,6 +81,17 @@ public class SniHandler extends ByteToMessageDecoder {
this((Mapping<String, ? extends SslContext>) mapping);
}
/**
* Creates a SNI detection handler with configured {@link SslContext}
* maintained by {@link AsyncMapping}
*
* @param mapping the mapping of domain name to {@link SslContext}
*/
@SuppressWarnings("unchecked")
public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping) {
this.mapping = (AsyncMapping<String, SslContext>) ObjectUtil.checkNotNull(mapping, "mapping");
}
/**
* @return the selected hostname
*/
@ -91,11 +108,12 @@ public class SniHandler extends ByteToMessageDecoder {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (!handshakeFailed && in.readableBytes() >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
if (!suppressRead && !handshakeFailed && in.readableBytes() >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
int writerIndex = in.writerIndex();
int readerIndex = in.readerIndex();
try {
loop: for (int i = 0; i < MAX_SSL_RECORDS; i++) {
loop:
for (int i = 0; i < MAX_SSL_RECORDS; i++) {
int command = in.getUnsignedByte(readerIndex);
// tls, but not handshake command
@ -183,6 +201,7 @@ public class SniHandler extends ByteToMessageDecoder {
int serverNameLength = in.getUnsignedShort(offset + 3);
String hostname = in.toString(offset + 5, serverNameLength,
CharsetUtil.UTF_8);
select(ctx, IDN.toASCII(hostname,
IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
return;
@ -212,13 +231,108 @@ public class SniHandler extends ByteToMessageDecoder {
}
}
private void select(ChannelHandlerContext ctx, String hostname) {
SslContext selectedContext = mapping.map(hostname);
selection = new Selection(selectedContext, hostname);
SslHandler sslHandler = selectedContext.newHandler(ctx.alloc());
private void select(final ChannelHandlerContext ctx, final String hostname) {
Future<SslContext> future = mapping.map(hostname, ctx.executor().<SslContext>newPromise());
if (future.isDone()) {
if (future.isSuccess()) {
replaceHandler(ctx, new Selection(future.getNow(), hostname));
} else {
throw new DecoderException("failed to get the SslContext for " + hostname, future.cause());
}
} else {
suppressRead = true;
future.addListener(new FutureListener<SslContext>() {
@Override
public void operationComplete(Future<SslContext> future) throws Exception {
try {
suppressRead = false;
if (future.isSuccess()) {
replaceHandler(ctx, new Selection(future.getNow(), hostname));
} else {
ctx.fireExceptionCaught(new DecoderException("failed to get the SslContext for "
+ hostname, future.cause()));
}
} finally {
if (readPending) {
readPending = false;
ctx.read();
}
}
}
});
}
}
private void replaceHandler(ChannelHandlerContext ctx, Selection selection) {
this.selection = selection;
SslHandler sslHandler = selection.context.newHandler(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
}
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
ctx.bind(localAddress, promise);
}
@Override
public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
ChannelPromise promise) throws Exception {
ctx.connect(remoteAddress, localAddress, promise);
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
ctx.disconnect(promise);
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
ctx.close(promise);
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
ctx.deregister(promise);
}
@Override
public void read(ChannelHandlerContext ctx) throws Exception {
if (suppressRead) {
readPending = true;
} else {
ctx.read();
}
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
ctx.write(msg, promise);
}
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
ctx.flush();
}
private static final class AsyncMappingAdapter implements AsyncMapping<String, SslContext> {
private final Mapping<? super String, ? extends SslContext> mapping;
private AsyncMappingAdapter(Mapping<? super String, ? extends SslContext> mapping) {
this.mapping = ObjectUtil.checkNotNull(mapping, "mapping");
}
@Override
public Future<SslContext> map(String input, Promise<SslContext> promise) {
final SslContext context;
try {
context = mapping.map(input);
} catch (Throwable cause) {
return promise.setFailure(cause);
}
return promise.setSuccess(context);
}
}
private static final class Selection {
final SslContext context;
final String hostname;