Allow to use WebSocketClientHandshaker and WebSocketServerHandshaker with HttpResponse / HttpRequest

Motivation:

To use WebSocketClientHandshaker / WebSocketServerHandshaker it's currently a requirement of having a HttpObjectAggregator in the ChannelPipeline. This is not a big deal when a user only wants to server WebSockets but is a limitation if the server serves WebSockets and normal HTTP traffic.

Modifications:

Allow to use WebSocketClientHandshaker and WebSocketServerHandshaker without HttpObjectAggregator in the ChannelPipeline.

Result:

More flexibility
This commit is contained in:
Norman Maurer 2015-01-06 06:55:01 +01:00
parent d1a60e682a
commit 6e942a3a20
4 changed files with 207 additions and 10 deletions

View File

@ -21,22 +21,33 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder; import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.StringUtil; import io.netty.util.internal.StringUtil;
import java.net.URI; import java.net.URI;
import java.nio.channels.ClosedChannelException;
/** /**
* Base class for web socket client handshake implementations * Base class for web socket client handshake implementations
*/ */
public abstract class WebSocketClientHandshaker { public abstract class WebSocketClientHandshaker {
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
static {
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
}
private final URI uri; private final URI uri;
@ -239,6 +250,12 @@ public abstract class WebSocketClientHandshaker {
p.remove(decompressor); p.remove(decompressor);
} }
// Remove aggregator if present before
HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
if (aggregator != null) {
p.remove(aggregator);
}
ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class); ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
if (ctx == null) { if (ctx == null) {
ctx = p.context(HttpClientCodec.class); ctx = p.context(HttpClientCodec.class);
@ -256,6 +273,93 @@ public abstract class WebSocketClientHandshaker {
} }
} }
/**
* Process the opening handshake initiated by {@link #handshake}}.
*
* @param channel
* Channel
* @param response
* HTTP response containing the closing handshake details
* @return future
* the {@link ChannelFuture} which is notified once the handshake completes.
*/
public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
return processHandshake(channel, response, channel.newPromise());
}
/**
* Process the opening handshake initiated by {@link #handshake}}.
*
* @param channel
* Channel
* @param response
* HTTP response containing the closing handshake details
* @param promise
* the {@link ChannelPromise} to notify once the handshake completes.
* @return future
* the {@link ChannelFuture} which is notified once the handshake completes.
*/
public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
final ChannelPromise promise) {
if (response instanceof FullHttpResponse) {
try {
finishHandshake(channel, (FullHttpResponse) response);
promise.setSuccess();
} catch (Throwable cause) {
promise.setFailure(cause);
}
} else {
ChannelPipeline p = channel.pipeline();
ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
if (ctx == null) {
ctx = p.context(HttpClientCodec.class);
if (ctx == null) {
return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
"a HttpResponseDecoder or HttpClientCodec"));
}
}
// Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more
// then enough for the websockets handshake payload.
//
// TODO: Make handshake work without HttpObjectAggregator at all.
String aggregatorName = "httpAggregator";
p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
// Remove ourself and do the actual handshake
ctx.pipeline().remove(this);
try {
finishHandshake(channel, msg);
promise.setSuccess();
} catch (Throwable cause) {
promise.setFailure(cause);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// Remove ourself and fail the handshake promise.
ctx.pipeline().remove(this);
promise.setFailure(cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
// Fail promise if Channel was closed
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
ctx.fireChannelInactive();
}
});
try {
ctx.fireChannelRead(ReferenceCountUtil.retain(response));
} catch (Throwable cause) {
promise.setFailure(cause);
}
}
return promise;
}
/** /**
* Verfiy the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong. * Verfiy the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
*/ */

View File

@ -21,19 +21,23 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpContentCompressor; import io.netty.handler.codec.http.HttpContentCompressor;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.StringUtil; import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
import java.nio.channels.ClosedChannelException;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.Set; import java.util.Set;
@ -43,6 +47,11 @@ import java.util.Set;
*/ */
public abstract class WebSocketServerHandshaker { public abstract class WebSocketServerHandshaker {
protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class); protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
static {
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
}
private final String uri; private final String uri;
@ -200,6 +209,94 @@ public abstract class WebSocketServerHandshaker {
return promise; return promise;
} }
/**
* Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
* {@link FullHttpRequest} which is passed in.
*
* @param channel
* Channel
* @param req
* HTTP Request
* @return future
* The {@link ChannelFuture} which is notified once the opening handshake completes
*/
public ChannelFuture handshake(Channel channel, HttpRequest req) {
return handshake(channel, req, null, channel.newPromise());
}
/**
* Performs the opening handshake
*
* When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
*
* @param channel
* Channel
* @param req
* HTTP Request
* @param responseHeaders
* Extra headers to add to the handshake response or {@code null} if no extra headers should be added
* @param promise
* the {@link ChannelPromise} to be notified when the opening handshake is done
* @return future
* the {@link ChannelFuture} which is notified when the opening handshake is done
*/
public final ChannelFuture handshake(final Channel channel, HttpRequest req,
final HttpHeaders responseHeaders, final ChannelPromise promise) {
if (req instanceof FullHttpRequest) {
return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
}
if (logger.isDebugEnabled()) {
logger.debug("{} WebSocket version {} server handshake", channel, version());
}
ChannelPipeline p = channel.pipeline();
ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
if (ctx == null) {
// this means the user use a HttpServerCodec
ctx = p.context(HttpServerCodec.class);
if (ctx == null) {
promise.setFailure(
new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
return promise;
}
}
// Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit o 8192 should be more then
// enough for the websockets handshake payload.
//
// TODO: Make handshake work without HttpObjectAggregator at all.
String aggregatorName = "httpAggregator";
p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
// Remove ourself and do the actual handshake
ctx.pipeline().remove(this);
handshake(channel, msg, responseHeaders, promise);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// Remove ourself and fail the handshake promise.
ctx.pipeline().remove(this);
promise.tryFailure(cause);
ctx.fireExceptionCaught(cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
// Fail promise if Channel was closed
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
ctx.fireChannelInactive();
}
});
try {
ctx.fireChannelRead(ReferenceCountUtil.retain(req));
} catch (Throwable cause) {
promise.setFailure(cause);
}
return promise;
}
/** /**
* Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request. * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
*/ */

View File

@ -22,10 +22,10 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderUtil; import io.netty.handler.codec.http.HttpHeaderUtil;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
@ -55,8 +55,8 @@ public class AutobahnServerHandler extends ChannelHandlerAdapter {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof FullHttpRequest) { if (msg instanceof HttpRequest) {
handleHttpRequest(ctx, (FullHttpRequest) msg); handleHttpRequest(ctx, (HttpRequest) msg);
} else if (msg instanceof WebSocketFrame) { } else if (msg instanceof WebSocketFrame) {
handleWebSocketFrame(ctx, (WebSocketFrame) msg); handleWebSocketFrame(ctx, (WebSocketFrame) msg);
} else { } else {
@ -69,19 +69,17 @@ public class AutobahnServerHandler extends ChannelHandlerAdapter {
ctx.flush(); ctx.flush();
} }
private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req)
throws Exception { throws Exception {
// Handle a bad request. // Handle a bad request.
if (!req.decoderResult().isSuccess()) { if (!req.decoderResult().isSuccess()) {
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST)); sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
req.release();
return; return;
} }
// Allow only GET methods. // Allow only GET methods.
if (req.method() != GET) { if (req.method() != GET) {
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN)); sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
req.release();
return; return;
} }
@ -94,7 +92,6 @@ public class AutobahnServerHandler extends ChannelHandlerAdapter {
} else { } else {
handshaker.handshake(ctx.channel(), req); handshaker.handshake(ctx.channel(), req);
} }
req.release();
} }
private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
@ -121,7 +118,7 @@ public class AutobahnServerHandler extends ChannelHandlerAdapter {
} }
private static void sendHttpResponse( private static void sendHttpResponse(
ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) { ChannelHandlerContext ctx, HttpRequest req, FullHttpResponse res) {
// Generate an error page if response status code is not OK (200). // Generate an error page if response status code is not OK (200).
if (res.status().code() != 200) { if (res.status().code() != 200) {
ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8); ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
@ -142,7 +139,7 @@ public class AutobahnServerHandler extends ChannelHandlerAdapter {
ctx.close(); ctx.close();
} }
private static String getWebSocketLocation(FullHttpRequest req) { private static String getWebSocketLocation(HttpRequest req) {
return "ws://" + req.headers().get(HttpHeaderNames.HOST); return "ws://" + req.headers().get(HttpHeaderNames.HOST);
} }
} }

View File

@ -28,7 +28,6 @@ public class AutobahnServerInitializer extends ChannelInitializer<SocketChannel>
ChannelPipeline pipeline = ch.pipeline(); ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast("encoder", new HttpResponseEncoder()); pipeline.addLast("encoder", new HttpResponseEncoder());
pipeline.addLast("decoder", new HttpRequestDecoder()); pipeline.addLast("decoder", new HttpRequestDecoder());
pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
pipeline.addLast("handler", new AutobahnServerHandler()); pipeline.addLast("handler", new AutobahnServerHandler());
} }
} }