diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandler.java index d7c3161ba1..6402799de2 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandler.java @@ -22,6 +22,7 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; @@ -62,8 +63,7 @@ public class WebSocketServerExtensionHandler implements ChannelHandler { } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) - throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof HttpRequest) { HttpRequest request = (HttpRequest) msg; @@ -103,32 +103,42 @@ public class WebSocketServerExtensionHandler implements ChannelHandler { @Override public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - if (msg instanceof HttpResponse && - WebSocketExtensionUtil.isWebsocketUpgrade(((HttpResponse) msg).headers()) && validExtensions != null) { - HttpResponse response = (HttpResponse) msg; - String headerValue = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS); + if (msg instanceof HttpResponse) { + HttpHeaders headers = ((HttpResponse) msg).headers(); - for (WebSocketServerExtension extension : validExtensions) { - WebSocketExtensionData extensionData = extension.newResponseData(); - headerValue = WebSocketExtensionUtil.appendExtension(headerValue, - extensionData.name(), extensionData.parameters()); - } + if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) { + + if (validExtensions != null) { + String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS); - promise.addListener((ChannelFutureListener) future -> { - if (future.isSuccess()) { for (WebSocketServerExtension extension : validExtensions) { - WebSocketExtensionDecoder decoder = extension.newExtensionDecoder(); - WebSocketExtensionEncoder encoder = extension.newExtensionEncoder(); - ctx.pipeline().addAfter(ctx.name(), decoder.getClass().getName(), decoder); - ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder); + WebSocketExtensionData extensionData = extension.newResponseData(); + headerValue = WebSocketExtensionUtil.appendExtension(headerValue, + extensionData.name(), + extensionData.parameters()); + } + promise.addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + for (WebSocketServerExtension extension : validExtensions) { + WebSocketExtensionDecoder decoder = extension.newExtensionDecoder(); + WebSocketExtensionEncoder encoder = extension.newExtensionEncoder(); + ctx.pipeline() + .addAfter(ctx.name(), decoder.getClass().getName(), decoder) + .addAfter(ctx.name(), encoder.getClass().getName(), encoder); + } + } + }); + + if (headerValue != null) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue); } } - ctx.pipeline().remove(ctx.name()); - }); - - if (headerValue != null) { - response.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue); + promise.addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + ctx.pipeline().remove(WebSocketServerExtensionHandler.this); + } + }); } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandlerTest.java index fd2821e907..36b3a60dc2 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandlerTest.java @@ -15,11 +15,13 @@ */ package io.netty.handler.codec.http.websocketx.extensions; +import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; +import java.io.IOException; import java.util.Collections; import java.util.List; @@ -62,8 +64,9 @@ public class WebSocketServerExtensionHandlerTest { when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1); // execute - EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( - mainHandshakerMock, fallbackHandshakerMock)); + WebSocketServerExtensionHandler extensionHandler = + new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock); + EmbeddedChannel ch = new EmbeddedChannel(extensionHandler); HttpRequest req = newUpgradeRequest("main, fallback"); ch.writeInbound(req); @@ -76,6 +79,7 @@ public class WebSocketServerExtensionHandlerTest { res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); // test + assertNull(ch.pipeline().context(extensionHandler)); assertEquals(1, resExts.size()); assertEquals("main", resExts.get(0).name()); assertTrue(resExts.get(0).parameters().isEmpty()); @@ -119,8 +123,9 @@ public class WebSocketServerExtensionHandlerTest { when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(new Dummy2Decoder()); // execute - EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( - mainHandshakerMock, fallbackHandshakerMock)); + WebSocketServerExtensionHandler extensionHandler = + new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock); + EmbeddedChannel ch = new EmbeddedChannel(extensionHandler); HttpRequest req = newUpgradeRequest("main, fallback"); ch.writeInbound(req); @@ -133,6 +138,7 @@ public class WebSocketServerExtensionHandlerTest { res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); // test + assertNull(ch.pipeline().context(extensionHandler)); assertEquals(2, resExts.size()); assertEquals("main", resExts.get(0).name()); assertEquals("fallback", resExts.get(1).name()); @@ -170,8 +176,9 @@ public class WebSocketServerExtensionHandlerTest { thenReturn(null); // execute - EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( - mainHandshakerMock, fallbackHandshakerMock)); + WebSocketServerExtensionHandler extensionHandler = + new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock); + EmbeddedChannel ch = new EmbeddedChannel(extensionHandler); HttpRequest req = newUpgradeRequest("unknown, unknown2"); ch.writeInbound(req); @@ -182,6 +189,7 @@ public class WebSocketServerExtensionHandlerTest { HttpResponse res2 = ch.readOutbound(); // test + assertNull(ch.pipeline().context(extensionHandler)); assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown")); @@ -190,4 +198,31 @@ public class WebSocketServerExtensionHandlerTest { verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown")); verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown2")); } + + @Test + public void testExtensionHandlerNotRemovedByFailureWritePromise() { + // initialize + when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main"))) + .thenReturn(mainExtensionMock); + when(mainExtensionMock.newResponseData()).thenReturn( + new WebSocketExtensionData("main", Collections.emptyMap())); + + // execute + WebSocketServerExtensionHandler extensionHandler = + new WebSocketServerExtensionHandler(mainHandshakerMock); + EmbeddedChannel ch = new EmbeddedChannel(extensionHandler); + + HttpRequest req = newUpgradeRequest("main"); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ChannelPromise failurePromise = ch.newPromise(); + ch.writeOneOutbound(res, failurePromise); + failurePromise.setFailure(new IOException("Cannot write response")); + + // test + assertNull(ch.readOutbound()); + assertNotNull(ch.pipeline().context(extensionHandler)); + assertTrue(ch.finish()); + } }