From 978a46cc0a5242702f248842f6feb97b8fbe0123 Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Wed, 31 Jan 2018 13:16:51 -0800 Subject: [PATCH] SslHandler unwrap out of order promise/event notificaiton Motivation: SslHandler#decode methods catch any exceptions and attempt to wrap before shutting down the engine. The intention is to write any alerts which the engine may have pending. However the wrap process may also attempt to write user data, and may also complete the associated promises. If this is the case, and a promise listener closes the channel then SslHandler may later propagate a SslHandshakeCompletionEvent user event through the pipeline. Since the channel has already been closed the user may no longer be paying attention to user events. Modifications: - Sslhandler#decode should first fail the associated handshake promise and propagate the SslHandshakeCompletionEvent before attempting to wrap Result: Fixes https://github.com/netty/netty/issues/7639 --- .../java/io/netty/handler/ssl/SslHandler.java | 10 +- .../io/netty/handler/ssl/SslHandlerTest.java | 101 +++++++++++++++++- 2 files changed, 109 insertions(+), 2 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 60e7a3786c..a5ebc363f3 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -1176,6 +1176,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private void handleUnwrapThrowable(ChannelHandlerContext ctx, Throwable cause) { try { + // We should attempt to notify the handshake failure before writing any pending data. If we are in unwrap + // and failed during the handshake process, and we attempt to wrap, then promises will fail, and if + // listeners immediately close the Channel then we may end up firing the handshake event after the Channel + // has been closed. + if (handshakePromise.tryFailure(cause)) { + ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); + } + // We need to flush one time as there may be an alert that we should send to the remote peer because // of the SSLException reported here. wrapAndFlush(ctx); @@ -1183,7 +1191,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH logger.debug("SSLException during trying to call SSLEngine.wrap(...)" + " because of an previous SSLException, ignoring...", ex); } finally { - setHandshakeFailure(ctx, cause); + setHandshakeFailure(ctx, cause, true, false); } PlatformDependent.throwException(cause); } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 65eaffd435..5366de289f 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -22,6 +22,8 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -29,9 +31,13 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelId; +import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; @@ -431,6 +437,99 @@ public class SslHandlerTest { assertTrue(events.isEmpty()); } + @Test(timeout = 5000) + public void testHandshakeFailBeforeWritePromise() throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()).build(); + final CountDownLatch latch = new CountDownLatch(2); + final CountDownLatch latch2 = new CountDownLatch(2); + final BlockingQueue events = new LinkedBlockingQueue(); + Channel serverChannel = null; + Channel clientChannel = null; + EventLoopGroup group = new DefaultEventLoopGroup(); + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ByteBuf buf = ctx.alloc().buffer(10); + buf.writeZero(buf.capacity()); + ctx.writeAndFlush(buf).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + events.add(future); + latch.countDown(); + } + }); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslCompletionEvent) { + events.add(evt); + latch.countDown(); + latch2.countDown(); + } + } + }); + } + }); + + Bootstrap cb = new Bootstrap(); + cb.group(group) + .channel(LocalChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ByteBuf buf = ctx.alloc().buffer(1000); + buf.writeZero(buf.capacity()); + ctx.writeAndFlush(buf); + } + }); + } + }); + + serverChannel = sb.bind(new LocalAddress("SslHandlerTest")).sync().channel(); + clientChannel = cb.connect(serverChannel.localAddress()).sync().channel(); + latch.await(); + + SslCompletionEvent evt = (SslCompletionEvent) events.take(); + assertTrue(evt instanceof SslHandshakeCompletionEvent); + assertThat(evt.cause(), is(instanceOf(SSLException.class))); + + ChannelFuture future = (ChannelFuture) events.take(); + assertThat(future.cause(), is(instanceOf(SSLException.class))); + + serverChannel.close().sync(); + serverChannel = null; + clientChannel.close().sync(); + clientChannel = null; + + latch2.await(); + evt = (SslCompletionEvent) events.take(); + assertTrue(evt instanceof SslCloseCompletionEvent); + assertThat(evt.cause(), is(instanceOf(ClosedChannelException.class))); + assertTrue(events.isEmpty()); + } finally { + if (serverChannel != null) { + serverChannel.close(); + } + if (clientChannel != null) { + clientChannel.close(); + } + group.shutdownGracefully(); + } + } + @Test public void writingReadOnlyBufferDoesNotBreakAggregation() throws Exception { SelfSignedCertificate ssc = new SelfSignedCertificate(); @@ -483,7 +582,7 @@ public class SslHandlerTest { firstBuffer.writeByte(0); firstBuffer = firstBuffer.asReadOnly(); ByteBuf secondBuffer = Unpooled.buffer(10); - secondBuffer.writerIndex(secondBuffer.capacity()); + secondBuffer.writeZero(secondBuffer.capacity()); cc.write(firstBuffer); cc.writeAndFlush(secondBuffer).syncUninterruptibly(); serverReceiveLatch.countDown();