409 lines
15 KiB
Java
409 lines
15 KiB
Java
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version
 * 2.0 (the "License"); you may not use this file except in compliance with the
 * License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
|
|
package io.netty.handler.codec.http.websocketx;
|
|
|
|
import io.netty.buffer.ByteBuf;
|
|
import io.netty.channel.ChannelHandler;
|
|
import io.netty.channel.ChannelHandlerContext;
|
|
import io.netty.channel.ChannelInboundHandler;
|
|
import io.netty.channel.ChannelPromise;
|
|
import io.netty.channel.embedded.EmbeddedChannel;
|
|
import io.netty.handler.codec.http.DefaultFullHttpRequest;
|
|
import io.netty.handler.codec.http.FullHttpRequest;
|
|
import io.netty.handler.codec.http.FullHttpResponse;
|
|
import io.netty.handler.codec.http.HttpClientCodec;
|
|
import io.netty.handler.codec.http.HttpHeaderValues;
|
|
import io.netty.handler.codec.http.HttpMethod;
|
|
import io.netty.handler.codec.http.HttpObjectAggregator;
|
|
import io.netty.handler.codec.http.HttpRequest;
|
|
import io.netty.handler.codec.http.HttpRequestDecoder;
|
|
import io.netty.handler.codec.http.HttpResponseEncoder;
|
|
import io.netty.handler.codec.http.HttpServerCodec;
|
|
import io.netty.util.CharsetUtil;
|
|
import io.netty.util.ReferenceCountUtil;
|
|
import org.junit.Before;
|
|
import org.junit.Test;
|
|
|
|
import java.util.ArrayDeque;
|
|
import java.util.Queue;
|
|
|
|
import static io.netty.handler.codec.http.HttpResponseStatus.*;
|
|
import static io.netty.handler.codec.http.HttpVersion.*;
|
|
import static org.junit.Assert.*;
|
|
|
|
public class WebSocketServerProtocolHandlerTest {
|
|
|
|
private final Queue<FullHttpResponse> responses = new ArrayDeque<>();
|
|
|
|
@Before
|
|
public void setUp() {
|
|
responses.clear();
|
|
}
|
|
|
|
@Test
|
|
public void testHttpUpgradeRequest() {
|
|
EmbeddedChannel ch = createChannel(new MockOutboundHandler());
|
|
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
|
|
writeUpgradeRequest(ch);
|
|
|
|
FullHttpResponse response = responses.remove();
|
|
assertEquals(SWITCHING_PROTOCOLS, response.status());
|
|
response.release();
|
|
assertNotNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx.channel()));
|
|
assertFalse(ch.finish());
|
|
}
|
|
|
|
@Test
|
|
public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() {
|
|
EmbeddedChannel ch = createChannel(new MockOutboundHandler());
|
|
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
|
|
ch.pipeline().addLast(new ChannelHandler() {
|
|
@Override
|
|
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
|
|
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
|
// We should have removed the handler already.
|
|
// assertNull(ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class));
|
|
}
|
|
}
|
|
});
|
|
writeUpgradeRequest(ch);
|
|
|
|
FullHttpResponse response = responses.remove();
|
|
assertEquals(SWITCHING_PROTOCOLS, response.status());
|
|
response.release();
|
|
assertNotNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx.channel()));
|
|
assertFalse(ch.finish());
|
|
}
|
|
|
|
@Test
|
|
public void testSubsequentHttpRequestsAfterUpgradeShouldReturn403() {
|
|
EmbeddedChannel ch = createChannel();
|
|
|
|
writeUpgradeRequest(ch);
|
|
|
|
FullHttpResponse response = responses.remove();
|
|
assertEquals(SWITCHING_PROTOCOLS, response.status());
|
|
response.release();
|
|
|
|
ch.writeInbound(new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, "/test"));
|
|
response = responses.remove();
|
|
assertEquals(FORBIDDEN, response.status());
|
|
response.release();
|
|
assertFalse(ch.finish());
|
|
}
|
|
|
|
@Test
|
|
public void testHttpUpgradeRequestInvalidUpgradeHeader() {
|
|
EmbeddedChannel ch = createChannel();
|
|
FullHttpRequest httpRequestWithEntity = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
|
|
.method(HttpMethod.GET)
|
|
.uri("/test")
|
|
.connection("Upgrade")
|
|
.version00()
|
|
.upgrade("BogusSocket")
|
|
.build();
|
|
|
|
ch.writeInbound(httpRequestWithEntity);
|
|
|
|
FullHttpResponse response = responses.remove();
|
|
assertEquals(BAD_REQUEST, response.status());
|
|
assertEquals("not a WebSocket handshake request: missing upgrade", getResponseMessage(response));
|
|
response.release();
|
|
assertFalse(ch.finish());
|
|
}
|
|
|
|
@Test
|
|
public void testHttpUpgradeRequestMissingWSKeyHeader() {
|
|
EmbeddedChannel ch = createChannel();
|
|
HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
|
|
.method(HttpMethod.GET)
|
|
.uri("/test")
|
|
.key(null)
|
|
.connection("Upgrade")
|
|
.upgrade(HttpHeaderValues.WEBSOCKET)
|
|
.version13()
|
|
.build();
|
|
|
|
ch.writeInbound(httpRequest);
|
|
|
|
FullHttpResponse response = responses.remove();
|
|
assertEquals(BAD_REQUEST, response.status());
|
|
assertEquals("not a WebSocket request: missing key", getResponseMessage(response));
|
|
response.release();
|
|
assertFalse(ch.finish());
|
|
}
|
|
|
|
@Test
|
|
public void testCreateUTF8Validator() {
|
|
WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
|
|
.websocketPath("/test")
|
|
.withUTF8Validator(true)
|
|
.build();
|
|
|
|
EmbeddedChannel ch = new EmbeddedChannel(
|
|
new WebSocketServerProtocolHandler(config),
|
|
new HttpRequestDecoder(),
|
|
new HttpResponseEncoder(),
|
|
new MockOutboundHandler());
|
|
writeUpgradeRequest(ch);
|
|
|
|
FullHttpResponse response = responses.remove();
|
|
assertEquals(SWITCHING_PROTOCOLS, response.status());
|
|
response.release();
|
|
|
|
assertNotNull(ch.pipeline().get(Utf8FrameValidator.class));
|
|
}
|
|
|
|
@Test
|
|
public void testDoNotCreateUTF8Validator() {
|
|
WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
|
|
.websocketPath("/test")
|
|
.withUTF8Validator(false)
|
|
.build();
|
|
|
|
EmbeddedChannel ch = new EmbeddedChannel(
|
|
new WebSocketServerProtocolHandler(config),
|
|
new HttpRequestDecoder(),
|
|
new HttpResponseEncoder(),
|
|
new MockOutboundHandler());
|
|
writeUpgradeRequest(ch);
|
|
|
|
FullHttpResponse response = responses.remove();
|
|
assertEquals(SWITCHING_PROTOCOLS, response.status());
|
|
response.release();
|
|
|
|
assertNull(ch.pipeline().get(Utf8FrameValidator.class));
|
|
}
|
|
|
|
@Test
|
|
public void testHandleTextFrame() {
|
|
CustomTextFrameHandler customTextFrameHandler = new CustomTextFrameHandler();
|
|
EmbeddedChannel ch = createChannel(customTextFrameHandler);
|
|
writeUpgradeRequest(ch);
|
|
|
|
FullHttpResponse response = responses.remove();
|
|
assertEquals(SWITCHING_PROTOCOLS, response.status());
|
|
response.release();
|
|
|
|
if (ch.pipeline().context(HttpRequestDecoder.class) != null) {
|
|
// Removing the HttpRequestDecoder because we are writing a TextWebSocketFrame and thus
|
|
// decoding is not necessary.
|
|
ch.pipeline().remove(HttpRequestDecoder.class);
|
|
}
|
|
|
|
ch.writeInbound(new TextWebSocketFrame("payload"));
|
|
|
|
assertEquals("processed: payload", customTextFrameHandler.getContent());
|
|
assertFalse(ch.finish());
|
|
}
|
|
|
|
@Test
|
|
public void testExplicitCloseFrameSentWhenServerChannelClosed() throws Exception {
|
|
WebSocketCloseStatus closeStatus = WebSocketCloseStatus.ENDPOINT_UNAVAILABLE;
|
|
EmbeddedChannel client = createClient();
|
|
EmbeddedChannel server = createServer();
|
|
|
|
assertFalse(server.writeInbound((ByteBuf) client.readOutbound()));
|
|
assertFalse(client.writeInbound((ByteBuf) server.readOutbound()));
|
|
|
|
// When server channel closed with explicit close-frame
|
|
assertTrue(server.writeOutbound(new CloseWebSocketFrame(closeStatus)));
|
|
server.close();
|
|
|
|
// Then client receives provided close-frame
|
|
assertTrue(client.writeInbound((ByteBuf) server.readOutbound()));
|
|
assertFalse(server.isOpen());
|
|
|
|
CloseWebSocketFrame closeMessage = client.readInbound();
|
|
assertEquals(closeMessage.statusCode(), closeStatus.code());
|
|
closeMessage.release();
|
|
|
|
client.close();
|
|
assertTrue(ReferenceCountUtil.release(client.readOutbound()));
|
|
assertFalse(client.finishAndReleaseAll());
|
|
assertFalse(server.finishAndReleaseAll());
|
|
}
|
|
|
|
@Test
|
|
public void testCloseFrameSentWhenServerChannelClosedSilently() throws Exception {
|
|
EmbeddedChannel client = createClient();
|
|
EmbeddedChannel server = createServer();
|
|
|
|
assertFalse(server.writeInbound((ByteBuf) client.readOutbound()));
|
|
assertFalse(client.writeInbound((ByteBuf) server.readOutbound()));
|
|
|
|
// When server channel closed without explicit close-frame
|
|
server.close();
|
|
|
|
// Then client receives NORMAL_CLOSURE close-frame
|
|
assertTrue(client.writeInbound((ByteBuf) server.readOutbound()));
|
|
assertFalse(server.isOpen());
|
|
|
|
CloseWebSocketFrame closeMessage = client.readInbound();
|
|
assertEquals(closeMessage.statusCode(), WebSocketCloseStatus.NORMAL_CLOSURE.code());
|
|
closeMessage.release();
|
|
|
|
client.close();
|
|
assertTrue(ReferenceCountUtil.release(client.readOutbound()));
|
|
assertFalse(client.finishAndReleaseAll());
|
|
assertFalse(server.finishAndReleaseAll());
|
|
}
|
|
|
|
@Test
|
|
public void testExplicitCloseFrameSentWhenClientChannelClosed() throws Exception {
|
|
WebSocketCloseStatus closeStatus = WebSocketCloseStatus.INVALID_PAYLOAD_DATA;
|
|
EmbeddedChannel client = createClient();
|
|
EmbeddedChannel server = createServer();
|
|
|
|
assertFalse(server.writeInbound((ByteBuf) client.readOutbound()));
|
|
assertFalse(client.writeInbound((ByteBuf) server.readOutbound()));
|
|
|
|
// When client channel closed with explicit close-frame
|
|
assertTrue(client.writeOutbound(new CloseWebSocketFrame(closeStatus)));
|
|
client.close();
|
|
|
|
// Then client receives provided close-frame
|
|
assertFalse(server.writeInbound((ByteBuf) client.readOutbound()));
|
|
assertFalse(client.isOpen());
|
|
assertFalse(server.isOpen());
|
|
|
|
CloseWebSocketFrame closeMessage = decode(server.<ByteBuf>readOutbound(), CloseWebSocketFrame.class);
|
|
assertEquals(closeMessage.statusCode(), closeStatus.code());
|
|
closeMessage.release();
|
|
|
|
assertFalse(client.finishAndReleaseAll());
|
|
assertFalse(server.finishAndReleaseAll());
|
|
}
|
|
|
|
@Test
|
|
public void testCloseFrameSentWhenClientChannelClosedSilently() throws Exception {
|
|
EmbeddedChannel client = createClient();
|
|
EmbeddedChannel server = createServer();
|
|
|
|
assertFalse(server.writeInbound((ByteBuf) client.readOutbound()));
|
|
assertFalse(client.writeInbound((ByteBuf) server.readOutbound()));
|
|
|
|
// When client channel closed without explicit close-frame
|
|
client.close();
|
|
|
|
// Then server receives NORMAL_CLOSURE close-frame
|
|
assertFalse(server.writeInbound((ByteBuf) client.readOutbound()));
|
|
assertFalse(client.isOpen());
|
|
assertFalse(server.isOpen());
|
|
|
|
CloseWebSocketFrame closeMessage = decode(server.<ByteBuf>readOutbound(), CloseWebSocketFrame.class);
|
|
assertEquals(closeMessage, new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE));
|
|
closeMessage.release();
|
|
|
|
assertFalse(client.finishAndReleaseAll());
|
|
assertFalse(server.finishAndReleaseAll());
|
|
}
|
|
|
|
private EmbeddedChannel createClient(ChannelHandler... handlers) throws Exception {
|
|
WebSocketClientProtocolConfig clientConfig = WebSocketClientProtocolConfig.newBuilder()
|
|
.webSocketUri("http://test/test")
|
|
.dropPongFrames(false)
|
|
.handleCloseFrames(false)
|
|
.build();
|
|
EmbeddedChannel ch = new EmbeddedChannel(false, false,
|
|
new HttpClientCodec(),
|
|
new HttpObjectAggregator(8192),
|
|
new WebSocketClientProtocolHandler(clientConfig)
|
|
);
|
|
ch.pipeline().addLast(handlers);
|
|
ch.register();
|
|
return ch;
|
|
}
|
|
|
|
private EmbeddedChannel createServer(ChannelHandler... handlers) throws Exception {
|
|
WebSocketServerProtocolConfig serverConfig = WebSocketServerProtocolConfig.newBuilder()
|
|
.websocketPath("/test")
|
|
.dropPongFrames(false)
|
|
.build();
|
|
EmbeddedChannel ch = new EmbeddedChannel(false, false,
|
|
new HttpServerCodec(),
|
|
new HttpObjectAggregator(8192),
|
|
new WebSocketServerProtocolHandler(serverConfig)
|
|
);
|
|
ch.pipeline().addLast(handlers);
|
|
ch.register();
|
|
return ch;
|
|
}
|
|
|
|
@SuppressWarnings("SameParameterValue")
|
|
private <T> T decode(ByteBuf input, Class<T> clazz) {
|
|
EmbeddedChannel ch = new EmbeddedChannel(new WebSocket13FrameDecoder(true, false, 65536, true));
|
|
assertTrue(ch.writeInbound(input));
|
|
Object decoded = ch.readInbound();
|
|
assertNotNull(decoded);
|
|
assertFalse(ch.finish());
|
|
return clazz.cast(decoded);
|
|
}
|
|
|
|
private EmbeddedChannel createChannel() {
|
|
return createChannel(null);
|
|
}
|
|
|
|
private EmbeddedChannel createChannel(ChannelHandler handler) {
|
|
WebSocketServerProtocolConfig serverConfig = WebSocketServerProtocolConfig.newBuilder()
|
|
.websocketPath("/test")
|
|
.sendCloseFrame(null)
|
|
.build();
|
|
return new EmbeddedChannel(
|
|
new WebSocketServerProtocolHandler(serverConfig),
|
|
new HttpRequestDecoder(),
|
|
new HttpResponseEncoder(),
|
|
new MockOutboundHandler(),
|
|
handler);
|
|
}
|
|
|
|
private static void writeUpgradeRequest(EmbeddedChannel ch) {
|
|
ch.writeInbound(WebSocketRequestBuilder.successful());
|
|
}
|
|
|
|
private static String getResponseMessage(FullHttpResponse response) {
|
|
return response.content().toString(CharsetUtil.UTF_8);
|
|
}
|
|
|
|
private class MockOutboundHandler implements ChannelHandler {
|
|
|
|
@Override
|
|
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
|
|
responses.add((FullHttpResponse) msg);
|
|
promise.setSuccess();
|
|
}
|
|
|
|
@Override
|
|
public void flush(ChannelHandlerContext ctx) {
|
|
}
|
|
}
|
|
|
|
private static class CustomTextFrameHandler implements ChannelInboundHandler {
|
|
private String content;
|
|
|
|
@Override
|
|
public void channelRead(ChannelHandlerContext ctx, Object msg) {
|
|
assertNull(content);
|
|
content = "processed: " + ((TextWebSocketFrame) msg).text();
|
|
ReferenceCountUtil.release(msg);
|
|
}
|
|
|
|
String getContent() {
|
|
return content;
|
|
}
|
|
}
|
|
}
|