From ef24732640e1f3f5621b65df476fee159d030d79 Mon Sep 17 00:00:00 2001 From: Dmitriy Dumanskiy Date: Tue, 16 Jul 2019 14:12:17 +0300 Subject: [PATCH] Cleanup in websockets, throw exception before allocating response if possible (#9361) Motivation: While fixing #9359 found few places that could be patched / improved separately. Modification: On handshake response generation - throw exception before allocating response objects if request is invalid. Result: No more leaks when exception is thrown. --- .../http/websocketx/CloseWebSocketFrame.java | 5 +---- .../websocketx/WebSocketClientHandshaker.java | 18 ++++++++++-------- .../WebSocketServerHandshaker00.java | 12 +++++++----- .../WebSocketServerHandshaker07.java | 8 ++++---- .../WebSocketServerHandshaker08.java | 9 +++++---- .../WebSocketServerHandshaker13.java | 9 +++++---- .../WebSocketServerProtocolHandler.java | 4 ++-- ...ebSocketServerProtocolHandshakeHandler.java | 2 -- 8 files changed, 34 insertions(+), 33 deletions(-) diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrame.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrame.java index d61af92632..0286c6fae3 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrame.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrame.java @@ -138,10 +138,7 @@ public class CloseWebSocketFrame extends WebSocketFrame { } binaryData.readerIndex(0); - int statusCode = binaryData.readShort(); - binaryData.readerIndex(0); - - return statusCode; + return binaryData.getShort(0); } /** diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java index e34b8909ad..4dd29d80d3 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java @@ -248,11 +248,10 @@ public abstract class WebSocketClientHandshaker { * the {@link ChannelPromise} to be notified when the opening handshake is sent */ public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) { - FullHttpRequest request = newHandshakeRequest(); - - HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class); + ChannelPipeline pipeline = channel.pipeline(); + HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class); if (decoder == null) { - HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class); + HttpClientCodec codec = pipeline.get(HttpClientCodec.class); if (codec == null) { promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " + "a HttpResponseDecoder or HttpClientCodec")); @@ -260,6 +259,8 @@ public abstract class WebSocketClientHandshaker { } } + FullHttpRequest request = newHandshakeRequest(); + channel.writeAndFlush(request).addListener((ChannelFutureListener) future -> { if (future.isSuccess()) { ChannelPipeline p = future.channel().pipeline(); @@ -559,14 +560,15 @@ public abstract class WebSocketClientHandshaker { return wsURL.getHost(); } String host = wsURL.getHost(); + String scheme = wsURL.getScheme(); if (port == HttpScheme.HTTP.port()) { - return HttpScheme.HTTP.name().contentEquals(wsURL.getScheme()) - || WebSocketScheme.WS.name().contentEquals(wsURL.getScheme()) ? + return HttpScheme.HTTP.name().contentEquals(scheme) + || WebSocketScheme.WS.name().contentEquals(scheme) ? host : NetUtil.toSocketAddressString(host, port); } if (port == HttpScheme.HTTPS.port()) { - return HttpScheme.HTTPS.name().contentEquals(wsURL.getScheme()) - || WebSocketScheme.WSS.name().contentEquals(wsURL.getScheme()) ? + return HttpScheme.HTTPS.name().contentEquals(scheme) + || WebSocketScheme.WSS.name().contentEquals(scheme) ? host : NetUtil.toSocketAddressString(host, port); } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00.java index c3e50ee152..5c6429be47 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00.java @@ -133,6 +133,12 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker { boolean isHixie76 = req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY1) && req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY2); + String origin = req.headers().get(HttpHeaderNames.ORIGIN); + //throw before allocating FullHttpResponse + if (origin == null && !isHixie76) { + throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names()); + } + // Create the WebSocket handshake response. FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, new HttpResponseStatus(101, isHixie76 ? "WebSocket Protocol Handshake" : "Web Socket Protocol Handshake")); @@ -146,7 +152,7 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker { // Fill in the headers and contents depending on handshake getMethod. if (isHixie76) { // New handshake getMethod with a challenge: - res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, req.headers().get(HttpHeaderNames.ORIGIN)); + res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, origin); res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_LOCATION, uri()); String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL); @@ -176,10 +182,6 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker { res.content().writeBytes(WebSocketUtil.md5(input.array())); } else { // Old Hixie 75 handshake getMethod with no challenge: - String origin = req.headers().get(HttpHeaderNames.ORIGIN); - if (origin == null) { - throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names()); - } res.headers().add(HttpHeaderNames.WEBSOCKET_ORIGIN, origin); res.headers().add(HttpHeaderNames.WEBSOCKET_LOCATION, uri()); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker07.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker07.java index 1ab627d74d..c52d5e1e93 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker07.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker07.java @@ -128,6 +128,10 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker { */ @Override protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { + CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); + if (key == null) { + throw new WebSocketHandshakeException("not a WebSocket request: missing key"); + } FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); @@ -136,10 +140,6 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker { res.headers().add(headers); } - CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); - if (key == null) { - throw new WebSocketHandshakeException("not a WebSocket request: missing key"); - } String acceptSeed = key + WEBSOCKET_07_ACCEPT_GUID; byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); String accept = WebSocketUtil.base64(sha1); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08.java index e81dcc9fa4..e69ed7f1bf 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08.java @@ -135,16 +135,17 @@ public class WebSocketServerHandshaker08 extends WebSocketServerHandshaker { */ @Override protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { + CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); + if (key == null) { + throw new WebSocketHandshakeException("not a WebSocket request: missing key"); + } + FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); if (headers != null) { res.headers().add(headers); } - CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); - if (key == null) { - throw new WebSocketHandshakeException("not a WebSocket request: missing key"); - } String acceptSeed = key + WEBSOCKET_08_ACCEPT_GUID; byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); String accept = WebSocketUtil.base64(sha1); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13.java index 976e48bc52..c3d92e85d9 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13.java @@ -134,15 +134,16 @@ public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker { */ @Override protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { + CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); + if (key == null) { + throw new WebSocketHandshakeException("not a WebSocket request: missing key"); + } + FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); if (headers != null) { res.headers().add(headers); } - CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); - if (key == null) { - throw new WebSocketHandshakeException("not a WebSocket request: missing key"); - } String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID; byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); String accept = WebSocketUtil.base64(sha1); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java index b08f489d66..2afcad3b4c 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java @@ -212,13 +212,13 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { ChannelPipeline cp = ctx.pipeline(); if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) { // Add the WebSocketHandshakeHandler before this one. - ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(), + cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(), new WebSocketServerProtocolHandshakeHandler( websocketPath, subprotocols, checkStartsWith, handshakeTimeoutMillis, decoderConfig)); } if (cp.get(Utf8FrameValidator.class) == null) { // Add the UFT8 checking before this one. - ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(), + cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), new Utf8FrameValidator()); } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java index 1b21c3dc98..c937092961 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java @@ -44,8 +44,6 @@ import static io.netty.util.internal.ObjectUtil.*; */ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { - private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; - private final String websocketPath; private final String subprotocols; private final boolean checkStartsWith;