Don't produce multiple calls to exceptionCaught(...) on SSL handshake failure (#10134)

Motivation:

Before release 4.1.23, there was only ONE call to exceptionCaught method when an ssl handshake failure occurs, now we have two when using the JDK provider.

Modifications:

- Ensure we only propagate one exception fi we already failed the handshake and channelInactive(...) produce an exception again
- Add unit test

Result:

Fixes https://github.com/netty/netty/issues/10119
This commit is contained in:
Norman Maurer 2020-03-26 09:12:39 +01:00 committed by GitHub
parent 0c3eae34ec
commit 62e3c463c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 143 additions and 1 deletions

View File

@ -1073,6 +1073,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
boolean handshakeFailed = handshakePromise.cause() != null;
ClosedChannelException exception = new ClosedChannelException();
// Make sure to release SSLEngine,
// and notify the handshake future if the connection has been closed during handshake.
@ -1081,7 +1083,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// Ensure we always notify the sslClosePromise as well
notifyClosePromise(exception);
super.channelInactive(ctx);
try {
super.channelInactive(ctx);
} catch (DecoderException e) {
if (!handshakeFailed || !(e.getCause() instanceof SSLException)) {
// We only rethrow the exception if the handshake did not fail before channelInactive(...) was called
// as otherwise this may produce duplicated failures as super.channelInactive(...) will also call
// channelRead(...).
//
// See https://github.com/netty/netty/issues/10119
throw e;
}
}
}
@Override

View File

@ -58,13 +58,19 @@ import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.ImmediateExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.PlatformDependent;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
import org.junit.Test;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.channels.ClosedChannelException;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
@ -79,10 +85,13 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLProtocolException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedTrustManager;
import static io.netty.buffer.Unpooled.wrappedBuffer;
import static org.hamcrest.CoreMatchers.containsString;
@ -92,6 +101,8 @@ import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -1180,4 +1191,122 @@ public class SslHandlerTest {
ReferenceCountUtil.release(sslClientCtx);
}
}
@Test(timeout = 10000L)
public void testHandshakeFailureOnlyFireExceptionOnce() throws Exception {
final SslContext sslClientCtx = SslContextBuilder.forClient()
.trustManager(new X509ExtendedTrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
throws CertificateException {
failVerification();
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
throws CertificateException {
failVerification();
}
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
throws CertificateException {
failVerification();
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
throws CertificateException {
failVerification();
}
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
failVerification();
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
failVerification();
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return EmptyArrays.EMPTY_X509_CERTIFICATES;
}
private void failVerification() throws CertificateException {
throw new CertificateException();
}
})
.sslProvider(SslProvider.JDK).build();
final SelfSignedCertificate cert = new SelfSignedCertificate();
final SslContext sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert())
.sslProvider(SslProvider.JDK).build();
EventLoopGroup group = new NioEventLoopGroup();
Channel sc = null;
final SslHandler clientSslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT);
final SslHandler serverSslHandler = sslServerCtx.newHandler(UnpooledByteBufAllocator.DEFAULT);
try {
final Object terminalEvent = new Object();
final BlockingQueue<Object> errorQueue = new LinkedBlockingQueue<Object>();
sc = new ServerBootstrap()
.group(group)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(serverSslHandler);
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(final ChannelHandlerContext ctx, Throwable cause) {
errorQueue.add(cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
errorQueue.add(terminalEvent);
}
});
}
})
.bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
final ChannelFuture future = new Bootstrap()
.group(group)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(clientSslHandler);
}
}).connect(sc.localAddress());
future.syncUninterruptibly();
clientSslHandler.handshakeFuture().addListener(new FutureListener<Channel>() {
@Override
public void operationComplete(Future<Channel> f) {
future.channel().close();
}
});
assertFalse(clientSslHandler.handshakeFuture().await().isSuccess());
assertFalse(serverSslHandler.handshakeFuture().await().isSuccess());
Object error = errorQueue.take();
assertThat(error, Matchers.instanceOf(DecoderException.class));
assertThat(((Throwable) error).getCause(), Matchers.<Throwable>instanceOf(SSLException.class));
Object terminal = errorQueue.take();
assertSame(terminalEvent, terminal);
assertNull(errorQueue.poll(1, TimeUnit.MILLISECONDS));
} finally {
if (sc != null) {
sc.close().syncUninterruptibly();
}
group.shutdownGracefully();
}
}
}