netty5/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverT...

324 lines
14 KiB
Java

/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License, version
* 2.0 (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent;
import org.junit.Before;
import org.junit.Test;
import java.net.URI;
import java.util.concurrent.CompletionException;
import static org.junit.Assert.*;
/**
 * Tests the hand-over between the WebSocket opening handshake and regular frame traffic.
 *
 * <p>Client and server are both {@link EmbeddedChannel}s. Bytes are moved between them explicitly with
 * {@link #transferAllDataWithMerge(EmbeddedChannel, EmbeddedChannel)}, so each test controls exactly when
 * the other side sees the data. Handler callbacks record what they observed in the instance fields below.
 * Those fields are reset before every test in {@link #setUp()}.
 */
public class WebSocketHandshakeHandOverTest {

    private boolean serverReceivedHandshake;
    private WebSocketServerProtocolHandler.HandshakeComplete serverHandshakeComplete;
    private boolean clientReceivedHandshake;
    private boolean clientReceivedMessage;
    private boolean serverReceivedCloseHandshake;
    private boolean clientForceClosed;
    private boolean clientHandshakeTimeout;

    /**
     * Server protocol handler that does not take part in the closing handshake.
     *
     * <p>It is built with {@code sendCloseFrame(null)}. Incoming {@link CloseWebSocketFrame}s are only recorded
     * in {@code serverReceivedCloseHandshake` and are NOT passed to the default handling in
     * {@code super.decode}, so the server never answers the client's close frame.
     */
    private final class CloseNoOpServerProtocolHandler extends WebSocketServerProtocolHandler {
        CloseNoOpServerProtocolHandler(String websocketPath) {
            super(WebSocketServerProtocolConfig.newBuilder()
                    .websocketPath(websocketPath)
                    .allowExtensions(false)
                    .sendCloseFrame(null)
                    .build());
        }

        @Override
        protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
            if (frame instanceof CloseWebSocketFrame) {
                serverReceivedCloseHandshake = true;
                return;
            }
            super.decode(ctx, frame);
        }
    }

    /** Resets every flag recorded by the pipeline handlers so tests do not influence each other. */
    @Before
    public void setUp() {
        serverReceivedHandshake = false;
        serverHandshakeComplete = null;
        clientReceivedHandshake = false;
        clientReceivedMessage = false;
        serverReceivedCloseHandshake = false;
        clientForceClosed = false;
        clientHandshakeTimeout = false;
    }

    /**
     * Checks that the server can write a frame as soon as it sees {@code HANDSHAKE_COMPLETE}.
     *
     * <p>The client must then receive that frame right after the handshake response. The test also checks
     * the details carried by the server's {@link WebSocketServerProtocolHandler.HandshakeComplete} event.
     */
    @Test
    public void testHandover() throws Exception {
        EmbeddedChannel serverChannel = createServerChannel(new SimpleChannelInboundHandler<Object>() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
                if (evt == ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
                    serverReceivedHandshake = true;
                    // Immediately send a message to the client on connect.
                    ctx.writeAndFlush(new TextWebSocketFrame("abc"));
                } else if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
                    serverHandshakeComplete = (WebSocketServerProtocolHandler.HandshakeComplete) evt;
                }
            }

            @Override
            protected void messageReceived(ChannelHandlerContext ctx, Object msg) throws Exception {
            }
        });

        EmbeddedChannel clientChannel = createClientChannel(new SimpleChannelInboundHandler<Object>() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
                if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
                    clientReceivedHandshake = true;
                }
            }

            @Override
            protected void messageReceived(ChannelHandlerContext ctx, Object msg) throws Exception {
                if (msg instanceof TextWebSocketFrame) {
                    clientReceivedMessage = true;
                }
            }
        });

        // Transfer the handshake from the client to the server.
        transferAllDataWithMerge(clientChannel, serverChannel);

        assertTrue(serverReceivedHandshake);
        assertNotNull(serverHandshakeComplete);
        assertEquals("/test", serverHandshakeComplete.requestUri());
        // NOTE(review): 8 is assumed to be the number of headers the default client handshaker sends;
        // update this expectation if the handshaker's request headers change.
        assertEquals(8, serverHandshakeComplete.requestHeaders().size());
        assertEquals("test-proto-2", serverHandshakeComplete.selectedSubprotocol());

        // Transfer the handshake response and the websocket message to the client.
        transferAllDataWithMerge(serverChannel, clientChannel);

        assertTrue(clientReceivedHandshake);
        assertTrue(clientReceivedMessage);
    }

    /**
     * Checks the client behaviour when the server never answers the handshake request.
     *
     * <p>The client's 100 ms handshake timeout must fire {@code HANDSHAKE_TIMEOUT}. The handshake future must
     * then fail with a {@link WebSocketHandshakeException}, which this test rethrows.
     *
     * <p>The JUnit timeout guards the polling loop below. Without it, a timeout that never fires would hang
     * the test run forever.
     */
    @Test(expected = WebSocketHandshakeException.class, timeout = 10000)
    public void testClientHandshakeTimeout() throws Throwable {
        EmbeddedChannel serverChannel = createServerChannel(new SimpleChannelInboundHandler<Object>() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
                if (evt == ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
                    serverReceivedHandshake = true;
                    // Immediately send a message to the client on connect.
                    ctx.writeAndFlush(new TextWebSocketFrame("abc"));
                } else if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
                    serverHandshakeComplete = (WebSocketServerProtocolHandler.HandshakeComplete) evt;
                }
            }

            @Override
            protected void messageReceived(ChannelHandlerContext ctx, Object msg) throws Exception {
            }
        });

        EmbeddedChannel clientChannel = createClientChannel(new SimpleChannelInboundHandler<Object>() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
                if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
                    clientReceivedHandshake = true;
                } else if (evt == ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) {
                    clientHandshakeTimeout = true;
                }
            }

            @Override
            protected void messageReceived(ChannelHandlerContext ctx, Object msg) throws Exception {
                if (msg instanceof TextWebSocketFrame) {
                    clientReceivedMessage = true;
                }
            }
        }, 100);

        // The client sends the handshake request to the server.
        transferAllDataWithMerge(clientChannel, serverChannel);
        // Deliberately do NOT transfer the server's response back to the client,
        // so that the client's handshake timeout has to fire.

        WebSocketClientProtocolHandshakeHandler handshakeHandler =
                (WebSocketClientProtocolHandshakeHandler) clientChannel
                        .pipeline().get(WebSocketClientProtocolHandshakeHandler.class.getName());

        while (!handshakeHandler.getHandshakeFuture().isDone()) {
            Thread.sleep(10);
            // We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop.
            clientChannel.runScheduledPendingTasks();
        }

        assertTrue(clientHandshakeTimeout);
        assertFalse(clientReceivedHandshake);
        assertFalse(clientReceivedMessage);

        // Should throw WebSocketHandshakeException, wrapped by syncUninterruptibly() in a CompletionException.
        try {
            handshakeHandler.getHandshakeFuture().syncUninterruptibly();
        } catch (CompletionException e) {
            throw e.getCause();
        } finally {
            serverChannel.finishAndReleaseAll();
        }
    }

    /**
     * Checks that the client force-closes the connection if the server never answers its close frame.
     *
     * <p>The server uses {@link CloseNoOpServerProtocolHandler}. The client's
     * {@link WebSocketClientProtocolHandler} is created with close-frame handling switched off. The test then
     * waits until the client channel's close future completes and asserts via
     * {@link WebSocketClientHandshaker#isForceCloseComplete()} that the force close actually closed it.
     */
    @Test(timeout = 10000)
    public void testClientHandshakerForceClose() throws Exception {
        // NOTE(review): the last argument (20) is presumably the force-close timeout in milliseconds.
        final WebSocketClientHandshaker handshaker = WebSocketClientHandshakerFactory.newHandshaker(
                new URI("ws://localhost:1234/test"), WebSocketVersion.V13, null, true,
                EmptyHttpHeaders.INSTANCE, Integer.MAX_VALUE, true, false, 20);

        EmbeddedChannel serverChannel = createServerChannel(
                new CloseNoOpServerProtocolHandler("/test"),
                new SimpleChannelInboundHandler<Object>() {
                    @Override
                    protected void messageReceived(ChannelHandlerContext ctx, Object msg) throws Exception {
                    }
                });

        EmbeddedChannel clientChannel = createClientChannel(handshaker, new SimpleChannelInboundHandler<Object>() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
                if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
                    ctx.channel().closeFuture().addListener((ChannelFutureListener) future -> clientForceClosed = true);
                    handshaker.close(ctx.channel(), new CloseWebSocketFrame());
                }
            }

            @Override
            protected void messageReceived(ChannelHandlerContext ctx, Object msg) throws Exception {
            }
        });

        // Transfer the handshake from the client to the server.
        transferAllDataWithMerge(clientChannel, serverChannel);
        // Transfer the handshake response from the server to the client.
        transferAllDataWithMerge(serverChannel, clientChannel);
        // Transfer the client's closing handshake to the server.
        transferAllDataWithMerge(clientChannel, serverChannel);

        assertTrue(serverReceivedCloseHandshake);
        // Should not be closed yet as we disabled closing the connection on the server.
        assertFalse(clientForceClosed);

        while (!clientForceClosed) {
            Thread.sleep(10);
            // We need to run all pending tasks as the force close timeout is scheduled on the EventLoop.
            clientChannel.runPendingTasks();
        }

        // clientForceClosed would be set to TRUE after any close,
        // so check here that the force close timeout was actually fired.
        assertTrue(handshaker.isForceCloseComplete());

        // Both should be empty.
        assertFalse(serverChannel.finishAndReleaseAll());
        assertFalse(clientChannel.finishAndReleaseAll());
    }

    /**
     * Transfers all pending data from the source channel into the destination channel.<br>
     * Merges all data into a single buffer before transmission into the destination.
     * Nothing is written to the destination if the source has no outbound data.
     * @param srcChannel The source channel; every pending outbound message must be a {@link ByteBuf}
     * @param dstChannel The destination channel
     */
    private static void transferAllDataWithMerge(EmbeddedChannel srcChannel, EmbeddedChannel dstChannel) {
        ByteBuf mergedBuffer = null;
        for (;;) {
            Object srcData = srcChannel.readOutbound();

            if (srcData != null) {
                assertTrue(srcData instanceof ByteBuf);
                ByteBuf srcBuf = (ByteBuf) srcData;
                try {
                    if (mergedBuffer == null) {
                        mergedBuffer = Unpooled.buffer();
                    }
                    mergedBuffer.writeBytes(srcBuf);
                } finally {
                    // The bytes have been copied, so the source buffer can be released right away.
                    srcBuf.release();
                }
            } else {
                break;
            }
        }

        if (mergedBuffer != null) {
            dstChannel.writeInbound(mergedBuffer);
        }
    }

    /** Creates a client channel for {@code ws://localhost:1234/test} requesting subprotocol "test-proto-2". */
    private static EmbeddedChannel createClientChannel(ChannelHandler handler) {
        return createClientChannel(handler, WebSocketClientProtocolConfig.newBuilder()
                .webSocketUri("ws://localhost:1234/test")
                .subprotocol("test-proto-2")
                .build());
    }

    /**
     * Like {@link #createClientChannel(ChannelHandler)}, with an explicit handshake timeout.
     *
     * @param timeoutMillis the handshake timeout passed to {@code handshakeTimeoutMillis}
     */
    private static EmbeddedChannel createClientChannel(ChannelHandler handler, long timeoutMillis) {
        return createClientChannel(handler, WebSocketClientProtocolConfig.newBuilder()
                .webSocketUri("ws://localhost:1234/test")
                .subprotocol("test-proto-2")
                .handshakeTimeoutMillis(timeoutMillis)
                .build());
    }

    /**
     * Builds the client pipeline: HTTP codec, aggregator, WebSocket client protocol handler, then {@code handler}.
     */
    private static EmbeddedChannel createClientChannel(ChannelHandler handler, WebSocketClientProtocolConfig config) {
        return new EmbeddedChannel(
                new HttpClientCodec(),
                new HttpObjectAggregator(8192),
                new WebSocketClientProtocolHandler(config),
                handler);
    }

    /** Builds a client pipeline around an explicitly configured {@code handshaker}. */
    private static EmbeddedChannel createClientChannel(WebSocketClientHandshaker handshaker,
                                                       ChannelHandler handler) {
        return new EmbeddedChannel(
                new HttpClientCodec(),
                new HttpObjectAggregator(8192),
                // Note that we're switching off close frames handling on purpose to test forced close on timeout.
                new WebSocketClientProtocolHandler(handshaker, false, false),
                handler);
    }

    /**
     * Builds a server pipeline for path "/test".
     *
     * <p>The server accepts subprotocols "test-proto-1" and "test-proto-2", followed by {@code handler}.
     */
    private static EmbeddedChannel createServerChannel(ChannelHandler handler) {
        return new EmbeddedChannel(
                new HttpServerCodec(),
                new HttpObjectAggregator(8192),
                new WebSocketServerProtocolHandler("/test", "test-proto-1, test-proto-2", false),
                handler);
    }

    /** Builds a server pipeline using a caller-supplied WebSocket protocol handler. */
    private static EmbeddedChannel createServerChannel(WebSocketServerProtocolHandler webSocketHandler,
                                                       ChannelHandler handler) {
        return new EmbeddedChannel(
                new HttpServerCodec(),
                new HttpObjectAggregator(8192),
                webSocketHandler,
                handler);
    }
}