diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 60e7a3786c..a5ebc363f3 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -1176,6 +1176,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private void handleUnwrapThrowable(ChannelHandlerContext ctx, Throwable cause) { try { + // We should attempt to notify the handshake failure before writing any pending data. If we are in unwrap + // and failed during the handshake process, and we attempt to wrap, then promises will fail, and if + // listeners immediately close the Channel then we may end up firing the handshake event after the Channel + // has been closed. + if (handshakePromise.tryFailure(cause)) { + ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); + } + // We need to flush one time as there may be an alert that we should send to the remote peer because // of the SSLException reported here. wrapAndFlush(ctx); @@ -1183,7 +1191,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH logger.debug("SSLException during trying to call SSLEngine.wrap(...)" + " because of an previous SSLException, ignoring...", ex); } finally { - setHandshakeFailure(ctx, cause); + setHandshakeFailure(ctx, cause, true, false); } PlatformDependent.throwException(cause); } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 65eaffd435..5366de289f 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -22,6 +22,8 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -29,9 +31,13 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelId; +import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; @@ -431,6 +437,99 @@ public class SslHandlerTest { assertTrue(events.isEmpty()); } + @Test(timeout = 5000) + public void testHandshakeFailBeforeWritePromise() throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()).build(); + final CountDownLatch latch = new CountDownLatch(2); + final CountDownLatch latch2 = new CountDownLatch(2); + final BlockingQueue events = new LinkedBlockingQueue(); + Channel serverChannel = null; + Channel clientChannel = null; + EventLoopGroup group = new DefaultEventLoopGroup(); + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ByteBuf buf = ctx.alloc().buffer(10); + buf.writeZero(buf.capacity()); + ctx.writeAndFlush(buf).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + events.add(future); + latch.countDown(); + } + }); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslCompletionEvent) { + events.add(evt); + latch.countDown(); + latch2.countDown(); + } + } + }); + } + }); + + Bootstrap cb = new Bootstrap(); + cb.group(group) + .channel(LocalChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ByteBuf buf = ctx.alloc().buffer(1000); + buf.writeZero(buf.capacity()); + ctx.writeAndFlush(buf); + } + }); + } + }); + + serverChannel = sb.bind(new LocalAddress("SslHandlerTest")).sync().channel(); + clientChannel = cb.connect(serverChannel.localAddress()).sync().channel(); + latch.await(); + + SslCompletionEvent evt = (SslCompletionEvent) events.take(); + assertTrue(evt instanceof SslHandshakeCompletionEvent); + assertThat(evt.cause(), is(instanceOf(SSLException.class))); + + ChannelFuture future = (ChannelFuture) events.take(); + assertThat(future.cause(), is(instanceOf(SSLException.class))); + + serverChannel.close().sync(); + serverChannel = null; + clientChannel.close().sync(); + clientChannel = null; + + latch2.await(); + evt = (SslCompletionEvent) events.take(); + assertTrue(evt instanceof SslCloseCompletionEvent); + assertThat(evt.cause(), is(instanceOf(ClosedChannelException.class))); + assertTrue(events.isEmpty()); + } finally { + if (serverChannel != null) { + serverChannel.close(); + } + if (clientChannel != null) { + clientChannel.close(); + } + group.shutdownGracefully(); + } + } + @Test public void writingReadOnlyBufferDoesNotBreakAggregation() throws Exception { SelfSignedCertificate ssc = new SelfSignedCertificate(); @@ -483,7 +582,7 @@ public class SslHandlerTest { firstBuffer.writeByte(0); firstBuffer = firstBuffer.asReadOnly(); ByteBuf secondBuffer = Unpooled.buffer(10); - secondBuffer.writerIndex(secondBuffer.capacity()); + secondBuffer.writeZero(secondBuffer.capacity()); cc.write(firstBuffer); cc.writeAndFlush(secondBuffer).syncUninterruptibly(); serverReceiveLatch.countDown();