Add validation check about websocket path (#10583)

Add validation check about websocket path

Motivation:

I add websocket handler in custom server with netty.
I first add WebSocketServerProtocolHandler in my channel pipeline.
It does work! but I found that it can pass "/websocketabc". (websocketPath is "/websocket")

Modification:
`isWebSocketPath()` method of `WebSocketServerProtocolHandshakeHandler` now checks that "startsWith" applies to the first URL path component, rather than the URL as a string.

Result:
Requests to "/websocketabc" are no longer passed to handlers for requests that starts-with "/websocket".
This commit is contained in:
Doyun Geum 2020-10-08 19:06:43 +09:00 committed by GitHub
parent 461b6e0f7b
commit d01471917b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 9 deletions

View File

@ -61,7 +61,7 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
@Override
public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
final FullHttpRequest req = (FullHttpRequest) msg;
if (isNotWebSocketPath(req)) {
if (!isWebSocketPath(req)) {
ctx.fireChannelRead(msg);
return;
}
@ -113,9 +113,21 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
}
}
private boolean isNotWebSocketPath(FullHttpRequest req) {
private boolean isWebSocketPath(FullHttpRequest req) {
String websocketPath = serverConfig.websocketPath();
return serverConfig.checkStartsWith() ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath);
String uri = req.uri();
boolean checkStartUri = uri.startsWith(websocketPath);
boolean checkNextUri = checkNextUri(uri, websocketPath);
return serverConfig.checkStartsWith() ? (checkStartUri && checkNextUri) : uri.equals(websocketPath);
}
private boolean checkNextUri(String uri, String websocketPath) {
int len = websocketPath.length();
if (uri.length() > len) {
char nextUri = uri.charAt(len);
return nextUri == '/' || nextUri == '?';
}
return true;
}
private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {

View File

@ -22,17 +22,18 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import org.junit.Before;
@ -194,6 +195,61 @@ public class WebSocketServerProtocolHandlerTest {
assertFalse(ch.finish());
}
@Test
public void testCheckValidWebSocketPath() {
HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
.method(HttpMethod.GET)
.uri("/test")
.key(HttpHeaderNames.SEC_WEBSOCKET_KEY)
.connection("Upgrade")
.upgrade(HttpHeaderValues.WEBSOCKET)
.version13()
.build();
WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
.websocketPath("/test")
.checkStartsWith(true)
.build();
EmbeddedChannel ch = new EmbeddedChannel(
new WebSocketServerProtocolHandler(config),
new HttpRequestDecoder(),
new HttpResponseEncoder(),
new MockOutboundHandler());
ch.writeInbound(httpRequest);
FullHttpResponse response = responses.remove();
assertEquals(SWITCHING_PROTOCOLS, response.status());
response.release();
}
@Test
public void testCheckInvalidWebSocketPath() {
HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
.method(HttpMethod.GET)
.uri("/testabc")
.key(HttpHeaderNames.SEC_WEBSOCKET_KEY)
.connection("Upgrade")
.upgrade(HttpHeaderValues.WEBSOCKET)
.version13()
.build();
WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
.websocketPath("/test")
.checkStartsWith(true)
.build();
EmbeddedChannel ch = new EmbeddedChannel(
new WebSocketServerProtocolHandler(config),
new HttpRequestDecoder(),
new HttpResponseEncoder(),
new MockOutboundHandler());
ch.writeInbound(httpRequest);
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
assertNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx.channel()));
}
@Test
public void testExplicitCloseFrameSentWhenServerChannelClosed() throws Exception {
WebSocketCloseStatus closeStatus = WebSocketCloseStatus.ENDPOINT_UNAVAILABLE;