Cleanup in websockets, throw exception before allocating response if possible (#9361)

Motivation:

While fixing #9359 found few places that could be patched / improved separately.

Modification:

On handshake response generation - throw exception before allocating response objects if request is invalid.

Result:

No more leaks when exception is thrown.
This commit is contained in:
Dmitriy Dumanskiy 2019-07-16 14:12:17 +03:00 committed by Norman Maurer
parent 4f172c13bb
commit cd824e4e31
8 changed files with 34 additions and 33 deletions

View File

@ -138,10 +138,7 @@ public class CloseWebSocketFrame extends WebSocketFrame {
}
binaryData.readerIndex(0);
int statusCode = binaryData.readShort();
binaryData.readerIndex(0);
return statusCode;
return binaryData.getShort(0);
}
/**

View File

@ -248,11 +248,10 @@ public abstract class WebSocketClientHandshaker {
* the {@link ChannelPromise} to be notified when the opening handshake is sent
*/
public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
FullHttpRequest request = newHandshakeRequest();
HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class);
ChannelPipeline pipeline = channel.pipeline();
HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
if (decoder == null) {
HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class);
HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
if (codec == null) {
promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
"a HttpResponseDecoder or HttpClientCodec"));
@ -260,6 +259,8 @@ public abstract class WebSocketClientHandshaker {
}
}
FullHttpRequest request = newHandshakeRequest();
channel.writeAndFlush(request).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
@ -584,14 +585,15 @@ public abstract class WebSocketClientHandshaker {
return wsURL.getHost();
}
String host = wsURL.getHost();
String scheme = wsURL.getScheme();
if (port == HttpScheme.HTTP.port()) {
return HttpScheme.HTTP.name().contentEquals(wsURL.getScheme())
|| WebSocketScheme.WS.name().contentEquals(wsURL.getScheme()) ?
return HttpScheme.HTTP.name().contentEquals(scheme)
|| WebSocketScheme.WS.name().contentEquals(scheme) ?
host : NetUtil.toSocketAddressString(host, port);
}
if (port == HttpScheme.HTTPS.port()) {
return HttpScheme.HTTPS.name().contentEquals(wsURL.getScheme())
|| WebSocketScheme.WSS.name().contentEquals(wsURL.getScheme()) ?
return HttpScheme.HTTPS.name().contentEquals(scheme)
|| WebSocketScheme.WSS.name().contentEquals(scheme) ?
host : NetUtil.toSocketAddressString(host, port);
}

View File

@ -133,6 +133,12 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
boolean isHixie76 = req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY1) &&
req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY2);
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
//throw before allocating FullHttpResponse
if (origin == null && !isHixie76) {
throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names());
}
// Create the WebSocket handshake response.
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, new HttpResponseStatus(101,
isHixie76 ? "WebSocket Protocol Handshake" : "Web Socket Protocol Handshake"));
@ -146,7 +152,7 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
// Fill in the headers and contents depending on handshake getMethod.
if (isHixie76) {
// New handshake getMethod with a challenge:
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, req.headers().get(HttpHeaderNames.ORIGIN));
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, origin);
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_LOCATION, uri());
String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
@ -176,10 +182,6 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
res.content().writeBytes(WebSocketUtil.md5(input.array()));
} else {
// Old Hixie 75 handshake getMethod with no challenge:
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
if (origin == null) {
throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names());
}
res.headers().add(HttpHeaderNames.WEBSOCKET_ORIGIN, origin);
res.headers().add(HttpHeaderNames.WEBSOCKET_LOCATION, uri());

View File

@ -128,6 +128,10 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {
*/
@Override
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
FullHttpResponse res =
new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
@ -136,10 +140,6 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {
res.headers().add(headers);
}
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
String acceptSeed = key + WEBSOCKET_07_ACCEPT_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
String accept = WebSocketUtil.base64(sha1);

View File

@ -135,16 +135,17 @@ public class WebSocketServerHandshaker08 extends WebSocketServerHandshaker {
*/
@Override
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
if (headers != null) {
res.headers().add(headers);
}
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
String acceptSeed = key + WEBSOCKET_08_ACCEPT_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
String accept = WebSocketUtil.base64(sha1);

View File

@ -134,15 +134,16 @@ public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {
*/
@Override
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
if (headers != null) {
res.headers().add(headers);
}
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
String accept = WebSocketUtil.base64(sha1);

View File

@ -213,13 +213,13 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
ChannelPipeline cp = ctx.pipeline();
if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
// Add the WebSocketHandshakeHandler before this one.
ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
new WebSocketServerProtocolHandshakeHandler(
websocketPath, subprotocols, checkStartsWith, handshakeTimeoutMillis, decoderConfig));
}
if (cp.get(Utf8FrameValidator.class) == null) {
// Add the UFT8 checking before this one.
ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
new Utf8FrameValidator());
}
}

View File

@ -44,8 +44,6 @@ import static io.netty.util.internal.ObjectUtil.*;
*/
class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
private final String websocketPath;
private final String subprotocols;
private final boolean checkStartsWith;