Fix remove 'WebSocketServerExtensionHandler' from pipeline after upgrade (#9940)
Motivation: We should remove WebSocketServerExtensionHandler from pipeline after successful WebSocket upgrade even if the client has not selected any extensions. Modification: Remove handler once upgrade is complete and no extensions are used. Result: Fixes #9939.
This commit is contained in:
parent
776e38af88
commit
91404e1828
@ -22,6 +22,7 @@ import io.netty.channel.ChannelHandler;
|
|||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelPromise;
|
import io.netty.channel.ChannelPromise;
|
||||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||||
|
import io.netty.handler.codec.http.HttpHeaders;
|
||||||
import io.netty.handler.codec.http.HttpRequest;
|
import io.netty.handler.codec.http.HttpRequest;
|
||||||
import io.netty.handler.codec.http.HttpResponse;
|
import io.netty.handler.codec.http.HttpResponse;
|
||||||
|
|
||||||
@ -62,8 +63,7 @@ public class WebSocketServerExtensionHandler implements ChannelHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void channelRead(ChannelHandlerContext ctx, Object msg)
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
throws Exception {
|
|
||||||
if (msg instanceof HttpRequest) {
|
if (msg instanceof HttpRequest) {
|
||||||
HttpRequest request = (HttpRequest) msg;
|
HttpRequest request = (HttpRequest) msg;
|
||||||
|
|
||||||
@ -103,32 +103,42 @@ public class WebSocketServerExtensionHandler implements ChannelHandler {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||||
if (msg instanceof HttpResponse &&
|
if (msg instanceof HttpResponse) {
|
||||||
WebSocketExtensionUtil.isWebsocketUpgrade(((HttpResponse) msg).headers()) && validExtensions != null) {
|
HttpHeaders headers = ((HttpResponse) msg).headers();
|
||||||
HttpResponse response = (HttpResponse) msg;
|
|
||||||
String headerValue = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
|
if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
|
||||||
|
|
||||||
|
if (validExtensions != null) {
|
||||||
|
String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
|
||||||
|
|
||||||
for (WebSocketServerExtension extension : validExtensions) {
|
for (WebSocketServerExtension extension : validExtensions) {
|
||||||
WebSocketExtensionData extensionData = extension.newResponseData();
|
WebSocketExtensionData extensionData = extension.newResponseData();
|
||||||
headerValue = WebSocketExtensionUtil.appendExtension(headerValue,
|
headerValue = WebSocketExtensionUtil.appendExtension(headerValue,
|
||||||
extensionData.name(), extensionData.parameters());
|
extensionData.name(),
|
||||||
|
extensionData.parameters());
|
||||||
}
|
}
|
||||||
|
|
||||||
promise.addListener((ChannelFutureListener) future -> {
|
promise.addListener((ChannelFutureListener) future -> {
|
||||||
if (future.isSuccess()) {
|
if (future.isSuccess()) {
|
||||||
for (WebSocketServerExtension extension : validExtensions) {
|
for (WebSocketServerExtension extension : validExtensions) {
|
||||||
WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
|
WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
|
||||||
WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
|
WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
|
||||||
ctx.pipeline().addAfter(ctx.name(), decoder.getClass().getName(), decoder);
|
ctx.pipeline()
|
||||||
ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder);
|
.addAfter(ctx.name(), decoder.getClass().getName(), decoder)
|
||||||
|
.addAfter(ctx.name(), encoder.getClass().getName(), encoder);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.pipeline().remove(ctx.name());
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if (headerValue != null) {
|
if (headerValue != null) {
|
||||||
response.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue);
|
headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
promise.addListener((ChannelFutureListener) future -> {
|
||||||
|
if (future.isSuccess()) {
|
||||||
|
ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,11 +15,13 @@
|
|||||||
*/
|
*/
|
||||||
package io.netty.handler.codec.http.websocketx.extensions;
|
package io.netty.handler.codec.http.websocketx.extensions;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
import io.netty.channel.embedded.EmbeddedChannel;
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||||
import io.netty.handler.codec.http.HttpRequest;
|
import io.netty.handler.codec.http.HttpRequest;
|
||||||
import io.netty.handler.codec.http.HttpResponse;
|
import io.netty.handler.codec.http.HttpResponse;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@ -62,8 +64,9 @@ public class WebSocketServerExtensionHandlerTest {
|
|||||||
when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
|
when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
|
||||||
|
|
||||||
// execute
|
// execute
|
||||||
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler(
|
WebSocketServerExtensionHandler extensionHandler =
|
||||||
mainHandshakerMock, fallbackHandshakerMock));
|
new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
|
||||||
|
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
|
||||||
|
|
||||||
HttpRequest req = newUpgradeRequest("main, fallback");
|
HttpRequest req = newUpgradeRequest("main, fallback");
|
||||||
ch.writeInbound(req);
|
ch.writeInbound(req);
|
||||||
@ -76,6 +79,7 @@ public class WebSocketServerExtensionHandlerTest {
|
|||||||
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
|
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
|
||||||
|
|
||||||
// test
|
// test
|
||||||
|
assertNull(ch.pipeline().context(extensionHandler));
|
||||||
assertEquals(1, resExts.size());
|
assertEquals(1, resExts.size());
|
||||||
assertEquals("main", resExts.get(0).name());
|
assertEquals("main", resExts.get(0).name());
|
||||||
assertTrue(resExts.get(0).parameters().isEmpty());
|
assertTrue(resExts.get(0).parameters().isEmpty());
|
||||||
@ -119,8 +123,9 @@ public class WebSocketServerExtensionHandlerTest {
|
|||||||
when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(new Dummy2Decoder());
|
when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(new Dummy2Decoder());
|
||||||
|
|
||||||
// execute
|
// execute
|
||||||
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler(
|
WebSocketServerExtensionHandler extensionHandler =
|
||||||
mainHandshakerMock, fallbackHandshakerMock));
|
new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
|
||||||
|
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
|
||||||
|
|
||||||
HttpRequest req = newUpgradeRequest("main, fallback");
|
HttpRequest req = newUpgradeRequest("main, fallback");
|
||||||
ch.writeInbound(req);
|
ch.writeInbound(req);
|
||||||
@ -133,6 +138,7 @@ public class WebSocketServerExtensionHandlerTest {
|
|||||||
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
|
res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
|
||||||
|
|
||||||
// test
|
// test
|
||||||
|
assertNull(ch.pipeline().context(extensionHandler));
|
||||||
assertEquals(2, resExts.size());
|
assertEquals(2, resExts.size());
|
||||||
assertEquals("main", resExts.get(0).name());
|
assertEquals("main", resExts.get(0).name());
|
||||||
assertEquals("fallback", resExts.get(1).name());
|
assertEquals("fallback", resExts.get(1).name());
|
||||||
@ -170,8 +176,9 @@ public class WebSocketServerExtensionHandlerTest {
|
|||||||
thenReturn(null);
|
thenReturn(null);
|
||||||
|
|
||||||
// execute
|
// execute
|
||||||
EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler(
|
WebSocketServerExtensionHandler extensionHandler =
|
||||||
mainHandshakerMock, fallbackHandshakerMock));
|
new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock);
|
||||||
|
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
|
||||||
|
|
||||||
HttpRequest req = newUpgradeRequest("unknown, unknown2");
|
HttpRequest req = newUpgradeRequest("unknown, unknown2");
|
||||||
ch.writeInbound(req);
|
ch.writeInbound(req);
|
||||||
@ -182,6 +189,7 @@ public class WebSocketServerExtensionHandlerTest {
|
|||||||
HttpResponse res2 = ch.readOutbound();
|
HttpResponse res2 = ch.readOutbound();
|
||||||
|
|
||||||
// test
|
// test
|
||||||
|
assertNull(ch.pipeline().context(extensionHandler));
|
||||||
assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
|
assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
|
||||||
|
|
||||||
verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
|
verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
|
||||||
@ -190,4 +198,31 @@ public class WebSocketServerExtensionHandlerTest {
|
|||||||
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
|
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown"));
|
||||||
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown2"));
|
verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown2"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testExtensionHandlerNotRemovedByFailureWritePromise() {
|
||||||
|
// initialize
|
||||||
|
when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main")))
|
||||||
|
.thenReturn(mainExtensionMock);
|
||||||
|
when(mainExtensionMock.newResponseData()).thenReturn(
|
||||||
|
new WebSocketExtensionData("main", Collections.<String, String>emptyMap()));
|
||||||
|
|
||||||
|
// execute
|
||||||
|
WebSocketServerExtensionHandler extensionHandler =
|
||||||
|
new WebSocketServerExtensionHandler(mainHandshakerMock);
|
||||||
|
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
|
||||||
|
|
||||||
|
HttpRequest req = newUpgradeRequest("main");
|
||||||
|
ch.writeInbound(req);
|
||||||
|
|
||||||
|
HttpResponse res = newUpgradeResponse(null);
|
||||||
|
ChannelPromise failurePromise = ch.newPromise();
|
||||||
|
ch.writeOneOutbound(res, failurePromise);
|
||||||
|
failurePromise.setFailure(new IOException("Cannot write response"));
|
||||||
|
|
||||||
|
// test
|
||||||
|
assertNull(ch.readOutbound());
|
||||||
|
assertNotNull(ch.pipeline().context(extensionHandler));
|
||||||
|
assertTrue(ch.finish());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user