Support handshake timeout in websocket handlers (#8856)
Motivation: Support handshake timeout option in websocket handlers. It makes sense to limit the time we need to move from `HANDSHAKE_ISSUED` to `HANDSHAKE_COMPLETE` states when upgrading to WebSockets Modification: - Add `handshakeTimeoutMillis` option in `WebSocketClientProtocolHandshakeHandler` and `WebSocketServerProtocolHandshakeHandler`. - Schedule a timeout task, the task will trigger user event `HANDSHAKE_TIMEOUT` if the handshake timed out. Result: Fixes issue https://github.com/netty/netty/issues/8841
This commit is contained in:
parent
23554e6997
commit
1465e3ce06
@ -23,6 +23,8 @@ import io.netty.handler.codec.http.HttpHeaders;
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
|
||||
import static io.netty.util.internal.ObjectUtil.*;
|
||||
|
||||
/**
|
||||
* This handler does all the heavy lifting for you to run a websocket client.
|
||||
*
|
||||
@ -38,9 +40,11 @@ import java.util.List;
|
||||
* {@link ClientHandshakeStateEvent#HANDSHAKE_ISSUED} or {@link ClientHandshakeStateEvent#HANDSHAKE_COMPLETE}.
|
||||
*/
|
||||
public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
|
||||
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
|
||||
|
||||
private final WebSocketClientHandshaker handshaker;
|
||||
private final boolean handleCloseFrames;
|
||||
private final long handshakeTimeoutMillis;
|
||||
|
||||
/**
|
||||
* Returns the used handshaker
|
||||
@ -53,6 +57,11 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
|
||||
* Events that are fired to notify about handshake status
|
||||
*/
|
||||
public enum ClientHandshakeStateEvent {
|
||||
/**
|
||||
* The Handshake was timed out
|
||||
*/
|
||||
HANDSHAKE_TIMEOUT,
|
||||
|
||||
/**
|
||||
* The Handshake was started but the server did not response yet to the request
|
||||
*/
|
||||
@ -92,9 +101,45 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
|
||||
boolean allowExtensions, HttpHeaders customHeaders,
|
||||
int maxFramePayloadLength, boolean handleCloseFrames,
|
||||
boolean performMasking, boolean allowMaskMismatch) {
|
||||
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders,
|
||||
maxFramePayloadLength, handleCloseFrames, performMasking, allowMaskMismatch, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Base constructor
|
||||
*
|
||||
* @param webSocketURL
|
||||
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
|
||||
* sent to this URL.
|
||||
* @param version
|
||||
* Version of web socket specification to use to connect to the server
|
||||
* @param subprotocol
|
||||
* Sub protocol request sent to the server.
|
||||
* @param customHeaders
|
||||
* Map of custom headers to add to the client request
|
||||
* @param maxFramePayloadLength
|
||||
* Maximum length of a frame's payload
|
||||
* @param handleCloseFrames
|
||||
* {@code true} if close frames should not be forwarded and just close the channel
|
||||
* @param performMasking
|
||||
* Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
|
||||
* with the websocket specifications. Client applications that communicate with a non-standard server
|
||||
* which doesn't require masking might set this to false to achieve a higher performance.
|
||||
* @param allowMaskMismatch
|
||||
* When set to true, frames which are not masked properly according to the standard will still be
|
||||
* accepted.
|
||||
* @param handshakeTimeoutMillis
|
||||
* Handshake timeout in mills, when handshake timeout, will trigger user
|
||||
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
|
||||
*/
|
||||
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
|
||||
boolean allowExtensions, HttpHeaders customHeaders,
|
||||
int maxFramePayloadLength, boolean handleCloseFrames, boolean performMasking,
|
||||
boolean allowMaskMismatch, long handshakeTimeoutMillis) {
|
||||
this(WebSocketClientHandshakerFactory.newHandshaker(webSocketURL, version, subprotocol,
|
||||
allowExtensions, customHeaders, maxFramePayloadLength,
|
||||
performMasking, allowMaskMismatch), handleCloseFrames);
|
||||
performMasking, allowMaskMismatch),
|
||||
handleCloseFrames, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -118,7 +163,34 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
|
||||
boolean allowExtensions, HttpHeaders customHeaders,
|
||||
int maxFramePayloadLength, boolean handleCloseFrames) {
|
||||
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
|
||||
handleCloseFrames, true, false);
|
||||
handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Base constructor
|
||||
*
|
||||
* @param webSocketURL
|
||||
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
|
||||
* sent to this URL.
|
||||
* @param version
|
||||
* Version of web socket specification to use to connect to the server
|
||||
* @param subprotocol
|
||||
* Sub protocol request sent to the server.
|
||||
* @param customHeaders
|
||||
* Map of custom headers to add to the client request
|
||||
* @param maxFramePayloadLength
|
||||
* Maximum length of a frame's payload
|
||||
* @param handleCloseFrames
|
||||
* {@code true} if close frames should not be forwarded and just close the channel
|
||||
* @param handshakeTimeoutMillis
|
||||
* Handshake timeout in mills, when handshake timeout, will trigger user
|
||||
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
|
||||
*/
|
||||
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
|
||||
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
|
||||
boolean handleCloseFrames, long handshakeTimeoutMillis) {
|
||||
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
|
||||
handleCloseFrames, true, false, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -140,7 +212,32 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
|
||||
boolean allowExtensions, HttpHeaders customHeaders,
|
||||
int maxFramePayloadLength) {
|
||||
this(webSocketURL, version, subprotocol,
|
||||
allowExtensions, customHeaders, maxFramePayloadLength, true);
|
||||
allowExtensions, customHeaders, maxFramePayloadLength, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Base constructor
|
||||
*
|
||||
* @param webSocketURL
|
||||
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
|
||||
* sent to this URL.
|
||||
* @param version
|
||||
* Version of web socket specification to use to connect to the server
|
||||
* @param subprotocol
|
||||
* Sub protocol request sent to the server.
|
||||
* @param customHeaders
|
||||
* Map of custom headers to add to the client request
|
||||
* @param maxFramePayloadLength
|
||||
* Maximum length of a frame's payload
|
||||
* @param handshakeTimeoutMillis
|
||||
* Handshake timeout in mills, when handshake timeout, will trigger user
|
||||
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
|
||||
*/
|
||||
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
|
||||
boolean allowExtensions, HttpHeaders customHeaders,
|
||||
int maxFramePayloadLength, long handshakeTimeoutMillis) {
|
||||
this(webSocketURL, version, subprotocol,
|
||||
allowExtensions, customHeaders, maxFramePayloadLength, true, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -153,7 +250,24 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
|
||||
* {@code true} if close frames should not be forwarded and just close the channel
|
||||
*/
|
||||
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames) {
|
||||
this(handshaker, handleCloseFrames, true);
|
||||
this(handshaker, handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Base constructor
|
||||
*
|
||||
* @param handshaker
|
||||
* The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection
|
||||
* was established to the remote peer.
|
||||
* @param handleCloseFrames
|
||||
* {@code true} if close frames should not be forwarded and just close the channel
|
||||
* @param handshakeTimeoutMillis
|
||||
* Handshake timeout in mills, when handshake timeout, will trigger user
|
||||
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
|
||||
*/
|
||||
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
|
||||
long handshakeTimeoutMillis) {
|
||||
this(handshaker, handleCloseFrames, true, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -169,9 +283,29 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
|
||||
*/
|
||||
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
|
||||
boolean dropPongFrames) {
|
||||
this(handshaker, handleCloseFrames, dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Base constructor
|
||||
*
|
||||
* @param handshaker
|
||||
* The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection
|
||||
* was established to the remote peer.
|
||||
* @param handleCloseFrames
|
||||
* {@code true} if close frames should not be forwarded and just close the channel
|
||||
* @param dropPongFrames
|
||||
* {@code true} if pong frames should not be forwarded
|
||||
* @param handshakeTimeoutMillis
|
||||
* Handshake timeout in mills, when handshake timeout, will trigger user
|
||||
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
|
||||
*/
|
||||
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
|
||||
boolean dropPongFrames, long handshakeTimeoutMillis) {
|
||||
super(dropPongFrames);
|
||||
this.handshaker = handshaker;
|
||||
this.handleCloseFrames = handleCloseFrames;
|
||||
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
|
||||
}
|
||||
|
||||
/**
|
||||
@ -182,7 +316,21 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
|
||||
* was established to the remote peer.
|
||||
*/
|
||||
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker) {
|
||||
this(handshaker, true);
|
||||
this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Base constructor
|
||||
*
|
||||
* @param handshaker
|
||||
* The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection
|
||||
* was established to the remote peer.
|
||||
* @param handshakeTimeoutMillis
|
||||
* Handshake timeout in mills, when handshake timeout, will trigger user
|
||||
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
|
||||
*/
|
||||
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) {
|
||||
this(handshaker, true, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -200,7 +348,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
|
||||
if (cp.get(WebSocketClientProtocolHandshakeHandler.class) == null) {
|
||||
// Add the WebSocketClientProtocolHandshakeHandler before this one.
|
||||
ctx.pipeline().addBefore(ctx.name(), WebSocketClientProtocolHandshakeHandler.class.getName(),
|
||||
new WebSocketClientProtocolHandshakeHandler(handshaker));
|
||||
new WebSocketClientProtocolHandshakeHandler(handshaker, handshakeTimeoutMillis));
|
||||
}
|
||||
if (cp.get(Utf8FrameValidator.class) == null) {
|
||||
// Add the UFT8 checking before this one.
|
||||
|
@ -15,16 +15,47 @@
|
||||
*/
|
||||
package io.netty.handler.codec.http.websocketx;
|
||||
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandler;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.FullHttpResponse;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent;
|
||||
import io.netty.util.concurrent.Future;
|
||||
import io.netty.util.concurrent.FutureListener;
|
||||
import io.netty.util.internal.ThrowableUtil;
|
||||
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static io.netty.util.internal.ObjectUtil.*;
|
||||
|
||||
class WebSocketClientProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||
|
||||
private static final WebSocketHandshakeException HANDSHAKE_TIMED_OUT_EXCEPTION = ThrowableUtil.unknownStackTrace(
|
||||
new WebSocketHandshakeException("handshake timed out"),
|
||||
WebSocketClientProtocolHandshakeHandler.class,
|
||||
"channelActive(...)");
|
||||
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
|
||||
|
||||
private final WebSocketClientHandshaker handshaker;
|
||||
private final long handshakeTimeoutMillis;
|
||||
private ChannelHandlerContext ctx;
|
||||
private ChannelPromise handshakePromise;
|
||||
|
||||
WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker) {
|
||||
this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) {
|
||||
this.handshaker = handshaker;
|
||||
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
|
||||
this.ctx = ctx;
|
||||
handshakePromise = ctx.newPromise();
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -32,12 +63,14 @@ class WebSocketClientProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||
ctx.fireChannelActive();
|
||||
handshaker.handshake(ctx.channel()).addListener((ChannelFutureListener) future -> {
|
||||
if (!future.isSuccess()) {
|
||||
handshakePromise.tryFailure(future.cause());
|
||||
ctx.fireExceptionCaught(future.cause());
|
||||
} else {
|
||||
ctx.fireUserEventTriggered(
|
||||
WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_ISSUED);
|
||||
}
|
||||
});
|
||||
applyHandshakeTimeout();
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -51,6 +84,7 @@ class WebSocketClientProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||
try {
|
||||
if (!handshaker.isHandshakeComplete()) {
|
||||
handshaker.finishHandshake(ctx.channel(), response);
|
||||
handshakePromise.trySuccess();
|
||||
ctx.fireUserEventTriggered(
|
||||
WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE);
|
||||
ctx.pipeline().remove(this);
|
||||
@ -61,4 +95,43 @@ class WebSocketClientProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||
response.release();
|
||||
}
|
||||
}
|
||||
|
||||
private void applyHandshakeTimeout() {
|
||||
final ChannelPromise localHandshakePromise = handshakePromise;
|
||||
if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
|
||||
return;
|
||||
}
|
||||
|
||||
final Future<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
if (localHandshakePromise.isDone()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT_EXCEPTION)) {
|
||||
ctx.flush()
|
||||
.fireUserEventTriggered(ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT)
|
||||
.close();
|
||||
}
|
||||
}
|
||||
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
|
||||
|
||||
// Cancel the handshake timeout when handshake is finished.
|
||||
localHandshakePromise.addListener(new FutureListener<Void>() {
|
||||
@Override
|
||||
public void operationComplete(Future<Void> f) throws Exception {
|
||||
timeoutFuture.cancel(false);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* This method is visible for testing.
|
||||
*
|
||||
* @return current handshake future
|
||||
*/
|
||||
ChannelFuture getHandshakeFuture() {
|
||||
return handshakePromise;
|
||||
}
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ import io.netty.util.AttributeKey;
|
||||
import java.util.List;
|
||||
|
||||
import static io.netty.handler.codec.http.HttpVersion.*;
|
||||
import static io.netty.util.internal.ObjectUtil.*;
|
||||
|
||||
/**
|
||||
* This handler does all the heavy lifting for you to run a websocket server.
|
||||
@ -63,7 +64,12 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
|
||||
* it provides extra information about the handshake
|
||||
*/
|
||||
@Deprecated
|
||||
HANDSHAKE_COMPLETE
|
||||
HANDSHAKE_COMPLETE,
|
||||
|
||||
/**
|
||||
* The Handshake was timed out
|
||||
*/
|
||||
HANDSHAKE_TIMEOUT
|
||||
}
|
||||
|
||||
/**
|
||||
@ -96,47 +102,94 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
|
||||
private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
|
||||
AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");
|
||||
|
||||
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
|
||||
|
||||
private final String websocketPath;
|
||||
private final String subprotocols;
|
||||
private final boolean allowExtensions;
|
||||
private final int maxFramePayloadLength;
|
||||
private final boolean allowMaskMismatch;
|
||||
private final boolean checkStartsWith;
|
||||
private final long handshakeTimeoutMillis;
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath) {
|
||||
this(websocketPath, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, long handshakeTimeoutMillis) {
|
||||
this(websocketPath, null, false);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
|
||||
this(websocketPath, null, false, 65536, false, checkStartsWith);
|
||||
this(websocketPath, checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith, long handshakeTimeoutMillis) {
|
||||
this(websocketPath, null, false, 65536, false, checkStartsWith, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
|
||||
this(websocketPath, subprotocols, false);
|
||||
this(websocketPath, subprotocols, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, long handshakeTimeoutMillis) {
|
||||
this(websocketPath, subprotocols, false, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
|
||||
this(websocketPath, subprotocols, allowExtensions, 65536);
|
||||
this(websocketPath, subprotocols, allowExtensions, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
|
||||
long handshakeTimeoutMillis) {
|
||||
this(websocketPath, subprotocols, allowExtensions, 65536, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false);
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize, long handshakeTimeoutMillis) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
|
||||
DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
|
||||
int maxFrameSize, boolean allowMaskMismatch, long handshakeTimeoutMillis) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false,
|
||||
handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true);
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
|
||||
DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
|
||||
boolean checkStartsWith, long handshakeTimeoutMillis) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true,
|
||||
handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
|
||||
boolean checkStartsWith, boolean dropPongFrames) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
|
||||
dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
|
||||
int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith,
|
||||
boolean dropPongFrames, long handshakeTimeoutMillis) {
|
||||
super(dropPongFrames);
|
||||
this.websocketPath = websocketPath;
|
||||
this.subprotocols = subprotocols;
|
||||
@ -144,6 +197,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
|
||||
maxFramePayloadLength = maxFrameSize;
|
||||
this.allowMaskMismatch = allowMaskMismatch;
|
||||
this.checkStartsWith = checkStartsWith;
|
||||
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -152,8 +206,11 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
|
||||
if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
|
||||
// Add the WebSocketHandshakeHandler before this one.
|
||||
ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
|
||||
new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols,
|
||||
allowExtensions, maxFramePayloadLength, allowMaskMismatch, checkStartsWith));
|
||||
new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols,
|
||||
allowExtensions, maxFramePayloadLength,
|
||||
allowMaskMismatch,
|
||||
checkStartsWith,
|
||||
handshakeTimeoutMillis));
|
||||
}
|
||||
if (cp.get(Utf8FrameValidator.class) == null) {
|
||||
// Add the UFT8 checking before this one.
|
||||
|
@ -20,22 +20,36 @@ import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandler;
|
||||
import io.netty.channel.ChannelPipeline;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.DefaultFullHttpResponse;
|
||||
import io.netty.handler.codec.http.FullHttpRequest;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpRequest;
|
||||
import io.netty.handler.codec.http.HttpResponse;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent;
|
||||
import io.netty.handler.ssl.SslHandler;
|
||||
import io.netty.util.concurrent.Future;
|
||||
import io.netty.util.concurrent.FutureListener;
|
||||
import io.netty.util.internal.ThrowableUtil;
|
||||
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static io.netty.handler.codec.http.HttpUtil.*;
|
||||
import static io.netty.handler.codec.http.HttpMethod.*;
|
||||
import static io.netty.handler.codec.http.HttpResponseStatus.*;
|
||||
import static io.netty.handler.codec.http.HttpUtil.*;
|
||||
import static io.netty.handler.codec.http.HttpVersion.*;
|
||||
import static io.netty.util.internal.ObjectUtil.*;
|
||||
|
||||
/**
|
||||
* Handles the HTTP handshake (the HTTP Upgrade request) for {@link WebSocketServerProtocolHandler}.
|
||||
*/
|
||||
class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||
private static final WebSocketHandshakeException HANDSHAKE_TIMED_OUT_EXCEPTION = ThrowableUtil.unknownStackTrace(
|
||||
new WebSocketHandshakeException("handshake timed out"),
|
||||
WebSocketServerProtocolHandshakeHandler.class,
|
||||
"channelRead(...)");
|
||||
|
||||
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
|
||||
|
||||
private final String websocketPath;
|
||||
private final String subprotocols;
|
||||
@ -43,20 +57,45 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||
private final int maxFramePayloadSize;
|
||||
private final boolean allowMaskMismatch;
|
||||
private final boolean checkStartsWith;
|
||||
private final long handshakeTimeoutMillis;
|
||||
private ChannelHandlerContext ctx;
|
||||
private ChannelPromise handshakePromise;
|
||||
|
||||
WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
|
||||
DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize,
|
||||
boolean allowMaskMismatch, long handshakeTimeoutMillis) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
|
||||
false, handshakeTimeoutMillis);
|
||||
}
|
||||
|
||||
WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
|
||||
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
|
||||
checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
|
||||
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
|
||||
boolean checkStartsWith, long handshakeTimeoutMillis) {
|
||||
this.websocketPath = websocketPath;
|
||||
this.subprotocols = subprotocols;
|
||||
this.allowExtensions = allowExtensions;
|
||||
maxFramePayloadSize = maxFrameSize;
|
||||
this.allowMaskMismatch = allowMaskMismatch;
|
||||
this.checkStartsWith = checkStartsWith;
|
||||
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
|
||||
this.ctx = ctx;
|
||||
handshakePromise = ctx.newPromise();
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -77,14 +116,17 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||
getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
|
||||
allowExtensions, maxFramePayloadSize, allowMaskMismatch);
|
||||
final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
|
||||
final ChannelPromise localHandshakePromise = handshakePromise;
|
||||
if (handshaker == null) {
|
||||
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
|
||||
} else {
|
||||
final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
|
||||
handshakeFuture.addListener((ChannelFutureListener) future -> {
|
||||
if (!future.isSuccess()) {
|
||||
localHandshakePromise.tryFailure(future.cause());
|
||||
ctx.fireExceptionCaught(future.cause());
|
||||
} else {
|
||||
localHandshakePromise.trySuccess();
|
||||
// Kept for compatibility
|
||||
ctx.fireUserEventTriggered(
|
||||
WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
|
||||
@ -93,6 +135,7 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||
req.uri(), req.headers(), handshaker.selectedSubprotocol()));
|
||||
}
|
||||
});
|
||||
applyHandshakeTimeout();
|
||||
WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
|
||||
ctx.pipeline().replace(this, "WS403Responder",
|
||||
WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
|
||||
@ -122,4 +165,31 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||
String host = req.headers().get(HttpHeaderNames.HOST);
|
||||
return protocol + "://" + host + path;
|
||||
}
|
||||
|
||||
private void applyHandshakeTimeout() {
|
||||
final ChannelPromise localHandshakePromise = handshakePromise;
|
||||
final long handshakeTimeoutMillis = this.handshakeTimeoutMillis;
|
||||
if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
|
||||
return;
|
||||
}
|
||||
|
||||
final Future<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT_EXCEPTION)) {
|
||||
ctx.flush()
|
||||
.fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT)
|
||||
.close();
|
||||
}
|
||||
}
|
||||
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
|
||||
|
||||
// Cancel the handshake timeout when handshake is finished.
|
||||
localHandshakePromise.addListener(new FutureListener<Void>() {
|
||||
@Override
|
||||
public void operationComplete(Future<Void> f) throws Exception {
|
||||
timeoutFuture.cancel(false);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -34,6 +34,7 @@ import org.junit.Test;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletionException;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@ -45,6 +46,7 @@ public class WebSocketHandshakeHandOverTest {
|
||||
private boolean clientReceivedMessage;
|
||||
private boolean serverReceivedCloseHandshake;
|
||||
private boolean clientForceClosed;
|
||||
private boolean clientHandshakeTimeout;
|
||||
|
||||
private final class CloseNoOpServerProtocolHandler extends WebSocketServerProtocolHandler {
|
||||
CloseNoOpServerProtocolHandler(String websocketPath) {
|
||||
@ -69,6 +71,7 @@ public class WebSocketHandshakeHandOverTest {
|
||||
clientReceivedMessage = false;
|
||||
serverReceivedCloseHandshake = false;
|
||||
clientForceClosed = false;
|
||||
clientHandshakeTimeout = false;
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -118,6 +121,68 @@ public class WebSocketHandshakeHandOverTest {
|
||||
assertTrue(clientReceivedMessage);
|
||||
}
|
||||
|
||||
@Test(expected = WebSocketHandshakeException.class)
|
||||
public void testClientHandshakeTimeout() throws Throwable {
|
||||
EmbeddedChannel serverChannel = createServerChannel(new SimpleChannelInboundHandler<Object>() {
|
||||
@Override
|
||||
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
|
||||
if (evt == ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
||||
serverReceivedHandshake = true;
|
||||
// immediately send a message to the client on connect
|
||||
ctx.writeAndFlush(new TextWebSocketFrame("abc"));
|
||||
} else if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
||||
serverHandshakeComplete = (WebSocketServerProtocolHandler.HandshakeComplete) evt;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
}
|
||||
});
|
||||
|
||||
EmbeddedChannel clientChannel = createClientChannel(new SimpleChannelInboundHandler<Object>() {
|
||||
@Override
|
||||
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
|
||||
if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
||||
clientReceivedHandshake = true;
|
||||
} else if (evt == ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) {
|
||||
clientHandshakeTimeout = true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg instanceof TextWebSocketFrame) {
|
||||
clientReceivedMessage = true;
|
||||
}
|
||||
}
|
||||
}, 100);
|
||||
// Client send the handshake request to server
|
||||
transferAllDataWithMerge(clientChannel, serverChannel);
|
||||
// Server do not send the response back
|
||||
// transferAllDataWithMerge(serverChannel, clientChannel);
|
||||
WebSocketClientProtocolHandshakeHandler handshakeHandler =
|
||||
(WebSocketClientProtocolHandshakeHandler) clientChannel
|
||||
.pipeline().get(WebSocketClientProtocolHandshakeHandler.class.getName());
|
||||
|
||||
while (!handshakeHandler.getHandshakeFuture().isDone()) {
|
||||
Thread.sleep(10);
|
||||
// We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop.
|
||||
clientChannel.runScheduledPendingTasks();
|
||||
}
|
||||
assertTrue(clientHandshakeTimeout);
|
||||
assertFalse(clientReceivedHandshake);
|
||||
assertFalse(clientReceivedMessage);
|
||||
// Should throw WebSocketHandshakeException
|
||||
try {
|
||||
handshakeHandler.getHandshakeFuture().syncUninterruptibly();
|
||||
} catch (CompletionException e) {
|
||||
throw e.getCause();
|
||||
} finally {
|
||||
serverChannel.finishAndReleaseAll();
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 10000)
|
||||
public void testClientHandshakerForceClose() throws Exception {
|
||||
final WebSocketClientHandshaker handshaker = WebSocketClientHandshakerFactory.newHandshaker(
|
||||
@ -245,4 +310,13 @@ public class WebSocketHandshakeHandOverTest {
|
||||
handler);
|
||||
}
|
||||
|
||||
private static EmbeddedChannel createClientChannel(ChannelHandler handler, long timeoutMillis) throws Exception {
|
||||
return new EmbeddedChannel(
|
||||
new HttpClientCodec(),
|
||||
new HttpObjectAggregator(8192),
|
||||
new WebSocketClientProtocolHandler(new URI("ws://localhost:1234/test"),
|
||||
WebSocketVersion.V13, "test-proto-2",
|
||||
false, null, 65536, timeoutMillis),
|
||||
handler);
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user