Add SslHandler.renegotiate()

Related: #3125

Motivation:

We did not expose a way to initiate TLS renegotiation and to get
notified when the renegotiation is done.

Modifications:

- Add SslHandler.renegotiate() so that a user can initiate TLS
  renegotiation and get the future that's notified on completion
- Make SslHandler.handshakeFuture() return the future for the most
  recent handshake so that a user can get the future of the last
  renegotiation
- Add the test for renegotiation to SocketSslEchoTest

Result:

Both client-initiated and server-initiated renegotiations are now
supported properly.
This commit is contained in:
Trustin Lee 2014-12-10 18:47:53 +09:00
parent 85ec4d9cc4
commit d1612f67ad
3 changed files with 197 additions and 75 deletions

View File

@ -32,8 +32,10 @@ import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@ -201,7 +203,7 @@ public class SslHandler extends ByteToMessageDecoder {
private boolean flushedBeforeHandshakeDone;
private PendingWriteQueue pendingUnencryptedWrites;
private final LazyChannelPromise handshakePromise = new LazyChannelPromise();
private Promise<Channel> handshakePromise = new LazyChannelPromise();
private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
/**
@ -291,7 +293,10 @@ public class SslHandler extends ByteToMessageDecoder {
}
/**
* Returns a {@link Future} that will get notified once the handshake completes.
* Returns a {@link Future} that will get notified once the current TLS handshake completes.
*
* @return the {@link Future} for the iniital TLS handshake if {@link #renegotiate()} was not invoked.
* The {@link Future} for the most recent {@linkplain #renegotiate() TLS renegotiation} otherwise.
*/
public Future<Channel> handshakeFuture() {
return handshakePromise;
@ -329,12 +334,11 @@ public class SslHandler extends ByteToMessageDecoder {
}
/**
* Return the {@link ChannelFuture} that will get notified if the inbound of the {@link SSLEngine} will get closed.
* Return the {@link Future} that will get notified if the inbound of the {@link SSLEngine} is closed.
*
* This method will return the same {@link ChannelFuture} all the time.
*
* For more informations see the apidocs of {@link SSLEngine}
* This method will return the same {@link Future} all the time.
*
* @see SSLEngine
*/
public Future<Channel> sslCloseFuture() {
return sslCloseFuture;
@ -1002,12 +1006,12 @@ public class SslHandler extends ByteToMessageDecoder {
wantsInboundHeapBuffer = true;
}
if (handshakePromise.trySuccess(ctx.channel())) {
if (logger.isDebugEnabled()) {
logger.debug(ctx.channel() + " HANDSHAKEN: " + engine.getSession().getCipherSuite());
}
ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
handshakePromise.trySuccess(ctx.channel());
if (logger.isDebugEnabled()) {
logger.debug(ctx.channel() + " HANDSHAKEN: " + engine.getSession().getCipherSuite());
}
ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
}
/**
@ -1066,39 +1070,92 @@ public class SslHandler extends ByteToMessageDecoder {
pendingUnencryptedWrites = new PendingWriteQueue(ctx);
if (ctx.channel().isActive() && engine.getUseClientMode()) {
// Begin the initial handshake.
// channelActive() event has been fired already, which means this.channelActive() will
// not be invoked. We have to initialize here instead.
handshake();
handshake(null);
} else {
// channelActive() event has not been fired yet. this.channelOpen() will be invoked
// and initialization will occur there.
}
}
private Future<Channel> handshake() {
final ScheduledFuture<?> timeoutFuture;
if (handshakeTimeoutMillis > 0) {
timeoutFuture = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (handshakePromise.isDone()) {
return;
}
notifyHandshakeFailure(HANDSHAKE_TIMED_OUT);
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
} else {
timeoutFuture = null;
/**
* Performs TLS renegotiation.
*/
public Future<Channel> renegotiate() {
ChannelHandlerContext ctx = this.ctx;
if (ctx == null) {
throw new IllegalStateException();
}
handshakePromise.addListener(new GenericFutureListener<Future<Channel>>() {
@Override
public void operationComplete(Future<Channel> f) throws Exception {
if (timeoutFuture != null) {
timeoutFuture.cancel(false);
return renegotiate(ctx.executor().<Channel>newPromise());
}
/**
* Performs TLS renegotiation.
*/
public Future<Channel> renegotiate(final Promise<Channel> promise) {
if (promise == null) {
throw new NullPointerException("promise");
}
ChannelHandlerContext ctx = this.ctx;
if (ctx == null) {
throw new IllegalStateException();
}
EventExecutor executor = ctx.executor();
if (!executor.inEventLoop()) {
executor.execute(new OneTimeTask() {
@Override
public void run() {
handshake(promise);
}
});
return promise;
}
handshake(promise);
return promise;
}
/**
* Performs TLS (re)negotiation.
*
* @param newHandshakePromise if {@code null}, use the existing {@link #handshakePromise},
* assuming that the current negotiation has not been finished.
* Currently, {@code null} is expected only for the initial handshake.
*/
private void handshake(final Promise<Channel> newHandshakePromise) {
final Promise<Channel> p;
if (newHandshakePromise != null) {
final Promise<Channel> oldHandshakePromise = handshakePromise;
if (!oldHandshakePromise.isDone()) {
// There's no need to handshake because handshake is in progress already.
// Merge the new promise into the old one.
oldHandshakePromise.addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> future) throws Exception {
if (future.isSuccess()) {
newHandshakePromise.setSuccess(future.getNow());
} else {
newHandshakePromise.setFailure(future.cause());
}
}
});
return;
}
});
handshakePromise = p = newHandshakePromise;
} else {
// Forced to reuse the old handshake.
p = handshakePromise;
assert !p.isDone();
}
// Begin handshake.
final ChannelHandlerContext ctx = this.ctx;
try {
engine.beginHandshake();
wrapNonAppData(ctx, false);
@ -1106,26 +1163,40 @@ public class SslHandler extends ByteToMessageDecoder {
} catch (Exception e) {
notifyHandshakeFailure(e);
}
return handshakePromise;
// Set timeout if necessary.
final long handshakeTimeoutMillis = this.handshakeTimeoutMillis;
if (handshakeTimeoutMillis <= 0 || p.isDone()) {
return;
}
final ScheduledFuture<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (p.isDone()) {
return;
}
notifyHandshakeFailure(HANDSHAKE_TIMED_OUT);
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
// Cancel the handshake timeout when handshake is finished.
p.addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> f) throws Exception {
timeoutFuture.cancel(false);
}
});
}
/**
* Issues a SSL handshake once connected when used in client-mode
* Issues an initial TLS handshake once connected when used in client-mode
*/
@Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
if (!startTls && engine.getUseClientMode()) {
// issue and handshake and add a listener to it which will fire an exception event if
// an exception was thrown while doing the handshake
handshake().addListener(new GenericFutureListener<Future<Channel>>() {
@Override
public void operationComplete(Future<Channel> future) throws Exception {
if (!future.isSuccess()) {
logger.debug("Failed to complete handshake", future.cause());
ctx.close();
}
}
});
// Begin the initial handshake
handshake(null);
}
ctx.fireChannelActive();
}

View File

@ -31,6 +31,7 @@ import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.OpenSslServerContext;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.util.concurrent.Future;
@ -51,6 +52,7 @@ import java.util.Random;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
@RunWith(Parameterized.class)
@ -77,8 +79,15 @@ public class SocketSslEchoTest extends AbstractSocketTest {
KEY_FILE = ssc.privateKey();
}
public enum RenegotiationType {
NONE, // no renegotiation
CLIENT_INITIATED, // renegotiation from client
SERVER_INITIATED, // renegotiation from server
}
@Parameters(name =
"{index}: serverEngine = {0}, clientEngine = {1}, useChunkedWriteHandler = {2}, useCompositeByteBuf = {3}")
"{index}: serverEngine = {0}, clientEngine = {1}, " +
"renegotiation = {2}, autoRead = {3}, useChunkedWriteHandler = {4}, useCompositeByteBuf = {5}")
public static Collection<Object[]> data() throws Exception {
List<SslContext> serverContexts = new ArrayList<SslContext>();
serverContexts.add(new JdkSslServerContext(CERT_FILE, KEY_FILE));
@ -99,8 +108,16 @@ public class SocketSslEchoTest extends AbstractSocketTest {
List<Object[]> params = new ArrayList<Object[]>();
for (SslContext sc: serverContexts) {
for (SslContext cc: clientContexts) {
for (int i = 0; i < 4; i ++) {
params.add(new Object[] { sc, cc, (i & 2) != 0, (i & 1) != 0 });
for (RenegotiationType rt: RenegotiationType.values()) {
if (rt != RenegotiationType.NONE &&
(sc instanceof OpenSslServerContext || cc instanceof OpenSslServerContext)) {
// TODO: OpenSslEngine does not support renegotiation yet.
continue;
}
for (int i = 0; i < 8; i++) {
params.add(new Object[] { sc, cc, rt, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
}
}
}
}
@ -110,13 +127,18 @@ public class SocketSslEchoTest extends AbstractSocketTest {
private final SslContext serverCtx;
private final SslContext clientCtx;
private final RenegotiationType renegotiationType;
private final boolean autoRead;
private final boolean useChunkedWriteHandler;
private final boolean useCompositeByteBuf;
public SocketSslEchoTest(
SslContext serverCtx, SslContext clientCtx, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
SslContext serverCtx, SslContext clientCtx, RenegotiationType renegotiationType,
boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
this.serverCtx = serverCtx;
this.clientCtx = clientCtx;
this.renegotiationType = renegotiationType;
this.autoRead = autoRead;
this.useChunkedWriteHandler = useChunkedWriteHandler;
this.useCompositeByteBuf = useCompositeByteBuf;
}
@ -127,21 +149,8 @@ public class SocketSslEchoTest extends AbstractSocketTest {
}
public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho(sb, cb, true);
}
@Test(timeout = 30000)
public void testSslEchoNotAutoRead() throws Throwable {
run();
}
public void testSslEchoNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho(sb, cb, false);
}
private void testSslEcho(ServerBootstrap sb, Bootstrap cb, boolean autoRead) throws Throwable {
final EchoHandler sh = new EchoHandler(true, useCompositeByteBuf, autoRead);
final EchoHandler ch = new EchoHandler(false, useCompositeByteBuf, autoRead);
final EchoHandler sh = new EchoHandler(true);
final EchoHandler ch = new EchoHandler(false);
sb.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
@ -175,6 +184,8 @@ public class SocketSslEchoTest extends AbstractSocketTest {
assertFalse(firstByteWriteFutureDone.get());
boolean needsRenegotiation = renegotiationType == RenegotiationType.CLIENT_INITIATED;
Future<Channel> renegoFuture = null;
for (int i = FIRST_MESSAGE_SIZE; i < data.length;) {
int length = Math.min(random.nextInt(1024 * 64), data.length - i);
ByteBuf buf = Unpooled.wrappedBuffer(data, i, length);
@ -184,8 +195,23 @@ public class SocketSslEchoTest extends AbstractSocketTest {
ChannelFuture future = cc.writeAndFlush(buf);
future.sync();
i += length;
if (needsRenegotiation && i >= data.length / 2) {
needsRenegotiation = false;
renegoFuture = cc.pipeline().get(SslHandler.class).renegotiate();
assertThat(renegoFuture, is(not(sameInstance(hf))));
}
}
// Wait until renegotiation is done.
if (renegoFuture != null) {
renegoFuture.sync();
}
if (sh.renegoFuture != null) {
sh.renegoFuture.sync();
}
// Ensure all data has been exchanged.
while (ch.counter < data.length) {
if (sh.exception.get() != null) {
break;
@ -232,20 +258,27 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (ch.exception.get() != null) {
throw ch.exception.get();
}
// When renegotiation is done, both the client and server side should be notified.
if (renegotiationType != RenegotiationType.NONE) {
assertThat(sh.negoCounter, is(2));
assertThat(ch.negoCounter, is(2));
} else {
assertThat(sh.negoCounter, is(1));
assertThat(ch.negoCounter, is(1));
}
}
private static class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
private class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
volatile Channel channel;
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
volatile int counter;
private final boolean server;
private final boolean composite;
private final boolean autoRead;
volatile Future<Channel> renegoFuture;
volatile int negoCounter;
EchoHandler(boolean server, boolean composite, boolean autoRead) {
EchoHandler(boolean server) {
this.server = server;
this.composite = composite;
this.autoRead = autoRead;
}
@Override
@ -266,13 +299,24 @@ public class SocketSslEchoTest extends AbstractSocketTest {
if (channel.parent() != null) {
ByteBuf buf = Unpooled.wrappedBuffer(actual);
if (composite) {
if (useCompositeByteBuf) {
buf = Unpooled.compositeBuffer().addComponent(buf).writerIndex(buf.writerIndex());
}
channel.write(buf);
}
counter += actual.length;
// Perform server-initiated renegotiation if necessary.
if (server && renegotiationType == RenegotiationType.SERVER_INITIATED &&
counter > data.length / 2 && renegoFuture == null) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
Future<Channel> hf = sslHandler.handshakeFuture();
renegoFuture = sslHandler.renegotiate();
assertThat(renegoFuture, is(not(sameInstance(hf))));
assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
}
}
@Override
@ -286,6 +330,13 @@ public class SocketSslEchoTest extends AbstractSocketTest {
}
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) {
negoCounter ++;
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx,
Throwable cause) throws Exception {

View File

@ -24,10 +24,10 @@ import io.netty.testsuite.transport.socket.SocketSslEchoTest;
import java.util.List;
public class EpollSocketSslEchoTest extends SocketSslEchoTest {
public EpollSocketSslEchoTest(
SslContext serverCtx, SslContext clientCtx, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
super(serverCtx, clientCtx, useChunkedWriteHandler, useCompositeByteBuf);
SslContext serverCtx, SslContext clientCtx, RenegotiationType renegotiationType,
boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
super(serverCtx, clientCtx, renegotiationType, autoRead, useChunkedWriteHandler, useCompositeByteBuf);
}
@Override