From 5ffac03f1efd84074e349228b6c345155caffe65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A7=A6=E4=B8=96=E6=88=90?= Date: Wed, 22 May 2019 18:37:28 +0800 Subject: [PATCH] Support handshake timeout in websocket handlers (#8856) Motivation: Support handshake timeout option in websocket handlers. It makes sense to limit the time we need to move from `HANDSHAKE_ISSUED` to `HANDSHAKE_COMPLETE` states when upgrading to WebSockets Modification: - Add `handshakeTimeoutMillis` option in `WebSocketClientProtocolHandshakeHandler` and `WebSocketServerProtocolHandshakeHandler`. - Schedule a timeout task, the task will trigger user event `HANDSHAKE_TIMEOUT` if the handshake timed out. Result: Fixes issue https://github.com/netty/netty/issues/8841 --- .../WebSocketClientProtocolHandler.java | 160 +++++++++++++++++- ...bSocketClientProtocolHandshakeHandler.java | 72 ++++++++ .../WebSocketServerProtocolHandler.java | 75 +++++++- ...bSocketServerProtocolHandshakeHandler.java | 75 +++++++- .../WebSocketHandshakeHandOverTest.java | 71 ++++++++ 5 files changed, 436 insertions(+), 17 deletions(-) diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java index a40e2d74e5..fcc7d16ab9 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java @@ -23,6 +23,8 @@ import io.netty.handler.codec.http.HttpHeaders; import java.net.URI; import java.util.List; +import static io.netty.util.internal.ObjectUtil.*; + /** * This handler does all the heavy lifting for you to run a websocket client. * @@ -38,9 +40,11 @@ import java.util.List; * {@link ClientHandshakeStateEvent#HANDSHAKE_ISSUED} or {@link ClientHandshakeStateEvent#HANDSHAKE_COMPLETE}. */ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; private final WebSocketClientHandshaker handshaker; private final boolean handleCloseFrames; + private final long handshakeTimeoutMillis; /** * Returns the used handshaker @@ -53,6 +57,11 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { * Events that are fired to notify about handshake status */ public enum ClientHandshakeStateEvent { + /** + * The Handshake was timed out + */ + HANDSHAKE_TIMEOUT, + /** * The Handshake was started but the server did not response yet to the request */ @@ -92,9 +101,45 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, boolean handleCloseFrames, boolean performMasking, boolean allowMaskMismatch) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, handleCloseFrames, performMasking, allowMaskMismatch, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param handleCloseFrames + * {@code true} if close frames should not be forwarded and just close the channel + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, + int maxFramePayloadLength, boolean handleCloseFrames, boolean performMasking, + boolean allowMaskMismatch, long handshakeTimeoutMillis) { this(WebSocketClientHandshakerFactory.newHandshaker(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, - performMasking, allowMaskMismatch), handleCloseFrames); + performMasking, allowMaskMismatch), + handleCloseFrames, handshakeTimeoutMillis); } /** @@ -118,7 +163,34 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, boolean handleCloseFrames) { this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, - handleCloseFrames, true, false); + handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param handleCloseFrames + * {@code true} if close frames should not be forwarded and just close the channel + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean handleCloseFrames, long handshakeTimeoutMillis) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, + handleCloseFrames, true, false, handshakeTimeoutMillis); } /** @@ -140,7 +212,32 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) { this(webSocketURL, version, subprotocol, - allowExtensions, customHeaders, maxFramePayloadLength, true); + allowExtensions, customHeaders, maxFramePayloadLength, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, + int maxFramePayloadLength, long handshakeTimeoutMillis) { + this(webSocketURL, version, subprotocol, + allowExtensions, customHeaders, maxFramePayloadLength, true, handshakeTimeoutMillis); } /** @@ -153,7 +250,24 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { * {@code true} if close frames should not be forwarded and just close the channel */ public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames) { - this(handshaker, handleCloseFrames, true); + this(handshaker, handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param handshaker + * The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection + * was established to the remote peer. + * @param handleCloseFrames + * {@code true} if close frames should not be forwarded and just close the channel + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames, + long handshakeTimeoutMillis) { + this(handshaker, handleCloseFrames, true, handshakeTimeoutMillis); } /** @@ -169,9 +283,29 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { */ public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames, boolean dropPongFrames) { + this(handshaker, handleCloseFrames, dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param handshaker + * The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection + * was established to the remote peer. + * @param handleCloseFrames + * {@code true} if close frames should not be forwarded and just close the channel + * @param dropPongFrames + * {@code true} if pong frames should not be forwarded + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames, + boolean dropPongFrames, long handshakeTimeoutMillis) { super(dropPongFrames); this.handshaker = handshaker; this.handleCloseFrames = handleCloseFrames; + this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); } /** @@ -182,7 +316,21 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { * was established to the remote peer. */ public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker) { - this(handshaker, true); + this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param handshaker + * The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection + * was established to the remote peer. + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) { + this(handshaker, true, handshakeTimeoutMillis); } @Override @@ -200,7 +348,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { if (cp.get(WebSocketClientProtocolHandshakeHandler.class) == null) { // Add the WebSocketClientProtocolHandshakeHandler before this one. ctx.pipeline().addBefore(ctx.name(), WebSocketClientProtocolHandshakeHandler.class.getName(), - new WebSocketClientProtocolHandshakeHandler(handshaker)); + new WebSocketClientProtocolHandshakeHandler(handshaker, handshakeTimeoutMillis)); } if (cp.get(Utf8FrameValidator.class) == null) { // Add the UFT8 checking before this one. diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandshakeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandshakeHandler.java index 3d130c63d6..1cbc3fb7d2 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandshakeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandshakeHandler.java @@ -19,13 +19,43 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.ScheduledFuture; +import io.netty.util.internal.ThrowableUtil; + +import java.util.concurrent.TimeUnit; + +import static io.netty.util.internal.ObjectUtil.*; class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapter { + private static final WebSocketHandshakeException HANDSHAKE_TIMED_OUT_EXCEPTION = ThrowableUtil.unknownStackTrace( + new WebSocketHandshakeException("handshake timed out"), + WebSocketClientProtocolHandshakeHandler.class, + "channelActive(...)"); + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; + private final WebSocketClientHandshaker handshaker; + private final long handshakeTimeoutMillis; + private ChannelHandlerContext ctx; + private ChannelPromise handshakePromise; WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker) { + this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) { this.handshaker = handshaker; + this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + handshakePromise = ctx.newPromise(); } @Override @@ -35,6 +65,7 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { + handshakePromise.tryFailure(future.cause()); ctx.fireExceptionCaught(future.cause()); } else { ctx.fireUserEventTriggered( @@ -42,6 +73,7 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt } } }); + applyHandshakeTimeout(); } @Override @@ -55,6 +87,7 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt try { if (!handshaker.isHandshakeComplete()) { handshaker.finishHandshake(ctx.channel(), response); + handshakePromise.trySuccess(); ctx.fireUserEventTriggered( WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE); ctx.pipeline().remove(this); @@ -65,4 +98,43 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt response.release(); } } + + private void applyHandshakeTimeout() { + final ChannelPromise localHandshakePromise = handshakePromise; + if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { + return; + } + + final Future timeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (localHandshakePromise.isDone()) { + return; + } + + if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT_EXCEPTION)) { + ctx.flush() + .fireUserEventTriggered(ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) + .close(); + } + } + }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + + // Cancel the handshake timeout when handshake is finished. + localHandshakePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future f) throws Exception { + timeoutFuture.cancel(false); + } + }); + } + + /** + * This method is visible for testing. + * + * @return current handshake future + */ + ChannelFuture getHandshakeFuture() { + return handshakePromise; + } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java index 3a19b14542..cffcffd74c 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java @@ -33,6 +33,7 @@ import io.netty.util.AttributeKey; import java.util.List; import static io.netty.handler.codec.http.HttpVersion.*; +import static io.netty.util.internal.ObjectUtil.*; /** * This handler does all the heavy lifting for you to run a websocket server. @@ -64,7 +65,12 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { * it provides extra information about the handshake */ @Deprecated - HANDSHAKE_COMPLETE + HANDSHAKE_COMPLETE, + + /** + * The Handshake was timed out + */ + HANDSHAKE_TIMEOUT } /** @@ -97,47 +103,94 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { private static final AttributeKey HANDSHAKER_ATTR_KEY = AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER"); + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; + private final String websocketPath; private final String subprotocols; private final boolean allowExtensions; private final int maxFramePayloadLength; private final boolean allowMaskMismatch; private final boolean checkStartsWith; + private final long handshakeTimeoutMillis; public WebSocketServerProtocolHandler(String websocketPath) { + this(websocketPath, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, long handshakeTimeoutMillis) { this(websocketPath, null, false); } public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) { - this(websocketPath, null, false, 65536, false, checkStartsWith); + this(websocketPath, checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith, long handshakeTimeoutMillis) { + this(websocketPath, null, false, 65536, false, checkStartsWith, handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) { - this(websocketPath, subprotocols, false); + this(websocketPath, subprotocols, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, false, handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) { - this(websocketPath, subprotocols, allowExtensions, 65536); + this(websocketPath, subprotocols, allowExtensions, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, + long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, 65536, handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize) { - this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false); + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false, handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) { - this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false); + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, + DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, + int maxFrameSize, boolean allowMaskMismatch, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false, + handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) { - this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true); + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, + DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, + boolean checkStartsWith, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true, + handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith, boolean dropPongFrames) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, + dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, + int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith, + boolean dropPongFrames, long handshakeTimeoutMillis) { super(dropPongFrames); this.websocketPath = websocketPath; this.subprotocols = subprotocols; @@ -145,6 +198,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { maxFramePayloadLength = maxFrameSize; this.allowMaskMismatch = allowMaskMismatch; this.checkStartsWith = checkStartsWith; + this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); } @Override @@ -153,8 +207,11 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) { // Add the WebSocketHandshakeHandler before this one. ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(), - new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols, - allowExtensions, maxFramePayloadLength, allowMaskMismatch, checkStartsWith)); + new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols, + allowExtensions, maxFramePayloadLength, + allowMaskMismatch, + checkStartsWith, + handshakeTimeoutMillis)); } if (cp.get(Utf8FrameValidator.class) == null) { // Add the UFT8 checking before this one. diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java index 365a5859a6..ade3991f61 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java @@ -20,22 +20,37 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent; import io.netty.handler.ssl.SslHandler; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.ScheduledFuture; +import io.netty.util.internal.ThrowableUtil; + +import java.util.concurrent.TimeUnit; -import static io.netty.handler.codec.http.HttpUtil.*; import static io.netty.handler.codec.http.HttpMethod.*; import static io.netty.handler.codec.http.HttpResponseStatus.*; +import static io.netty.handler.codec.http.HttpUtil.*; import static io.netty.handler.codec.http.HttpVersion.*; +import static io.netty.util.internal.ObjectUtil.*; /** * Handles the HTTP handshake (the HTTP Upgrade request) for {@link WebSocketServerProtocolHandler}. */ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter { + private static final WebSocketHandshakeException HANDSHAKE_TIMED_OUT_EXCEPTION = ThrowableUtil.unknownStackTrace( + new WebSocketHandshakeException("handshake timed out"), + WebSocketServerProtocolHandshakeHandler.class, + "channelRead(...)"); + + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; private final String websocketPath; private final String subprotocols; @@ -43,20 +58,45 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt private final int maxFramePayloadSize; private final boolean allowMaskMismatch; private final boolean checkStartsWith; + private final long handshakeTimeoutMillis; + private ChannelHandlerContext ctx; + private ChannelPromise handshakePromise; WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) { - this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false); + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, + DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, + boolean allowMaskMismatch, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, + false, handshakeTimeoutMillis); } WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, + checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, + boolean checkStartsWith, long handshakeTimeoutMillis) { this.websocketPath = websocketPath; this.subprotocols = subprotocols; this.allowExtensions = allowExtensions; maxFramePayloadSize = maxFrameSize; this.allowMaskMismatch = allowMaskMismatch; this.checkStartsWith = checkStartsWith; + this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + handshakePromise = ctx.newPromise(); } @Override @@ -77,6 +117,7 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols, allowExtensions, maxFramePayloadSize, allowMaskMismatch); final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req); + final ChannelPromise localHandshakePromise = handshakePromise; if (handshaker == null) { WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); } else { @@ -85,8 +126,10 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { + localHandshakePromise.tryFailure(future.cause()); ctx.fireExceptionCaught(future.cause()); } else { + localHandshakePromise.trySuccess(); // Kept for compatibility ctx.fireUserEventTriggered( WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE); @@ -96,6 +139,7 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt } } }); + applyHandshakeTimeout(); WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker); ctx.pipeline().replace(this, "WS403Responder", WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()); @@ -125,4 +169,31 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt String host = req.headers().get(HttpHeaderNames.HOST); return protocol + "://" + host + path; } + + private void applyHandshakeTimeout() { + final ChannelPromise localHandshakePromise = handshakePromise; + final long handshakeTimeoutMillis = this.handshakeTimeoutMillis; + if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { + return; + } + + final Future timeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT_EXCEPTION)) { + ctx.flush() + .fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT) + .close(); + } + } + }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + + // Cancel the handshake timeout when handshake is finished. + localHandshakePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future f) throws Exception { + timeoutFuture.cancel(false); + } + }); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java index 5583a27339..e99b6d79ef 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java @@ -45,6 +45,7 @@ public class WebSocketHandshakeHandOverTest { private boolean clientReceivedMessage; private boolean serverReceivedCloseHandshake; private boolean clientForceClosed; + private boolean clientHandshakeTimeout; private final class CloseNoOpServerProtocolHandler extends WebSocketServerProtocolHandler { CloseNoOpServerProtocolHandler(String websocketPath) { @@ -69,6 +70,7 @@ public class WebSocketHandshakeHandOverTest { clientReceivedMessage = false; serverReceivedCloseHandshake = false; clientForceClosed = false; + clientHandshakeTimeout = false; } @Test @@ -118,6 +120,66 @@ public class WebSocketHandshakeHandOverTest { assertTrue(clientReceivedMessage); } + @Test(expected = WebSocketHandshakeException.class) + public void testClientHandshakeTimeout() throws Exception { + EmbeddedChannel serverChannel = createServerChannel(new SimpleChannelInboundHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) { + serverReceivedHandshake = true; + // immediately send a message to the client on connect + ctx.writeAndFlush(new TextWebSocketFrame("abc")); + } else if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { + serverHandshakeComplete = (WebSocketServerProtocolHandler.HandshakeComplete) evt; + } + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + } + }); + + EmbeddedChannel clientChannel = createClientChannel(new SimpleChannelInboundHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + clientReceivedHandshake = true; + } else if (evt == ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) { + clientHandshakeTimeout = true; + } + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof TextWebSocketFrame) { + clientReceivedMessage = true; + } + } + }, 100); + // Client send the handshake request to server + transferAllDataWithMerge(clientChannel, serverChannel); + // Server do not send the response back + // transferAllDataWithMerge(serverChannel, clientChannel); + WebSocketClientProtocolHandshakeHandler handshakeHandler = + (WebSocketClientProtocolHandshakeHandler) clientChannel + .pipeline().get(WebSocketClientProtocolHandshakeHandler.class.getName()); + + while (!handshakeHandler.getHandshakeFuture().isDone()) { + Thread.sleep(10); + // We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop. + clientChannel.runScheduledPendingTasks(); + } + assertTrue(clientHandshakeTimeout); + assertFalse(clientReceivedHandshake); + assertFalse(clientReceivedMessage); + // Should throw WebSocketHandshakeException + try { + handshakeHandler.getHandshakeFuture().syncUninterruptibly(); + } finally { + serverChannel.finishAndReleaseAll(); + } + } + @Test(timeout = 10000) public void testClientHandshakerForceClose() throws Exception { final WebSocketClientHandshaker handshaker = WebSocketClientHandshakerFactory.newHandshaker( @@ -245,4 +307,13 @@ public class WebSocketHandshakeHandOverTest { handler); } + private static EmbeddedChannel createClientChannel(ChannelHandler handler, long timeoutMillis) throws Exception { + return new EmbeddedChannel( + new HttpClientCodec(), + new HttpObjectAggregator(8192), + new WebSocketClientProtocolHandler(new URI("ws://localhost:1234/test"), + WebSocketVersion.V13, "test-proto-2", + false, null, 65536, timeoutMillis), + handler); + } }