diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java index 59e3ed597a..d67b85f379 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java @@ -21,21 +21,32 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestEncoder; +import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.StringUtil; import java.net.URI; +import java.nio.channels.ClosedChannelException; /** * Base class for web socket client handshake implementations */ public abstract class WebSocketClientHandshaker { + private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + + static { + CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); + } private final URI uri; @@ -238,6 +249,12 @@ public abstract class WebSocketClientHandshaker { p.remove(decompressor); } + // Remove aggregator if present before + HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class); + if (aggregator != null) { + p.remove(aggregator); + } + ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class); if (ctx == null) { ctx = p.context(HttpClientCodec.class); @@ -255,6 +272,93 @@ public abstract class WebSocketClientHandshaker { } } + /** + * Process the opening handshake initiated by {@link #handshake}}. + * + * @param channel + * Channel + * @param response + * HTTP response containing the closing handshake details + * @return future + * the {@link ChannelFuture} which is notified once the handshake completes. + */ + public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) { + return processHandshake(channel, response, channel.newPromise()); + } + + /** + * Process the opening handshake initiated by {@link #handshake}}. + * + * @param channel + * Channel + * @param response + * HTTP response containing the closing handshake details + * @param promise + * the {@link ChannelPromise} to notify once the handshake completes. + * @return future + * the {@link ChannelFuture} which is notified once the handshake completes. + */ + public final ChannelFuture processHandshake(final Channel channel, HttpResponse response, + final ChannelPromise promise) { + if (response instanceof FullHttpResponse) { + try { + finishHandshake(channel, (FullHttpResponse) response); + promise.setSuccess(); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } else { + ChannelPipeline p = channel.pipeline(); + ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class); + if (ctx == null) { + ctx = p.context(HttpClientCodec.class); + if (ctx == null) { + return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " + + "a HttpResponseDecoder or HttpClientCodec")); + } + } + // Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more + // then enough for the websockets handshake payload. + // + // TODO: Make handshake work without HttpObjectAggregator at all. + String aggregatorName = "httpAggregator"; + p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192)); + p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception { + // Remove ourself and do the actual handshake + ctx.pipeline().remove(this); + try { + finishHandshake(channel, msg); + promise.setSuccess(); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + // Remove ourself and fail the handshake promise. + ctx.pipeline().remove(this); + promise.setFailure(cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + // Fail promise if Channel was closed + promise.tryFailure(CLOSED_CHANNEL_EXCEPTION); + ctx.fireChannelInactive(); + } + }); + try { + ctx.fireChannelRead(ReferenceCountUtil.retain(response)); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + return promise; + } + /** * Verfiy the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong. */ diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java index c2ca118d35..09dc546c89 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java @@ -21,19 +21,23 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpContentCompressor; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.StringUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import java.nio.channels.ClosedChannelException; import java.util.Collections; import java.util.LinkedHashSet; import java.util.Set; @@ -43,6 +47,11 @@ import java.util.Set; */ public abstract class WebSocketServerHandshaker { protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class); + private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + + static { + CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); + } private final String uri; @@ -200,6 +209,94 @@ public abstract class WebSocketServerHandshaker { return promise; } + /** + * Performs the opening handshake. When call this method you MUST NOT retain the + * {@link FullHttpRequest} which is passed in. + * + * @param channel + * Channel + * @param req + * HTTP Request + * @return future + * The {@link ChannelFuture} which is notified once the opening handshake completes + */ + public ChannelFuture handshake(Channel channel, HttpRequest req) { + return handshake(channel, req, null, channel.newPromise()); + } + + /** + * Performs the opening handshake + * + * When call this method you MUST NOT retain the {@link HttpRequest} which is passed in. + * + * @param channel + * Channel + * @param req + * HTTP Request + * @param responseHeaders + * Extra headers to add to the handshake response or {@code null} if no extra headers should be added + * @param promise + * the {@link ChannelPromise} to be notified when the opening handshake is done + * @return future + * the {@link ChannelFuture} which is notified when the opening handshake is done + */ + public final ChannelFuture handshake(final Channel channel, HttpRequest req, + final HttpHeaders responseHeaders, final ChannelPromise promise) { + + if (req instanceof FullHttpRequest) { + return handshake(channel, (FullHttpRequest) req, responseHeaders, promise); + } + if (logger.isDebugEnabled()) { + logger.debug("{} WebSocket version {} server handshake", channel, version()); + } + ChannelPipeline p = channel.pipeline(); + ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class); + if (ctx == null) { + // this means the user use a HttpServerCodec + ctx = p.context(HttpServerCodec.class); + if (ctx == null) { + promise.setFailure( + new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline")); + return promise; + } + } + // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit o 8192 should be more then + // enough for the websockets handshake payload. + // + // TODO: Make handshake work without HttpObjectAggregator at all. + String aggregatorName = "httpAggregator"; + p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192)); + p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception { + // Remove ourself and do the actual handshake + ctx.pipeline().remove(this); + handshake(channel, msg, responseHeaders, promise); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + // Remove ourself and fail the handshake promise. + ctx.pipeline().remove(this); + promise.tryFailure(cause); + ctx.fireExceptionCaught(cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + // Fail promise if Channel was closed + promise.tryFailure(CLOSED_CHANNEL_EXCEPTION); + ctx.fireChannelInactive(); + } + }); + try { + ctx.fireChannelRead(ReferenceCountUtil.retain(req)); + } catch (Throwable cause) { + promise.setFailure(cause); + } + return promise; + } + /** * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request. */ diff --git a/testsuite/src/main/java/io/netty/testsuite/websockets/autobahn/AutobahnServerHandler.java b/testsuite/src/main/java/io/netty/testsuite/websockets/autobahn/AutobahnServerHandler.java index 64093ae5d8..8c3045a147 100644 --- a/testsuite/src/main/java/io/netty/testsuite/websockets/autobahn/AutobahnServerHandler.java +++ b/testsuite/src/main/java/io/netty/testsuite/websockets/autobahn/AutobahnServerHandler.java @@ -25,6 +25,7 @@ import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; @@ -56,8 +57,8 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (msg instanceof FullHttpRequest) { - handleHttpRequest(ctx, (FullHttpRequest) msg); + if (msg instanceof HttpRequest) { + handleHttpRequest(ctx, (HttpRequest) msg); } else if (msg instanceof WebSocketFrame) { handleWebSocketFrame(ctx, (WebSocketFrame) msg); } else { @@ -70,19 +71,17 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter { ctx.flush(); } - private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) + private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req) throws Exception { // Handle a bad request. if (!req.getDecoderResult().isSuccess()) { sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST)); - req.release(); return; } // Allow only GET methods. if (req.getMethod() != GET) { sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN)); - req.release(); return; } @@ -95,7 +94,6 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter { } else { handshaker.handshake(ctx.channel(), req); } - req.release(); } private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { @@ -124,7 +122,7 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter { } private static void sendHttpResponse( - ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) { + ChannelHandlerContext ctx, HttpRequest req, FullHttpResponse res) { // Generate an error page if response status code is not OK (200). if (res.getStatus().code() != 200) { ByteBuf buf = Unpooled.copiedBuffer(res.getStatus().toString(), CharsetUtil.UTF_8); @@ -145,7 +143,7 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter { ctx.close(); } - private static String getWebSocketLocation(FullHttpRequest req) { - return "ws://" + req.headers().get(HttpHeaders.Names.HOST); + private static String getWebSocketLocation(HttpRequest req) { + return "ws://" + req.headers().get(Names.HOST); } } diff --git a/testsuite/src/main/java/io/netty/testsuite/websockets/autobahn/AutobahnServerInitializer.java b/testsuite/src/main/java/io/netty/testsuite/websockets/autobahn/AutobahnServerInitializer.java index fef9c1c0d4..e6b7596662 100644 --- a/testsuite/src/main/java/io/netty/testsuite/websockets/autobahn/AutobahnServerInitializer.java +++ b/testsuite/src/main/java/io/netty/testsuite/websockets/autobahn/AutobahnServerInitializer.java @@ -28,7 +28,6 @@ public class AutobahnServerInitializer extends ChannelInitializer ChannelPipeline pipeline = ch.pipeline(); pipeline.addLast("encoder", new HttpResponseEncoder()); pipeline.addLast("decoder", new HttpRequestDecoder()); - pipeline.addLast("aggregator", new HttpObjectAggregator(65536)); pipeline.addLast("handler", new AutobahnServerHandler()); } }