Add SslHandler.renegotiate()

Related: #3125

Motivation:

We did not expose a way to initiate TLS renegotiation and to get
notified when the renegotiation is done.

Modifications:

- Add SslHandler.renegotiate() so that a user can initiate TLS
  renegotiation and get the future that's notified on completion
- Make SslHandler.handshakeFuture() return the future for the most
  recent handshake so that a user can get the future of the last
  renegotiation
- Add the test for renegotiation to SocketSslEchoTest

Result:

Both client-initiated and server-initiated renegotiations are now
supported properly.
This commit is contained in:
Trustin Lee 2014-12-10 18:47:53 +09:00
parent b653491a9f
commit e9685ea45a
3 changed files with 204 additions and 83 deletions

View File

@ -33,9 +33,11 @@ import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.concurrent.DefaultPromise; import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ImmediateExecutor; import io.netty.util.concurrent.ImmediateExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
@ -208,7 +210,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private boolean readDuringHandshake; private boolean readDuringHandshake;
private PendingWriteQueue pendingUnencryptedWrites; private PendingWriteQueue pendingUnencryptedWrites;
private final LazyChannelPromise handshakePromise = new LazyChannelPromise(); private Promise<Channel> handshakePromise = new LazyChannelPromise();
private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise(); private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
/** /**
@ -319,7 +321,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
/** /**
* Returns a {@link Future} that will get notified once the handshake completes. * Returns a {@link Future} that will get notified once the current TLS handshake completes.
*
* @return the {@link Future} for the iniital TLS handshake if {@link #renegotiate()} was not invoked.
* The {@link Future} for the most recent {@linkplain #renegotiate() TLS renegotiation} otherwise.
*/ */
public Future<Channel> handshakeFuture() { public Future<Channel> handshakeFuture() {
return handshakePromise; return handshakePromise;
@ -357,12 +362,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
/** /**
* Return the {@link ChannelFuture} that will get notified if the inbound of the {@link SSLEngine} will get closed. * Return the {@link Future} that will get notified if the inbound of the {@link SSLEngine} is closed.
* *
* This method will return the same {@link ChannelFuture} all the time. * This method will return the same {@link Future} all the time.
*
* For more informations see the apidocs of {@link SSLEngine}
* *
* @see SSLEngine
*/ */
public Future<Channel> sslCloseFuture() { public Future<Channel> sslCloseFuture() {
return sslCloseFuture; return sslCloseFuture;
@ -1110,12 +1114,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
wantsInboundHeapBuffer = true; wantsInboundHeapBuffer = true;
} }
if (handshakePromise.trySuccess(ctx.channel())) { handshakePromise.trySuccess(ctx.channel());
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug(ctx.channel() + " HANDSHAKEN: " + engine.getSession().getCipherSuite()); logger.debug(ctx.channel() + " HANDSHAKEN: " + engine.getSession().getCipherSuite());
} }
ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
}
if (readDuringHandshake && !ctx.channel().config().isAutoRead()) { if (readDuringHandshake && !ctx.channel().config().isAutoRead()) {
readDuringHandshake = false; readDuringHandshake = false;
@ -1179,39 +1183,92 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
pendingUnencryptedWrites = new PendingWriteQueue(ctx); pendingUnencryptedWrites = new PendingWriteQueue(ctx);
if (ctx.channel().isActive() && engine.getUseClientMode()) { if (ctx.channel().isActive() && engine.getUseClientMode()) {
// Begin the initial handshake.
// channelActive() event has been fired already, which means this.channelActive() will // channelActive() event has been fired already, which means this.channelActive() will
// not be invoked. We have to initialize here instead. // not be invoked. We have to initialize here instead.
handshake(); handshake(null);
} else { } else {
// channelActive() event has not been fired yet. this.channelOpen() will be invoked // channelActive() event has not been fired yet. this.channelOpen() will be invoked
// and initialization will occur there. // and initialization will occur there.
} }
} }
private Future<Channel> handshake() { /**
final ScheduledFuture<?> timeoutFuture; * Performs TLS renegotiation.
if (handshakeTimeoutMillis > 0) { */
timeoutFuture = ctx.executor().schedule(new Runnable() { public Future<Channel> renegotiate() {
@Override ChannelHandlerContext ctx = this.ctx;
public void run() { if (ctx == null) {
if (handshakePromise.isDone()) { throw new IllegalStateException();
return;
}
notifyHandshakeFailure(HANDSHAKE_TIMED_OUT);
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
} else {
timeoutFuture = null;
} }
handshakePromise.addListener(new GenericFutureListener<Future<Channel>>() { return renegotiate(ctx.executor().<Channel>newPromise());
}
/**
* Performs TLS renegotiation.
*/
public Future<Channel> renegotiate(final Promise<Channel> promise) {
if (promise == null) {
throw new NullPointerException("promise");
}
ChannelHandlerContext ctx = this.ctx;
if (ctx == null) {
throw new IllegalStateException();
}
EventExecutor executor = ctx.executor();
if (!executor.inEventLoop()) {
executor.execute(new OneTimeTask() {
@Override @Override
public void operationComplete(Future<Channel> f) throws Exception { public void run() {
if (timeoutFuture != null) { handshake(promise);
timeoutFuture.cancel(false); }
});
return promise;
}
handshake(promise);
return promise;
}
/**
* Performs TLS (re)negotiation.
*
* @param newHandshakePromise if {@code null}, use the existing {@link #handshakePromise},
* assuming that the current negotiation has not been finished.
* Currently, {@code null} is expected only for the initial handshake.
*/
private void handshake(final Promise<Channel> newHandshakePromise) {
final Promise<Channel> p;
if (newHandshakePromise != null) {
final Promise<Channel> oldHandshakePromise = handshakePromise;
if (!oldHandshakePromise.isDone()) {
// There's no need to handshake because handshake is in progress already.
// Merge the new promise into the old one.
oldHandshakePromise.addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> future) throws Exception {
if (future.isSuccess()) {
newHandshakePromise.setSuccess(future.getNow());
} else {
newHandshakePromise.setFailure(future.cause());
} }
} }
}); });
return;
}
handshakePromise = p = newHandshakePromise;
} else {
// Forced to reuse the old handshake.
p = handshakePromise;
assert !p.isDone();
}
// Begin handshake.
final ChannelHandlerContext ctx = this.ctx;
try { try {
engine.beginHandshake(); engine.beginHandshake();
wrapNonAppData(ctx, false); wrapNonAppData(ctx, false);
@ -1219,26 +1276,40 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} catch (Exception e) { } catch (Exception e) {
notifyHandshakeFailure(e); notifyHandshakeFailure(e);
} }
return handshakePromise;
// Set timeout if necessary.
final long handshakeTimeoutMillis = this.handshakeTimeoutMillis;
if (handshakeTimeoutMillis <= 0 || p.isDone()) {
return;
}
final ScheduledFuture<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (p.isDone()) {
return;
}
notifyHandshakeFailure(HANDSHAKE_TIMED_OUT);
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
// Cancel the handshake timeout when handshake is finished.
p.addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> f) throws Exception {
timeoutFuture.cancel(false);
}
});
} }
/** /**
* Issues a SSL handshake once connected when used in client-mode * Issues an initial TLS handshake once connected when used in client-mode
*/ */
@Override @Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception { public void channelActive(final ChannelHandlerContext ctx) throws Exception {
if (!startTls && engine.getUseClientMode()) { if (!startTls && engine.getUseClientMode()) {
// issue and handshake and add a listener to it which will fire an exception event if // Begin the initial handshake
// an exception was thrown while doing the handshake handshake(null);
handshake().addListener(new GenericFutureListener<Future<Channel>>() {
@Override
public void operationComplete(Future<Channel> future) throws Exception {
if (!future.isSuccess()) {
logger.debug("Failed to complete handshake", future.cause());
ctx.close();
}
}
});
} }
ctx.fireChannelActive(); ctx.fireChannelActive();
} }

View File

@ -31,6 +31,7 @@ import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.OpenSslServerContext; import io.netty.handler.ssl.OpenSslServerContext;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
@ -54,6 +55,7 @@ import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@RunWith(Parameterized.class) @RunWith(Parameterized.class)
@ -80,10 +82,16 @@ public class SocketSslEchoTest extends AbstractSocketTest {
KEY_FILE = ssc.privateKey(); KEY_FILE = ssc.privateKey();
} }
@Parameters(name = "{index}: " + public enum RenegotiationType {
"serverEngine = {0}, clientEngine = {1}, " + NONE, // no renegotiation
"serverUsesDelegatedTaskExecutor = {2}, clientUsesDelegatedTaskExecutor = {3}, " + CLIENT_INITIATED, // renegotiation from client
"useChunkedWriteHandler = {4}, useCompositeByteBuf = {5}") SERVER_INITIATED, // renegotiation from server
}
@Parameters(name =
"{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
"serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
"autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
public static Collection<Object[]> data() throws Exception { public static Collection<Object[]> data() throws Exception {
List<SslContext> serverContexts = new ArrayList<SslContext>(); List<SslContext> serverContexts = new ArrayList<SslContext>();
serverContexts.add(new JdkSslServerContext(CERT_FILE, KEY_FILE)); serverContexts.add(new JdkSslServerContext(CERT_FILE, KEY_FILE));
@ -104,8 +112,18 @@ public class SocketSslEchoTest extends AbstractSocketTest {
List<Object[]> params = new ArrayList<Object[]>(); List<Object[]> params = new ArrayList<Object[]>();
for (SslContext sc: serverContexts) { for (SslContext sc: serverContexts) {
for (SslContext cc: clientContexts) { for (SslContext cc: clientContexts) {
for (int i = 0; i < 16; i ++) { for (RenegotiationType rt: RenegotiationType.values()) {
params.add(new Object[] { sc, cc, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 }); if (rt != RenegotiationType.NONE &&
(sc instanceof OpenSslServerContext || cc instanceof OpenSslServerContext)) {
// TODO: OpenSslEngine does not support renegotiation yet.
continue;
}
for (int i = 0; i < 32; i++) {
params.add(new Object[] {
sc, cc, rt,
(i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
}
} }
} }
} }
@ -115,19 +133,23 @@ public class SocketSslEchoTest extends AbstractSocketTest {
private final SslContext serverCtx; private final SslContext serverCtx;
private final SslContext clientCtx; private final SslContext clientCtx;
private final RenegotiationType renegotiationType;
private final boolean serverUsesDelegatedTaskExecutor; private final boolean serverUsesDelegatedTaskExecutor;
private final boolean clientUsesDelegatedTaskExecutor; private final boolean clientUsesDelegatedTaskExecutor;
private final boolean autoRead;
private final boolean useChunkedWriteHandler; private final boolean useChunkedWriteHandler;
private final boolean useCompositeByteBuf; private final boolean useCompositeByteBuf;
public SocketSslEchoTest( public SocketSslEchoTest(
SslContext serverCtx, SslContext clientCtx, SslContext serverCtx, SslContext clientCtx, RenegotiationType renegotiationType,
boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor, boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
boolean useChunkedWriteHandler, boolean useCompositeByteBuf) { boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
this.serverCtx = serverCtx; this.serverCtx = serverCtx;
this.clientCtx = clientCtx; this.clientCtx = clientCtx;
this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor; this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor; this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
this.renegotiationType = renegotiationType;
this.autoRead = autoRead;
this.useChunkedWriteHandler = useChunkedWriteHandler; this.useChunkedWriteHandler = useChunkedWriteHandler;
this.useCompositeByteBuf = useCompositeByteBuf; this.useCompositeByteBuf = useCompositeByteBuf;
} }
@ -138,22 +160,9 @@ public class SocketSslEchoTest extends AbstractSocketTest {
} }
public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable { public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho(sb, cb, true);
}
@Test(timeout = 30000)
public void testSslEchoNotAutoRead() throws Throwable {
run();
}
public void testSslEchoNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho(sb, cb, false);
}
private void testSslEcho(ServerBootstrap sb, Bootstrap cb, boolean autoRead) throws Throwable {
final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool(); final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
final EchoHandler sh = new EchoHandler(true, useCompositeByteBuf, autoRead); final EchoHandler sh = new EchoHandler(true);
final EchoHandler ch = new EchoHandler(false, useCompositeByteBuf, autoRead); final EchoHandler ch = new EchoHandler(false);
sb.childHandler(new ChannelInitializer<SocketChannel>() { sb.childHandler(new ChannelInitializer<SocketChannel>() {
@Override @Override
@ -199,6 +208,8 @@ public class SocketSslEchoTest extends AbstractSocketTest {
assertFalse(firstByteWriteFutureDone.get()); assertFalse(firstByteWriteFutureDone.get());
boolean needsRenegotiation = renegotiationType == RenegotiationType.CLIENT_INITIATED;
Future<Channel> renegoFuture = null;
for (int i = FIRST_MESSAGE_SIZE; i < data.length;) { for (int i = FIRST_MESSAGE_SIZE; i < data.length;) {
int length = Math.min(random.nextInt(1024 * 64), data.length - i); int length = Math.min(random.nextInt(1024 * 64), data.length - i);
ByteBuf buf = Unpooled.wrappedBuffer(data, i, length); ByteBuf buf = Unpooled.wrappedBuffer(data, i, length);
@ -208,8 +219,23 @@ public class SocketSslEchoTest extends AbstractSocketTest {
ChannelFuture future = cc.writeAndFlush(buf); ChannelFuture future = cc.writeAndFlush(buf);
future.sync(); future.sync();
i += length; i += length;
if (needsRenegotiation && i >= data.length / 2) {
needsRenegotiation = false;
renegoFuture = cc.pipeline().get(SslHandler.class).renegotiate();
assertThat(renegoFuture, is(not(sameInstance(hf))));
}
} }
// Wait until renegotiation is done.
if (renegoFuture != null) {
renegoFuture.sync();
}
if (sh.renegoFuture != null) {
sh.renegoFuture.sync();
}
// Ensure all data has been exchanged.
while (ch.counter < data.length) { while (ch.counter < data.length) {
if (sh.exception.get() != null) { if (sh.exception.get() != null) {
break; break;
@ -257,20 +283,27 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (ch.exception.get() != null) { if (ch.exception.get() != null) {
throw ch.exception.get(); throw ch.exception.get();
} }
// When renegotiation is done, both the client and server side should be notified.
if (renegotiationType != RenegotiationType.NONE) {
assertThat(sh.negoCounter, is(2));
assertThat(ch.negoCounter, is(2));
} else {
assertThat(sh.negoCounter, is(1));
assertThat(ch.negoCounter, is(1));
}
} }
private static class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> { private class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
volatile Channel channel; volatile Channel channel;
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>(); final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
volatile int counter; volatile int counter;
private final boolean server; private final boolean server;
private final boolean composite; volatile Future<Channel> renegoFuture;
private final boolean autoRead; volatile int negoCounter;
EchoHandler(boolean server, boolean composite, boolean autoRead) { EchoHandler(boolean server) {
this.server = server; this.server = server;
this.composite = composite;
this.autoRead = autoRead;
} }
@Override @Override
@ -291,13 +324,24 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (channel.parent() != null) { if (channel.parent() != null) {
ByteBuf buf = Unpooled.wrappedBuffer(actual); ByteBuf buf = Unpooled.wrappedBuffer(actual);
if (composite) { if (useCompositeByteBuf) {
buf = Unpooled.compositeBuffer().addComponent(buf).writerIndex(buf.writerIndex()); buf = Unpooled.compositeBuffer().addComponent(buf).writerIndex(buf.writerIndex());
} }
channel.write(buf); channel.write(buf);
} }
counter += actual.length; counter += actual.length;
// Perform server-initiated renegotiation if necessary.
if (server && renegotiationType == RenegotiationType.SERVER_INITIATED &&
counter > data.length / 2 && renegoFuture == null) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
Future<Channel> hf = sslHandler.handshakeFuture();
renegoFuture = sslHandler.renegotiate();
assertThat(renegoFuture, is(not(sameInstance(hf))));
assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
}
} }
@Override @Override
@ -311,6 +355,13 @@ public class SocketSslEchoTest extends AbstractSocketTest {
} }
} }
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) {
negoCounter ++;
}
}
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, public void exceptionCaught(ChannelHandlerContext ctx,
Throwable cause) throws Exception { Throwable cause) throws Exception {

View File

@ -24,15 +24,14 @@ import io.netty.testsuite.transport.socket.SocketSslEchoTest;
import java.util.List; import java.util.List;
public class EpollSocketSslEchoTest extends SocketSslEchoTest { public class EpollSocketSslEchoTest extends SocketSslEchoTest {
public EpollSocketSslEchoTest( public EpollSocketSslEchoTest(
SslContext serverCtx, SslContext clientCtx, SslContext serverCtx, SslContext clientCtx, RenegotiationType renegotiationType,
boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor, boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
boolean useChunkedWriteHandler, boolean useCompositeByteBuf) { boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
super(
serverCtx, clientCtx, super(serverCtx, clientCtx, renegotiationType,
serverUsesDelegatedTaskExecutor, clientUsesDelegatedTaskExecutor, serverUsesDelegatedTaskExecutor, clientUsesDelegatedTaskExecutor,
useChunkedWriteHandler, useCompositeByteBuf); autoRead, useChunkedWriteHandler, useCompositeByteBuf);
} }
@Override @Override