From 21b9dd00cd8c263f5b312996ecaa74a1e608547c Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 11 Sep 2012 08:40:33 +0200 Subject: [PATCH] WebSocket enhancements for 3.x --- .../WebSocketServerProtocolHandler.java | 114 ++++++++++++ ...bSocketServerProtocolHandshakeHandler.java | 107 +++++++++++ .../websocketx/WebSocketRequestBuilder.java | 132 +++++++++++++ .../WebSocketServerProtocolHandlerTest.java | 176 ++++++++++++++++++ 4 files changed, 529 insertions(+) create mode 100644 src/main/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java create mode 100644 src/main/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java create mode 100644 src/test/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketRequestBuilder.java create mode 100644 src/test/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java diff --git a/src/main/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java b/src/main/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java new file mode 100644 index 0000000000..e8245984ad --- /dev/null +++ b/src/main/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java @@ -0,0 +1,114 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package org.jboss.netty.handler.codec.http.websocketx; + +import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import org.jboss.netty.buffer.ChannelBuffers; +import org.jboss.netty.channel.ChannelFutureListener; +import org.jboss.netty.channel.ChannelHandler; +import org.jboss.netty.channel.ChannelHandlerContext; +import org.jboss.netty.channel.ChannelPipeline; +import org.jboss.netty.channel.ChannelStateEvent; +import org.jboss.netty.channel.ExceptionEvent; +import org.jboss.netty.channel.MessageEvent; +import org.jboss.netty.channel.SimpleChannelHandler; +import org.jboss.netty.handler.codec.http.DefaultHttpResponse; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; + +/** + * Handles WebSocket control frames (Close, Ping, Pong) and data frames (Text and Binary) are passed + * to the next handler in the pipeline. + */ +public class WebSocketServerProtocolHandler extends SimpleChannelHandler { + + private final String websocketPath; + private final String subprotocols; + private final boolean allowExtensions; + + public WebSocketServerProtocolHandler(String websocketPath) { + this(websocketPath, null, false); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) { + this(websocketPath, subprotocols, false); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) { + this.websocketPath = websocketPath; + this.subprotocols = subprotocols; + this.allowExtensions = allowExtensions; + } + + @Override + public void channelBound(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { + ChannelPipeline cp = ctx.getPipeline(); + if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) { + // Add the WebSocketHandshakeHandler before this one. + ctx.getPipeline().addBefore(ctx.getName(), WebSocketServerProtocolHandshakeHandler.class.getName(), + new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols, allowExtensions)); + } + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { + if (e.getMessage() instanceof WebSocketFrame) { + WebSocketFrame frame = (WebSocketFrame) e.getMessage(); + if (frame instanceof CloseWebSocketFrame) { + WebSocketServerHandshaker handshaker = WebSocketServerProtocolHandler.getHandshaker(ctx); + handshaker.close(ctx.getChannel(), (CloseWebSocketFrame) frame); + return; + } else if (frame instanceof PingWebSocketFrame) { + ctx.getChannel().write(new PongWebSocketFrame(frame.getBinaryData())); + return; + } + } + ctx.sendUpstream(e); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception { + if (e.getCause() instanceof WebSocketHandshakeException) { + DefaultHttpResponse response = new DefaultHttpResponse(HTTP_1_1, HttpResponseStatus.BAD_REQUEST); + response.setContent(ChannelBuffers.wrappedBuffer(e.getCause().getMessage().getBytes())); + ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE); + } else { + ctx.getChannel().close(); + } + } + + static WebSocketServerHandshaker getHandshaker(ChannelHandlerContext ctx) { + return (WebSocketServerHandshaker) ctx.getAttachment(); + } + + static void setHandshaker(ChannelHandlerContext ctx, WebSocketServerHandshaker handshaker) { + ctx.setAttachment(handshaker); + } + + static ChannelHandler forbiddenHttpRequestResponder() { + return new SimpleChannelHandler() { + @Override + public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { + if (!(e.getMessage() instanceof WebSocketFrame)) { + DefaultHttpResponse response = new DefaultHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN); + ctx.getChannel().write(response); + } else { + ctx.sendUpstream(e); + } + } + }; + } + +} diff --git a/src/main/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java b/src/main/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java new file mode 100644 index 0000000000..0ace6042b3 --- /dev/null +++ b/src/main/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java @@ -0,0 +1,107 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package org.jboss.netty.handler.codec.http.websocketx; + +import static org.jboss.netty.handler.codec.http.HttpHeaders.isKeepAlive; +import static org.jboss.netty.handler.codec.http.HttpMethod.GET; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1; + +import org.jboss.netty.channel.ChannelFuture; +import org.jboss.netty.channel.ChannelFutureListener; +import org.jboss.netty.channel.ChannelHandlerContext; +import org.jboss.netty.channel.ChannelPipeline; +import org.jboss.netty.channel.MessageEvent; +import org.jboss.netty.channel.SimpleChannelHandler; +import org.jboss.netty.handler.codec.http.DefaultHttpResponse; +import org.jboss.netty.handler.codec.http.HttpHeaders; +import org.jboss.netty.handler.codec.http.HttpRequest; +import org.jboss.netty.handler.codec.http.HttpResponse; +import org.jboss.netty.handler.ssl.SslHandler; +import org.jboss.netty.logging.InternalLogger; +import org.jboss.netty.logging.InternalLoggerFactory; + +/** + * Handles the HTTP handshake (the HTTP Upgrade request) + */ +public class WebSocketServerProtocolHandshakeHandler extends SimpleChannelHandler { + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(WebSocketServerProtocolHandshakeHandler.class); + private final String websocketPath; + private final String subprotocols; + private final boolean allowExtensions; + + public WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, + boolean allowExtensions) { + this.websocketPath = websocketPath; + this.subprotocols = subprotocols; + this.allowExtensions = allowExtensions; + } + + @Override + public void messageReceived(final ChannelHandlerContext ctx, MessageEvent e) throws Exception { + if (e.getMessage() instanceof HttpRequest) { + HttpRequest req = (HttpRequest) e.getMessage(); + if (req.getMethod() != GET) { + sendHttpResponse(ctx, req, new DefaultHttpResponse(HTTP_1_1, FORBIDDEN)); + return; + } + + final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory( + getWebSocketLocation(ctx.getPipeline(), req, websocketPath), subprotocols, allowExtensions); + final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req); + if (handshaker == null) { + wsFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel()); + } else { + final ChannelFuture handshakeFuture = handshaker.handshake(ctx.getChannel(), req); + handshakeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + //ctx.fireExceptionCaught(future.getCause()); + } + } + }); + WebSocketServerProtocolHandler.setHandshaker(ctx, handshaker); + ctx.getPipeline().replace(this, "WS403Responder", + WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()); + } + } + } + + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.error("Exception Caught", cause); + //ctx.close(); + } + + private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) { + ChannelFuture f = ctx.getChannel().write(res); + if (!isKeepAlive(req) || res.getStatus().getCode() != 200) { + f.addListener(ChannelFutureListener.CLOSE); + } + } + + private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) { + String protocol = "ws"; + if (cp.get(SslHandler.class) != null) { + // SSL in use so use Secure WebSockets + protocol = "wss"; + } + return protocol + "://" + req.getHeader(HttpHeaders.Names.HOST) + path; + } + +} diff --git a/src/test/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketRequestBuilder.java b/src/test/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketRequestBuilder.java new file mode 100644 index 0000000000..43cfff16db --- /dev/null +++ b/src/test/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketRequestBuilder.java @@ -0,0 +1,132 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.jboss.netty.handler.codec.http.websocketx; + + +import static org.jboss.netty.handler.codec.http.HttpHeaders.Values.WEBSOCKET; +import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import org.jboss.netty.handler.codec.http.DefaultHttpRequest; +import org.jboss.netty.handler.codec.http.HttpMethod; +import org.jboss.netty.handler.codec.http.HttpRequest; +import org.jboss.netty.handler.codec.http.HttpVersion; +import org.jboss.netty.handler.codec.http.HttpHeaders.Names; + +public class WebSocketRequestBuilder { + + private HttpVersion httpVersion; + private HttpMethod method; + private String uri; + private String host; + private String upgrade; + private String connection; + private String key; + private String origin; + private WebSocketVersion version; + + public WebSocketRequestBuilder httpVersion(HttpVersion httpVersion) { + this.httpVersion = httpVersion; + return this; + } + + public WebSocketRequestBuilder method(HttpMethod method) { + this.method = method; + return this; + } + + public WebSocketRequestBuilder uri(String uri) { + this.uri = uri; + return this; + } + + public WebSocketRequestBuilder host(String host) { + this.host = host; + return this; + } + + public WebSocketRequestBuilder upgrade(String upgrade) { + this.upgrade = upgrade; + return this; + } + + public WebSocketRequestBuilder connection(String connection) { + this.connection = connection; + return this; + } + + public WebSocketRequestBuilder key(String key) { + this.key = key; + return this; + } + + public WebSocketRequestBuilder origin(String origin) { + this.origin = origin; + return this; + } + + public WebSocketRequestBuilder version13() { + this.version = WebSocketVersion.V13; + return this; + } + + public WebSocketRequestBuilder version8() { + this.version = WebSocketVersion.V08; + return this; + } + + public WebSocketRequestBuilder version00() { + this.version = null; + return this; + } + + public WebSocketRequestBuilder noVersion() { + return this; + } + + public HttpRequest build() { + HttpRequest req = new DefaultHttpRequest(httpVersion, method, uri); + if (host != null) { + req.setHeader(Names.HOST, host); + } + if (upgrade != null) { + req.setHeader(Names.UPGRADE, upgrade); + } + if (connection != null) { + req.setHeader(Names.CONNECTION, connection); + } + if (key != null) { + req.setHeader(Names.SEC_WEBSOCKET_KEY, key); + } + if (origin != null) { + req.setHeader(Names.SEC_WEBSOCKET_ORIGIN, origin); + } + if (version != null) { + req.setHeader(Names.SEC_WEBSOCKET_VERSION, version.toHttpHeaderValue()); + } + return req; + } + + public static HttpRequest sucessful() { + return new WebSocketRequestBuilder().httpVersion(HTTP_1_1) + .method(HttpMethod.GET) + .uri("/test") + .host("server.example.com") + .upgrade(WEBSOCKET.toLowerCase()) + .key("dGhlIHNhbXBsZSBub25jZQ==") + .origin("http://example.com") + .version13() + .build(); + } +} diff --git a/src/test/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java b/src/test/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java new file mode 100644 index 0000000000..d8b8f411ec --- /dev/null +++ b/src/test/java/org/jboss/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java @@ -0,0 +1,176 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.jboss.netty.handler.codec.http.websocketx; + +import static org.jboss.netty.handler.codec.http.HttpHeaders.Values.WEBSOCKET; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static org.jboss.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS; +import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.util.LinkedList; +import java.util.Queue; + +import org.jboss.netty.channel.Channel; +import org.jboss.netty.channel.ChannelHandlerContext; +import org.jboss.netty.channel.MessageEvent; +import org.jboss.netty.channel.SimpleChannelHandler; +import org.jboss.netty.handler.codec.embedder.CodecEmbedderException; +import org.jboss.netty.handler.codec.embedder.DecoderEmbedder; +import org.jboss.netty.handler.codec.http.DefaultHttpRequest; +import org.jboss.netty.handler.codec.http.HttpMethod; +import org.jboss.netty.handler.codec.http.HttpRequest; +import org.jboss.netty.handler.codec.http.HttpRequestDecoder; +import org.jboss.netty.handler.codec.http.HttpResponse; +import org.jboss.netty.handler.codec.http.HttpResponseEncoder; +import org.jboss.netty.handler.codec.http.HttpVersion; +import org.junit.Test; + +public class WebSocketServerProtocolHandlerTest { + + @Test + public void testHttpUpgradeRequest() { + DecoderEmbedder embedder = decoderEmbedder(); + ChannelHandlerContext ctx = embedder.getPipeline().getContext(WebSocketServerProtocolHandshakeHandler.class); + HttpResponseInterceptor responseInterceptor = addHttpResponseInterceptor(embedder); + + embedder.offer(WebSocketRequestBuilder.sucessful()); + + HttpResponse response = responseInterceptor.getHttpResponse(); + assertEquals(SWITCHING_PROTOCOLS, response.getStatus()); + assertNotNull(WebSocketServerProtocolHandler.getHandshaker(ctx)); + } + + private HttpResponseInterceptor addHttpResponseInterceptor(DecoderEmbedder embedder) { + HttpResponseInterceptor interceptor = new HttpResponseInterceptor(); + embedder.getPipeline().addLast("httpEncoder", interceptor); + return interceptor; + } + + @Test + public void testSubsequentHttpRequestsAfterUpgradeShouldReturn403() throws Exception { + DecoderEmbedder embedder = decoderEmbedder(); + HttpResponseInterceptor responseInterceptor = addHttpResponseInterceptor(embedder); + + embedder.offer(WebSocketRequestBuilder.sucessful()); + embedder.offer(new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "path")); + + assertEquals(SWITCHING_PROTOCOLS, responseInterceptor.getHttpResponse().getStatus()); + assertEquals(FORBIDDEN, responseInterceptor.getHttpResponse().getStatus()); + } + + @Test + public void testHttpUpgradeRequestInvalidUpgradeHeader() { + DecoderEmbedder embedder = decoderEmbedder(); + + HttpRequest invalidUpgradeRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1) + .method(HttpMethod.GET) + .uri("/test") + .connection("Upgrade") + .version00() + .upgrade("BogusSocket") + .build(); + try { + embedder.offer(invalidUpgradeRequest); + } catch (Exception e) { + assertWebSocketHandshakeException(e); + } + } + + @Test + public void testHttpUpgradeRequestMissingWSKeyHeader() { + DecoderEmbedder embedder = decoderEmbedder(); + HttpRequest missingWSKeyRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1) + .method(HttpMethod.GET) + .uri("/test") + .key(null) + .connection("Upgrade") + .upgrade(WEBSOCKET.toLowerCase()) + .version13() + .build(); + + try { + embedder.offer(missingWSKeyRequest); + } catch (Exception e) { + assertWebSocketHandshakeException(e); + } + } + + private void assertWebSocketHandshakeException(Exception e) { + assertTrue(e instanceof CodecEmbedderException); + assertTrue(e.getCause() instanceof WebSocketHandshakeException); + } + + @Test + public void testHandleTextFrame() { + CustomTextFrameHandler customTextFrameHandler = new CustomTextFrameHandler(); + DecoderEmbedder embedder = decoderEmbedder(customTextFrameHandler); + + embedder.offer(WebSocketRequestBuilder.sucessful()); + embedder.offer(new TextWebSocketFrame("payload")); + + assertEquals("processed: payload", customTextFrameHandler.getContent()); + } + + private DecoderEmbedder decoderEmbedder(SimpleChannelHandler handler) { + DecoderEmbedder decoder = decoderEmbedder(); + decoder.getPipeline().addFirst("someHandler", handler); + return decoder; + } + + private DecoderEmbedder decoderEmbedder() { + DecoderEmbedder decoder = new DecoderEmbedder( + new HttpRequestDecoder(), + new WebSocketServerProtocolHandler("path")); + return decoder; + } + + private static class CustomTextFrameHandler extends SimpleChannelHandler { + private String content; + + @Override + public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { + if (e.getMessage() instanceof TextWebSocketFrame) { + TextWebSocketFrame frame = (TextWebSocketFrame) e.getMessage(); + content = "processed: " + frame.getText(); + } + } + + public String getContent() { + return content; + } + + } + + private static class HttpResponseInterceptor extends HttpResponseEncoder { + + private Queue responses = new LinkedList(); + + @Override + protected Object encode(ChannelHandlerContext ctx, Channel channel, Object msg) throws Exception { + responses.add((HttpResponse) msg); + return super.encode(ctx, channel, msg); + } + + public HttpResponse getHttpResponse() { + return responses.poll(); + } + + } + +}