Add websocket encoder / decoder in correct order to the pipeline when HttpServerCodec is used (#9386)

Motivation:

We need to ensure we place the encoder before the decoder when doing the websockets upgrade as the decoder may produce a close frame when protocol violations are detected.

Modifications:

- Correctly place encoder before decoder
- Add unit test

Result:

Fixes https://github.com/netty/netty/issues/9300
This commit is contained in:
Norman Maurer 2019-07-18 10:19:09 +02:00
parent 8e370fbbcd
commit 2ef4b16138
2 changed files with 56 additions and 11 deletions

View File

@ -211,8 +211,8 @@ public abstract class WebSocketServerHandshaker {
new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
return promise;
}
p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
encoderName = ctx.name();
} else {
p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());

View File

@ -16,6 +16,8 @@
package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
@ -27,10 +29,15 @@ import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.Test;
import java.util.Iterator;
import static io.netty.handler.codec.http.HttpVersion.*;
public class WebSocketServerHandshaker13Test {
@ -47,8 +54,50 @@ public class WebSocketServerHandshaker13Test {
private static void testPerformOpeningHandshake0(boolean subProtocol) {
EmbeddedChannel ch = new EmbeddedChannel(
new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder());
new HttpObjectAggregator(42), new HttpResponseEncoder(), new HttpRequestDecoder());
if (subProtocol) {
testUpgrade0(ch, new WebSocketServerHandshaker13(
"ws://example.com/chat", "chat", false, Integer.MAX_VALUE, false));
} else {
testUpgrade0(ch, new WebSocketServerHandshaker13(
"ws://example.com/chat", null, false, Integer.MAX_VALUE, false));
}
Assert.assertFalse(ch.finish());
}
@Test
public void testCloseReasonWithEncoderAndDecoder() {
testCloseReason0(new HttpResponseEncoder(), new HttpRequestDecoder());
}
@Test
public void testCloseReasonWithCodec() {
testCloseReason0(new HttpServerCodec());
}
private static void testCloseReason0(ChannelHandler... handlers) {
EmbeddedChannel ch = new EmbeddedChannel(
new HttpObjectAggregator(42));
ch.pipeline().addLast(handlers);
testUpgrade0(ch, new WebSocketServerHandshaker13("ws://example.com/chat", "chat",
WebSocketDecoderConfig.newBuilder().maxFramePayloadLength(4).closeOnProtocolViolation(true).build()));
ch.writeOutbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[8])));
ByteBuf buffer = ch.readOutbound();
try {
ch.writeInbound(buffer);
Assert.fail();
} catch (CorruptedWebSocketFrameException expected) {
// expected
}
ReferenceCounted closeMessage = ch.readOutbound();
Assert.assertThat(closeMessage, CoreMatchers.instanceOf(ByteBuf.class));
closeMessage.release();
Assert.assertFalse(ch.finish());
}
private static void testUpgrade0(EmbeddedChannel ch, WebSocketServerHandshaker13 handshaker) {
FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, "/chat");
req.headers().set(HttpHeaderNames.HOST, "server.example.com");
req.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET);
@ -58,13 +107,7 @@ public class WebSocketServerHandshaker13Test {
req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat");
req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13");
if (subProtocol) {
new WebSocketServerHandshaker13(
"ws://example.com/chat", "chat", false, Integer.MAX_VALUE, false).handshake(ch, req);
} else {
new WebSocketServerHandshaker13(
"ws://example.com/chat", null, false, Integer.MAX_VALUE, false).handshake(ch, req);
}
handshaker.handshake(ch, req);
ByteBuf resBuf = ch.readOutbound();
@ -74,8 +117,10 @@ public class WebSocketServerHandshaker13Test {
Assert.assertEquals(
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT));
if (subProtocol) {
Assert.assertEquals("chat", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL));
Iterator<String> subProtocols = handshaker.subprotocols().iterator();
if (subProtocols.hasNext()) {
Assert.assertEquals(subProtocols.next(),
res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL));
} else {
Assert.assertNull(res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL));
}