Support handshake timeout in websocket handlers (#8856)

Motivation:

Support handshake timeout option in websocket handlers. It makes sense to limit the time we need to move from `HANDSHAKE_ISSUED` to `HANDSHAKE_COMPLETE` states when upgrading to WebSockets

Modification:

- Add `handshakeTimeoutMillis` option in `WebSocketClientProtocolHandshakeHandler`  and `WebSocketServerProtocolHandshakeHandler`.
- Schedule a timeout task, the task will trigger user event `HANDSHAKE_TIMEOUT` if the handshake timed out.

Result:

Fixes issue https://github.com/netty/netty/issues/8841
This commit is contained in:
秦世成 2019-05-22 18:37:28 +08:00 committed by Norman Maurer
parent 2ca526fac6
commit 5ffac03f1e
5 changed files with 436 additions and 17 deletions

View File

@ -23,6 +23,8 @@ import io.netty.handler.codec.http.HttpHeaders;
import java.net.URI;
import java.util.List;
import static io.netty.util.internal.ObjectUtil.*;
/**
* This handler does all the heavy lifting for you to run a websocket client.
*
@ -38,9 +40,11 @@ import java.util.List;
* {@link ClientHandshakeStateEvent#HANDSHAKE_ISSUED} or {@link ClientHandshakeStateEvent#HANDSHAKE_COMPLETE}.
*/
public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
private final WebSocketClientHandshaker handshaker;
private final boolean handleCloseFrames;
private final long handshakeTimeoutMillis;
/**
* Returns the used handshaker
@ -53,6 +57,11 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
* Events that are fired to notify about handshake status
*/
public enum ClientHandshakeStateEvent {
/**
* The Handshake was timed out
*/
HANDSHAKE_TIMEOUT,
/**
* The Handshake was started but the server did not response yet to the request
*/
@ -92,9 +101,45 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, boolean handleCloseFrames,
boolean performMasking, boolean allowMaskMismatch) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders,
maxFramePayloadLength, handleCloseFrames, performMasking, allowMaskMismatch, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
/**
* Base constructor
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param handleCloseFrames
* {@code true} if close frames should not be forwarded and just close the channel
* @param performMasking
* Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
* with the websocket specifications. Client applications that communicate with a non-standard server
* which doesn't require masking might set this to false to achieve a higher performance.
* @param allowMaskMismatch
* When set to true, frames which are not masked properly according to the standard will still be
* accepted.
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, boolean handleCloseFrames, boolean performMasking,
boolean allowMaskMismatch, long handshakeTimeoutMillis) {
this(WebSocketClientHandshakerFactory.newHandshaker(webSocketURL, version, subprotocol,
allowExtensions, customHeaders, maxFramePayloadLength,
performMasking, allowMaskMismatch), handleCloseFrames);
performMasking, allowMaskMismatch),
handleCloseFrames, handshakeTimeoutMillis);
}
/**
@ -118,7 +163,34 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, boolean handleCloseFrames) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
handleCloseFrames, true, false);
handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
/**
* Base constructor
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param handleCloseFrames
* {@code true} if close frames should not be forwarded and just close the channel
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean handleCloseFrames, long handshakeTimeoutMillis) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
handleCloseFrames, true, false, handshakeTimeoutMillis);
}
/**
@ -140,7 +212,32 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength) {
this(webSocketURL, version, subprotocol,
allowExtensions, customHeaders, maxFramePayloadLength, true);
allowExtensions, customHeaders, maxFramePayloadLength, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
/**
* Base constructor
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, long handshakeTimeoutMillis) {
this(webSocketURL, version, subprotocol,
allowExtensions, customHeaders, maxFramePayloadLength, true, handshakeTimeoutMillis);
}
/**
@ -153,7 +250,24 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
* {@code true} if close frames should not be forwarded and just close the channel
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames) {
this(handshaker, handleCloseFrames, true);
this(handshaker, handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
/**
* Base constructor
*
* @param handshaker
* The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection
* was established to the remote peer.
* @param handleCloseFrames
* {@code true} if close frames should not be forwarded and just close the channel
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
long handshakeTimeoutMillis) {
this(handshaker, handleCloseFrames, true, handshakeTimeoutMillis);
}
/**
@ -169,9 +283,29 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
boolean dropPongFrames) {
this(handshaker, handleCloseFrames, dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
/**
* Base constructor
*
* @param handshaker
* The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection
* was established to the remote peer.
* @param handleCloseFrames
* {@code true} if close frames should not be forwarded and just close the channel
* @param dropPongFrames
* {@code true} if pong frames should not be forwarded
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
boolean dropPongFrames, long handshakeTimeoutMillis) {
super(dropPongFrames);
this.handshaker = handshaker;
this.handleCloseFrames = handleCloseFrames;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
}
/**
@ -182,7 +316,21 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
* was established to the remote peer.
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker) {
this(handshaker, true);
this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
/**
* Base constructor
*
* @param handshaker
* The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection
* was established to the remote peer.
* @param handshakeTimeoutMillis
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) {
this(handshaker, true, handshakeTimeoutMillis);
}
@Override
@ -200,7 +348,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
if (cp.get(WebSocketClientProtocolHandshakeHandler.class) == null) {
// Add the WebSocketClientProtocolHandshakeHandler before this one.
ctx.pipeline().addBefore(ctx.name(), WebSocketClientProtocolHandshakeHandler.class.getName(),
new WebSocketClientProtocolHandshakeHandler(handshaker));
new WebSocketClientProtocolHandshakeHandler(handshaker, handshakeTimeoutMillis));
}
if (cp.get(Utf8FrameValidator.class) == null) {
// Add the UFT8 checking before this one.

View File

@ -19,13 +19,43 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.ThrowableUtil;
import java.util.concurrent.TimeUnit;
import static io.netty.util.internal.ObjectUtil.*;
class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
private static final WebSocketHandshakeException HANDSHAKE_TIMED_OUT_EXCEPTION = ThrowableUtil.unknownStackTrace(
new WebSocketHandshakeException("handshake timed out"),
WebSocketClientProtocolHandshakeHandler.class,
"channelActive(...)");
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
private final WebSocketClientHandshaker handshaker;
private final long handshakeTimeoutMillis;
private ChannelHandlerContext ctx;
private ChannelPromise handshakePromise;
WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker) {
this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) {
this.handshaker = handshaker;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
this.ctx = ctx;
handshakePromise = ctx.newPromise();
}
@Override
@ -35,6 +65,7 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
handshakePromise.tryFailure(future.cause());
ctx.fireExceptionCaught(future.cause());
} else {
ctx.fireUserEventTriggered(
@ -42,6 +73,7 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
}
}
});
applyHandshakeTimeout();
}
@Override
@ -55,6 +87,7 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
try {
if (!handshaker.isHandshakeComplete()) {
handshaker.finishHandshake(ctx.channel(), response);
handshakePromise.trySuccess();
ctx.fireUserEventTriggered(
WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE);
ctx.pipeline().remove(this);
@ -65,4 +98,43 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
response.release();
}
}
private void applyHandshakeTimeout() {
final ChannelPromise localHandshakePromise = handshakePromise;
if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
return;
}
final Future<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (localHandshakePromise.isDone()) {
return;
}
if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT_EXCEPTION)) {
ctx.flush()
.fireUserEventTriggered(ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT)
.close();
}
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
// Cancel the handshake timeout when handshake is finished.
localHandshakePromise.addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> f) throws Exception {
timeoutFuture.cancel(false);
}
});
}
/**
* This method is visible for testing.
*
* @return current handshake future
*/
ChannelFuture getHandshakeFuture() {
return handshakePromise;
}
}

View File

@ -33,6 +33,7 @@ import io.netty.util.AttributeKey;
import java.util.List;
import static io.netty.handler.codec.http.HttpVersion.*;
import static io.netty.util.internal.ObjectUtil.*;
/**
* This handler does all the heavy lifting for you to run a websocket server.
@ -64,7 +65,12 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
* it provides extra information about the handshake
*/
@Deprecated
HANDSHAKE_COMPLETE
HANDSHAKE_COMPLETE,
/**
* The Handshake was timed out
*/
HANDSHAKE_TIMEOUT
}
/**
@ -97,47 +103,94 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
private final String websocketPath;
private final String subprotocols;
private final boolean allowExtensions;
private final int maxFramePayloadLength;
private final boolean allowMaskMismatch;
private final boolean checkStartsWith;
private final long handshakeTimeoutMillis;
public WebSocketServerProtocolHandler(String websocketPath) {
this(websocketPath, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
public WebSocketServerProtocolHandler(String websocketPath, long handshakeTimeoutMillis) {
this(websocketPath, null, false);
}
public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
this(websocketPath, null, false, 65536, false, checkStartsWith);
this(websocketPath, checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith, long handshakeTimeoutMillis) {
this(websocketPath, null, false, 65536, false, checkStartsWith, handshakeTimeoutMillis);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
this(websocketPath, subprotocols, false);
this(websocketPath, subprotocols, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, long handshakeTimeoutMillis) {
this(websocketPath, subprotocols, false, handshakeTimeoutMillis);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
this(websocketPath, subprotocols, allowExtensions, 65536);
this(websocketPath, subprotocols, allowExtensions, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
long handshakeTimeoutMillis) {
this(websocketPath, subprotocols, allowExtensions, 65536, handshakeTimeoutMillis);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false);
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, long handshakeTimeoutMillis) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false, handshakeTimeoutMillis);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
int maxFrameSize, boolean allowMaskMismatch, long handshakeTimeoutMillis) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false,
handshakeTimeoutMillis);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true);
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
boolean checkStartsWith, long handshakeTimeoutMillis) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true,
handshakeTimeoutMillis);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
boolean checkStartsWith, boolean dropPongFrames) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith,
boolean dropPongFrames, long handshakeTimeoutMillis) {
super(dropPongFrames);
this.websocketPath = websocketPath;
this.subprotocols = subprotocols;
@ -145,6 +198,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
maxFramePayloadLength = maxFrameSize;
this.allowMaskMismatch = allowMaskMismatch;
this.checkStartsWith = checkStartsWith;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
}
@Override
@ -153,8 +207,11 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
// Add the WebSocketHandshakeHandler before this one.
ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols,
allowExtensions, maxFramePayloadLength, allowMaskMismatch, checkStartsWith));
new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols,
allowExtensions, maxFramePayloadLength,
allowMaskMismatch,
checkStartsWith,
handshakeTimeoutMillis));
}
if (cp.get(Utf8FrameValidator.class) == null) {
// Add the UFT8 checking before this one.

View File

@ -20,22 +20,37 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.ThrowableUtil;
import java.util.concurrent.TimeUnit;
import static io.netty.handler.codec.http.HttpUtil.*;
import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static io.netty.handler.codec.http.HttpUtil.*;
import static io.netty.handler.codec.http.HttpVersion.*;
import static io.netty.util.internal.ObjectUtil.*;
/**
* Handles the HTTP handshake (the HTTP Upgrade request) for {@link WebSocketServerProtocolHandler}.
*/
class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
private static final WebSocketHandshakeException HANDSHAKE_TIMED_OUT_EXCEPTION = ThrowableUtil.unknownStackTrace(
new WebSocketHandshakeException("handshake timed out"),
WebSocketServerProtocolHandshakeHandler.class,
"channelRead(...)");
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
private final String websocketPath;
private final String subprotocols;
@ -43,20 +58,45 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
private final int maxFramePayloadSize;
private final boolean allowMaskMismatch;
private final boolean checkStartsWith;
private final long handshakeTimeoutMillis;
private ChannelHandlerContext ctx;
private ChannelPromise handshakePromise;
WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize,
boolean allowMaskMismatch, long handshakeTimeoutMillis) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
false, handshakeTimeoutMillis);
}
WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS);
}
WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
boolean checkStartsWith, long handshakeTimeoutMillis) {
this.websocketPath = websocketPath;
this.subprotocols = subprotocols;
this.allowExtensions = allowExtensions;
maxFramePayloadSize = maxFrameSize;
this.allowMaskMismatch = allowMaskMismatch;
this.checkStartsWith = checkStartsWith;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
this.ctx = ctx;
handshakePromise = ctx.newPromise();
}
@Override
@ -77,6 +117,7 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
allowExtensions, maxFramePayloadSize, allowMaskMismatch);
final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
final ChannelPromise localHandshakePromise = handshakePromise;
if (handshaker == null) {
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
} else {
@ -85,8 +126,10 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
localHandshakePromise.tryFailure(future.cause());
ctx.fireExceptionCaught(future.cause());
} else {
localHandshakePromise.trySuccess();
// Kept for compatibility
ctx.fireUserEventTriggered(
WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
@ -96,6 +139,7 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
}
}
});
applyHandshakeTimeout();
WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
ctx.pipeline().replace(this, "WS403Responder",
WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
@ -125,4 +169,31 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
String host = req.headers().get(HttpHeaderNames.HOST);
return protocol + "://" + host + path;
}
private void applyHandshakeTimeout() {
final ChannelPromise localHandshakePromise = handshakePromise;
final long handshakeTimeoutMillis = this.handshakeTimeoutMillis;
if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
return;
}
final Future<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT_EXCEPTION)) {
ctx.flush()
.fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT)
.close();
}
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
// Cancel the handshake timeout when handshake is finished.
localHandshakePromise.addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> f) throws Exception {
timeoutFuture.cancel(false);
}
});
}
}

View File

@ -45,6 +45,7 @@ public class WebSocketHandshakeHandOverTest {
private boolean clientReceivedMessage;
private boolean serverReceivedCloseHandshake;
private boolean clientForceClosed;
private boolean clientHandshakeTimeout;
private final class CloseNoOpServerProtocolHandler extends WebSocketServerProtocolHandler {
CloseNoOpServerProtocolHandler(String websocketPath) {
@ -69,6 +70,7 @@ public class WebSocketHandshakeHandOverTest {
clientReceivedMessage = false;
serverReceivedCloseHandshake = false;
clientForceClosed = false;
clientHandshakeTimeout = false;
}
@Test
@ -118,6 +120,66 @@ public class WebSocketHandshakeHandOverTest {
assertTrue(clientReceivedMessage);
}
@Test(expected = WebSocketHandshakeException.class)
public void testClientHandshakeTimeout() throws Exception {
EmbeddedChannel serverChannel = createServerChannel(new SimpleChannelInboundHandler<Object>() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt == ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
serverReceivedHandshake = true;
// immediately send a message to the client on connect
ctx.writeAndFlush(new TextWebSocketFrame("abc"));
} else if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
serverHandshakeComplete = (WebSocketServerProtocolHandler.HandshakeComplete) evt;
}
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
}
});
EmbeddedChannel clientChannel = createClientChannel(new SimpleChannelInboundHandler<Object>() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
clientReceivedHandshake = true;
} else if (evt == ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) {
clientHandshakeTimeout = true;
}
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof TextWebSocketFrame) {
clientReceivedMessage = true;
}
}
}, 100);
// Client send the handshake request to server
transferAllDataWithMerge(clientChannel, serverChannel);
// Server do not send the response back
// transferAllDataWithMerge(serverChannel, clientChannel);
WebSocketClientProtocolHandshakeHandler handshakeHandler =
(WebSocketClientProtocolHandshakeHandler) clientChannel
.pipeline().get(WebSocketClientProtocolHandshakeHandler.class.getName());
while (!handshakeHandler.getHandshakeFuture().isDone()) {
Thread.sleep(10);
// We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop.
clientChannel.runScheduledPendingTasks();
}
assertTrue(clientHandshakeTimeout);
assertFalse(clientReceivedHandshake);
assertFalse(clientReceivedMessage);
// Should throw WebSocketHandshakeException
try {
handshakeHandler.getHandshakeFuture().syncUninterruptibly();
} finally {
serverChannel.finishAndReleaseAll();
}
}
@Test(timeout = 10000)
public void testClientHandshakerForceClose() throws Exception {
final WebSocketClientHandshaker handshaker = WebSocketClientHandshakerFactory.newHandshaker(
@ -245,4 +307,13 @@ public class WebSocketHandshakeHandOverTest {
handler);
}
private static EmbeddedChannel createClientChannel(ChannelHandler handler, long timeoutMillis) throws Exception {
return new EmbeddedChannel(
new HttpClientCodec(),
new HttpObjectAggregator(8192),
new WebSocketClientProtocolHandler(new URI("ws://localhost:1234/test"),
WebSocketVersion.V13, "test-proto-2",
false, null, 65536, timeoutMillis),
handler);
}
}