Allow to use WebSocketClientHandshaker and WebSocketServerHandshaker with HttpResponse / HttpRequest

Motivation:

To use WebSocketClientHandshaker / WebSocketServerHandshaker it's currently a requirement of having a HttpObjectAggregator in the ChannelPipeline. This is not a big deal when a user only wants to server WebSockets but is a limitation if the server serves WebSockets and normal HTTP traffic.

Modifications:

Allow to use WebSocketClientHandshaker and WebSocketServerHandshaker without HttpObjectAggregator in the ChannelPipeline.

Result:

More flexibility
This commit is contained in:
Norman Maurer 2015-01-06 06:55:01 +01:00
parent dd32313fad
commit 261a30d8af
4 changed files with 208 additions and 10 deletions

View File

@ -21,21 +21,32 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.StringUtil;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
/**
* Base class for web socket client handshake implementations
*/
public abstract class WebSocketClientHandshaker {
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
static {
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
}
private final URI uri;
@ -238,6 +249,12 @@ public abstract class WebSocketClientHandshaker {
p.remove(decompressor);
}
// Remove aggregator if present before
HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
if (aggregator != null) {
p.remove(aggregator);
}
ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
if (ctx == null) {
ctx = p.context(HttpClientCodec.class);
@ -255,6 +272,93 @@ public abstract class WebSocketClientHandshaker {
}
}
/**
* Process the opening handshake initiated by {@link #handshake}}.
*
* @param channel
* Channel
* @param response
* HTTP response containing the closing handshake details
* @return future
* the {@link ChannelFuture} which is notified once the handshake completes.
*/
public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
return processHandshake(channel, response, channel.newPromise());
}
/**
* Process the opening handshake initiated by {@link #handshake}}.
*
* @param channel
* Channel
* @param response
* HTTP response containing the closing handshake details
* @param promise
* the {@link ChannelPromise} to notify once the handshake completes.
* @return future
* the {@link ChannelFuture} which is notified once the handshake completes.
*/
public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
final ChannelPromise promise) {
if (response instanceof FullHttpResponse) {
try {
finishHandshake(channel, (FullHttpResponse) response);
promise.setSuccess();
} catch (Throwable cause) {
promise.setFailure(cause);
}
} else {
ChannelPipeline p = channel.pipeline();
ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
if (ctx == null) {
ctx = p.context(HttpClientCodec.class);
if (ctx == null) {
return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
"a HttpResponseDecoder or HttpClientCodec"));
}
}
// Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more
// then enough for the websockets handshake payload.
//
// TODO: Make handshake work without HttpObjectAggregator at all.
String aggregatorName = "httpAggregator";
p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
// Remove ourself and do the actual handshake
ctx.pipeline().remove(this);
try {
finishHandshake(channel, msg);
promise.setSuccess();
} catch (Throwable cause) {
promise.setFailure(cause);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// Remove ourself and fail the handshake promise.
ctx.pipeline().remove(this);
promise.setFailure(cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
// Fail promise if Channel was closed
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
ctx.fireChannelInactive();
}
});
try {
ctx.fireChannelRead(ReferenceCountUtil.retain(response));
} catch (Throwable cause) {
promise.setFailure(cause);
}
}
return promise;
}
/**
* Verfiy the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
*/

View File

@ -21,19 +21,23 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpContentCompressor;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.nio.channels.ClosedChannelException;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
@ -43,6 +47,11 @@ import java.util.Set;
*/
public abstract class WebSocketServerHandshaker {
protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
static {
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
}
private final String uri;
@ -200,6 +209,94 @@ public abstract class WebSocketServerHandshaker {
return promise;
}
/**
* Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
* {@link FullHttpRequest} which is passed in.
*
* @param channel
* Channel
* @param req
* HTTP Request
* @return future
* The {@link ChannelFuture} which is notified once the opening handshake completes
*/
public ChannelFuture handshake(Channel channel, HttpRequest req) {
return handshake(channel, req, null, channel.newPromise());
}
/**
* Performs the opening handshake
*
* When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
*
* @param channel
* Channel
* @param req
* HTTP Request
* @param responseHeaders
* Extra headers to add to the handshake response or {@code null} if no extra headers should be added
* @param promise
* the {@link ChannelPromise} to be notified when the opening handshake is done
* @return future
* the {@link ChannelFuture} which is notified when the opening handshake is done
*/
public final ChannelFuture handshake(final Channel channel, HttpRequest req,
final HttpHeaders responseHeaders, final ChannelPromise promise) {
if (req instanceof FullHttpRequest) {
return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
}
if (logger.isDebugEnabled()) {
logger.debug("{} WebSocket version {} server handshake", channel, version());
}
ChannelPipeline p = channel.pipeline();
ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
if (ctx == null) {
// this means the user use a HttpServerCodec
ctx = p.context(HttpServerCodec.class);
if (ctx == null) {
promise.setFailure(
new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
return promise;
}
}
// Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit o 8192 should be more then
// enough for the websockets handshake payload.
//
// TODO: Make handshake work without HttpObjectAggregator at all.
String aggregatorName = "httpAggregator";
p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
// Remove ourself and do the actual handshake
ctx.pipeline().remove(this);
handshake(channel, msg, responseHeaders, promise);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// Remove ourself and fail the handshake promise.
ctx.pipeline().remove(this);
promise.tryFailure(cause);
ctx.fireExceptionCaught(cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
// Fail promise if Channel was closed
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
ctx.fireChannelInactive();
}
});
try {
ctx.fireChannelRead(ReferenceCountUtil.retain(req));
} catch (Throwable cause) {
promise.setFailure(cause);
}
return promise;
}
/**
* Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
*/

View File

@ -25,6 +25,7 @@ import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
@ -56,8 +57,8 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof FullHttpRequest) {
handleHttpRequest(ctx, (FullHttpRequest) msg);
if (msg instanceof HttpRequest) {
handleHttpRequest(ctx, (HttpRequest) msg);
} else if (msg instanceof WebSocketFrame) {
handleWebSocketFrame(ctx, (WebSocketFrame) msg);
} else {
@ -70,19 +71,17 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
ctx.flush();
}
private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req)
private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req)
throws Exception {
// Handle a bad request.
if (!req.getDecoderResult().isSuccess()) {
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
req.release();
return;
}
// Allow only GET methods.
if (req.getMethod() != GET) {
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
req.release();
return;
}
@ -95,7 +94,6 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
} else {
handshaker.handshake(ctx.channel(), req);
}
req.release();
}
private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
@ -124,7 +122,7 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
}
private static void sendHttpResponse(
ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
ChannelHandlerContext ctx, HttpRequest req, FullHttpResponse res) {
// Generate an error page if response status code is not OK (200).
if (res.getStatus().code() != 200) {
ByteBuf buf = Unpooled.copiedBuffer(res.getStatus().toString(), CharsetUtil.UTF_8);
@ -145,7 +143,7 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
ctx.close();
}
private static String getWebSocketLocation(FullHttpRequest req) {
return "ws://" + req.headers().get(HttpHeaders.Names.HOST);
private static String getWebSocketLocation(HttpRequest req) {
return "ws://" + req.headers().get(Names.HOST);
}
}

View File

@ -28,7 +28,6 @@ public class AutobahnServerInitializer extends ChannelInitializer<SocketChannel>
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast("encoder", new HttpResponseEncoder());
pipeline.addLast("decoder", new HttpRequestDecoder());
pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
pipeline.addLast("handler", new AutobahnServerHandler());
}
}