[#3687] Correctly store WebSocketServerHandshaker in Channel attributes

Motivation:

As we stored the WebSocketServerHandshaker in the ChannelHandlerContext it was always null and so no close frame was send if WebSocketServerProtocolHandler was used.

Modifications:

Store WebSocketServerHAndshaker in the Channel attributes and so make it visibile between different handlers.

Result:

Correctly send close frame.
This commit is contained in:
Norman Maurer 2015-09-09 14:25:32 +02:00
parent 08b4c7d6b5
commit c73cd35de0
3 changed files with 8 additions and 7 deletions

View File

@ -16,6 +16,7 @@
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
@ -102,7 +103,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
@Override @Override
protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception { protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
if (frame instanceof CloseWebSocketFrame) { if (frame instanceof CloseWebSocketFrame) {
WebSocketServerHandshaker handshaker = getHandshaker(ctx); WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel());
if (handshaker != null) { if (handshaker != null) {
frame.retain(); frame.retain();
handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame); handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame);
@ -125,12 +126,12 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
} }
} }
static WebSocketServerHandshaker getHandshaker(ChannelHandlerContext ctx) { static WebSocketServerHandshaker getHandshaker(Channel channel) {
return ctx.attr(HANDSHAKER_ATTR_KEY).get(); return channel.attr(HANDSHAKER_ATTR_KEY).get();
} }
static void setHandshaker(ChannelHandlerContext ctx, WebSocketServerHandshaker handshaker) { static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
ctx.attr(HANDSHAKER_ATTR_KEY).set(handshaker); channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
} }
static ChannelHandler forbiddenHttpRequestResponder() { static ChannelHandler forbiddenHttpRequestResponder() {

View File

@ -78,7 +78,7 @@ class WebSocketServerProtocolHandshakeHandler
} }
} }
}); });
WebSocketServerProtocolHandler.setHandshaker(ctx, handshaker); WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
ctx.pipeline().replace(this, "WS403Responder", ctx.pipeline().replace(this, "WS403Responder",
WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()); WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
} }

View File

@ -55,7 +55,7 @@ public class WebSocketServerProtocolHandlerTest {
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class); ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
writeUpgradeRequest(ch); writeUpgradeRequest(ch);
assertEquals(SWITCHING_PROTOCOLS, ReferenceCountUtil.releaseLater(responses.remove()).getStatus()); assertEquals(SWITCHING_PROTOCOLS, ReferenceCountUtil.releaseLater(responses.remove()).getStatus());
assertNotNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx)); assertNotNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx.channel()));
} }
@Test @Test