From d01471917b94d15fbe8b3b2e0f0ed8f24ee2954a Mon Sep 17 00:00:00 2001 From: Doyun Geum <46062199+Doyuni@users.noreply.github.com> Date: Thu, 8 Oct 2020 19:06:43 +0900 Subject: [PATCH] Add validation check about websocket path (#10583) Add validation check about websocket path Motivation: I add websocket handler in custom server with netty. I first add WebSocketServerProtocolHandler in my channel pipeline. It does work! but I found that it can pass "/websocketabc". (websocketPath is "/websocket") Modification: `isWebSocketPath()` method of `WebSocketServerProtocolHandshakeHandler` now checks that "startsWith" applies to the first URL path component, rather than the URL as a string. Result: Requests to "/websocketabc" are no longer passed to handlers for requests that starts-with "/websocket". --- ...bSocketServerProtocolHandshakeHandler.java | 18 ++++- .../WebSocketServerProtocolHandlerTest.java | 68 +++++++++++++++++-- 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java index ec9b4ff104..f3f5332756 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java @@ -61,7 +61,7 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt @Override public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { final FullHttpRequest req = (FullHttpRequest) msg; - if (isNotWebSocketPath(req)) { + if (!isWebSocketPath(req)) { ctx.fireChannelRead(msg); return; } @@ -113,9 +113,21 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt } } - private boolean isNotWebSocketPath(FullHttpRequest req) { + private boolean isWebSocketPath(FullHttpRequest req) { String websocketPath = serverConfig.websocketPath(); - return serverConfig.checkStartsWith() ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath); + String uri = req.uri(); + boolean checkStartUri = uri.startsWith(websocketPath); + boolean checkNextUri = checkNextUri(uri, websocketPath); + return serverConfig.checkStartsWith() ? (checkStartUri && checkNextUri) : uri.equals(websocketPath); + } + + private boolean checkNextUri(String uri, String websocketPath) { + int len = websocketPath.length(); + if (uri.length() > len) { + char nextUri = uri.charAt(len); + return nextUri == '/' || nextUri == '?'; + } + return true; } private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) { diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java index 69cda51f0b..1ca29b0a5e 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java @@ -22,17 +22,18 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; + import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpObjectAggregator; -import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import org.junit.Before; @@ -194,6 +195,61 @@ public class WebSocketServerProtocolHandlerTest { assertFalse(ch.finish()); } + @Test + public void testCheckValidWebSocketPath() { + HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1) + .method(HttpMethod.GET) + .uri("/test") + .key(HttpHeaderNames.SEC_WEBSOCKET_KEY) + .connection("Upgrade") + .upgrade(HttpHeaderValues.WEBSOCKET) + .version13() + .build(); + + WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder() + .websocketPath("/test") + .checkStartsWith(true) + .build(); + + EmbeddedChannel ch = new EmbeddedChannel( + new WebSocketServerProtocolHandler(config), + new HttpRequestDecoder(), + new HttpResponseEncoder(), + new MockOutboundHandler()); + ch.writeInbound(httpRequest); + + FullHttpResponse response = responses.remove(); + assertEquals(SWITCHING_PROTOCOLS, response.status()); + response.release(); + } + + @Test + public void testCheckInvalidWebSocketPath() { + HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1) + .method(HttpMethod.GET) + .uri("/testabc") + .key(HttpHeaderNames.SEC_WEBSOCKET_KEY) + .connection("Upgrade") + .upgrade(HttpHeaderValues.WEBSOCKET) + .version13() + .build(); + + WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder() + .websocketPath("/test") + .checkStartsWith(true) + .build(); + + EmbeddedChannel ch = new EmbeddedChannel( + new WebSocketServerProtocolHandler(config), + new HttpRequestDecoder(), + new HttpResponseEncoder(), + new MockOutboundHandler()); + ch.writeInbound(httpRequest); + + ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class); + assertNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx.channel())); + } + @Test public void testExplicitCloseFrameSentWhenServerChannelClosed() throws Exception { WebSocketCloseStatus closeStatus = WebSocketCloseStatus.ENDPOINT_UNAVAILABLE;