Add SslHandler.renegotiate()

Related: #3125

Motivation:

We did not expose a way to initiate TLS renegotiation and to get
notified when the renegotiation is done.

Modifications:

- Add SslHandler.renegotiate() so that a user can initiate TLS
  renegotiation and get the future that's notified on completion
- Make SslHandler.handshakeFuture() return the future for the most
  recent handshake so that a user can get the future of the last
  renegotiation
- Add the test for renegotiation to SocketSslEchoTest

Result:

Both client-initiated and server-initiated renegotiations are now
supported properly.
This commit is contained in:
Trustin Lee 2014-12-10 18:47:53 +09:00
parent b653491a9f
commit e9685ea45a
3 changed files with 204 additions and 83 deletions

View File

@ -33,9 +33,11 @@ import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ImmediateExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@ -208,7 +210,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private boolean readDuringHandshake;
private PendingWriteQueue pendingUnencryptedWrites;
private final LazyChannelPromise handshakePromise = new LazyChannelPromise();
private Promise<Channel> handshakePromise = new LazyChannelPromise();
private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
/**
@ -319,7 +321,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
/**
* Returns a {@link Future} that will get notified once the handshake completes.
* Returns a {@link Future} that will get notified once the current TLS handshake completes.
*
* @return the {@link Future} for the iniital TLS handshake if {@link #renegotiate()} was not invoked.
* The {@link Future} for the most recent {@linkplain #renegotiate() TLS renegotiation} otherwise.
*/
public Future<Channel> handshakeFuture() {
return handshakePromise;
@ -357,12 +362,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
/**
* Return the {@link ChannelFuture} that will get notified if the inbound of the {@link SSLEngine} will get closed.
* Return the {@link Future} that will get notified if the inbound of the {@link SSLEngine} is closed.
*
* This method will return the same {@link ChannelFuture} all the time.
*
* For more informations see the apidocs of {@link SSLEngine}
* This method will return the same {@link Future} all the time.
*
* @see SSLEngine
*/
public Future<Channel> sslCloseFuture() {
return sslCloseFuture;
@ -1110,12 +1114,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
wantsInboundHeapBuffer = true;
}
if (handshakePromise.trySuccess(ctx.channel())) {
handshakePromise.trySuccess(ctx.channel());
if (logger.isDebugEnabled()) {
logger.debug(ctx.channel() + " HANDSHAKEN: " + engine.getSession().getCipherSuite());
}
ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
}
if (readDuringHandshake && !ctx.channel().config().isAutoRead()) {
readDuringHandshake = false;
@ -1179,39 +1183,92 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
pendingUnencryptedWrites = new PendingWriteQueue(ctx);
if (ctx.channel().isActive() && engine.getUseClientMode()) {
// Begin the initial handshake.
// channelActive() event has been fired already, which means this.channelActive() will
// not be invoked. We have to initialize here instead.
handshake();
handshake(null);
} else {
// channelActive() event has not been fired yet. this.channelOpen() will be invoked
// and initialization will occur there.
}
}
private Future<Channel> handshake() {
final ScheduledFuture<?> timeoutFuture;
if (handshakeTimeoutMillis > 0) {
timeoutFuture = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (handshakePromise.isDone()) {
return;
}
notifyHandshakeFailure(HANDSHAKE_TIMED_OUT);
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
} else {
timeoutFuture = null;
/**
* Performs TLS renegotiation.
*/
public Future<Channel> renegotiate() {
ChannelHandlerContext ctx = this.ctx;
if (ctx == null) {
throw new IllegalStateException();
}
handshakePromise.addListener(new GenericFutureListener<Future<Channel>>() {
return renegotiate(ctx.executor().<Channel>newPromise());
}
/**
* Performs TLS renegotiation.
*/
public Future<Channel> renegotiate(final Promise<Channel> promise) {
if (promise == null) {
throw new NullPointerException("promise");
}
ChannelHandlerContext ctx = this.ctx;
if (ctx == null) {
throw new IllegalStateException();
}
EventExecutor executor = ctx.executor();
if (!executor.inEventLoop()) {
executor.execute(new OneTimeTask() {
@Override
public void operationComplete(Future<Channel> f) throws Exception {
if (timeoutFuture != null) {
timeoutFuture.cancel(false);
public void run() {
handshake(promise);
}
});
return promise;
}
handshake(promise);
return promise;
}
/**
* Performs TLS (re)negotiation.
*
* @param newHandshakePromise if {@code null}, use the existing {@link #handshakePromise},
* assuming that the current negotiation has not been finished.
* Currently, {@code null} is expected only for the initial handshake.
*/
private void handshake(final Promise<Channel> newHandshakePromise) {
final Promise<Channel> p;
if (newHandshakePromise != null) {
final Promise<Channel> oldHandshakePromise = handshakePromise;
if (!oldHandshakePromise.isDone()) {
// There's no need to handshake because handshake is in progress already.
// Merge the new promise into the old one.
oldHandshakePromise.addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> future) throws Exception {
if (future.isSuccess()) {
newHandshakePromise.setSuccess(future.getNow());
} else {
newHandshakePromise.setFailure(future.cause());
}
}
});
return;
}
handshakePromise = p = newHandshakePromise;
} else {
// Forced to reuse the old handshake.
p = handshakePromise;
assert !p.isDone();
}
// Begin handshake.
final ChannelHandlerContext ctx = this.ctx;
try {
engine.beginHandshake();
wrapNonAppData(ctx, false);
@ -1219,26 +1276,40 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} catch (Exception e) {
notifyHandshakeFailure(e);
}
return handshakePromise;
// Set timeout if necessary.
final long handshakeTimeoutMillis = this.handshakeTimeoutMillis;
if (handshakeTimeoutMillis <= 0 || p.isDone()) {
return;
}
final ScheduledFuture<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (p.isDone()) {
return;
}
notifyHandshakeFailure(HANDSHAKE_TIMED_OUT);
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
// Cancel the handshake timeout when handshake is finished.
p.addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> f) throws Exception {
timeoutFuture.cancel(false);
}
});
}
/**
* Issues a SSL handshake once connected when used in client-mode
* Issues an initial TLS handshake once connected when used in client-mode
*/
@Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
if (!startTls && engine.getUseClientMode()) {
// issue and handshake and add a listener to it which will fire an exception event if
// an exception was thrown while doing the handshake
handshake().addListener(new GenericFutureListener<Future<Channel>>() {
@Override
public void operationComplete(Future<Channel> future) throws Exception {
if (!future.isSuccess()) {
logger.debug("Failed to complete handshake", future.cause());
ctx.close();
}
}
});
// Begin the initial handshake
handshake(null);
}
ctx.fireChannelActive();
}

View File

@ -31,6 +31,7 @@ import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.OpenSslServerContext;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.util.concurrent.Future;
@ -54,6 +55,7 @@ import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
@RunWith(Parameterized.class)
@ -80,10 +82,16 @@ public class SocketSslEchoTest extends AbstractSocketTest {
KEY_FILE = ssc.privateKey();
}
@Parameters(name = "{index}: " +
"serverEngine = {0}, clientEngine = {1}, " +
"serverUsesDelegatedTaskExecutor = {2}, clientUsesDelegatedTaskExecutor = {3}, " +
"useChunkedWriteHandler = {4}, useCompositeByteBuf = {5}")
public enum RenegotiationType {
NONE, // no renegotiation
CLIENT_INITIATED, // renegotiation from client
SERVER_INITIATED, // renegotiation from server
}
@Parameters(name =
"{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
"serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
"autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
public static Collection<Object[]> data() throws Exception {
List<SslContext> serverContexts = new ArrayList<SslContext>();
serverContexts.add(new JdkSslServerContext(CERT_FILE, KEY_FILE));
@ -104,8 +112,18 @@ public class SocketSslEchoTest extends AbstractSocketTest {
List<Object[]> params = new ArrayList<Object[]>();
for (SslContext sc: serverContexts) {
for (SslContext cc: clientContexts) {
for (int i = 0; i < 16; i ++) {
params.add(new Object[] { sc, cc, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
for (RenegotiationType rt: RenegotiationType.values()) {
if (rt != RenegotiationType.NONE &&
(sc instanceof OpenSslServerContext || cc instanceof OpenSslServerContext)) {
// TODO: OpenSslEngine does not support renegotiation yet.
continue;
}
for (int i = 0; i < 32; i++) {
params.add(new Object[] {
sc, cc, rt,
(i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
}
}
}
}
@ -115,19 +133,23 @@ public class SocketSslEchoTest extends AbstractSocketTest {
private final SslContext serverCtx;
private final SslContext clientCtx;
private final RenegotiationType renegotiationType;
private final boolean serverUsesDelegatedTaskExecutor;
private final boolean clientUsesDelegatedTaskExecutor;
private final boolean autoRead;
private final boolean useChunkedWriteHandler;
private final boolean useCompositeByteBuf;
public SocketSslEchoTest(
SslContext serverCtx, SslContext clientCtx,
SslContext serverCtx, SslContext clientCtx, RenegotiationType renegotiationType,
boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
this.serverCtx = serverCtx;
this.clientCtx = clientCtx;
this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
this.renegotiationType = renegotiationType;
this.autoRead = autoRead;
this.useChunkedWriteHandler = useChunkedWriteHandler;
this.useCompositeByteBuf = useCompositeByteBuf;
}
@ -138,22 +160,9 @@ public class SocketSslEchoTest extends AbstractSocketTest {
}
public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho(sb, cb, true);
}
@Test(timeout = 30000)
public void testSslEchoNotAutoRead() throws Throwable {
run();
}
public void testSslEchoNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho(sb, cb, false);
}
private void testSslEcho(ServerBootstrap sb, Bootstrap cb, boolean autoRead) throws Throwable {
final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
final EchoHandler sh = new EchoHandler(true, useCompositeByteBuf, autoRead);
final EchoHandler ch = new EchoHandler(false, useCompositeByteBuf, autoRead);
final EchoHandler sh = new EchoHandler(true);
final EchoHandler ch = new EchoHandler(false);
sb.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
@ -199,6 +208,8 @@ public class SocketSslEchoTest extends AbstractSocketTest {
assertFalse(firstByteWriteFutureDone.get());
boolean needsRenegotiation = renegotiationType == RenegotiationType.CLIENT_INITIATED;
Future<Channel> renegoFuture = null;
for (int i = FIRST_MESSAGE_SIZE; i < data.length;) {
int length = Math.min(random.nextInt(1024 * 64), data.length - i);
ByteBuf buf = Unpooled.wrappedBuffer(data, i, length);
@ -208,8 +219,23 @@ public class SocketSslEchoTest extends AbstractSocketTest {
ChannelFuture future = cc.writeAndFlush(buf);
future.sync();
i += length;
if (needsRenegotiation && i >= data.length / 2) {
needsRenegotiation = false;
renegoFuture = cc.pipeline().get(SslHandler.class).renegotiate();
assertThat(renegoFuture, is(not(sameInstance(hf))));
}
}
// Wait until renegotiation is done.
if (renegoFuture != null) {
renegoFuture.sync();
}
if (sh.renegoFuture != null) {
sh.renegoFuture.sync();
}
// Ensure all data has been exchanged.
while (ch.counter < data.length) {
if (sh.exception.get() != null) {
break;
@ -257,20 +283,27 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (ch.exception.get() != null) {
throw ch.exception.get();
}
// When renegotiation is done, both the client and server side should be notified.
if (renegotiationType != RenegotiationType.NONE) {
assertThat(sh.negoCounter, is(2));
assertThat(ch.negoCounter, is(2));
} else {
assertThat(sh.negoCounter, is(1));
assertThat(ch.negoCounter, is(1));
}
}
private static class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
private class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
volatile Channel channel;
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
volatile int counter;
private final boolean server;
private final boolean composite;
private final boolean autoRead;
volatile Future<Channel> renegoFuture;
volatile int negoCounter;
EchoHandler(boolean server, boolean composite, boolean autoRead) {
EchoHandler(boolean server) {
this.server = server;
this.composite = composite;
this.autoRead = autoRead;
}
@Override
@ -291,13 +324,24 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (channel.parent() != null) {
ByteBuf buf = Unpooled.wrappedBuffer(actual);
if (composite) {
if (useCompositeByteBuf) {
buf = Unpooled.compositeBuffer().addComponent(buf).writerIndex(buf.writerIndex());
}
channel.write(buf);
}
counter += actual.length;
// Perform server-initiated renegotiation if necessary.
if (server && renegotiationType == RenegotiationType.SERVER_INITIATED &&
counter > data.length / 2 && renegoFuture == null) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
Future<Channel> hf = sslHandler.handshakeFuture();
renegoFuture = sslHandler.renegotiate();
assertThat(renegoFuture, is(not(sameInstance(hf))));
assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
}
}
@Override
@ -311,6 +355,13 @@ public class SocketSslEchoTest extends AbstractSocketTest {
}
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) {
negoCounter ++;
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx,
Throwable cause) throws Exception {

View File

@ -24,15 +24,14 @@ import io.netty.testsuite.transport.socket.SocketSslEchoTest;
import java.util.List;
public class EpollSocketSslEchoTest extends SocketSslEchoTest {
public EpollSocketSslEchoTest(
SslContext serverCtx, SslContext clientCtx,
SslContext serverCtx, SslContext clientCtx, RenegotiationType renegotiationType,
boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
super(
serverCtx, clientCtx,
boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
super(serverCtx, clientCtx, renegotiationType,
serverUsesDelegatedTaskExecutor, clientUsesDelegatedTaskExecutor,
useChunkedWriteHandler, useCompositeByteBuf);
autoRead, useChunkedWriteHandler, useCompositeByteBuf);
}
@Override