Ensure WebSocket*Handshaker can not corrupt pipeline when HttpProxyHa… (#10103)

Motivation:

HttpProxyHandler itself will add a HttpClientCodec into the ChannelPipeline and so confuse the WebSocket*Handshaker when it tries to modify the pipeline as it will replace the wrong HttpClientCodec.

Modifications:

Wrap the internal HttpClientCodec that is added by HttpProxyHandler so it will not be replaced by HttpProxyHandler.

Result:

Fixes  https://github.com/netty/netty/issues/5201 and https://github.com/netty/netty/issues/5070
This commit is contained in:
Norman Maurer 2020-03-16 11:38:41 +01:00
parent b2ee13d2fc
commit 598b04f1d0
2 changed files with 130 additions and 5 deletions

View File

@ -19,8 +19,10 @@ package io.netty.handler.proxy;
import static java.util.Objects.requireNonNull;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpClientCodec;
@ -44,10 +46,15 @@ public final class HttpProxyHandler extends ProxyHandler {
private static final String PROTOCOL = "http";
private static final String AUTH_BASIC = "basic";
private static final byte[] BASIC_BYTES = "Basic ".getBytes(StandardCharsets.UTF_8);
private final HttpClientCodec codec = new HttpClientCodec();
// Wrapper for the HttpClientCodec to prevent it to be removed by other handlers by mistake (for example the
// WebSocket*Handshaker.
//
// See:
// - https://github.com/netty/netty/issues/5201
// - https://github.com/netty/netty/issues/5070
private final HttpClientCodecWrapper codecWrapper = new HttpClientCodecWrapper();
private final String username;
private final String password;
private final CharSequence authorization;
@ -128,17 +135,17 @@ public final class HttpProxyHandler extends ProxyHandler {
protected void addCodec(ChannelHandlerContext ctx) throws Exception {
ChannelPipeline p = ctx.pipeline();
String name = ctx.name();
p.addBefore(name, null, codec);
p.addBefore(name, null, codecWrapper);
}
@Override
protected void removeEncoder(ChannelHandlerContext ctx) throws Exception {
codec.removeOutboundHandler();
codecWrapper.codec.removeOutboundHandler();
}
@Override
protected void removeDecoder(ChannelHandlerContext ctx) throws Exception {
codec.removeInboundHandler();
codecWrapper.codec.removeInboundHandler();
}
@Override
@ -218,4 +225,105 @@ public final class HttpProxyHandler extends ProxyHandler {
return headers;
}
}
private static final class HttpClientCodecWrapper implements ChannelHandler {
final HttpClientCodec codec = new HttpClientCodec();
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
codec.handlerAdded(ctx);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
codec.handlerRemoved(ctx);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
codec.exceptionCaught(ctx, cause);
}
@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
codec.channelRegistered(ctx);
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
codec.channelUnregistered(ctx);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
codec.channelActive(ctx);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
codec.channelInactive(ctx);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
codec.channelRead(ctx, msg);
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
codec.channelReadComplete(ctx);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
codec.userEventTriggered(ctx, evt);
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
codec.channelWritabilityChanged(ctx);
}
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress,
ChannelPromise promise) throws Exception {
codec.bind(ctx, localAddress, promise);
}
@Override
public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
ChannelPromise promise) throws Exception {
codec.connect(ctx, remoteAddress, localAddress, promise);
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
codec.disconnect(ctx, promise);
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
codec.close(ctx, promise);
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
codec.deregister(ctx, promise);
}
@Override
public void read(ChannelHandlerContext ctx) throws Exception {
codec.read(ctx);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
codec.write(ctx, msg, promise);
}
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
codec.flush(ctx);
}
}
}

View File

@ -25,6 +25,7 @@ import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.MultithreadEventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalHandler;
@ -32,6 +33,7 @@ import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponseEncoder;
@ -39,6 +41,7 @@ import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.proxy.HttpProxyHandler.HttpProxyConnectException;
import io.netty.util.NetUtil;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Test;
@ -271,4 +274,18 @@ public class HttpProxyHandlerTest {
}
verify(ctx).connect(proxyAddress, null, promise);
}
@Test
public void testHttpClientCodecIsInvisible() {
EmbeddedChannel channel = new EmbeddedChannel(new HttpProxyHandler(
new InetSocketAddress(NetUtil.LOCALHOST, 8080))) {
@Override
public boolean isActive() {
// We want to simulate that the Channel did not become active yet.
return false;
}
};
assertNotNull(channel.pipeline().get(HttpProxyHandler.class));
assertNull(channel.pipeline().get(HttpClientCodec.class));
}
}