Norman Maurer 697952e3e6
Add WebSocketClientHandshaker / WebSocketServerHandshaker close methods that take ChannelHandlerContext as parameter (#11171)
Motivation:

At the moment we only expose close(...) methods that take a Channel as paramater. This can be problematic as the write will start at the end of the pipeline which may contain ChannelOutboundHandler implementations that not expect WebSocketFrame objects. We should better also support to pass in a ChannelHandlerContext as starting point for the write which ensures that the WebSocketFrame objects will be handled correctly from this position of the pipeline.

Modifications:

- Add new close(...) methods that take a ChannelHandlerContext
- Add javadoc sentence to point users to the new methods.

Result:

Be able to "start" the close at the right position in the pipeline.
2021-04-21 15:16:10 +02:00

146 lines
5.9 KiB
Java

/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.testsuite.autobahn;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.StringUtil;
import java.util.logging.Level;
import java.util.logging.Logger;
import static io.netty.handler.codec.http.HttpUtil.*;
import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static io.netty.handler.codec.http.HttpVersion.*;
/**
* Handles handshakes and messages
*/
public class AutobahnServerHandler extends ChannelInboundHandlerAdapter {
private static final Logger logger = Logger.getLogger(AutobahnServerHandler.class.getName());
private WebSocketServerHandshaker handshaker;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof HttpRequest) {
handleHttpRequest(ctx, (HttpRequest) msg);
} else if (msg instanceof WebSocketFrame) {
handleWebSocketFrame(ctx, (WebSocketFrame) msg);
} else {
throw new IllegalStateException("unknown message: " + msg);
}
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
ctx.flush();
}
private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req)
throws Exception {
// Handle a bad request.
if (!req.decoderResult().isSuccess()) {
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST, ctx.alloc().buffer(0)));
return;
}
// Allow only GET methods.
if (!GET.equals(req.method())) {
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN, ctx.alloc().buffer(0)));
return;
}
// Handshake
WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
getWebSocketLocation(req), null, false, Integer.MAX_VALUE);
handshaker = wsFactory.newHandshaker(req);
if (handshaker == null) {
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
} else {
handshaker.handshake(ctx.channel(), req);
}
}
private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
if (logger.isLoggable(Level.FINE)) {
logger.fine(String.format(
"Channel %s received %s", ctx.channel().hashCode(), StringUtil.simpleClassName(frame)));
}
if (frame instanceof CloseWebSocketFrame) {
handshaker.close(ctx, (CloseWebSocketFrame) frame);
} else if (frame instanceof PingWebSocketFrame) {
ctx.write(new PongWebSocketFrame(frame.isFinalFragment(), frame.rsv(), frame.content()));
} else if (frame instanceof TextWebSocketFrame ||
frame instanceof BinaryWebSocketFrame ||
frame instanceof ContinuationWebSocketFrame) {
ctx.write(frame);
} else if (frame instanceof PongWebSocketFrame) {
frame.release();
// Ignore
} else {
throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass()
.getName()));
}
}
private static void sendHttpResponse(
ChannelHandlerContext ctx, HttpRequest req, FullHttpResponse res) {
// Generate an error page if response status code is not OK (200).
if (res.status().code() != 200) {
ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
res.content().writeBytes(buf);
buf.release();
setContentLength(res, res.content().readableBytes());
}
// Send the response and close the connection if necessary.
ChannelFuture f = ctx.channel().writeAndFlush(res);
if (!isKeepAlive(req) || res.status().code() != 200) {
f.addListener(ChannelFutureListener.CLOSE);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.close();
}
private static String getWebSocketLocation(HttpRequest req) {
return "ws://" + req.headers().get(HttpHeaderNames.HOST);
}
}