Fix remove 'WebSocketServerExtensionHandler' from pipeline after upgrade (#9940)

Motivation:

We should remove WebSocketServerExtensionHandler from pipeline after successful WebSocket upgrade even if the client has not selected any extensions.

Modification:

Remove handler once upgrade is complete and no extensions are used.

Result:

Fixes #9939.
This commit is contained in:
Andrey Mizurov 2020-01-15 12:49:51 +03:00 committed by Norman Maurer
parent 776e38af88
commit 91404e1828
2 changed files with 73 additions and 28 deletions

View File

@ -22,6 +22,7 @@ import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
@ -62,8 +63,7 @@ public class WebSocketServerExtensionHandler implements ChannelHandler {
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg)
throws Exception {
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof HttpRequest) {
HttpRequest request = (HttpRequest) msg;
@ -103,32 +103,42 @@ public class WebSocketServerExtensionHandler implements ChannelHandler {
@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof HttpResponse &&
WebSocketExtensionUtil.isWebsocketUpgrade(((HttpResponse) msg).headers()) && validExtensions != null) {
HttpResponse response = (HttpResponse) msg;
String headerValue = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
if (msg instanceof HttpResponse) {
HttpHeaders headers = ((HttpResponse) msg).headers();
if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
if (validExtensions != null) {
String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
for (WebSocketServerExtension extension : validExtensions) {
WebSocketExtensionData extensionData = extension.newResponseData();
headerValue = WebSocketExtensionUtil.appendExtension(headerValue,
extensionData.name(), extensionData.parameters());
extensionData.name(),
extensionData.parameters());
}
promise.addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
for (WebSocketServerExtension extension : validExtensions) {
WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
ctx.pipeline().addAfter(ctx.name(), decoder.getClass().getName(), decoder);
ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder);
ctx.pipeline()
.addAfter(ctx.name(), decoder.getClass().getName(), decoder)
.addAfter(ctx.name(), encoder.getClass().getName(), encoder);
}
}
ctx.pipeline().remove(ctx.name());
});
if (headerValue != null) {
response.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue);
headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue);
}
}
promise.addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
}
});
}
}

View File

@ -15,11 +15,13 @@
*/
package io.netty.handler.codec.http.websocketx.extensions;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
@ -62,8 +64,9 @@ public class WebSocketServerExtensionHandlerTest {
when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
// execute
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler(
mainHandshakerMock, fallbackHandshakerMock));
WebSocketServerExtensionHandler extensionHandler =
new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
HttpRequest req = newUpgradeRequest("main, fallback");
ch.writeInbound(req);
@ -76,6 +79,7 @@ public class WebSocketServerExtensionHandlerTest {
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
// test
assertNull(ch.pipeline().context(extensionHandler));
assertEquals(1, resExts.size());
assertEquals("main", resExts.get(0).name());
assertTrue(resExts.get(0).parameters().isEmpty());
@ -119,8 +123,9 @@ public class WebSocketServerExtensionHandlerTest {
when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(new Dummy2Decoder());
// execute
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler(
mainHandshakerMock, fallbackHandshakerMock));
WebSocketServerExtensionHandler extensionHandler =
new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
HttpRequest req = newUpgradeRequest("main, fallback");
ch.writeInbound(req);
@ -133,6 +138,7 @@ public class WebSocketServerExtensionHandlerTest {
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
// test
assertNull(ch.pipeline().context(extensionHandler));
assertEquals(2, resExts.size());
assertEquals("main", resExts.get(0).name());
assertEquals("fallback", resExts.get(1).name());
@ -170,8 +176,9 @@ public class WebSocketServerExtensionHandlerTest {
thenReturn(null);
// execute
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler(
mainHandshakerMock, fallbackHandshakerMock));
WebSocketServerExtensionHandler extensionHandler =
new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
HttpRequest req = newUpgradeRequest("unknown, unknown2");
ch.writeInbound(req);
@ -182,6 +189,7 @@ public class WebSocketServerExtensionHandlerTest {
HttpResponse res2 = ch.readOutbound();
// test
assertNull(ch.pipeline().context(extensionHandler));
assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
@ -190,4 +198,31 @@ public class WebSocketServerExtensionHandlerTest {
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown2"));
}
@Test
public void testExtensionHandlerNotRemovedByFailureWritePromise() {
// initialize
when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main")))
.thenReturn(mainExtensionMock);
when(mainExtensionMock.newResponseData()).thenReturn(
new WebSocketExtensionData("main", Collections.<String, String>emptyMap()));
// execute
WebSocketServerExtensionHandler extensionHandler =
new WebSocketServerExtensionHandler(mainHandshakerMock);
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
HttpRequest req = newUpgradeRequest("main");
ch.writeInbound(req);
HttpResponse res = newUpgradeResponse(null);
ChannelPromise failurePromise = ch.newPromise();
ch.writeOneOutbound(res, failurePromise);
failurePromise.setFailure(new IOException("Cannot write response"));
// test
assertNull(ch.readOutbound());
assertNotNull(ch.pipeline().context(extensionHandler));
assertTrue(ch.finish());
}
}