Add websocket encoder / decoder in correct order to the pipeline when HttpServerCodec is used (#9386)

Motivation:

We need to ensure we place the encoder before the decoder when doing the websockets upgrade as the decoder may produce a close frame when protocol violations are detected.

Modifications:

- Correctly place encoder before decoder
- Add unit test

Result:

Fixes https://github.com/netty/netty/issues/9300
This commit is contained in:
Norman Maurer 2019-07-18 10:19:09 +02:00 committed by GitHub
parent dd1785ba66
commit 26c3abc63c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 11 deletions

View File

@ -209,8 +209,8 @@ public abstract class WebSocketServerHandshaker {
new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline")); new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
return promise; return promise;
} }
p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder()); p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
encoderName = ctx.name(); encoderName = ctx.name();
} else { } else {
p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder()); p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());

View File

@ -16,6 +16,8 @@
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
@ -27,10 +29,15 @@ import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import org.hamcrest.CoreMatchers;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.util.Iterator;
import static io.netty.handler.codec.http.HttpVersion.*; import static io.netty.handler.codec.http.HttpVersion.*;
public class WebSocketServerHandshaker13Test { public class WebSocketServerHandshaker13Test {
@ -47,8 +54,50 @@ public class WebSocketServerHandshaker13Test {
private static void testPerformOpeningHandshake0(boolean subProtocol) { private static void testPerformOpeningHandshake0(boolean subProtocol) {
EmbeddedChannel ch = new EmbeddedChannel( EmbeddedChannel ch = new EmbeddedChannel(
new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder()); new HttpObjectAggregator(42), new HttpResponseEncoder(), new HttpRequestDecoder());
if (subProtocol) {
testUpgrade0(ch, new WebSocketServerHandshaker13(
"ws://example.com/chat", "chat", false, Integer.MAX_VALUE, false));
} else {
testUpgrade0(ch, new WebSocketServerHandshaker13(
"ws://example.com/chat", null, false, Integer.MAX_VALUE, false));
}
Assert.assertFalse(ch.finish());
}
@Test
public void testCloseReasonWithEncoderAndDecoder() {
testCloseReason0(new HttpResponseEncoder(), new HttpRequestDecoder());
}
@Test
public void testCloseReasonWithCodec() {
testCloseReason0(new HttpServerCodec());
}
private static void testCloseReason0(ChannelHandler... handlers) {
EmbeddedChannel ch = new EmbeddedChannel(
new HttpObjectAggregator(42));
ch.pipeline().addLast(handlers);
testUpgrade0(ch, new WebSocketServerHandshaker13("ws://example.com/chat", "chat",
WebSocketDecoderConfig.newBuilder().maxFramePayloadLength(4).closeOnProtocolViolation(true).build()));
ch.writeOutbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[8])));
ByteBuf buffer = ch.readOutbound();
try {
ch.writeInbound(buffer);
Assert.fail();
} catch (CorruptedWebSocketFrameException expected) {
// expected
}
ReferenceCounted closeMessage = ch.readOutbound();
Assert.assertThat(closeMessage, CoreMatchers.instanceOf(ByteBuf.class));
closeMessage.release();
Assert.assertFalse(ch.finish());
}
private static void testUpgrade0(EmbeddedChannel ch, WebSocketServerHandshaker13 handshaker) {
FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, "/chat"); FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, "/chat");
req.headers().set(HttpHeaderNames.HOST, "server.example.com"); req.headers().set(HttpHeaderNames.HOST, "server.example.com");
req.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET); req.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET);
@ -58,13 +107,7 @@ public class WebSocketServerHandshaker13Test {
req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat"); req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat");
req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13"); req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13");
if (subProtocol) { handshaker.handshake(ch, req);
new WebSocketServerHandshaker13(
"ws://example.com/chat", "chat", false, Integer.MAX_VALUE, false).handshake(ch, req);
} else {
new WebSocketServerHandshaker13(
"ws://example.com/chat", null, false, Integer.MAX_VALUE, false).handshake(ch, req);
}
ByteBuf resBuf = ch.readOutbound(); ByteBuf resBuf = ch.readOutbound();
@ -74,8 +117,10 @@ public class WebSocketServerHandshaker13Test {
Assert.assertEquals( Assert.assertEquals(
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT)); "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT));
if (subProtocol) { Iterator<String> subProtocols = handshaker.subprotocols().iterator();
Assert.assertEquals("chat", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)); if (subProtocols.hasNext()) {
Assert.assertEquals(subProtocols.next(),
res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL));
} else { } else {
Assert.assertNull(res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)); Assert.assertNull(res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL));
} }