Send close frame on channel close, when this frame was not send manually (#9745)

Motivation:
By default CloseWebSocketFrames are handled automatically.
However I need manually manage their sending both on client- and on server-sides.

Modification:
Send close frame on channel close automatically, when it was not send before explicitly.

Result:
No more messages like "Connection closed by remote peer" for normal close flows.
This commit is contained in:
ursa 2019-11-18 19:32:21 +00:00 committed by Norman Maurer
parent 3bb75198cd
commit 7ff8cde66f
8 changed files with 357 additions and 36 deletions

View File

@ -30,8 +30,8 @@ import static io.netty.util.internal.ObjectUtil.checkPositive;
public final class WebSocketClientProtocolConfig {
static final WebSocketClientProtocolConfig DEFAULT = new WebSocketClientProtocolConfig(
URI.create("https://localhost/"), null, WebSocketVersion.V13, false,
EmptyHttpHeaders.INSTANCE, 65536, true, false, true, true, 10000L, -1, false);
URI.create("https://localhost/"), null, WebSocketVersion.V13, false, EmptyHttpHeaders.INSTANCE,
65536, true, false, true, WebSocketCloseStatus.NORMAL_CLOSURE, true, 10000L, -1, false);
private final URI webSocketUri;
private final String subprotocol;
@ -42,6 +42,7 @@ public final class WebSocketClientProtocolConfig {
private final boolean performMasking;
private final boolean allowMaskMismatch;
private final boolean handleCloseFrames;
private final WebSocketCloseStatus sendCloseFrame;
private final boolean dropPongFrames;
private final long handshakeTimeoutMillis;
private final long forceCloseTimeoutMillis;
@ -57,6 +58,7 @@ public final class WebSocketClientProtocolConfig {
boolean performMasking,
boolean allowMaskMismatch,
boolean handleCloseFrames,
WebSocketCloseStatus sendCloseFrame,
boolean dropPongFrames,
long handshakeTimeoutMillis,
long forceCloseTimeoutMillis,
@ -72,6 +74,7 @@ public final class WebSocketClientProtocolConfig {
this.allowMaskMismatch = allowMaskMismatch;
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
this.handleCloseFrames = handleCloseFrames;
this.sendCloseFrame = sendCloseFrame;
this.dropPongFrames = dropPongFrames;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
this.absoluteUpgradeUrl = absoluteUpgradeUrl;
@ -113,6 +116,10 @@ public final class WebSocketClientProtocolConfig {
return handleCloseFrames;
}
public WebSocketCloseStatus sendCloseFrame() {
return sendCloseFrame;
}
public boolean dropPongFrames() {
return dropPongFrames;
}
@ -141,6 +148,7 @@ public final class WebSocketClientProtocolConfig {
", performMasking=" + performMasking +
", allowMaskMismatch=" + allowMaskMismatch +
", handleCloseFrames=" + handleCloseFrames +
", sendCloseFrame=" + sendCloseFrame +
", dropPongFrames=" + dropPongFrames +
", handshakeTimeoutMillis=" + handshakeTimeoutMillis +
", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis +
@ -166,6 +174,7 @@ public final class WebSocketClientProtocolConfig {
private boolean performMasking;
private boolean allowMaskMismatch;
private boolean handleCloseFrames;
private WebSocketCloseStatus sendCloseFrame;
private boolean dropPongFrames;
private long handshakeTimeoutMillis;
private long forceCloseTimeoutMillis;
@ -174,19 +183,20 @@ public final class WebSocketClientProtocolConfig {
private Builder(WebSocketClientProtocolConfig clientConfig) {
ObjectUtil.checkNotNull(clientConfig, "clientConfig");
this.webSocketUri = clientConfig.webSocketUri();
this.subprotocol = clientConfig.subprotocol();
this.version = clientConfig.version();
this.allowExtensions = clientConfig.allowExtensions();
this.customHeaders = clientConfig.customHeaders();
this.maxFramePayloadLength = clientConfig.maxFramePayloadLength();
this.performMasking = clientConfig.performMasking();
this.allowMaskMismatch = clientConfig.allowMaskMismatch();
this.handleCloseFrames = clientConfig.handleCloseFrames();
this.dropPongFrames = clientConfig.dropPongFrames();
this.handshakeTimeoutMillis = clientConfig.handshakeTimeoutMillis();
this.forceCloseTimeoutMillis = clientConfig.forceCloseTimeoutMillis();
this.absoluteUpgradeUrl = clientConfig.absoluteUpgradeUrl();
webSocketUri = clientConfig.webSocketUri();
subprotocol = clientConfig.subprotocol();
version = clientConfig.version();
allowExtensions = clientConfig.allowExtensions();
customHeaders = clientConfig.customHeaders();
maxFramePayloadLength = clientConfig.maxFramePayloadLength();
performMasking = clientConfig.performMasking();
allowMaskMismatch = clientConfig.allowMaskMismatch();
handleCloseFrames = clientConfig.handleCloseFrames();
sendCloseFrame = clientConfig.sendCloseFrame();
dropPongFrames = clientConfig.dropPongFrames();
handshakeTimeoutMillis = clientConfig.handshakeTimeoutMillis();
forceCloseTimeoutMillis = clientConfig.forceCloseTimeoutMillis();
absoluteUpgradeUrl = clientConfig.absoluteUpgradeUrl();
}
/**
@ -272,6 +282,14 @@ public final class WebSocketClientProtocolConfig {
return this;
}
/**
* Close frame to send, when close frame was not send manually. Or {@code null} to disable proper close.
*/
public Builder sendCloseFrame(WebSocketCloseStatus sendCloseFrame) {
this.sendCloseFrame = sendCloseFrame;
return this;
}
/**
* {@code true} if pong frames should not be forwarded
*/
@ -319,6 +337,7 @@ public final class WebSocketClientProtocolConfig {
performMasking,
allowMaskMismatch,
handleCloseFrames,
sendCloseFrame,
dropPongFrames,
handshakeTimeoutMillis,
forceCloseTimeoutMillis,

View File

@ -42,8 +42,7 @@ import static io.netty.util.internal.ObjectUtil.*;
*/
public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
private final WebSocketClientHandshaker handshaker;
private final boolean handleCloseFrames;
private final long handshakeTimeoutMillis;
private final WebSocketClientProtocolConfig clientConfig;
/**
* Returns the used handshaker
@ -92,8 +91,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
clientConfig.forceCloseTimeoutMillis(),
clientConfig.absoluteUpgradeUrl()
);
this.handleCloseFrames = clientConfig.handleCloseFrames();
this.handshakeTimeoutMillis = clientConfig.handshakeTimeoutMillis();
this.clientConfig = clientConfig;
}
/**
@ -327,8 +325,10 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
boolean dropPongFrames, long handshakeTimeoutMillis) {
super(dropPongFrames);
this.handshaker = handshaker;
this.handleCloseFrames = handleCloseFrames;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
this.clientConfig = WebSocketClientProtocolConfig.newBuilder()
.handleCloseFrames(handleCloseFrames)
.handshakeTimeoutMillis(handshakeTimeoutMillis)
.build();
}
/**
@ -358,7 +358,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
@Override
protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
if (handleCloseFrames && frame instanceof CloseWebSocketFrame) {
if (clientConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame) {
ctx.close();
return;
}
@ -371,12 +371,16 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
if (cp.get(WebSocketClientProtocolHandshakeHandler.class) == null) {
// Add the WebSocketClientProtocolHandshakeHandler before this one.
ctx.pipeline().addBefore(ctx.name(), WebSocketClientProtocolHandshakeHandler.class.getName(),
new WebSocketClientProtocolHandshakeHandler(handshaker, handshakeTimeoutMillis));
new WebSocketClientProtocolHandshakeHandler(handshaker, clientConfig.handshakeTimeoutMillis()));
}
if (cp.get(Utf8FrameValidator.class) == null) {
// Add the UFT8 checking before this one.
ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
new Utf8FrameValidator());
}
if (clientConfig.sendCloseFrame() != null) {
cp.addBefore(ctx.name(), WebSocketCloseFrameHandler.class.getName(),
new WebSocketCloseFrameHandler(clientConfig.sendCloseFrame(), clientConfig.forceCloseTimeoutMillis()));
}
}
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.ObjectUtil;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.TimeUnit;
/**
* Send {@link CloseWebSocketFrame} message on channel close, if close frame was not sent before.
*/
final class WebSocketCloseFrameHandler extends ChannelOutboundHandlerAdapter {
private final WebSocketCloseStatus closeStatus;
private final long forceCloseTimeoutMillis;
private ChannelPromise closeSent;
WebSocketCloseFrameHandler(WebSocketCloseStatus closeStatus, long forceCloseTimeoutMillis) {
this.closeStatus = ObjectUtil.checkNotNull(closeStatus, "closeStatus");
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
}
@Override
public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception {
if (!ctx.channel().isActive()) {
ctx.close(promise);
return;
}
if (closeSent == null) {
write(ctx, new CloseWebSocketFrame(closeStatus), ctx.newPromise());
}
flush(ctx);
applyCloseSentTimeout(ctx);
closeSent.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
ctx.close(promise);
}
});
}
@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (closeSent != null) {
ReferenceCountUtil.release(msg);
promise.setFailure(new ClosedChannelException());
return;
}
if (msg instanceof CloseWebSocketFrame) {
promise = promise.unvoid();
closeSent = promise;
}
super.write(ctx, msg, promise);
}
private void applyCloseSentTimeout(ChannelHandlerContext ctx) {
if (closeSent.isDone() || forceCloseTimeoutMillis < 0) {
return;
}
final ScheduledFuture<?> timeoutTask = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (!closeSent.isDone()) {
closeSent.tryFailure(new WebSocketHandshakeException("send close frame timed out"));
}
}
}, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
closeSent.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
timeoutTask.cancel(false);
}
});
}
}

View File

@ -25,14 +25,16 @@ import static io.netty.util.internal.ObjectUtil.checkPositive;
*/
public final class WebSocketServerProtocolConfig {
static final WebSocketServerProtocolConfig DEFAULT =
new WebSocketServerProtocolConfig("/", null, false, 10000L, true, true, WebSocketDecoderConfig.DEFAULT);
static final WebSocketServerProtocolConfig DEFAULT = new WebSocketServerProtocolConfig(
"/", null, false, 10000L, 0, true, WebSocketCloseStatus.NORMAL_CLOSURE, true, WebSocketDecoderConfig.DEFAULT);
private final String websocketPath;
private final String subprotocols;
private final boolean checkStartsWith;
private final long handshakeTimeoutMillis;
private final long forceCloseTimeoutMillis;
private final boolean handleCloseFrames;
private final WebSocketCloseStatus sendCloseFrame;
private final boolean dropPongFrames;
private final WebSocketDecoderConfig decoderConfig;
@ -41,7 +43,9 @@ public final class WebSocketServerProtocolConfig {
String subprotocols,
boolean checkStartsWith,
long handshakeTimeoutMillis,
long forceCloseTimeoutMillis,
boolean handleCloseFrames,
WebSocketCloseStatus sendCloseFrame,
boolean dropPongFrames,
WebSocketDecoderConfig decoderConfig
) {
@ -49,7 +53,9 @@ public final class WebSocketServerProtocolConfig {
this.subprotocols = subprotocols;
this.checkStartsWith = checkStartsWith;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
this.handleCloseFrames = handleCloseFrames;
this.sendCloseFrame = sendCloseFrame;
this.dropPongFrames = dropPongFrames;
this.decoderConfig = decoderConfig == null ? WebSocketDecoderConfig.DEFAULT : decoderConfig;
}
@ -70,10 +76,18 @@ public final class WebSocketServerProtocolConfig {
return handshakeTimeoutMillis;
}
public long forceCloseTimeoutMillis() {
return forceCloseTimeoutMillis;
}
public boolean handleCloseFrames() {
return handleCloseFrames;
}
public WebSocketCloseStatus sendCloseFrame() {
return sendCloseFrame;
}
public boolean dropPongFrames() {
return dropPongFrames;
}
@ -89,7 +103,9 @@ public final class WebSocketServerProtocolConfig {
", subprotocols=" + subprotocols +
", checkStartsWith=" + checkStartsWith +
", handshakeTimeoutMillis=" + handshakeTimeoutMillis +
", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis +
", handleCloseFrames=" + handleCloseFrames +
", sendCloseFrame=" + sendCloseFrame +
", dropPongFrames=" + dropPongFrames +
", decoderConfig=" + decoderConfig +
"}";
@ -108,7 +124,9 @@ public final class WebSocketServerProtocolConfig {
private String subprotocols;
private boolean checkStartsWith;
private long handshakeTimeoutMillis;
private long forceCloseTimeoutMillis;
private boolean handleCloseFrames;
private WebSocketCloseStatus sendCloseFrame;
private boolean dropPongFrames;
private WebSocketDecoderConfig decoderConfig;
private WebSocketDecoderConfig.Builder decoderConfigBuilder;
@ -119,7 +137,9 @@ public final class WebSocketServerProtocolConfig {
subprotocols = serverConfig.subprotocols();
checkStartsWith = serverConfig.checkStartsWith();
handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis();
forceCloseTimeoutMillis = serverConfig.forceCloseTimeoutMillis();
handleCloseFrames = serverConfig.handleCloseFrames();
sendCloseFrame = serverConfig.sendCloseFrame();
dropPongFrames = serverConfig.dropPongFrames();
decoderConfig = serverConfig.decoderConfig();
}
@ -158,6 +178,14 @@ public final class WebSocketServerProtocolConfig {
return this;
}
/**
* Close the connection if it was not closed by the client after timeout specified
*/
public Builder forceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
return this;
}
/**
* {@code true} if close frames should not be forwarded and just close the channel
*/
@ -166,6 +194,14 @@ public final class WebSocketServerProtocolConfig {
return this;
}
/**
* Close frame to send, when close frame was not send manually. Or {@code null} to disable proper close.
*/
public Builder sendCloseFrame(WebSocketCloseStatus sendCloseFrame) {
this.sendCloseFrame = sendCloseFrame;
return this;
}
/**
* {@code true} if pong frames should not be forwarded
*/
@ -229,7 +265,9 @@ public final class WebSocketServerProtocolConfig {
subprotocols,
checkStartsWith,
handshakeTimeoutMillis,
forceCloseTimeoutMillis,
handleCloseFrames,
sendCloseFrame,
dropPongFrames,
decoderConfigBuilder == null ? decoderConfig : decoderConfigBuilder.build()
);

View File

@ -228,6 +228,10 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
new Utf8FrameValidator());
}
if (serverConfig.sendCloseFrame() != null) {
cp.addBefore(ctx.name(), WebSocketCloseFrameHandler.class.getName(),
new WebSocketCloseFrameHandler(serverConfig.sendCloseFrame(), serverConfig.forceCloseTimeoutMillis()));
}
}
@Override

View File

@ -78,6 +78,9 @@ public class WebSocket08EncoderDecoderTest {
Assert.assertEquals(errorMessage, response.reasonText());
response.release();
Assert.assertFalse(inChannel.finish());
Assert.assertFalse(outChannel.finish());
// Without auto-close
config = WebSocketDecoderConfig.newBuilder()
.maxFramePayloadLength(maxPayloadLength)
@ -91,10 +94,11 @@ public class WebSocket08EncoderDecoderTest {
response = inChannel.readOutbound();
Assert.assertNull(response);
// Release test data
binTestData.release();
Assert.assertFalse(inChannel.finish());
Assert.assertFalse(outChannel.finish());
// Release test data
binTestData.release();
}
private void executeProtocolViolationTest(EmbeddedChannel outChannel, EmbeddedChannel inChannel,

View File

@ -50,7 +50,11 @@ public class WebSocketHandshakeHandOverTest {
private final class CloseNoOpServerProtocolHandler extends WebSocketServerProtocolHandler {
CloseNoOpServerProtocolHandler(String websocketPath) {
super(websocketPath, null, false);
super(WebSocketServerProtocolConfig.newBuilder()
.websocketPath(websocketPath)
.allowExtensions(false)
.sendCloseFrame(null)
.build());
}
@Override

View File

@ -15,6 +15,7 @@
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
@ -23,11 +24,14 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import org.junit.Before;
@ -50,7 +54,7 @@ public class WebSocketServerProtocolHandlerTest {
}
@Test
public void testHttpUpgradeRequest() throws Exception {
public void testHttpUpgradeRequest() {
EmbeddedChannel ch = createChannel(new MockOutboundHandler());
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
writeUpgradeRequest(ch);
@ -63,12 +67,12 @@ public class WebSocketServerProtocolHandlerTest {
}
@Test
public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() throws Exception {
public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() {
EmbeddedChannel ch = createChannel(new MockOutboundHandler());
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
ch.pipeline().addLast(new ChannelHandler() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
// We should have removed the handler already.
assertNull(ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class));
@ -85,7 +89,7 @@ public class WebSocketServerProtocolHandlerTest {
}
@Test
public void testSubsequentHttpRequestsAfterUpgradeShouldReturn403() throws Exception {
public void testSubsequentHttpRequestsAfterUpgradeShouldReturn403() {
EmbeddedChannel ch = createChannel();
writeUpgradeRequest(ch);
@ -206,13 +210,160 @@ public class WebSocketServerProtocolHandlerTest {
assertFalse(ch.finish());
}
@Test
public void testExplicitCloseFrameSentWhenServerChannelClosed() throws Exception {
WebSocketCloseStatus closeStatus = WebSocketCloseStatus.ENDPOINT_UNAVAILABLE;
EmbeddedChannel client = createClient();
EmbeddedChannel server = createServer();
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.writeInbound(server.readOutbound()));
// When server channel closed with explicit close-frame
assertTrue(server.writeOutbound(new CloseWebSocketFrame(closeStatus)));
server.close();
// Then client receives provided close-frame
assertTrue(client.writeInbound(server.readOutbound()));
assertFalse(server.isOpen());
CloseWebSocketFrame closeMessage = client.readInbound();
assertEquals(closeMessage.statusCode(), closeStatus.code());
closeMessage.release();
client.close();
assertTrue(ReferenceCountUtil.release(client.readOutbound()));
assertFalse(client.finishAndReleaseAll());
assertFalse(server.finishAndReleaseAll());
}
@Test
public void testCloseFrameSentWhenServerChannelClosedSilently() throws Exception {
EmbeddedChannel client = createClient();
EmbeddedChannel server = createServer();
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.writeInbound(server.readOutbound()));
// When server channel closed without explicit close-frame
server.close();
// Then client receives NORMAL_CLOSURE close-frame
assertTrue(client.writeInbound(server.readOutbound()));
assertFalse(server.isOpen());
CloseWebSocketFrame closeMessage = client.readInbound();
assertEquals(closeMessage.statusCode(), WebSocketCloseStatus.NORMAL_CLOSURE.code());
closeMessage.release();
client.close();
assertTrue(ReferenceCountUtil.release(client.readOutbound()));
assertFalse(client.finishAndReleaseAll());
assertFalse(server.finishAndReleaseAll());
}
@Test
public void testExplicitCloseFrameSentWhenClientChannelClosed() throws Exception {
WebSocketCloseStatus closeStatus = WebSocketCloseStatus.INVALID_PAYLOAD_DATA;
EmbeddedChannel client = createClient();
EmbeddedChannel server = createServer();
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.writeInbound(server.readOutbound()));
// When client channel closed with explicit close-frame
assertTrue(client.writeOutbound(new CloseWebSocketFrame(closeStatus)));
client.close();
// Then client receives provided close-frame
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.isOpen());
assertFalse(server.isOpen());
CloseWebSocketFrame closeMessage = decode(server.<ByteBuf>readOutbound(), CloseWebSocketFrame.class);
assertEquals(closeMessage.statusCode(), closeStatus.code());
closeMessage.release();
assertFalse(client.finishAndReleaseAll());
assertFalse(server.finishAndReleaseAll());
}
@Test
public void testCloseFrameSentWhenClientChannelClosedSilently() throws Exception {
EmbeddedChannel client = createClient();
EmbeddedChannel server = createServer();
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.writeInbound(server.readOutbound()));
// When client channel closed without explicit close-frame
client.close();
// Then server receives NORMAL_CLOSURE close-frame
assertFalse(server.writeInbound(client.readOutbound()));
assertFalse(client.isOpen());
assertFalse(server.isOpen());
CloseWebSocketFrame closeMessage = decode(server.<ByteBuf>readOutbound(), CloseWebSocketFrame.class);
assertEquals(closeMessage, new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE));
closeMessage.release();
assertFalse(client.finishAndReleaseAll());
assertFalse(server.finishAndReleaseAll());
}
private EmbeddedChannel createClient(ChannelHandler... handlers) throws Exception {
WebSocketClientProtocolConfig clientConfig = WebSocketClientProtocolConfig.newBuilder()
.webSocketUri("http://test/test")
.dropPongFrames(false)
.handleCloseFrames(false)
.build();
EmbeddedChannel ch = new EmbeddedChannel(false, false,
new HttpClientCodec(),
new HttpObjectAggregator(8192),
new WebSocketClientProtocolHandler(clientConfig)
);
ch.pipeline().addLast(handlers);
ch.register();
return ch;
}
private EmbeddedChannel createServer(ChannelHandler... handlers) throws Exception {
WebSocketServerProtocolConfig serverConfig = WebSocketServerProtocolConfig.newBuilder()
.websocketPath("/test")
.dropPongFrames(false)
.build();
EmbeddedChannel ch = new EmbeddedChannel(false, false,
new HttpServerCodec(),
new HttpObjectAggregator(8192),
new WebSocketServerProtocolHandler(serverConfig)
);
ch.pipeline().addLast(handlers);
ch.register();
return ch;
}
@SuppressWarnings("SameParameterValue")
private <T> T decode(ByteBuf input, Class<T> clazz) {
EmbeddedChannel ch = new EmbeddedChannel(new WebSocket13FrameDecoder(true, false, 65536, true));
assertTrue(ch.writeInbound(input));
Object decoded = ch.readInbound();
assertNotNull(decoded);
assertFalse(ch.finish());
return clazz.cast(decoded);
}
private EmbeddedChannel createChannel() {
return createChannel(null);
}
private EmbeddedChannel createChannel(ChannelHandler handler) {
WebSocketServerProtocolConfig serverConfig = WebSocketServerProtocolConfig.newBuilder()
.websocketPath("/test")
.sendCloseFrame(null)
.build();
return new EmbeddedChannel(
new WebSocketServerProtocolHandler("/test"),
new WebSocketServerProtocolHandler(serverConfig),
new HttpRequestDecoder(),
new HttpResponseEncoder(),
new MockOutboundHandler(),
@ -230,13 +381,13 @@ public class WebSocketServerProtocolHandlerTest {
private class MockOutboundHandler implements ChannelHandler {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
responses.add((FullHttpResponse) msg);
promise.setSuccess();
}
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
public void flush(ChannelHandlerContext ctx) {
}
}
@ -244,7 +395,7 @@ public class WebSocketServerProtocolHandlerTest {
private String content;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
public void channelRead(ChannelHandlerContext ctx, Object msg) {
assertNull(content);
content = "processed: " + ((TextWebSocketFrame) msg).text();
ReferenceCountUtil.release(msg);