adjust validation logic when websocket server check starts with '/' (#11191)

Motivation:

When create a WebSocketServerProtocolConfig to check URI path starts from '/',
only '/' or '//subPath' can be passed by the checker,but '/subPath' should be
passed as well

Modifications:

in `WebSocketServerProtocolHandshakeHandler.isWebSocketPath()` treat '/' a special case

Result:
'/subPath' can be passed
This commit is contained in:
roy 2021-04-26 16:32:10 +08:00 committed by Norman Maurer
parent f221e4d706
commit 636244c287
2 changed files with 37 additions and 1 deletions

View File

@ -112,7 +112,7 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
String websocketPath = serverConfig.websocketPath(); String websocketPath = serverConfig.websocketPath();
String uri = req.uri(); String uri = req.uri();
boolean checkStartUri = uri.startsWith(websocketPath); boolean checkStartUri = uri.startsWith(websocketPath);
boolean checkNextUri = checkNextUri(uri, websocketPath); boolean checkNextUri = "/".equals(websocketPath) || checkNextUri(uri, websocketPath);
return serverConfig.checkStartsWith() ? (checkStartUri && checkNextUri) : uri.equals(websocketPath); return serverConfig.checkStartsWith() ? (checkStartUri && checkNextUri) : uri.equals(websocketPath);
} }

View File

@ -193,6 +193,38 @@ public class WebSocketServerProtocolHandlerTest {
assertFalse(ch.finish()); assertFalse(ch.finish());
} }
@Test
public void testCheckWebSocketPathStartWithSlash() {
WebSocketRequestBuilder builder = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
.method(HttpMethod.GET)
.key(HttpHeaderNames.SEC_WEBSOCKET_KEY)
.connection("Upgrade")
.upgrade(HttpHeaderValues.WEBSOCKET)
.version13();
WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
.websocketPath("/")
.checkStartsWith(true)
.build();
FullHttpResponse response;
createChannel(config, null).writeInbound(builder.uri("/test").build());
response = responses.remove();
assertEquals(SWITCHING_PROTOCOLS, response.status());
response.release();
createChannel(config, null).writeInbound(builder.uri("/?q=v").build());
response = responses.remove();
assertEquals(SWITCHING_PROTOCOLS, response.status());
response.release();
createChannel(config, null).writeInbound(builder.uri("/").build());
response = responses.remove();
assertEquals(SWITCHING_PROTOCOLS, response.status());
response.release();
}
@Test @Test
public void testCheckValidWebSocketPath() { public void testCheckValidWebSocketPath() {
HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1) HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
@ -400,6 +432,10 @@ public class WebSocketServerProtocolHandlerTest {
.websocketPath("/test") .websocketPath("/test")
.sendCloseFrame(null) .sendCloseFrame(null)
.build(); .build();
return createChannel(serverConfig, handler);
}
private EmbeddedChannel createChannel(WebSocketServerProtocolConfig serverConfig, ChannelHandler handler) {
return new EmbeddedChannel( return new EmbeddedChannel(
new WebSocketServerProtocolHandler(serverConfig), new WebSocketServerProtocolHandler(serverConfig),
new HttpRequestDecoder(), new HttpRequestDecoder(),