SslHandler unwrap out of order promise/event notificaiton

Motivation:
SslHandler#decode methods catch any exceptions and attempt to wrap
before shutting down the engine. The intention is to write any alerts
which the engine may have pending. However the wrap process may also
attempt to write user data, and may also complete the associated
promises. If this is the case, and a promise listener closes the channel
then SslHandler may later propagate a SslHandshakeCompletionEvent user
event through the pipeline. Since the channel has already been closed
the user may no longer be paying attention to user events.

Modifications:
- Sslhandler#decode should first fail the associated handshake promise
and propagate the SslHandshakeCompletionEvent before attempting to wrap

Result:
Fixes https://github.com/netty/netty/issues/7639
This commit is contained in:
Scott Mitchell 2018-01-31 13:16:51 -08:00
parent 77156660ff
commit 94e55692d4
3 changed files with 114 additions and 7 deletions

View File

@ -1125,6 +1125,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private void handleUnwrapThrowable(ChannelHandlerContext ctx, Throwable cause) {
try {
// We should attempt to notify the handshake failure before writing any pending data. If we are in unwrap
// and failed during the handshake process, and we attempt to wrap, then promises will fail, and if
// listeners immediately close the Channel then we may end up firing the handshake event after the Channel
// has been closed.
if (handshakePromise.tryFailure(cause)) {
ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause));
}
// We need to flush one time as there may be an alert that we should send to the remote peer because
// of the SSLException reported here.
wrapAndFlush(ctx);
@ -1132,7 +1140,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
logger.debug("SSLException during trying to call SSLEngine.wrap(...)" +
" because of an previous SSLException, ignoring...", ex);
} finally {
setHandshakeFailure(ctx, cause);
setHandshakeFailure(ctx, cause, true, false);
}
PlatformDependent.throwException(cause);
}

View File

@ -34,7 +34,7 @@ import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;
import java.nio.channels.ClosedChannelException;
import javax.net.ssl.SSLException;
public class SniClientTest {
@ -67,7 +67,7 @@ public class SniClientTest {
SniClientJava8TestUtil.testSniClient(SslProvider.JDK, SslProvider.JDK, true);
}
@Test(timeout = 30000, expected = ClosedChannelException.class)
@Test(timeout = 30000, expected = SSLException.class)
public void testSniSNIMatcherDoesNotMatchClientJdkSslServerJdkSsl() throws Exception {
Assume.assumeTrue(PlatformDependent.javaVersion() >= 8);
SniClientJava8TestUtil.testSniClient(SslProvider.JDK, SslProvider.JDK, false);
@ -80,7 +80,7 @@ public class SniClientTest {
SniClientJava8TestUtil.testSniClient(SslProvider.OPENSSL, SslProvider.OPENSSL, true);
}
@Test(timeout = 30000, expected = ClosedChannelException.class)
@Test(timeout = 30000, expected = SSLException.class)
public void testSniSNIMatcherDoesNotMatchClientOpenSslServerOpenSsl() throws Exception {
Assume.assumeTrue(PlatformDependent.javaVersion() >= 8);
Assume.assumeTrue(OpenSsl.isAvailable());
@ -94,7 +94,7 @@ public class SniClientTest {
SniClientJava8TestUtil.testSniClient(SslProvider.JDK, SslProvider.OPENSSL, true);
}
@Test(timeout = 30000, expected = ClosedChannelException.class)
@Test(timeout = 30000, expected = SSLException.class)
public void testSniSNIMatcherDoesNotMatchClientJdkSslServerOpenSsl() throws Exception {
Assume.assumeTrue(PlatformDependent.javaVersion() >= 8);
Assume.assumeTrue(OpenSsl.isAvailable());
@ -108,7 +108,7 @@ public class SniClientTest {
SniClientJava8TestUtil.testSniClient(SslProvider.OPENSSL, SslProvider.JDK, true);
}
@Test(timeout = 30000, expected = ClosedChannelException.class)
@Test(timeout = 30000, expected = SSLException.class)
public void testSniSNIMatcherDoesNotMatchClientOpenSslServerJdkSsl() throws Exception {
Assume.assumeTrue(PlatformDependent.javaVersion() >= 8);
Assume.assumeTrue(OpenSsl.isAvailable());

View File

@ -22,6 +22,8 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
@ -30,6 +32,10 @@ import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalEventLoopGroup;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
@ -382,6 +388,99 @@ public class SslHandlerTest {
assertTrue(events.isEmpty());
}
@Test(timeout = 5000)
public void testHandshakeFailBeforeWritePromise() throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate();
final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()).build();
final CountDownLatch latch = new CountDownLatch(2);
final CountDownLatch latch2 = new CountDownLatch(2);
final BlockingQueue<Object> events = new LinkedBlockingQueue<Object>();
Channel serverChannel = null;
Channel clientChannel = null;
EventLoopGroup group = new LocalEventLoopGroup();
try {
ServerBootstrap sb = new ServerBootstrap();
sb.group(group)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc()));
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
ByteBuf buf = ctx.alloc().buffer(10);
buf.writeZero(buf.capacity());
ctx.writeAndFlush(buf).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
events.add(future);
latch.countDown();
}
});
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslCompletionEvent) {
events.add(evt);
latch.countDown();
latch2.countDown();
}
}
});
}
});
Bootstrap cb = new Bootstrap();
cb.group(group)
.channel(LocalChannel.class)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
ByteBuf buf = ctx.alloc().buffer(1000);
buf.writeZero(buf.capacity());
ctx.writeAndFlush(buf);
}
});
}
});
serverChannel = sb.bind(new LocalAddress("SslHandlerTest")).sync().channel();
clientChannel = cb.connect(serverChannel.localAddress()).sync().channel();
latch.await();
SslCompletionEvent evt = (SslCompletionEvent) events.take();
assertTrue(evt instanceof SslHandshakeCompletionEvent);
assertThat(evt.cause(), is(instanceOf(SSLException.class)));
ChannelFuture future = (ChannelFuture) events.take();
assertThat(future.cause(), is(instanceOf(SSLException.class)));
serverChannel.close().sync();
serverChannel = null;
clientChannel.close().sync();
clientChannel = null;
latch2.await();
evt = (SslCompletionEvent) events.take();
assertTrue(evt instanceof SslCloseCompletionEvent);
assertThat(evt.cause(), is(instanceOf(ClosedChannelException.class)));
assertTrue(events.isEmpty());
} finally {
if (serverChannel != null) {
serverChannel.close();
}
if (clientChannel != null) {
clientChannel.close();
}
group.shutdownGracefully();
}
}
@Test
public void writingReadOnlyBufferDoesNotBreakAggregation() throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate();
@ -434,7 +533,7 @@ public class SslHandlerTest {
firstBuffer.writeByte(0);
firstBuffer = Unpooled.unmodifiableBuffer(firstBuffer);
ByteBuf secondBuffer = Unpooled.buffer(10);
secondBuffer.writerIndex(secondBuffer.capacity());
secondBuffer.writeZero(secondBuffer.capacity());
cc.write(firstBuffer);
cc.writeAndFlush(secondBuffer).syncUninterruptibly();
serverReceiveLatch.countDown();