diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 3b683ff1e0..d175f93330 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -33,9 +33,11 @@ import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.util.concurrent.DefaultPromise; import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.Future; -import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.ImmediateExecutor; +import io.netty.util.concurrent.Promise; import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -207,7 +209,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private boolean flushedBeforeHandshakeDone; private PendingWriteQueue pendingUnencryptedWrites; - private final LazyChannelPromise handshakePromise = new LazyChannelPromise(); + private Promise handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise(); /** @@ -318,7 +320,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } /** - * Returns a {@link Future} that will get notified once the handshake completes. + * Returns a {@link Future} that will get notified once the current TLS handshake completes. + * + * @return the {@link Future} for the iniital TLS handshake if {@link #renegotiate()} was not invoked. + * The {@link Future} for the most recent {@linkplain #renegotiate() TLS renegotiation} otherwise. */ public Future handshakeFuture() { return handshakePromise; @@ -356,12 +361,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } /** - * Return the {@link ChannelFuture} that will get notified if the inbound of the {@link SSLEngine} will get closed. + * Return the {@link Future} that will get notified if the inbound of the {@link SSLEngine} is closed. * - * This method will return the same {@link ChannelFuture} all the time. - * - * For more informations see the apidocs of {@link SSLEngine} + * This method will return the same {@link Future} all the time. * + * @see SSLEngine */ public Future sslCloseFuture() { return sslCloseFuture; @@ -1099,12 +1103,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH wantsInboundHeapBuffer = true; } - if (handshakePromise.trySuccess(ctx.channel())) { - if (logger.isDebugEnabled()) { - logger.debug(ctx.channel() + " HANDSHAKEN: " + engine.getSession().getCipherSuite()); - } - ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); + handshakePromise.trySuccess(ctx.channel()); + + if (logger.isDebugEnabled()) { + logger.debug(ctx.channel() + " HANDSHAKEN: " + engine.getSession().getCipherSuite()); } + ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); } /** @@ -1163,39 +1167,92 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH pendingUnencryptedWrites = new PendingWriteQueue(ctx); if (ctx.channel().isActive() && engine.getUseClientMode()) { + // Begin the initial handshake. // channelActive() event has been fired already, which means this.channelActive() will // not be invoked. We have to initialize here instead. - handshake(); + handshake(null); } else { // channelActive() event has not been fired yet. this.channelOpen() will be invoked // and initialization will occur there. } } - private Future handshake() { - final ScheduledFuture timeoutFuture; - if (handshakeTimeoutMillis > 0) { - timeoutFuture = ctx.executor().schedule(new Runnable() { - @Override - public void run() { - if (handshakePromise.isDone()) { - return; - } - notifyHandshakeFailure(HANDSHAKE_TIMED_OUT); - } - }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); - } else { - timeoutFuture = null; + /** + * Performs TLS renegotiation. + */ + public Future renegotiate() { + ChannelHandlerContext ctx = this.ctx; + if (ctx == null) { + throw new IllegalStateException(); } - handshakePromise.addListener(new GenericFutureListener>() { - @Override - public void operationComplete(Future f) throws Exception { - if (timeoutFuture != null) { - timeoutFuture.cancel(false); + return renegotiate(ctx.executor().newPromise()); + } + + /** + * Performs TLS renegotiation. + */ + public Future renegotiate(final Promise promise) { + if (promise == null) { + throw new NullPointerException("promise"); + } + + ChannelHandlerContext ctx = this.ctx; + if (ctx == null) { + throw new IllegalStateException(); + } + + EventExecutor executor = ctx.executor(); + if (!executor.inEventLoop()) { + executor.execute(new OneTimeTask() { + @Override + public void run() { + handshake(promise); } + }); + return promise; + } + + handshake(promise); + return promise; + } + + /** + * Performs TLS (re)negotiation. + * + * @param newHandshakePromise if {@code null}, use the existing {@link #handshakePromise}, + * assuming that the current negotiation has not been finished. + * Currently, {@code null} is expected only for the initial handshake. + */ + private void handshake(final Promise newHandshakePromise) { + final Promise p; + if (newHandshakePromise != null) { + final Promise oldHandshakePromise = handshakePromise; + if (!oldHandshakePromise.isDone()) { + // There's no need to handshake because handshake is in progress already. + // Merge the new promise into the old one. + oldHandshakePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + if (future.isSuccess()) { + newHandshakePromise.setSuccess(future.getNow()); + } else { + newHandshakePromise.setFailure(future.cause()); + } + } + }); + return; } - }); + + handshakePromise = p = newHandshakePromise; + } else { + // Forced to reuse the old handshake. + p = handshakePromise; + assert !p.isDone(); + } + + // Begin handshake. + final ChannelHandlerContext ctx = this.ctx; try { engine.beginHandshake(); wrapNonAppData(ctx, false); @@ -1203,26 +1260,40 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } catch (Exception e) { notifyHandshakeFailure(e); } - return handshakePromise; + + // Set timeout if necessary. + final long handshakeTimeoutMillis = this.handshakeTimeoutMillis; + if (handshakeTimeoutMillis <= 0 || p.isDone()) { + return; + } + + final ScheduledFuture timeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (p.isDone()) { + return; + } + notifyHandshakeFailure(HANDSHAKE_TIMED_OUT); + } + }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + + // Cancel the handshake timeout when handshake is finished. + p.addListener(new FutureListener() { + @Override + public void operationComplete(Future f) throws Exception { + timeoutFuture.cancel(false); + } + }); } /** - * Issues a SSL handshake once connected when used in client-mode + * Issues an initial TLS handshake once connected when used in client-mode */ @Override public void channelActive(final ChannelHandlerContext ctx) throws Exception { if (!startTls && engine.getUseClientMode()) { - // issue and handshake and add a listener to it which will fire an exception event if - // an exception was thrown while doing the handshake - handshake().addListener(new GenericFutureListener>() { - @Override - public void operationComplete(Future future) throws Exception { - if (!future.isSuccess()) { - logger.debug("Failed to complete handshake", future.cause()); - ctx.close(); - } - } - }); + // Begin the initial handshake + handshake(null); } ctx.fireChannelActive(); } diff --git a/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java index d377e4996e..394fbe8314 100644 --- a/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java +++ b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java @@ -31,6 +31,7 @@ import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.OpenSslServerContext; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslHandler; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.util.concurrent.Future; @@ -54,6 +55,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import static org.hamcrest.Matchers.*; import static org.junit.Assert.*; @RunWith(Parameterized.class) @@ -80,10 +82,16 @@ public class SocketSslEchoTest extends AbstractSocketTest { KEY_FILE = ssc.privateKey(); } - @Parameters(name = "{index}: " + - "serverEngine = {0}, clientEngine = {1}, " + - "serverUsesDelegatedTaskExecutor = {2}, clientUsesDelegatedTaskExecutor = {3}, " + - "useChunkedWriteHandler = {4}, useCompositeByteBuf = {5}") + public enum RenegotiationType { + NONE, // no renegotiation + CLIENT_INITIATED, // renegotiation from client + SERVER_INITIATED, // renegotiation from server + } + + @Parameters(name = + "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " + + "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " + + "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}") public static Collection data() throws Exception { List serverContexts = new ArrayList(); serverContexts.add(new JdkSslServerContext(CERT_FILE, KEY_FILE)); @@ -104,8 +112,18 @@ public class SocketSslEchoTest extends AbstractSocketTest { List params = new ArrayList(); for (SslContext sc: serverContexts) { for (SslContext cc: clientContexts) { - for (int i = 0; i < 16; i ++) { - params.add(new Object[] { sc, cc, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 }); + for (RenegotiationType rt: RenegotiationType.values()) { + if (rt != RenegotiationType.NONE && + (sc instanceof OpenSslServerContext || cc instanceof OpenSslServerContext)) { + // TODO: OpenSslEngine does not support renegotiation yet. + continue; + } + + for (int i = 0; i < 32; i++) { + params.add(new Object[] { + sc, cc, rt, + (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 }); + } } } } @@ -115,19 +133,23 @@ public class SocketSslEchoTest extends AbstractSocketTest { private final SslContext serverCtx; private final SslContext clientCtx; + private final RenegotiationType renegotiationType; private final boolean serverUsesDelegatedTaskExecutor; private final boolean clientUsesDelegatedTaskExecutor; + private final boolean autoRead; private final boolean useChunkedWriteHandler; private final boolean useCompositeByteBuf; public SocketSslEchoTest( - SslContext serverCtx, SslContext clientCtx, + SslContext serverCtx, SslContext clientCtx, RenegotiationType renegotiationType, boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor, - boolean useChunkedWriteHandler, boolean useCompositeByteBuf) { + boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) { this.serverCtx = serverCtx; this.clientCtx = clientCtx; this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor; this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor; + this.renegotiationType = renegotiationType; + this.autoRead = autoRead; this.useChunkedWriteHandler = useChunkedWriteHandler; this.useCompositeByteBuf = useCompositeByteBuf; } @@ -138,22 +160,9 @@ public class SocketSslEchoTest extends AbstractSocketTest { } public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable { - testSslEcho(sb, cb, true); - } - - @Test(timeout = 30000) - public void testSslEchoNotAutoRead() throws Throwable { - run(); - } - - public void testSslEchoNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable { - testSslEcho(sb, cb, false); - } - - private void testSslEcho(ServerBootstrap sb, Bootstrap cb, boolean autoRead) throws Throwable { final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool(); - final EchoHandler sh = new EchoHandler(true, useCompositeByteBuf, autoRead); - final EchoHandler ch = new EchoHandler(false, useCompositeByteBuf, autoRead); + final EchoHandler sh = new EchoHandler(true); + final EchoHandler ch = new EchoHandler(false); sb.childHandler(new ChannelInitializer() { @Override @@ -199,6 +208,8 @@ public class SocketSslEchoTest extends AbstractSocketTest { assertFalse(firstByteWriteFutureDone.get()); + boolean needsRenegotiation = renegotiationType == RenegotiationType.CLIENT_INITIATED; + Future renegoFuture = null; for (int i = FIRST_MESSAGE_SIZE; i < data.length;) { int length = Math.min(random.nextInt(1024 * 64), data.length - i); ByteBuf buf = Unpooled.wrappedBuffer(data, i, length); @@ -208,8 +219,23 @@ public class SocketSslEchoTest extends AbstractSocketTest { ChannelFuture future = cc.writeAndFlush(buf); future.sync(); i += length; + + if (needsRenegotiation && i >= data.length / 2) { + needsRenegotiation = false; + renegoFuture = cc.pipeline().get(SslHandler.class).renegotiate(); + assertThat(renegoFuture, is(not(sameInstance(hf)))); + } } + // Wait until renegotiation is done. + if (renegoFuture != null) { + renegoFuture.sync(); + } + if (sh.renegoFuture != null) { + sh.renegoFuture.sync(); + } + + // Ensure all data has been exchanged. while (ch.counter < data.length) { if (sh.exception.get() != null) { break; @@ -257,20 +283,27 @@ public class SocketSslEchoTest extends AbstractSocketTest { if (ch.exception.get() != null) { throw ch.exception.get(); } + + // When renegotiation is done, both the client and server side should be notified. + if (renegotiationType != RenegotiationType.NONE) { + assertThat(sh.negoCounter, is(2)); + assertThat(ch.negoCounter, is(2)); + } else { + assertThat(sh.negoCounter, is(1)); + assertThat(ch.negoCounter, is(1)); + } } - private static class EchoHandler extends SimpleChannelInboundHandler { + private class EchoHandler extends SimpleChannelInboundHandler { volatile Channel channel; final AtomicReference exception = new AtomicReference(); volatile int counter; private final boolean server; - private final boolean composite; - private final boolean autoRead; + volatile Future renegoFuture; + volatile int negoCounter; - EchoHandler(boolean server, boolean composite, boolean autoRead) { + EchoHandler(boolean server) { this.server = server; - this.composite = composite; - this.autoRead = autoRead; } @Override @@ -291,13 +324,24 @@ public class SocketSslEchoTest extends AbstractSocketTest { if (channel.parent() != null) { ByteBuf buf = Unpooled.wrappedBuffer(actual); - if (composite) { + if (useCompositeByteBuf) { buf = Unpooled.compositeBuffer().addComponent(buf).writerIndex(buf.writerIndex()); } channel.write(buf); } counter += actual.length; + + // Perform server-initiated renegotiation if necessary. + if (server && renegotiationType == RenegotiationType.SERVER_INITIATED && + counter > data.length / 2 && renegoFuture == null) { + + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + Future hf = sslHandler.handshakeFuture(); + renegoFuture = sslHandler.renegotiate(); + assertThat(renegoFuture, is(not(sameInstance(hf)))); + assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture()))); + } } @Override @@ -311,6 +355,13 @@ public class SocketSslEchoTest extends AbstractSocketTest { } } + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { + negoCounter ++; + } + } + @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslEchoTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslEchoTest.java index 06e01f9229..bb5c2985ea 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslEchoTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketSslEchoTest.java @@ -24,15 +24,14 @@ import io.netty.testsuite.transport.socket.SocketSslEchoTest; import java.util.List; public class EpollSocketSslEchoTest extends SocketSslEchoTest { - public EpollSocketSslEchoTest( - SslContext serverCtx, SslContext clientCtx, + SslContext serverCtx, SslContext clientCtx, RenegotiationType renegotiationType, boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor, - boolean useChunkedWriteHandler, boolean useCompositeByteBuf) { - super( - serverCtx, clientCtx, - serverUsesDelegatedTaskExecutor, clientUsesDelegatedTaskExecutor, - useChunkedWriteHandler, useCompositeByteBuf); + boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf) { + + super(serverCtx, clientCtx, renegotiationType, + serverUsesDelegatedTaskExecutor, clientUsesDelegatedTaskExecutor, + autoRead, useChunkedWriteHandler, useCompositeByteBuf); } @Override