netty5/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java
秦世成 5ffac03f1e Support handshake timeout in websocket handlers (#8856)
Motivation:

Support a handshake timeout option in the WebSocket handlers. It makes sense to limit the time it may take to move from the `HANDSHAKE_ISSUED` state to the `HANDSHAKE_COMPLETE` state when upgrading to WebSockets.

Modification:

- Add a `handshakeTimeoutMillis` option to `WebSocketClientProtocolHandshakeHandler` and `WebSocketServerProtocolHandshakeHandler`.
- Schedule a timeout task that triggers the `HANDSHAKE_TIMEOUT` user event if the handshake does not complete in time.

Result:

Fixes issue https://github.com/netty/netty/issues/8841
2019-05-22 12:37:28 +02:00

200 lines
9.0 KiB
Java

/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.ThrowableUtil;
import java.util.concurrent.TimeUnit;
import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static io.netty.handler.codec.http.HttpUtil.*;
import static io.netty.handler.codec.http.HttpVersion.*;
import static io.netty.util.internal.ObjectUtil.*;
/**
 * Handles the HTTP handshake (the HTTP Upgrade request) for {@link WebSocketServerProtocolHandler}.
 * <p>
 * Every inbound message is cast to {@link FullHttpRequest}, so the pipeline must deliver aggregated
 * requests to this handler.
 * <ul>
 *     <li>A request whose URI does not match the configured WebSocket path is forwarded unchanged.
 *     The match is exact, or a prefix match when {@code checkStartsWith} is set.</li>
 *     <li>A matching request that is not a {@code GET} is answered with {@code 403 Forbidden}.
 *     The connection is then closed.</li>
 *     <li>If no {@link WebSocketServerHandshaker} supports the requested version, an
 *     unsupported-version response is sent.</li>
 *     <li>Otherwise the handshake response is written and the handshaker is registered on the
 *     channel. This handler then replaces itself in the pipeline with
 *     {@code WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()}, registered under the
 *     name {@code "WS403Responder"}.</li>
 * </ul>
 * Once the handshake has been started, it must finish within {@code handshakeTimeoutMillis}.
 * If it does not, the handshake promise is failed with a {@link WebSocketHandshakeException}, a
 * {@link ServerHandshakeStateEvent#HANDSHAKE_TIMEOUT} user event is fired, and the channel is closed.
 */
class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {

    /** Timeout used by the constructors that do not take an explicit {@code handshakeTimeoutMillis}. */
    private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;

    /** URI (or URI prefix, see {@link #checkStartsWith}) that identifies WebSocket upgrade requests. */
    private final String websocketPath;
    /** Subprotocols passed to the {@link WebSocketServerHandshakerFactory}. */
    private final String subprotocols;
    /** Passed to the {@link WebSocketServerHandshakerFactory}. */
    private final boolean allowExtensions;
    /** Maximum frame payload size passed to the {@link WebSocketServerHandshakerFactory}. */
    private final int maxFramePayloadSize;
    /** Passed to the {@link WebSocketServerHandshakerFactory}. */
    private final boolean allowMaskMismatch;
    /** When {@code true}, any URI starting with {@link #websocketPath} matches; otherwise it must be equal. */
    private final boolean checkStartsWith;
    /** Maximum time, in milliseconds, a started handshake may take before the channel is closed. */
    private final long handshakeTimeoutMillis;

    /** Context of this handler; set in {@link #handlerAdded(ChannelHandlerContext)}. */
    private ChannelHandlerContext ctx;
    /** Completed when the handshake succeeds, fails, or times out; created in {@code handlerAdded}. */
    private ChannelPromise handshakePromise;

    /**
     * Creates a handler that matches the path exactly and uses the default handshake timeout.
     */
    WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
             DEFAULT_HANDSHAKE_TIMEOUT_MS);
    }

    /**
     * Creates a handler that matches the path exactly and uses the given handshake timeout.
     *
     * @param handshakeTimeoutMillis handshake timeout in milliseconds; must be positive
     */
    WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
                                            boolean allowExtensions, int maxFrameSize,
                                            boolean allowMaskMismatch, long handshakeTimeoutMillis) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
             false, handshakeTimeoutMillis);
    }

    /**
     * Creates a handler with configurable path matching and the default handshake timeout.
     *
     * @param checkStartsWith {@code true} to accept any URI that starts with {@code websocketPath}
     */
    WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
             checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS);
    }

    /**
     * Creates a fully configured handler.
     *
     * @param websocketPath          URI (or URI prefix) that identifies WebSocket upgrade requests
     * @param subprotocols           subprotocols passed to the handshaker factory
     * @param allowExtensions        passed to the handshaker factory
     * @param maxFrameSize           maximum frame payload size passed to the handshaker factory
     * @param allowMaskMismatch      passed to the handshaker factory
     * @param checkStartsWith        {@code true} to accept any URI that starts with {@code websocketPath}
     * @param handshakeTimeoutMillis handshake timeout in milliseconds; must be positive
     * @throws IllegalArgumentException if {@code handshakeTimeoutMillis} is not positive
     */
    WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
                                            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
                                            boolean checkStartsWith, long handshakeTimeoutMillis) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
        this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
    }

    /**
     * Remembers the context and creates the promise that tracks the handshake's outcome.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        this.ctx = ctx;
        handshakePromise = ctx.newPromise();
    }

    /**
     * Performs the WebSocket handshake for a matching request. Any other request is passed on.
     * <p>
     * For a matching request, this handler releases the message once it is done with it.
     */
    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        final FullHttpRequest req = (FullHttpRequest) msg;
        if (isNotWebSocketPath(req)) {
            // Not a WebSocket upgrade for our path: hand the message (and its ownership) downstream.
            ctx.fireChannelRead(msg);
            return;
        }

        try {
            if (!GET.equals(req.method())) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }

            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            // Captured so the listener below always completes the promise belonging to this handshake.
            final ChannelPromise localHandshakePromise = handshakePromise;
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
                handshakeFuture.addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        if (!future.isSuccess()) {
                            localHandshakePromise.tryFailure(future.cause());
                            ctx.fireExceptionCaught(future.cause());
                        } else {
                            localHandshakePromise.trySuccess();
                            // Kept for compatibility
                            ctx.fireUserEventTriggered(
                                    WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                            // NOTE(review): req.uri()/req.headers() may be read after req.release() in the
                            // finally block below; this relies on release() only freeing the request body.
                            ctx.fireUserEventTriggered(
                                    new WebSocketServerProtocolHandler.HandshakeComplete(
                                            req.uri(), req.headers(), handshaker.selectedSubprotocol()));
                        }
                    }
                });
                applyHandshakeTimeout();
                WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
                ctx.pipeline().replace(this, "WS403Responder",
                        WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
            }
        } finally {
            req.release();
        }
    }

    /**
     * Returns {@code true} if the request URI does not match {@link #websocketPath}.
     * The check is a prefix check when {@link #checkStartsWith} is set, and an equality check otherwise.
     */
    private boolean isNotWebSocketPath(FullHttpRequest req) {
        return checkStartsWith ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath);
    }

    /**
     * Writes and flushes {@code res} on the channel. The channel is closed afterwards unless the
     * request is keep-alive and the response status is {@code 200}.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Builds the WebSocket location from the request's {@code Host} header and {@code path}.
     * The scheme is {@code wss} when an {@link SslHandler} is in the pipeline, and {@code ws} otherwise.
     */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        String host = req.headers().get(HttpHeaderNames.HOST);
        return protocol + "://" + host + path;
    }

    /**
     * Schedules a task that fails the handshake promise once {@link #handshakeTimeoutMillis} has elapsed.
     * If that task wins the race to complete the promise, it flushes, fires
     * {@link ServerHandshakeStateEvent#HANDSHAKE_TIMEOUT} and closes the channel.
     * The task is cancelled as soon as the promise completes by any other means.
     */
    private void applyHandshakeTimeout() {
        final ChannelPromise localHandshakePromise = handshakePromise;
        final long handshakeTimeoutMillis = this.handshakeTimeoutMillis;
        // The constructor already rejects non-positive timeouts; the first check is a defensive guard.
        if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
            return;
        }

        final Future<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
            @Override
            public void run() {
                // A fresh exception per timeout: a shared static Throwable is mutable (suppressed
                // exceptions, cause) and would carry state across unrelated channels.
                if (localHandshakePromise.tryFailure(newHandshakeTimedOutException())) {
                    ctx.flush()
                       .fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT)
                       .close();
                }
            }
        }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);

        // Cancel the handshake timeout when handshake is finished.
        localHandshakePromise.addListener(new FutureListener<Void>() {
            @Override
            public void operationComplete(Future<Void> f) throws Exception {
                timeoutFuture.cancel(false);
            }
        });
    }

    /**
     * Creates the exception used to fail a timed-out handshake.
     * <p>
     * The stack trace is replaced with a single synthetic frame naming this class and
     * {@code applyHandshakeTimeout()}. The real trace, from the event loop running the scheduled task,
     * carries no useful information.
     */
    private static WebSocketHandshakeException newHandshakeTimedOutException() {
        return ThrowableUtil.unknownStackTrace(
                new WebSocketHandshakeException("handshake timed out"),
                WebSocketServerProtocolHandshakeHandler.class,
                "applyHandshakeTimeout()");
    }
}