Cleanup in websockets, throw exception before allocating response if possible (#9361)
Motivation: While fixing #9359 found few places that could be patched / improved separately. Modification: On handshake response generation - throw exception before allocating response objects if request is invalid. Result: No more leaks when exception is thrown.
This commit is contained in:
parent
4f172c13bb
commit
cd824e4e31
@ -138,10 +138,7 @@ public class CloseWebSocketFrame extends WebSocketFrame {
|
||||
}
|
||||
|
||||
binaryData.readerIndex(0);
|
||||
int statusCode = binaryData.readShort();
|
||||
binaryData.readerIndex(0);
|
||||
|
||||
return statusCode;
|
||||
return binaryData.getShort(0);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -248,11 +248,10 @@ public abstract class WebSocketClientHandshaker {
|
||||
* the {@link ChannelPromise} to be notified when the opening handshake is sent
|
||||
*/
|
||||
public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
|
||||
FullHttpRequest request = newHandshakeRequest();
|
||||
|
||||
HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class);
|
||||
ChannelPipeline pipeline = channel.pipeline();
|
||||
HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
|
||||
if (decoder == null) {
|
||||
HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class);
|
||||
HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
|
||||
if (codec == null) {
|
||||
promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
|
||||
"a HttpResponseDecoder or HttpClientCodec"));
|
||||
@ -260,6 +259,8 @@ public abstract class WebSocketClientHandshaker {
|
||||
}
|
||||
}
|
||||
|
||||
FullHttpRequest request = newHandshakeRequest();
|
||||
|
||||
channel.writeAndFlush(request).addListener(new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) {
|
||||
@ -584,14 +585,15 @@ public abstract class WebSocketClientHandshaker {
|
||||
return wsURL.getHost();
|
||||
}
|
||||
String host = wsURL.getHost();
|
||||
String scheme = wsURL.getScheme();
|
||||
if (port == HttpScheme.HTTP.port()) {
|
||||
return HttpScheme.HTTP.name().contentEquals(wsURL.getScheme())
|
||||
|| WebSocketScheme.WS.name().contentEquals(wsURL.getScheme()) ?
|
||||
return HttpScheme.HTTP.name().contentEquals(scheme)
|
||||
|| WebSocketScheme.WS.name().contentEquals(scheme) ?
|
||||
host : NetUtil.toSocketAddressString(host, port);
|
||||
}
|
||||
if (port == HttpScheme.HTTPS.port()) {
|
||||
return HttpScheme.HTTPS.name().contentEquals(wsURL.getScheme())
|
||||
|| WebSocketScheme.WSS.name().contentEquals(wsURL.getScheme()) ?
|
||||
return HttpScheme.HTTPS.name().contentEquals(scheme)
|
||||
|| WebSocketScheme.WSS.name().contentEquals(scheme) ?
|
||||
host : NetUtil.toSocketAddressString(host, port);
|
||||
}
|
||||
|
||||
|
@ -133,6 +133,12 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
|
||||
boolean isHixie76 = req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY1) &&
|
||||
req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY2);
|
||||
|
||||
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
|
||||
//throw before allocating FullHttpResponse
|
||||
if (origin == null && !isHixie76) {
|
||||
throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names());
|
||||
}
|
||||
|
||||
// Create the WebSocket handshake response.
|
||||
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, new HttpResponseStatus(101,
|
||||
isHixie76 ? "WebSocket Protocol Handshake" : "Web Socket Protocol Handshake"));
|
||||
@ -146,7 +152,7 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
|
||||
// Fill in the headers and contents depending on handshake getMethod.
|
||||
if (isHixie76) {
|
||||
// New handshake getMethod with a challenge:
|
||||
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, req.headers().get(HttpHeaderNames.ORIGIN));
|
||||
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, origin);
|
||||
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_LOCATION, uri());
|
||||
|
||||
String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
|
||||
@ -176,10 +182,6 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
|
||||
res.content().writeBytes(WebSocketUtil.md5(input.array()));
|
||||
} else {
|
||||
// Old Hixie 75 handshake getMethod with no challenge:
|
||||
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
|
||||
if (origin == null) {
|
||||
throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names());
|
||||
}
|
||||
res.headers().add(HttpHeaderNames.WEBSOCKET_ORIGIN, origin);
|
||||
res.headers().add(HttpHeaderNames.WEBSOCKET_LOCATION, uri());
|
||||
|
||||
|
@ -128,6 +128,10 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {
|
||||
*/
|
||||
@Override
|
||||
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
|
||||
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
||||
if (key == null) {
|
||||
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
||||
}
|
||||
|
||||
FullHttpResponse res =
|
||||
new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
|
||||
@ -136,10 +140,6 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {
|
||||
res.headers().add(headers);
|
||||
}
|
||||
|
||||
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
||||
if (key == null) {
|
||||
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
||||
}
|
||||
String acceptSeed = key + WEBSOCKET_07_ACCEPT_GUID;
|
||||
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
|
||||
String accept = WebSocketUtil.base64(sha1);
|
||||
|
@ -135,16 +135,17 @@ public class WebSocketServerHandshaker08 extends WebSocketServerHandshaker {
|
||||
*/
|
||||
@Override
|
||||
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
|
||||
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
||||
if (key == null) {
|
||||
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
||||
}
|
||||
|
||||
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
|
||||
|
||||
if (headers != null) {
|
||||
res.headers().add(headers);
|
||||
}
|
||||
|
||||
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
||||
if (key == null) {
|
||||
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
||||
}
|
||||
String acceptSeed = key + WEBSOCKET_08_ACCEPT_GUID;
|
||||
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
|
||||
String accept = WebSocketUtil.base64(sha1);
|
||||
|
@ -134,15 +134,16 @@ public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {
|
||||
*/
|
||||
@Override
|
||||
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
|
||||
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
||||
if (key == null) {
|
||||
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
||||
}
|
||||
|
||||
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
|
||||
if (headers != null) {
|
||||
res.headers().add(headers);
|
||||
}
|
||||
|
||||
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
||||
if (key == null) {
|
||||
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
||||
}
|
||||
String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID;
|
||||
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
|
||||
String accept = WebSocketUtil.base64(sha1);
|
||||
|
@ -213,13 +213,13 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
|
||||
ChannelPipeline cp = ctx.pipeline();
|
||||
if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
|
||||
// Add the WebSocketHandshakeHandler before this one.
|
||||
ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
|
||||
cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
|
||||
new WebSocketServerProtocolHandshakeHandler(
|
||||
websocketPath, subprotocols, checkStartsWith, handshakeTimeoutMillis, decoderConfig));
|
||||
}
|
||||
if (cp.get(Utf8FrameValidator.class) == null) {
|
||||
// Add the UFT8 checking before this one.
|
||||
ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
|
||||
cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
|
||||
new Utf8FrameValidator());
|
||||
}
|
||||
}
|
||||
|
@ -44,8 +44,6 @@ import static io.netty.util.internal.ObjectUtil.*;
|
||||
*/
|
||||
class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
|
||||
|
||||
private final String websocketPath;
|
||||
private final String subprotocols;
|
||||
private final boolean checkStartsWith;
|
||||
|
Loading…
Reference in New Issue
Block a user