Fix remove 'WebSocketServerExtensionHandler' from pipeline after upgrade (#9940)
Motivation: We should remove WebSocketServerExtensionHandler from pipeline after successful WebSocket upgrade even if the client has not selected any extensions. Modification: Remove handler once upgrade is complete and no extensions are used. Result: Fixes #9939.
This commit is contained in:
parent
776e38af88
commit
91404e1828
@ -22,6 +22,7 @@ import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpRequest;
|
||||
import io.netty.handler.codec.http.HttpResponse;
|
||||
|
||||
@ -62,8 +63,7 @@ public class WebSocketServerExtensionHandler implements ChannelHandler {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg)
|
||||
throws Exception {
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg instanceof HttpRequest) {
|
||||
HttpRequest request = (HttpRequest) msg;
|
||||
|
||||
@ -103,32 +103,42 @@ public class WebSocketServerExtensionHandler implements ChannelHandler {
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof HttpResponse &&
|
||||
WebSocketExtensionUtil.isWebsocketUpgrade(((HttpResponse) msg).headers()) && validExtensions != null) {
|
||||
HttpResponse response = (HttpResponse) msg;
|
||||
String headerValue = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
|
||||
if (msg instanceof HttpResponse) {
|
||||
HttpHeaders headers = ((HttpResponse) msg).headers();
|
||||
|
||||
if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
|
||||
|
||||
if (validExtensions != null) {
|
||||
String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
|
||||
|
||||
for (WebSocketServerExtension extension : validExtensions) {
|
||||
WebSocketExtensionData extensionData = extension.newResponseData();
|
||||
headerValue = WebSocketExtensionUtil.appendExtension(headerValue,
|
||||
extensionData.name(), extensionData.parameters());
|
||||
extensionData.name(),
|
||||
extensionData.parameters());
|
||||
}
|
||||
|
||||
promise.addListener((ChannelFutureListener) future -> {
|
||||
if (future.isSuccess()) {
|
||||
for (WebSocketServerExtension extension : validExtensions) {
|
||||
WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
|
||||
WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
|
||||
ctx.pipeline().addAfter(ctx.name(), decoder.getClass().getName(), decoder);
|
||||
ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder);
|
||||
ctx.pipeline()
|
||||
.addAfter(ctx.name(), decoder.getClass().getName(), decoder)
|
||||
.addAfter(ctx.name(), encoder.getClass().getName(), encoder);
|
||||
}
|
||||
}
|
||||
|
||||
ctx.pipeline().remove(ctx.name());
|
||||
});
|
||||
|
||||
if (headerValue != null) {
|
||||
response.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue);
|
||||
headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue);
|
||||
}
|
||||
}
|
||||
|
||||
promise.addListener((ChannelFutureListener) future -> {
|
||||
if (future.isSuccess()) {
|
||||
ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -15,11 +15,13 @@
|
||||
*/
|
||||
package io.netty.handler.codec.http.websocketx.extensions;
|
||||
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpRequest;
|
||||
import io.netty.handler.codec.http.HttpResponse;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@ -62,8 +64,9 @@ public class WebSocketServerExtensionHandlerTest {
|
||||
when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
|
||||
|
||||
// execute
|
||||
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler(
|
||||
mainHandshakerMock, fallbackHandshakerMock));
|
||||
WebSocketServerExtensionHandler extensionHandler =
|
||||
new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
|
||||
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
|
||||
|
||||
HttpRequest req = newUpgradeRequest("main, fallback");
|
||||
ch.writeInbound(req);
|
||||
@ -76,6 +79,7 @@ public class WebSocketServerExtensionHandlerTest {
|
||||
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
|
||||
|
||||
// test
|
||||
assertNull(ch.pipeline().context(extensionHandler));
|
||||
assertEquals(1, resExts.size());
|
||||
assertEquals("main", resExts.get(0).name());
|
||||
assertTrue(resExts.get(0).parameters().isEmpty());
|
||||
@ -119,8 +123,9 @@ public class WebSocketServerExtensionHandlerTest {
|
||||
when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(new Dummy2Decoder());
|
||||
|
||||
// execute
|
||||
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler(
|
||||
mainHandshakerMock, fallbackHandshakerMock));
|
||||
WebSocketServerExtensionHandler extensionHandler =
|
||||
new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
|
||||
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
|
||||
|
||||
HttpRequest req = newUpgradeRequest("main, fallback");
|
||||
ch.writeInbound(req);
|
||||
@ -133,6 +138,7 @@ public class WebSocketServerExtensionHandlerTest {
|
||||
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
|
||||
|
||||
// test
|
||||
assertNull(ch.pipeline().context(extensionHandler));
|
||||
assertEquals(2, resExts.size());
|
||||
assertEquals("main", resExts.get(0).name());
|
||||
assertEquals("fallback", resExts.get(1).name());
|
||||
@ -170,8 +176,9 @@ public class WebSocketServerExtensionHandlerTest {
|
||||
thenReturn(null);
|
||||
|
||||
// execute
|
||||
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler(
|
||||
mainHandshakerMock, fallbackHandshakerMock));
|
||||
WebSocketServerExtensionHandler extensionHandler =
|
||||
new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
|
||||
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
|
||||
|
||||
HttpRequest req = newUpgradeRequest("unknown, unknown2");
|
||||
ch.writeInbound(req);
|
||||
@ -182,6 +189,7 @@ public class WebSocketServerExtensionHandlerTest {
|
||||
HttpResponse res2 = ch.readOutbound();
|
||||
|
||||
// test
|
||||
assertNull(ch.pipeline().context(extensionHandler));
|
||||
assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
|
||||
|
||||
verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
|
||||
@ -190,4 +198,31 @@ public class WebSocketServerExtensionHandlerTest {
|
||||
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
|
||||
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown2"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExtensionHandlerNotRemovedByFailureWritePromise() {
|
||||
// initialize
|
||||
when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main")))
|
||||
.thenReturn(mainExtensionMock);
|
||||
when(mainExtensionMock.newResponseData()).thenReturn(
|
||||
new WebSocketExtensionData("main", Collections.<String, String>emptyMap()));
|
||||
|
||||
// execute
|
||||
WebSocketServerExtensionHandler extensionHandler =
|
||||
new WebSocketServerExtensionHandler(mainHandshakerMock);
|
||||
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
|
||||
|
||||
HttpRequest req = newUpgradeRequest("main");
|
||||
ch.writeInbound(req);
|
||||
|
||||
HttpResponse res = newUpgradeResponse(null);
|
||||
ChannelPromise failurePromise = ch.newPromise();
|
||||
ch.writeOneOutbound(res, failurePromise);
|
||||
failurePromise.setFailure(new IOException("Cannot write response"));
|
||||
|
||||
// test
|
||||
assertNull(ch.readOutbound());
|
||||
assertNotNull(ch.pipeline().context(extensionHandler));
|
||||
assertTrue(ch.finish());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user