diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java index dc48516687..e6e7e38b74 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java @@ -102,11 +102,16 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { private final boolean allowExtensions; private final int maxFramePayloadLength; private final boolean allowMaskMismatch; + private final boolean checkStartsWith; public WebSocketServerProtocolHandler(String websocketPath) { this(websocketPath, null, false); } + public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) { + this(websocketPath, null, false, 65536, false, checkStartsWith); + } + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) { this(websocketPath, subprotocols, false); } @@ -122,11 +127,17 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) { this.websocketPath = websocketPath; this.subprotocols = subprotocols; this.allowExtensions = allowExtensions; maxFramePayloadLength = maxFrameSize; this.allowMaskMismatch = allowMaskMismatch; + this.checkStartsWith = checkStartsWith; } @Override @@ -136,7 +147,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { // Add the WebSocketHandshakeHandler before this one. ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(), new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols, - allowExtensions, maxFramePayloadLength, allowMaskMismatch)); + allowExtensions, maxFramePayloadLength, allowMaskMismatch, checkStartsWith)); } if (cp.get(Utf8FrameValidator.class) == null) { // Add the UFT8 checking before this one. diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java index e63a715db1..fc341ce106 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java @@ -42,20 +42,32 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt private final boolean allowExtensions; private final int maxFramePayloadSize; private final boolean allowMaskMismatch; + private final boolean checkStartsWith; WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false); + } + + WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) { this.websocketPath = websocketPath; this.subprotocols = subprotocols; this.allowExtensions = allowExtensions; maxFramePayloadSize = maxFrameSize; this.allowMaskMismatch = allowMaskMismatch; + this.checkStartsWith = checkStartsWith; } @Override public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { final FullHttpRequest req = (FullHttpRequest) msg; - if (!websocketPath.equals(req.uri())) { + if (checkStartsWith) { + if (!req.uri().startsWith(websocketPath)) { + ctx.fireChannelRead(msg); + return; + } + } else if (!req.uri().equals(websocketPath)) { ctx.fireChannelRead(msg); return; }