Fix remove 'WebSocketServerExtensionHandler' from pipeline after upgrade (#9940)

Motivation:

We should remove WebSocketServerExtensionHandler from pipeline after successful WebSocket upgrade even if the client has not selected any extensions.

Modification:

Remove handler once upgrade is complete and no extensions are used.

Result:

Fixes #9939.
This commit is contained in:
Andrey Mizurov 2020-01-15 12:49:51 +03:00 committed by Norman Maurer
parent 776e38af88
commit 91404e1828
2 changed files with 73 additions and 28 deletions

View File

@ -22,6 +22,7 @@ import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
@ -62,8 +63,7 @@ public class WebSocketServerExtensionHandler implements ChannelHandler {
} }
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
throws Exception {
if (msg instanceof HttpRequest) { if (msg instanceof HttpRequest) {
HttpRequest request = (HttpRequest) msg; HttpRequest request = (HttpRequest) msg;
@ -103,32 +103,42 @@ public class WebSocketServerExtensionHandler implements ChannelHandler {
@Override @Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof HttpResponse && if (msg instanceof HttpResponse) {
WebSocketExtensionUtil.isWebsocketUpgrade(((HttpResponse) msg).headers()) && validExtensions != null) { HttpHeaders headers = ((HttpResponse) msg).headers();
HttpResponse response = (HttpResponse) msg;
String headerValue = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS); if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
if (validExtensions != null) {
String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
for (WebSocketServerExtension extension : validExtensions) { for (WebSocketServerExtension extension : validExtensions) {
WebSocketExtensionData extensionData = extension.newResponseData(); WebSocketExtensionData extensionData = extension.newResponseData();
headerValue = WebSocketExtensionUtil.appendExtension(headerValue, headerValue = WebSocketExtensionUtil.appendExtension(headerValue,
extensionData.name(), extensionData.parameters()); extensionData.name(),
extensionData.parameters());
} }
promise.addListener((ChannelFutureListener) future -> { promise.addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) { if (future.isSuccess()) {
for (WebSocketServerExtension extension : validExtensions) { for (WebSocketServerExtension extension : validExtensions) {
WebSocketExtensionDecoder decoder = extension.newExtensionDecoder(); WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
WebSocketExtensionEncoder encoder = extension.newExtensionEncoder(); WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
ctx.pipeline().addAfter(ctx.name(), decoder.getClass().getName(), decoder); ctx.pipeline()
ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder); .addAfter(ctx.name(), decoder.getClass().getName(), decoder)
.addAfter(ctx.name(), encoder.getClass().getName(), encoder);
} }
} }
ctx.pipeline().remove(ctx.name());
}); });
if (headerValue != null) { if (headerValue != null) {
response.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue); headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue);
}
}
promise.addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
}
});
} }
} }

View File

@ -15,11 +15,13 @@
*/ */
package io.netty.handler.codec.http.websocketx.extensions; package io.netty.handler.codec.http.websocketx.extensions;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -62,8 +64,9 @@ public class WebSocketServerExtensionHandlerTest {
when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1); when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
// execute // execute
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( WebSocketServerExtensionHandler extensionHandler =
mainHandshakerMock, fallbackHandshakerMock)); new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
HttpRequest req = newUpgradeRequest("main, fallback"); HttpRequest req = newUpgradeRequest("main, fallback");
ch.writeInbound(req); ch.writeInbound(req);
@ -76,6 +79,7 @@ public class WebSocketServerExtensionHandlerTest {
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
// test // test
assertNull(ch.pipeline().context(extensionHandler));
assertEquals(1, resExts.size()); assertEquals(1, resExts.size());
assertEquals("main", resExts.get(0).name()); assertEquals("main", resExts.get(0).name());
assertTrue(resExts.get(0).parameters().isEmpty()); assertTrue(resExts.get(0).parameters().isEmpty());
@ -119,8 +123,9 @@ public class WebSocketServerExtensionHandlerTest {
when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(new Dummy2Decoder()); when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(new Dummy2Decoder());
// execute // execute
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( WebSocketServerExtensionHandler extensionHandler =
mainHandshakerMock, fallbackHandshakerMock)); new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
HttpRequest req = newUpgradeRequest("main, fallback"); HttpRequest req = newUpgradeRequest("main, fallback");
ch.writeInbound(req); ch.writeInbound(req);
@ -133,6 +138,7 @@ public class WebSocketServerExtensionHandlerTest {
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
// test // test
assertNull(ch.pipeline().context(extensionHandler));
assertEquals(2, resExts.size()); assertEquals(2, resExts.size());
assertEquals("main", resExts.get(0).name()); assertEquals("main", resExts.get(0).name());
assertEquals("fallback", resExts.get(1).name()); assertEquals("fallback", resExts.get(1).name());
@ -170,8 +176,9 @@ public class WebSocketServerExtensionHandlerTest {
thenReturn(null); thenReturn(null);
// execute // execute
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( WebSocketServerExtensionHandler extensionHandler =
mainHandshakerMock, fallbackHandshakerMock)); new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
HttpRequest req = newUpgradeRequest("unknown, unknown2"); HttpRequest req = newUpgradeRequest("unknown, unknown2");
ch.writeInbound(req); ch.writeInbound(req);
@ -182,6 +189,7 @@ public class WebSocketServerExtensionHandlerTest {
HttpResponse res2 = ch.readOutbound(); HttpResponse res2 = ch.readOutbound();
// test // test
assertNull(ch.pipeline().context(extensionHandler));
assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown")); verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
@ -190,4 +198,31 @@ public class WebSocketServerExtensionHandlerTest {
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown")); verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown2")); verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown2"));
} }
@Test
public void testExtensionHandlerNotRemovedByFailureWritePromise() {
// initialize
when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main")))
.thenReturn(mainExtensionMock);
when(mainExtensionMock.newResponseData()).thenReturn(
new WebSocketExtensionData("main", Collections.<String, String>emptyMap()));
// execute
WebSocketServerExtensionHandler extensionHandler =
new WebSocketServerExtensionHandler(mainHandshakerMock);
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
HttpRequest req = newUpgradeRequest("main");
ch.writeInbound(req);
HttpResponse res = newUpgradeResponse(null);
ChannelPromise failurePromise = ch.newPromise();
ch.writeOneOutbound(res, failurePromise);
failurePromise.setFailure(new IOException("Cannot write response"));
// test
assertNull(ch.readOutbound());
assertNotNull(ch.pipeline().context(extensionHandler));
assertTrue(ch.finish());
}
} }