Send close frame on channel close, when this frame was not send manually (#9745)

Motivation:
By default CloseWebSocketFrames are handled automatically.
However I need manually manage their sending both on client- and on server-sides.

Modification:
Send close frame on channel close automatically, when it was not send before explicitly.

Result:
No more messages like "Connection closed by remote peer" for normal close flows.
This commit is contained in:
ursa 2019-11-18 19:32:21 +00:00 committed by Norman Maurer
parent 7632de7084
commit b5230c7b9c
8 changed files with 357 additions and 36 deletions

View File

@ -30,8 +30,8 @@ import static io.netty.util.internal.ObjectUtil.checkPositive;
public final class WebSocketClientProtocolConfig { public final class WebSocketClientProtocolConfig {
static final WebSocketClientProtocolConfig DEFAULT = new WebSocketClientProtocolConfig( static final WebSocketClientProtocolConfig DEFAULT = new WebSocketClientProtocolConfig(
URI.create("https://localhost/"), null, WebSocketVersion.V13, false, URI.create("https://localhost/"), null, WebSocketVersion.V13, false, EmptyHttpHeaders.INSTANCE,
EmptyHttpHeaders.INSTANCE, 65536, true, false, true, true, 10000L, -1, false); 65536, true, false, true, WebSocketCloseStatus.NORMAL_CLOSURE, true, 10000L, -1, false);
private final URI webSocketUri; private final URI webSocketUri;
private final String subprotocol; private final String subprotocol;
@ -42,6 +42,7 @@ public final class WebSocketClientProtocolConfig {
private final boolean performMasking; private final boolean performMasking;
private final boolean allowMaskMismatch; private final boolean allowMaskMismatch;
private final boolean handleCloseFrames; private final boolean handleCloseFrames;
private final WebSocketCloseStatus sendCloseFrame;
private final boolean dropPongFrames; private final boolean dropPongFrames;
private final long handshakeTimeoutMillis; private final long handshakeTimeoutMillis;
private final long forceCloseTimeoutMillis; private final long forceCloseTimeoutMillis;
@ -57,6 +58,7 @@ public final class WebSocketClientProtocolConfig {
boolean performMasking, boolean performMasking,
boolean allowMaskMismatch, boolean allowMaskMismatch,
boolean handleCloseFrames, boolean handleCloseFrames,
WebSocketCloseStatus sendCloseFrame,
boolean dropPongFrames, boolean dropPongFrames,
long handshakeTimeoutMillis, long handshakeTimeoutMillis,
long forceCloseTimeoutMillis, long forceCloseTimeoutMillis,
@ -72,6 +74,7 @@ public final class WebSocketClientProtocolConfig {
this.allowMaskMismatch = allowMaskMismatch; this.allowMaskMismatch = allowMaskMismatch;
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis; this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
this.handleCloseFrames = handleCloseFrames; this.handleCloseFrames = handleCloseFrames;
this.sendCloseFrame = sendCloseFrame;
this.dropPongFrames = dropPongFrames; this.dropPongFrames = dropPongFrames;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
this.absoluteUpgradeUrl = absoluteUpgradeUrl; this.absoluteUpgradeUrl = absoluteUpgradeUrl;
@ -113,6 +116,10 @@ public final class WebSocketClientProtocolConfig {
return handleCloseFrames; return handleCloseFrames;
} }
public WebSocketCloseStatus sendCloseFrame() {
return sendCloseFrame;
}
public boolean dropPongFrames() { public boolean dropPongFrames() {
return dropPongFrames; return dropPongFrames;
} }
@ -141,6 +148,7 @@ public final class WebSocketClientProtocolConfig {
", performMasking=" + performMasking + ", performMasking=" + performMasking +
", allowMaskMismatch=" + allowMaskMismatch + ", allowMaskMismatch=" + allowMaskMismatch +
", handleCloseFrames=" + handleCloseFrames + ", handleCloseFrames=" + handleCloseFrames +
", sendCloseFrame=" + sendCloseFrame +
", dropPongFrames=" + dropPongFrames + ", dropPongFrames=" + dropPongFrames +
", handshakeTimeoutMillis=" + handshakeTimeoutMillis + ", handshakeTimeoutMillis=" + handshakeTimeoutMillis +
", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis + ", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis +
@ -166,6 +174,7 @@ public final class WebSocketClientProtocolConfig {
private boolean performMasking; private boolean performMasking;
private boolean allowMaskMismatch; private boolean allowMaskMismatch;
private boolean handleCloseFrames; private boolean handleCloseFrames;
private WebSocketCloseStatus sendCloseFrame;
private boolean dropPongFrames; private boolean dropPongFrames;
private long handshakeTimeoutMillis; private long handshakeTimeoutMillis;
private long forceCloseTimeoutMillis; private long forceCloseTimeoutMillis;
@ -174,19 +183,20 @@ public final class WebSocketClientProtocolConfig {
private Builder(WebSocketClientProtocolConfig clientConfig) { private Builder(WebSocketClientProtocolConfig clientConfig) {
ObjectUtil.checkNotNull(clientConfig, "clientConfig"); ObjectUtil.checkNotNull(clientConfig, "clientConfig");
this.webSocketUri = clientConfig.webSocketUri(); webSocketUri = clientConfig.webSocketUri();
this.subprotocol = clientConfig.subprotocol(); subprotocol = clientConfig.subprotocol();
this.version = clientConfig.version(); version = clientConfig.version();
this.allowExtensions = clientConfig.allowExtensions(); allowExtensions = clientConfig.allowExtensions();
this.customHeaders = clientConfig.customHeaders(); customHeaders = clientConfig.customHeaders();
this.maxFramePayloadLength = clientConfig.maxFramePayloadLength(); maxFramePayloadLength = clientConfig.maxFramePayloadLength();
this.performMasking = clientConfig.performMasking(); performMasking = clientConfig.performMasking();
this.allowMaskMismatch = clientConfig.allowMaskMismatch(); allowMaskMismatch = clientConfig.allowMaskMismatch();
this.handleCloseFrames = clientConfig.handleCloseFrames(); handleCloseFrames = clientConfig.handleCloseFrames();
this.dropPongFrames = clientConfig.dropPongFrames(); sendCloseFrame = clientConfig.sendCloseFrame();
this.handshakeTimeoutMillis = clientConfig.handshakeTimeoutMillis(); dropPongFrames = clientConfig.dropPongFrames();
this.forceCloseTimeoutMillis = clientConfig.forceCloseTimeoutMillis(); handshakeTimeoutMillis = clientConfig.handshakeTimeoutMillis();
this.absoluteUpgradeUrl = clientConfig.absoluteUpgradeUrl(); forceCloseTimeoutMillis = clientConfig.forceCloseTimeoutMillis();
absoluteUpgradeUrl = clientConfig.absoluteUpgradeUrl();
} }
/** /**
@ -272,6 +282,14 @@ public final class WebSocketClientProtocolConfig {
return this; return this;
} }
/**
* Close frame to send, when close frame was not send manually. Or {@code null} to disable proper close.
*/
public Builder sendCloseFrame(WebSocketCloseStatus sendCloseFrame) {
this.sendCloseFrame = sendCloseFrame;
return this;
}
/** /**
* {@code true} if pong frames should not be forwarded * {@code true} if pong frames should not be forwarded
*/ */
@ -319,6 +337,7 @@ public final class WebSocketClientProtocolConfig {
performMasking, performMasking,
allowMaskMismatch, allowMaskMismatch,
handleCloseFrames, handleCloseFrames,
sendCloseFrame,
dropPongFrames, dropPongFrames,
handshakeTimeoutMillis, handshakeTimeoutMillis,
forceCloseTimeoutMillis, forceCloseTimeoutMillis,

View File

@ -42,8 +42,7 @@ import static io.netty.util.internal.ObjectUtil.*;
*/ */
public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
private final WebSocketClientHandshaker handshaker; private final WebSocketClientHandshaker handshaker;
private final boolean handleCloseFrames; private final WebSocketClientProtocolConfig clientConfig;
private final long handshakeTimeoutMillis;
/** /**
* Returns the used handshaker * Returns the used handshaker
@ -92,8 +91,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
clientConfig.forceCloseTimeoutMillis(), clientConfig.forceCloseTimeoutMillis(),
clientConfig.absoluteUpgradeUrl() clientConfig.absoluteUpgradeUrl()
); );
this.handleCloseFrames = clientConfig.handleCloseFrames(); this.clientConfig = clientConfig;
this.handshakeTimeoutMillis = clientConfig.handshakeTimeoutMillis();
} }
/** /**
@ -327,8 +325,10 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
boolean dropPongFrames, long handshakeTimeoutMillis) { boolean dropPongFrames, long handshakeTimeoutMillis) {
super(dropPongFrames); super(dropPongFrames);
this.handshaker = handshaker; this.handshaker = handshaker;
this.handleCloseFrames = handleCloseFrames; this.clientConfig = WebSocketClientProtocolConfig.newBuilder()
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); .handleCloseFrames(handleCloseFrames)
.handshakeTimeoutMillis(handshakeTimeoutMillis)
.build();
} }
/** /**
@ -358,7 +358,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
@Override @Override
protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception { protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
if (handleCloseFrames && frame instanceof CloseWebSocketFrame) { if (clientConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame) {
ctx.close(); ctx.close();
return; return;
} }
@ -371,12 +371,16 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
if (cp.get(WebSocketClientProtocolHandshakeHandler.class) == null) { if (cp.get(WebSocketClientProtocolHandshakeHandler.class) == null) {
// Add the WebSocketClientProtocolHandshakeHandler before this one. // Add the WebSocketClientProtocolHandshakeHandler before this one.
ctx.pipeline().addBefore(ctx.name(), WebSocketClientProtocolHandshakeHandler.class.getName(), ctx.pipeline().addBefore(ctx.name(), WebSocketClientProtocolHandshakeHandler.class.getName(),
new WebSocketClientProtocolHandshakeHandler(handshaker, handshakeTimeoutMillis)); new WebSocketClientProtocolHandshakeHandler(handshaker, clientConfig.handshakeTimeoutMillis()));
} }
if (cp.get(Utf8FrameValidator.class) == null) { if (cp.get(Utf8FrameValidator.class) == null) {
// Add the UFT8 checking before this one. // Add the UFT8 checking before this one.
ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(), ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
new Utf8FrameValidator()); new Utf8FrameValidator());
} }
if (clientConfig.sendCloseFrame() != null) {
cp.addBefore(ctx.name(), WebSocketCloseFrameHandler.class.getName(),
new WebSocketCloseFrameHandler(clientConfig.sendCloseFrame(), clientConfig.forceCloseTimeoutMillis()));
}
} }
} }

View File

@ -0,0 +1,97 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.ObjectUtil;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.TimeUnit;
/**
* Send {@link CloseWebSocketFrame} message on channel close, if close frame was not sent before.
*/
final class WebSocketCloseFrameHandler extends ChannelOutboundHandlerAdapter {
private final WebSocketCloseStatus closeStatus;
private final long forceCloseTimeoutMillis;
private ChannelPromise closeSent;
WebSocketCloseFrameHandler(WebSocketCloseStatus closeStatus, long forceCloseTimeoutMillis) {
this.closeStatus = ObjectUtil.checkNotNull(closeStatus, "closeStatus");
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
}
@Override
public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception {
if (!ctx.channel().isActive()) {
ctx.close(promise);
return;
}
if (closeSent == null) {
write(ctx, new CloseWebSocketFrame(closeStatus), ctx.newPromise());
}
flush(ctx);
applyCloseSentTimeout(ctx);
closeSent.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
ctx.close(promise);
}
});
}
@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (closeSent != null) {
ReferenceCountUtil.release(msg);
promise.setFailure(new ClosedChannelException());
return;
}
if (msg instanceof CloseWebSocketFrame) {
promise = promise.unvoid();
closeSent = promise;
}
super.write(ctx, msg, promise);
}
private void applyCloseSentTimeout(ChannelHandlerContext ctx) {
if (closeSent.isDone() || forceCloseTimeoutMillis < 0) {
return;
}
final ScheduledFuture<?> timeoutTask = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (!closeSent.isDone()) {
closeSent.tryFailure(new WebSocketHandshakeException("send close frame timed out"));
}
}
}, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
closeSent.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
timeoutTask.cancel(false);
}
});
}
}

View File

@ -25,14 +25,16 @@ import static io.netty.util.internal.ObjectUtil.checkPositive;
*/ */
public final class WebSocketServerProtocolConfig { public final class WebSocketServerProtocolConfig {
static final WebSocketServerProtocolConfig DEFAULT = static final WebSocketServerProtocolConfig DEFAULT = new WebSocketServerProtocolConfig(
new WebSocketServerProtocolConfig("/", null, false, 10000L, true, true, WebSocketDecoderConfig.DEFAULT); "/", null, false, 10000L, 0, true, WebSocketCloseStatus.NORMAL_CLOSURE, true, WebSocketDecoderConfig.DEFAULT);
private final String websocketPath; private final String websocketPath;
private final String subprotocols; private final String subprotocols;
private final boolean checkStartsWith; private final boolean checkStartsWith;
private final long handshakeTimeoutMillis; private final long handshakeTimeoutMillis;
private final long forceCloseTimeoutMillis;
private final boolean handleCloseFrames; private final boolean handleCloseFrames;
private final WebSocketCloseStatus sendCloseFrame;
private final boolean dropPongFrames; private final boolean dropPongFrames;
private final WebSocketDecoderConfig decoderConfig; private final WebSocketDecoderConfig decoderConfig;
@ -41,7 +43,9 @@ public final class WebSocketServerProtocolConfig {
String subprotocols, String subprotocols,
boolean checkStartsWith, boolean checkStartsWith,
long handshakeTimeoutMillis, long handshakeTimeoutMillis,
long forceCloseTimeoutMillis,
boolean handleCloseFrames, boolean handleCloseFrames,
WebSocketCloseStatus sendCloseFrame,
boolean dropPongFrames, boolean dropPongFrames,
WebSocketDecoderConfig decoderConfig WebSocketDecoderConfig decoderConfig
) { ) {
@ -49,7 +53,9 @@ public final class WebSocketServerProtocolConfig {
this.subprotocols = subprotocols; this.subprotocols = subprotocols;
this.checkStartsWith = checkStartsWith; this.checkStartsWith = checkStartsWith;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
this.handleCloseFrames = handleCloseFrames; this.handleCloseFrames = handleCloseFrames;
this.sendCloseFrame = sendCloseFrame;
this.dropPongFrames = dropPongFrames; this.dropPongFrames = dropPongFrames;
this.decoderConfig = decoderConfig == null ? WebSocketDecoderConfig.DEFAULT : decoderConfig; this.decoderConfig = decoderConfig == null ? WebSocketDecoderConfig.DEFAULT : decoderConfig;
} }
@ -70,10 +76,18 @@ public final class WebSocketServerProtocolConfig {
return handshakeTimeoutMillis; return handshakeTimeoutMillis;
} }
public long forceCloseTimeoutMillis() {
return forceCloseTimeoutMillis;
}
public boolean handleCloseFrames() { public boolean handleCloseFrames() {
return handleCloseFrames; return handleCloseFrames;
} }
public WebSocketCloseStatus sendCloseFrame() {
return sendCloseFrame;
}
public boolean dropPongFrames() { public boolean dropPongFrames() {
return dropPongFrames; return dropPongFrames;
} }
@ -89,7 +103,9 @@ public final class WebSocketServerProtocolConfig {
", subprotocols=" + subprotocols + ", subprotocols=" + subprotocols +
", checkStartsWith=" + checkStartsWith + ", checkStartsWith=" + checkStartsWith +
", handshakeTimeoutMillis=" + handshakeTimeoutMillis + ", handshakeTimeoutMillis=" + handshakeTimeoutMillis +
", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis +
", handleCloseFrames=" + handleCloseFrames + ", handleCloseFrames=" + handleCloseFrames +
", sendCloseFrame=" + sendCloseFrame +
", dropPongFrames=" + dropPongFrames + ", dropPongFrames=" + dropPongFrames +
", decoderConfig=" + decoderConfig + ", decoderConfig=" + decoderConfig +
"}"; "}";
@ -108,7 +124,9 @@ public final class WebSocketServerProtocolConfig {
private String subprotocols; private String subprotocols;
private boolean checkStartsWith; private boolean checkStartsWith;
private long handshakeTimeoutMillis; private long handshakeTimeoutMillis;
private long forceCloseTimeoutMillis;
private boolean handleCloseFrames; private boolean handleCloseFrames;
private WebSocketCloseStatus sendCloseFrame;
private boolean dropPongFrames; private boolean dropPongFrames;
private WebSocketDecoderConfig decoderConfig; private WebSocketDecoderConfig decoderConfig;
private WebSocketDecoderConfig.Builder decoderConfigBuilder; private WebSocketDecoderConfig.Builder decoderConfigBuilder;
@ -119,7 +137,9 @@ public final class WebSocketServerProtocolConfig {
subprotocols = serverConfig.subprotocols(); subprotocols = serverConfig.subprotocols();
checkStartsWith = serverConfig.checkStartsWith(); checkStartsWith = serverConfig.checkStartsWith();
handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis(); handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis();
forceCloseTimeoutMillis = serverConfig.forceCloseTimeoutMillis();
handleCloseFrames = serverConfig.handleCloseFrames(); handleCloseFrames = serverConfig.handleCloseFrames();
sendCloseFrame = serverConfig.sendCloseFrame();
dropPongFrames = serverConfig.dropPongFrames(); dropPongFrames = serverConfig.dropPongFrames();
decoderConfig = serverConfig.decoderConfig(); decoderConfig = serverConfig.decoderConfig();
} }
@ -158,6 +178,14 @@ public final class WebSocketServerProtocolConfig {
return this; return this;
} }
/**
* Close the connection if it was not closed by the client after timeout specified
*/
public Builder forceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
return this;
}
/** /**
* {@code true} if close frames should not be forwarded and just close the channel * {@code true} if close frames should not be forwarded and just close the channel
*/ */
@ -166,6 +194,14 @@ public final class WebSocketServerProtocolConfig {
return this; return this;
} }
/**
* Close frame to send, when close frame was not send manually. Or {@code null} to disable proper close.
*/
public Builder sendCloseFrame(WebSocketCloseStatus sendCloseFrame) {
this.sendCloseFrame = sendCloseFrame;
return this;
}
/** /**
* {@code true} if pong frames should not be forwarded * {@code true} if pong frames should not be forwarded
*/ */
@ -229,7 +265,9 @@ public final class WebSocketServerProtocolConfig {
subprotocols, subprotocols,
checkStartsWith, checkStartsWith,
handshakeTimeoutMillis, handshakeTimeoutMillis,
forceCloseTimeoutMillis,
handleCloseFrames, handleCloseFrames,
sendCloseFrame,
dropPongFrames, dropPongFrames,
decoderConfigBuilder == null ? decoderConfig : decoderConfigBuilder.build() decoderConfigBuilder == null ? decoderConfig : decoderConfigBuilder.build()
); );

View File

@ -229,6 +229,10 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
new Utf8FrameValidator()); new Utf8FrameValidator());
} }
if (serverConfig.sendCloseFrame() != null) {
cp.addBefore(ctx.name(), WebSocketCloseFrameHandler.class.getName(),
new WebSocketCloseFrameHandler(serverConfig.sendCloseFrame(), serverConfig.forceCloseTimeoutMillis()));
}
} }
@Override @Override

View File

@ -78,6 +78,9 @@ public class WebSocket08EncoderDecoderTest {
Assert.assertEquals(errorMessage, response.reasonText()); Assert.assertEquals(errorMessage, response.reasonText());
response.release(); response.release();
Assert.assertFalse(inChannel.finish());
Assert.assertFalse(outChannel.finish());
// Without auto-close // Without auto-close
config = WebSocketDecoderConfig.newBuilder() config = WebSocketDecoderConfig.newBuilder()
.maxFramePayloadLength(maxPayloadLength) .maxFramePayloadLength(maxPayloadLength)
@ -91,10 +94,11 @@ public class WebSocket08EncoderDecoderTest {
response = inChannel.readOutbound(); response = inChannel.readOutbound();
Assert.assertNull(response); Assert.assertNull(response);
// Release test data
binTestData.release();
Assert.assertFalse(inChannel.finish()); Assert.assertFalse(inChannel.finish());
Assert.assertFalse(outChannel.finish()); Assert.assertFalse(outChannel.finish());
// Release test data
binTestData.release();
} }
private void executeProtocolViolationTest(EmbeddedChannel outChannel, EmbeddedChannel inChannel, private void executeProtocolViolationTest(EmbeddedChannel outChannel, EmbeddedChannel inChannel,

View File

@ -49,7 +49,11 @@ public class WebSocketHandshakeHandOverTest {
private final class CloseNoOpServerProtocolHandler extends WebSocketServerProtocolHandler { private final class CloseNoOpServerProtocolHandler extends WebSocketServerProtocolHandler {
CloseNoOpServerProtocolHandler(String websocketPath) { CloseNoOpServerProtocolHandler(String websocketPath) {
super(websocketPath, null, false); super(WebSocketServerProtocolConfig.newBuilder()
.websocketPath(websocketPath)
.allowExtensions(false)
.sendCloseFrame(null)
.build());
} }
@Override @Override

View File

@ -15,6 +15,7 @@
*/ */
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
@ -24,11 +25,14 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import org.junit.Before; import org.junit.Before;
@ -51,7 +55,7 @@ public class WebSocketServerProtocolHandlerTest {
} }
@Test @Test
public void testHttpUpgradeRequest() throws Exception { public void testHttpUpgradeRequest() {
EmbeddedChannel ch = createChannel(new MockOutboundHandler()); EmbeddedChannel ch = createChannel(new MockOutboundHandler());
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class); ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
writeUpgradeRequest(ch); writeUpgradeRequest(ch);
@ -64,12 +68,12 @@ public class WebSocketServerProtocolHandlerTest {
} }
@Test @Test
public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() throws Exception { public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() {
EmbeddedChannel ch = createChannel(new MockOutboundHandler()); EmbeddedChannel ch = createChannel(new MockOutboundHandler());
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class); ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
// We should have removed the handler already. // We should have removed the handler already.
assertNull(ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class)); assertNull(ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class));
@ -86,7 +90,7 @@ public class WebSocketServerProtocolHandlerTest {
} }
@Test @Test
public void testSubsequentHttpRequestsAfterUpgradeShouldReturn403() throws Exception { public void testSubsequentHttpRequestsAfterUpgradeShouldReturn403() {
EmbeddedChannel ch = createChannel(); EmbeddedChannel ch = createChannel();
writeUpgradeRequest(ch); writeUpgradeRequest(ch);
@ -207,13 +211,160 @@ public class WebSocketServerProtocolHandlerTest {
assertFalse(ch.finish()); assertFalse(ch.finish());
} }
@Test
public void testExplicitCloseFrameSentWhenServerChannelClosed() throws Exception {
WebSocketCloseStatus closeStatus = WebSocketCloseStatus.ENDPOINT_UNAVAILABLE;
EmbeddedChannel client = createClient();
EmbeddedChannel server = createServer();
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.writeInbound(server.readOutbound()));
// When server channel closed with explicit close-frame
assertTrue(server.writeOutbound(new CloseWebSocketFrame(closeStatus)));
server.close();
// Then client receives provided close-frame
assertTrue(client.writeInbound(server.readOutbound()));
assertFalse(server.isOpen());
CloseWebSocketFrame closeMessage = client.readInbound();
assertEquals(closeMessage.statusCode(), closeStatus.code());
closeMessage.release();
client.close();
assertTrue(ReferenceCountUtil.release(client.readOutbound()));
assertFalse(client.finishAndReleaseAll());
assertFalse(server.finishAndReleaseAll());
}
@Test
public void testCloseFrameSentWhenServerChannelClosedSilently() throws Exception {
EmbeddedChannel client = createClient();
EmbeddedChannel server = createServer();
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.writeInbound(server.readOutbound()));
// When server channel closed without explicit close-frame
server.close();
// Then client receives NORMAL_CLOSURE close-frame
assertTrue(client.writeInbound(server.readOutbound()));
assertFalse(server.isOpen());
CloseWebSocketFrame closeMessage = client.readInbound();
assertEquals(closeMessage.statusCode(), WebSocketCloseStatus.NORMAL_CLOSURE.code());
closeMessage.release();
client.close();
assertTrue(ReferenceCountUtil.release(client.readOutbound()));
assertFalse(client.finishAndReleaseAll());
assertFalse(server.finishAndReleaseAll());
}
@Test
public void testExplicitCloseFrameSentWhenClientChannelClosed() throws Exception {
WebSocketCloseStatus closeStatus = WebSocketCloseStatus.INVALID_PAYLOAD_DATA;
EmbeddedChannel client = createClient();
EmbeddedChannel server = createServer();
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.writeInbound(server.readOutbound()));
// When client channel closed with explicit close-frame
assertTrue(client.writeOutbound(new CloseWebSocketFrame(closeStatus)));
client.close();
// Then client receives provided close-frame
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.isOpen());
assertFalse(server.isOpen());
CloseWebSocketFrame closeMessage = decode(server.<ByteBuf>readOutbound(), CloseWebSocketFrame.class);
assertEquals(closeMessage.statusCode(), closeStatus.code());
closeMessage.release();
assertFalse(client.finishAndReleaseAll());
assertFalse(server.finishAndReleaseAll());
}
@Test
public void testCloseFrameSentWhenClientChannelClosedSilently() throws Exception {
EmbeddedChannel client = createClient();
EmbeddedChannel server = createServer();
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.writeInbound(server.readOutbound()));
// When client channel closed without explicit close-frame
client.close();
// Then server receives NORMAL_CLOSURE close-frame
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.isOpen());
assertFalse(server.isOpen());
CloseWebSocketFrame closeMessage = decode(server.<ByteBuf>readOutbound(), CloseWebSocketFrame.class);
assertEquals(closeMessage, new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE));
closeMessage.release();
assertFalse(client.finishAndReleaseAll());
assertFalse(server.finishAndReleaseAll());
}
private EmbeddedChannel createClient(ChannelHandler... handlers) throws Exception {
WebSocketClientProtocolConfig clientConfig = WebSocketClientProtocolConfig.newBuilder()
.webSocketUri("http://test/test")
.dropPongFrames(false)
.handleCloseFrames(false)
.build();
EmbeddedChannel ch = new EmbeddedChannel(false, false,
new HttpClientCodec(),
new HttpObjectAggregator(8192),
new WebSocketClientProtocolHandler(clientConfig)
);
ch.pipeline().addLast(handlers);
ch.register();
return ch;
}
private EmbeddedChannel createServer(ChannelHandler... handlers) throws Exception {
WebSocketServerProtocolConfig serverConfig = WebSocketServerProtocolConfig.newBuilder()
.websocketPath("/test")
.dropPongFrames(false)
.build();
EmbeddedChannel ch = new EmbeddedChannel(false, false,
new HttpServerCodec(),
new HttpObjectAggregator(8192),
new WebSocketServerProtocolHandler(serverConfig)
);
ch.pipeline().addLast(handlers);
ch.register();
return ch;
}
@SuppressWarnings("SameParameterValue")
private <T> T decode(ByteBuf input, Class<T> clazz) {
EmbeddedChannel ch = new EmbeddedChannel(new WebSocket13FrameDecoder(true, false, 65536, true));
assertTrue(ch.writeInbound(input));
Object decoded = ch.readInbound();
assertNotNull(decoded);
assertFalse(ch.finish());
return clazz.cast(decoded);
}
private EmbeddedChannel createChannel() { private EmbeddedChannel createChannel() {
return createChannel(null); return createChannel(null);
} }
private EmbeddedChannel createChannel(ChannelHandler handler) { private EmbeddedChannel createChannel(ChannelHandler handler) {
WebSocketServerProtocolConfig serverConfig = WebSocketServerProtocolConfig.newBuilder()
.websocketPath("/test")
.sendCloseFrame(null)
.build();
return new EmbeddedChannel( return new EmbeddedChannel(
new WebSocketServerProtocolHandler("/test"), new WebSocketServerProtocolHandler(serverConfig),
new HttpRequestDecoder(), new HttpRequestDecoder(),
new HttpResponseEncoder(), new HttpResponseEncoder(),
new MockOutboundHandler(), new MockOutboundHandler(),
@ -231,13 +382,13 @@ public class WebSocketServerProtocolHandlerTest {
private class MockOutboundHandler extends ChannelOutboundHandlerAdapter { private class MockOutboundHandler extends ChannelOutboundHandlerAdapter {
@Override @Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
responses.add((FullHttpResponse) msg); responses.add((FullHttpResponse) msg);
promise.setSuccess(); promise.setSuccess();
} }
@Override @Override
public void flush(ChannelHandlerContext ctx) throws Exception { public void flush(ChannelHandlerContext ctx) {
} }
} }
@ -245,7 +396,7 @@ public class WebSocketServerProtocolHandlerTest {
private String content; private String content;
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) {
assertNull(content); assertNull(content);
content = "processed: " + ((TextWebSocketFrame) msg).text(); content = "processed: " + ((TextWebSocketFrame) msg).text();
ReferenceCountUtil.release(msg); ReferenceCountUtil.release(msg);