Allow to use WebSocketClientHandshaker and WebSocketServerHandshaker with HttpResponse / HttpRequest
Motivation: To use WebSocketClientHandshaker / WebSocketServerHandshaker it's currently a requirement of having a HttpObjectAggregator in the ChannelPipeline. This is not a big deal when a user only wants to server WebSockets but is a limitation if the server serves WebSockets and normal HTTP traffic. Modifications: Allow to use WebSocketClientHandshaker and WebSocketServerHandshaker without HttpObjectAggregator in the ChannelPipeline. Result: More flexibility
This commit is contained in:
parent
b984ca7979
commit
afa9e71ed3
@ -21,22 +21,33 @@ import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPipeline;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.SimpleChannelInboundHandler;
|
||||
import io.netty.handler.codec.http.FullHttpRequest;
|
||||
import io.netty.handler.codec.http.FullHttpResponse;
|
||||
import io.netty.handler.codec.http.HttpClientCodec;
|
||||
import io.netty.handler.codec.http.HttpContentDecompressor;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||
import io.netty.handler.codec.http.HttpRequestEncoder;
|
||||
import io.netty.handler.codec.http.HttpResponse;
|
||||
import io.netty.handler.codec.http.HttpResponseDecoder;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import io.netty.util.internal.StringUtil;
|
||||
|
||||
import java.net.URI;
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
|
||||
/**
|
||||
* Base class for web socket client handshake implementations
|
||||
*/
|
||||
public abstract class WebSocketClientHandshaker {
|
||||
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
|
||||
|
||||
static {
|
||||
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
|
||||
}
|
||||
|
||||
private final URI uri;
|
||||
|
||||
@ -239,6 +250,12 @@ public abstract class WebSocketClientHandshaker {
|
||||
p.remove(decompressor);
|
||||
}
|
||||
|
||||
// Remove aggregator if present before
|
||||
HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
|
||||
if (aggregator != null) {
|
||||
p.remove(aggregator);
|
||||
}
|
||||
|
||||
ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
|
||||
if (ctx == null) {
|
||||
ctx = p.context(HttpClientCodec.class);
|
||||
@ -256,6 +273,93 @@ public abstract class WebSocketClientHandshaker {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process the opening handshake initiated by {@link #handshake}}.
|
||||
*
|
||||
* @param channel
|
||||
* Channel
|
||||
* @param response
|
||||
* HTTP response containing the closing handshake details
|
||||
* @return future
|
||||
* the {@link ChannelFuture} which is notified once the handshake completes.
|
||||
*/
|
||||
public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
|
||||
return processHandshake(channel, response, channel.newPromise());
|
||||
}
|
||||
|
||||
/**
|
||||
* Process the opening handshake initiated by {@link #handshake}}.
|
||||
*
|
||||
* @param channel
|
||||
* Channel
|
||||
* @param response
|
||||
* HTTP response containing the closing handshake details
|
||||
* @param promise
|
||||
* the {@link ChannelPromise} to notify once the handshake completes.
|
||||
* @return future
|
||||
* the {@link ChannelFuture} which is notified once the handshake completes.
|
||||
*/
|
||||
public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
|
||||
final ChannelPromise promise) {
|
||||
if (response instanceof FullHttpResponse) {
|
||||
try {
|
||||
finishHandshake(channel, (FullHttpResponse) response);
|
||||
promise.setSuccess();
|
||||
} catch (Throwable cause) {
|
||||
promise.setFailure(cause);
|
||||
}
|
||||
} else {
|
||||
ChannelPipeline p = channel.pipeline();
|
||||
ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
|
||||
if (ctx == null) {
|
||||
ctx = p.context(HttpClientCodec.class);
|
||||
if (ctx == null) {
|
||||
return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
|
||||
"a HttpResponseDecoder or HttpClientCodec"));
|
||||
}
|
||||
}
|
||||
// Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more
|
||||
// then enough for the websockets handshake payload.
|
||||
//
|
||||
// TODO: Make handshake work without HttpObjectAggregator at all.
|
||||
String aggregatorName = "httpAggregator";
|
||||
p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
|
||||
p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
|
||||
@Override
|
||||
protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
|
||||
// Remove ourself and do the actual handshake
|
||||
ctx.pipeline().remove(this);
|
||||
try {
|
||||
finishHandshake(channel, msg);
|
||||
promise.setSuccess();
|
||||
} catch (Throwable cause) {
|
||||
promise.setFailure(cause);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
// Remove ourself and fail the handshake promise.
|
||||
ctx.pipeline().remove(this);
|
||||
promise.setFailure(cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
|
||||
// Fail promise if Channel was closed
|
||||
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
|
||||
ctx.fireChannelInactive();
|
||||
}
|
||||
});
|
||||
try {
|
||||
ctx.fireChannelRead(ReferenceCountUtil.retain(response));
|
||||
} catch (Throwable cause) {
|
||||
promise.setFailure(cause);
|
||||
}
|
||||
}
|
||||
return promise;
|
||||
}
|
||||
|
||||
/**
|
||||
* Verfiy the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
|
||||
*/
|
||||
|
@ -21,19 +21,23 @@ import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPipeline;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.SimpleChannelInboundHandler;
|
||||
import io.netty.handler.codec.http.FullHttpRequest;
|
||||
import io.netty.handler.codec.http.FullHttpResponse;
|
||||
import io.netty.handler.codec.http.HttpContentCompressor;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||
import io.netty.handler.codec.http.HttpRequest;
|
||||
import io.netty.handler.codec.http.HttpRequestDecoder;
|
||||
import io.netty.handler.codec.http.HttpResponseEncoder;
|
||||
import io.netty.handler.codec.http.HttpServerCodec;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import io.netty.util.internal.StringUtil;
|
||||
import io.netty.util.internal.logging.InternalLogger;
|
||||
import io.netty.util.internal.logging.InternalLoggerFactory;
|
||||
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Set;
|
||||
@ -43,6 +47,11 @@ import java.util.Set;
|
||||
*/
|
||||
public abstract class WebSocketServerHandshaker {
|
||||
protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
|
||||
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
|
||||
|
||||
static {
|
||||
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
|
||||
}
|
||||
|
||||
private final String uri;
|
||||
|
||||
@ -200,6 +209,94 @@ public abstract class WebSocketServerHandshaker {
|
||||
return promise;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
|
||||
* {@link FullHttpRequest} which is passed in.
|
||||
*
|
||||
* @param channel
|
||||
* Channel
|
||||
* @param req
|
||||
* HTTP Request
|
||||
* @return future
|
||||
* The {@link ChannelFuture} which is notified once the opening handshake completes
|
||||
*/
|
||||
public ChannelFuture handshake(Channel channel, HttpRequest req) {
|
||||
return handshake(channel, req, null, channel.newPromise());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the opening handshake
|
||||
*
|
||||
* When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
|
||||
*
|
||||
* @param channel
|
||||
* Channel
|
||||
* @param req
|
||||
* HTTP Request
|
||||
* @param responseHeaders
|
||||
* Extra headers to add to the handshake response or {@code null} if no extra headers should be added
|
||||
* @param promise
|
||||
* the {@link ChannelPromise} to be notified when the opening handshake is done
|
||||
* @return future
|
||||
* the {@link ChannelFuture} which is notified when the opening handshake is done
|
||||
*/
|
||||
public final ChannelFuture handshake(final Channel channel, HttpRequest req,
|
||||
final HttpHeaders responseHeaders, final ChannelPromise promise) {
|
||||
|
||||
if (req instanceof FullHttpRequest) {
|
||||
return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
|
||||
}
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("{} WebSocket version {} server handshake", channel, version());
|
||||
}
|
||||
ChannelPipeline p = channel.pipeline();
|
||||
ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
|
||||
if (ctx == null) {
|
||||
// this means the user use a HttpServerCodec
|
||||
ctx = p.context(HttpServerCodec.class);
|
||||
if (ctx == null) {
|
||||
promise.setFailure(
|
||||
new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
|
||||
return promise;
|
||||
}
|
||||
}
|
||||
// Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit o 8192 should be more then
|
||||
// enough for the websockets handshake payload.
|
||||
//
|
||||
// TODO: Make handshake work without HttpObjectAggregator at all.
|
||||
String aggregatorName = "httpAggregator";
|
||||
p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
|
||||
p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
|
||||
@Override
|
||||
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
|
||||
// Remove ourself and do the actual handshake
|
||||
ctx.pipeline().remove(this);
|
||||
handshake(channel, msg, responseHeaders, promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
// Remove ourself and fail the handshake promise.
|
||||
ctx.pipeline().remove(this);
|
||||
promise.tryFailure(cause);
|
||||
ctx.fireExceptionCaught(cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
|
||||
// Fail promise if Channel was closed
|
||||
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
|
||||
ctx.fireChannelInactive();
|
||||
}
|
||||
});
|
||||
try {
|
||||
ctx.fireChannelRead(ReferenceCountUtil.retain(req));
|
||||
} catch (Throwable cause) {
|
||||
promise.setFailure(cause);
|
||||
}
|
||||
return promise;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
|
||||
*/
|
||||
|
@ -22,9 +22,9 @@ import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.DefaultFullHttpResponse;
|
||||
import io.netty.handler.codec.http.FullHttpRequest;
|
||||
import io.netty.handler.codec.http.FullHttpResponse;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpRequest;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
|
||||
@ -55,8 +55,8 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg instanceof FullHttpRequest) {
|
||||
handleHttpRequest(ctx, (FullHttpRequest) msg);
|
||||
if (msg instanceof HttpRequest) {
|
||||
handleHttpRequest(ctx, (HttpRequest) msg);
|
||||
} else if (msg instanceof WebSocketFrame) {
|
||||
handleWebSocketFrame(ctx, (WebSocketFrame) msg);
|
||||
} else {
|
||||
@ -69,19 +69,17 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
|
||||
ctx.flush();
|
||||
}
|
||||
|
||||
private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req)
|
||||
private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req)
|
||||
throws Exception {
|
||||
// Handle a bad request.
|
||||
if (!req.decoderResult().isSuccess()) {
|
||||
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
|
||||
req.release();
|
||||
return;
|
||||
}
|
||||
|
||||
// Allow only GET methods.
|
||||
if (req.method() != GET) {
|
||||
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
|
||||
req.release();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -94,7 +92,6 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
|
||||
} else {
|
||||
handshaker.handshake(ctx.channel(), req);
|
||||
}
|
||||
req.release();
|
||||
}
|
||||
|
||||
private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
|
||||
@ -121,7 +118,7 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
|
||||
}
|
||||
|
||||
private static void sendHttpResponse(
|
||||
ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
|
||||
ChannelHandlerContext ctx, HttpRequest req, FullHttpResponse res) {
|
||||
// Generate an error page if response status code is not OK (200).
|
||||
if (res.status().code() != 200) {
|
||||
ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
|
||||
@ -142,7 +139,7 @@ public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
|
||||
ctx.close();
|
||||
}
|
||||
|
||||
private static String getWebSocketLocation(FullHttpRequest req) {
|
||||
private static String getWebSocketLocation(HttpRequest req) {
|
||||
return "ws://" + req.headers().get(HttpHeaderNames.HOST);
|
||||
}
|
||||
}
|
||||
|
@ -28,7 +28,6 @@ public class AutobahnServerInitializer extends ChannelInitializer<SocketChannel>
|
||||
ChannelPipeline pipeline = ch.pipeline();
|
||||
pipeline.addLast("encoder", new HttpResponseEncoder());
|
||||
pipeline.addLast("decoder", new HttpRequestDecoder());
|
||||
pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
|
||||
pipeline.addLast("handler", new AutobahnServerHandler());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user