Remove the Handler only after it has initialized the channel (#9132)

Motivation:

Previously, any 'relative' pipeline operations, such as
ctx.pipeline().replace(), .addBefore(), addAfter(), etc
would fail as the handler was not present in the pipeline.

Modification:

Used the pattern from ChannelInitializer when invoking configurePipeline().

Result:

Fixes #9131
This commit is contained in:
RoganDawes 2019-05-13 13:49:17 +02:00 committed by Norman Maurer
parent cb85e03d72
commit 3221bf6854
2 changed files with 55 additions and 31 deletions

View File

@ -79,22 +79,29 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) {
ctx.pipeline().remove(this);
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
if (handshakeEvent.isSuccess()) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
if (sslHandler == null) {
throw new IllegalStateException("cannot find a SslHandler in the pipeline (required for " +
"application-level protocol negotiation)");
try {
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
if (handshakeEvent.isSuccess()) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
if (sslHandler == null) {
throw new IllegalStateException("cannot find a SslHandler in the pipeline (required for "
+ "application-level protocol negotiation)");
}
String protocol = sslHandler.applicationProtocol();
configurePipeline(ctx, protocol != null ? protocol : fallbackProtocol);
} else {
handshakeFailure(ctx, handshakeEvent.cause());
}
} catch (Throwable cause) {
exceptionCaught(ctx, cause);
} finally {
ChannelPipeline pipeline = ctx.pipeline();
if (pipeline.context(this) != null) {
pipeline.remove(this);
}
String protocol = sslHandler.applicationProtocol();
configurePipeline(ctx, protocol != null? protocol : fallbackProtocol);
} else {
handshakeFailure(ctx, handshakeEvent.cause());
}
}
ctx.fireUserEventTriggered(evt);
}
@ -119,6 +126,7 @@ public abstract class ApplicationProtocolNegotiationHandler extends ChannelInbou
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.warn("{} Failed to select the application-level protocol:", ctx.channel(), cause);
ctx.fireExceptionCaught(cause);
ctx.close();
}
}

View File

@ -16,6 +16,32 @@
package io.netty.handler.ssl;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;
import java.io.File;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
@ -44,28 +70,10 @@ import io.netty.util.DomainNameMappingBuilder;
import io.netty.util.Mapping;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.internal.ResourcesUtil;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.ResourcesUtil;
import io.netty.util.internal.StringUtil;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.io.File;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.*;
import static org.junit.Assume.assumeTrue;
@RunWith(Parameterized.class)
public class SniHandlerTest {
@ -339,6 +347,8 @@ public class SniHandlerTest {
SslContext sniContext = makeSslContext(provider, true);
final SslContext clientContext = makeSslClientContext(provider, true);
try {
final AtomicBoolean serverApnCtx = new AtomicBoolean(false);
final AtomicBoolean clientApnCtx = new AtomicBoolean(false);
final CountDownLatch serverApnDoneLatch = new CountDownLatch(1);
final CountDownLatch clientApnDoneLatch = new CountDownLatch(1);
@ -363,6 +373,8 @@ public class SniHandlerTest {
p.addLast(new ApplicationProtocolNegotiationHandler("foo") {
@Override
protected void configurePipeline(ChannelHandlerContext ctx, String protocol) {
// addresses issue #9131
serverApnCtx.set(ctx.pipeline().context(this) != null);
serverApnDoneLatch.countDown();
}
});
@ -381,6 +393,8 @@ public class SniHandlerTest {
ch.pipeline().addLast(new ApplicationProtocolNegotiationHandler("foo") {
@Override
protected void configurePipeline(ChannelHandlerContext ctx, String protocol) {
// addresses issue #9131
clientApnCtx.set(ctx.pipeline().context(this) != null);
clientApnDoneLatch.countDown();
}
});
@ -395,6 +409,8 @@ public class SniHandlerTest {
assertTrue(serverApnDoneLatch.await(5, TimeUnit.SECONDS));
assertTrue(clientApnDoneLatch.await(5, TimeUnit.SECONDS));
assertTrue(serverApnCtx.get());
assertTrue(clientApnCtx.get());
assertThat(handler.hostname(), is("sni.fake.site"));
assertThat(handler.sslContext(), is(sniContext));
} finally {