diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java index a40e2d74e5..fcc7d16ab9 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java @@ -23,6 +23,8 @@ import io.netty.handler.codec.http.HttpHeaders; import java.net.URI; import java.util.List; +import static io.netty.util.internal.ObjectUtil.*; + /** * This handler does all the heavy lifting for you to run a websocket client. * @@ -38,9 +40,11 @@ import java.util.List; * {@link ClientHandshakeStateEvent#HANDSHAKE_ISSUED} or {@link ClientHandshakeStateEvent#HANDSHAKE_COMPLETE}. */ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; private final WebSocketClientHandshaker handshaker; private final boolean handleCloseFrames; + private final long handshakeTimeoutMillis; /** * Returns the used handshaker @@ -53,6 +57,11 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { * Events that are fired to notify about handshake status */ public enum ClientHandshakeStateEvent { + /** + * The Handshake was timed out + */ + HANDSHAKE_TIMEOUT, + /** * The Handshake was started but the server did not response yet to the request */ @@ -92,9 +101,45 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, boolean handleCloseFrames, boolean performMasking, boolean allowMaskMismatch) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, handleCloseFrames, performMasking, allowMaskMismatch, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param handleCloseFrames + * {@code true} if close frames should not be forwarded and just close the channel + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, + int maxFramePayloadLength, boolean handleCloseFrames, boolean performMasking, + boolean allowMaskMismatch, long handshakeTimeoutMillis) { this(WebSocketClientHandshakerFactory.newHandshaker(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, - performMasking, allowMaskMismatch), handleCloseFrames); + performMasking, allowMaskMismatch), + handleCloseFrames, handshakeTimeoutMillis); } /** @@ -118,7 +163,34 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, boolean handleCloseFrames) { this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, - handleCloseFrames, true, false); + handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param handleCloseFrames + * {@code true} if close frames should not be forwarded and just close the channel + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean handleCloseFrames, long handshakeTimeoutMillis) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, + handleCloseFrames, true, false, handshakeTimeoutMillis); } /** @@ -140,7 +212,32 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) { this(webSocketURL, version, subprotocol, - allowExtensions, customHeaders, maxFramePayloadLength, true); + allowExtensions, customHeaders, maxFramePayloadLength, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, + int maxFramePayloadLength, long handshakeTimeoutMillis) { + this(webSocketURL, version, subprotocol, + allowExtensions, customHeaders, maxFramePayloadLength, true, handshakeTimeoutMillis); } /** @@ -153,7 +250,24 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { * {@code true} if close frames should not be forwarded and just close the channel */ public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames) { - this(handshaker, handleCloseFrames, true); + this(handshaker, handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param handshaker + * The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection + * was established to the remote peer. + * @param handleCloseFrames + * {@code true} if close frames should not be forwarded and just close the channel + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames, + long handshakeTimeoutMillis) { + this(handshaker, handleCloseFrames, true, handshakeTimeoutMillis); } /** @@ -169,9 +283,29 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { */ public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames, boolean dropPongFrames) { + this(handshaker, handleCloseFrames, dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param handshaker + * The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection + * was established to the remote peer. + * @param handleCloseFrames + * {@code true} if close frames should not be forwarded and just close the channel + * @param dropPongFrames + * {@code true} if pong frames should not be forwarded + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames, + boolean dropPongFrames, long handshakeTimeoutMillis) { super(dropPongFrames); this.handshaker = handshaker; this.handleCloseFrames = handleCloseFrames; + this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); } /** @@ -182,7 +316,21 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { * was established to the remote peer. */ public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker) { - this(handshaker, true); + this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + /** + * Base constructor + * + * @param handshaker + * The {@link WebSocketClientHandshaker} which will be used to issue the handshake once the connection + * was established to the remote peer. + * @param handshakeTimeoutMillis + * Handshake timeout in mills, when handshake timeout, will trigger user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} + */ + public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) { + this(handshaker, true, handshakeTimeoutMillis); } @Override @@ -200,7 +348,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { if (cp.get(WebSocketClientProtocolHandshakeHandler.class) == null) { // Add the WebSocketClientProtocolHandshakeHandler before this one. ctx.pipeline().addBefore(ctx.name(), WebSocketClientProtocolHandshakeHandler.class.getName(), - new WebSocketClientProtocolHandshakeHandler(handshaker)); + new WebSocketClientProtocolHandshakeHandler(handshaker, handshakeTimeoutMillis)); } if (cp.get(Utf8FrameValidator.class) == null) { // Add the UFT8 checking before this one. diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandshakeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandshakeHandler.java index 7077c017bd..4eaffec763 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandshakeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandshakeHandler.java @@ -15,16 +15,47 @@ */ package io.netty.handler.codec.http.websocketx; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.internal.ThrowableUtil; + +import java.util.concurrent.TimeUnit; + +import static io.netty.util.internal.ObjectUtil.*; class WebSocketClientProtocolHandshakeHandler implements ChannelInboundHandler { + + private static final WebSocketHandshakeException HANDSHAKE_TIMED_OUT_EXCEPTION = ThrowableUtil.unknownStackTrace( + new WebSocketHandshakeException("handshake timed out"), + WebSocketClientProtocolHandshakeHandler.class, + "channelActive(...)"); + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; + private final WebSocketClientHandshaker handshaker; + private final long handshakeTimeoutMillis; + private ChannelHandlerContext ctx; + private ChannelPromise handshakePromise; WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker) { + this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) { this.handshaker = handshaker; + this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + handshakePromise = ctx.newPromise(); } @Override @@ -32,12 +63,14 @@ class WebSocketClientProtocolHandshakeHandler implements ChannelInboundHandler { ctx.fireChannelActive(); handshaker.handshake(ctx.channel()).addListener((ChannelFutureListener) future -> { if (!future.isSuccess()) { + handshakePromise.tryFailure(future.cause()); ctx.fireExceptionCaught(future.cause()); } else { ctx.fireUserEventTriggered( WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_ISSUED); } }); + applyHandshakeTimeout(); } @Override @@ -51,6 +84,7 @@ class WebSocketClientProtocolHandshakeHandler implements ChannelInboundHandler { try { if (!handshaker.isHandshakeComplete()) { handshaker.finishHandshake(ctx.channel(), response); + handshakePromise.trySuccess(); ctx.fireUserEventTriggered( WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE); ctx.pipeline().remove(this); @@ -61,4 +95,43 @@ class WebSocketClientProtocolHandshakeHandler implements ChannelInboundHandler { response.release(); } } + + private void applyHandshakeTimeout() { + final ChannelPromise localHandshakePromise = handshakePromise; + if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { + return; + } + + final Future timeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (localHandshakePromise.isDone()) { + return; + } + + if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT_EXCEPTION)) { + ctx.flush() + .fireUserEventTriggered(ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) + .close(); + } + } + }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + + // Cancel the handshake timeout when handshake is finished. + localHandshakePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future f) throws Exception { + timeoutFuture.cancel(false); + } + }); + } + + /** + * This method is visible for testing. + * + * @return current handshake future + */ + ChannelFuture getHandshakeFuture() { + return handshakePromise; + } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java index 82c631c4b0..ae1f8d17f5 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java @@ -32,6 +32,7 @@ import io.netty.util.AttributeKey; import java.util.List; import static io.netty.handler.codec.http.HttpVersion.*; +import static io.netty.util.internal.ObjectUtil.*; /** * This handler does all the heavy lifting for you to run a websocket server. @@ -63,7 +64,12 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { * it provides extra information about the handshake */ @Deprecated - HANDSHAKE_COMPLETE + HANDSHAKE_COMPLETE, + + /** + * The Handshake was timed out + */ + HANDSHAKE_TIMEOUT } /** @@ -96,47 +102,94 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { private static final AttributeKey HANDSHAKER_ATTR_KEY = AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER"); + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; + private final String websocketPath; private final String subprotocols; private final boolean allowExtensions; private final int maxFramePayloadLength; private final boolean allowMaskMismatch; private final boolean checkStartsWith; + private final long handshakeTimeoutMillis; public WebSocketServerProtocolHandler(String websocketPath) { + this(websocketPath, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, long handshakeTimeoutMillis) { this(websocketPath, null, false); } public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) { - this(websocketPath, null, false, 65536, false, checkStartsWith); + this(websocketPath, checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith, long handshakeTimeoutMillis) { + this(websocketPath, null, false, 65536, false, checkStartsWith, handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) { - this(websocketPath, subprotocols, false); + this(websocketPath, subprotocols, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, false, handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) { - this(websocketPath, subprotocols, allowExtensions, 65536); + this(websocketPath, subprotocols, allowExtensions, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, + long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, 65536, handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize) { - this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false); + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false, handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) { - this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false); + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, + DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, + int maxFrameSize, boolean allowMaskMismatch, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false, + handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) { - this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true); + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, + DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, + boolean checkStartsWith, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true, + handshakeTimeoutMillis); } public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith, boolean dropPongFrames) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, + dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, + int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith, + boolean dropPongFrames, long handshakeTimeoutMillis) { super(dropPongFrames); this.websocketPath = websocketPath; this.subprotocols = subprotocols; @@ -144,6 +197,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { maxFramePayloadLength = maxFrameSize; this.allowMaskMismatch = allowMaskMismatch; this.checkStartsWith = checkStartsWith; + this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); } @Override @@ -152,8 +206,11 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) { // Add the WebSocketHandshakeHandler before this one. ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(), - new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols, - allowExtensions, maxFramePayloadLength, allowMaskMismatch, checkStartsWith)); + new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols, + allowExtensions, maxFramePayloadLength, + allowMaskMismatch, + checkStartsWith, + handshakeTimeoutMillis)); } if (cp.get(Utf8FrameValidator.class) == null) { // Add the UFT8 checking before this one. diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java index db1161a240..3fcd933e53 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java @@ -20,22 +20,36 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent; import io.netty.handler.ssl.SslHandler; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.internal.ThrowableUtil; + +import java.util.concurrent.TimeUnit; -import static io.netty.handler.codec.http.HttpUtil.*; import static io.netty.handler.codec.http.HttpMethod.*; import static io.netty.handler.codec.http.HttpResponseStatus.*; +import static io.netty.handler.codec.http.HttpUtil.*; import static io.netty.handler.codec.http.HttpVersion.*; +import static io.netty.util.internal.ObjectUtil.*; /** * Handles the HTTP handshake (the HTTP Upgrade request) for {@link WebSocketServerProtocolHandler}. */ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { + private static final WebSocketHandshakeException HANDSHAKE_TIMED_OUT_EXCEPTION = ThrowableUtil.unknownStackTrace( + new WebSocketHandshakeException("handshake timed out"), + WebSocketServerProtocolHandshakeHandler.class, + "channelRead(...)"); + + private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; private final String websocketPath; private final String subprotocols; @@ -43,20 +57,45 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { private final int maxFramePayloadSize; private final boolean allowMaskMismatch; private final boolean checkStartsWith; + private final long handshakeTimeoutMillis; + private ChannelHandlerContext ctx; + private ChannelPromise handshakePromise; WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) { - this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false); + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, + DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, + boolean allowMaskMismatch, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, + false, handshakeTimeoutMillis); } WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, + checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS); + } + + WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, + boolean checkStartsWith, long handshakeTimeoutMillis) { this.websocketPath = websocketPath; this.subprotocols = subprotocols; this.allowExtensions = allowExtensions; maxFramePayloadSize = maxFrameSize; this.allowMaskMismatch = allowMaskMismatch; this.checkStartsWith = checkStartsWith; + this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + handshakePromise = ctx.newPromise(); } @Override @@ -77,14 +116,17 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols, allowExtensions, maxFramePayloadSize, allowMaskMismatch); final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req); + final ChannelPromise localHandshakePromise = handshakePromise; if (handshaker == null) { WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); } else { final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req); handshakeFuture.addListener((ChannelFutureListener) future -> { if (!future.isSuccess()) { + localHandshakePromise.tryFailure(future.cause()); ctx.fireExceptionCaught(future.cause()); } else { + localHandshakePromise.trySuccess(); // Kept for compatibility ctx.fireUserEventTriggered( WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE); @@ -93,6 +135,7 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { req.uri(), req.headers(), handshaker.selectedSubprotocol())); } }); + applyHandshakeTimeout(); WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker); ctx.pipeline().replace(this, "WS403Responder", WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()); @@ -122,4 +165,31 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { String host = req.headers().get(HttpHeaderNames.HOST); return protocol + "://" + host + path; } + + private void applyHandshakeTimeout() { + final ChannelPromise localHandshakePromise = handshakePromise; + final long handshakeTimeoutMillis = this.handshakeTimeoutMillis; + if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { + return; + } + + final Future timeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (localHandshakePromise.tryFailure(HANDSHAKE_TIMED_OUT_EXCEPTION)) { + ctx.flush() + .fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT) + .close(); + } + } + }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + + // Cancel the handshake timeout when handshake is finished. + localHandshakePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future f) throws Exception { + timeoutFuture.cancel(false); + } + }); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java index 5583a27339..87e4e22ff0 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java @@ -34,6 +34,7 @@ import org.junit.Test; import java.net.URI; import java.util.List; +import java.util.concurrent.CompletionException; import static org.junit.Assert.*; @@ -45,6 +46,7 @@ public class WebSocketHandshakeHandOverTest { private boolean clientReceivedMessage; private boolean serverReceivedCloseHandshake; private boolean clientForceClosed; + private boolean clientHandshakeTimeout; private final class CloseNoOpServerProtocolHandler extends WebSocketServerProtocolHandler { CloseNoOpServerProtocolHandler(String websocketPath) { @@ -69,6 +71,7 @@ public class WebSocketHandshakeHandOverTest { clientReceivedMessage = false; serverReceivedCloseHandshake = false; clientForceClosed = false; + clientHandshakeTimeout = false; } @Test @@ -118,6 +121,68 @@ public class WebSocketHandshakeHandOverTest { assertTrue(clientReceivedMessage); } + @Test(expected = WebSocketHandshakeException.class) + public void testClientHandshakeTimeout() throws Throwable { + EmbeddedChannel serverChannel = createServerChannel(new SimpleChannelInboundHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) { + serverReceivedHandshake = true; + // immediately send a message to the client on connect + ctx.writeAndFlush(new TextWebSocketFrame("abc")); + } else if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { + serverHandshakeComplete = (WebSocketServerProtocolHandler.HandshakeComplete) evt; + } + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + } + }); + + EmbeddedChannel clientChannel = createClientChannel(new SimpleChannelInboundHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + clientReceivedHandshake = true; + } else if (evt == ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) { + clientHandshakeTimeout = true; + } + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof TextWebSocketFrame) { + clientReceivedMessage = true; + } + } + }, 100); + // Client send the handshake request to server + transferAllDataWithMerge(clientChannel, serverChannel); + // Server do not send the response back + // transferAllDataWithMerge(serverChannel, clientChannel); + WebSocketClientProtocolHandshakeHandler handshakeHandler = + (WebSocketClientProtocolHandshakeHandler) clientChannel + .pipeline().get(WebSocketClientProtocolHandshakeHandler.class.getName()); + + while (!handshakeHandler.getHandshakeFuture().isDone()) { + Thread.sleep(10); + // We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop. + clientChannel.runScheduledPendingTasks(); + } + assertTrue(clientHandshakeTimeout); + assertFalse(clientReceivedHandshake); + assertFalse(clientReceivedMessage); + // Should throw WebSocketHandshakeException + try { + handshakeHandler.getHandshakeFuture().syncUninterruptibly(); + } catch (CompletionException e) { + throw e.getCause(); + } finally { + serverChannel.finishAndReleaseAll(); + } + } + @Test(timeout = 10000) public void testClientHandshakerForceClose() throws Exception { final WebSocketClientHandshaker handshaker = WebSocketClientHandshakerFactory.newHandshaker( @@ -245,4 +310,13 @@ public class WebSocketHandshakeHandOverTest { handler); } + private static EmbeddedChannel createClientChannel(ChannelHandler handler, long timeoutMillis) throws Exception { + return new EmbeddedChannel( + new HttpClientCodec(), + new HttpObjectAggregator(8192), + new WebSocketClientProtocolHandler(new URI("ws://localhost:1234/test"), + WebSocketVersion.V13, "test-proto-2", + false, null, 65536, timeoutMillis), + handler); + } }