Cleanup in websockets, throw exception before allocating response if possible (#9361)

Motivation:

While fixing #9359 found few places that could be patched / improved separately.

Modification:

On handshake response generation - throw exception before allocating response objects if request is invalid.

Result:

No more leaks when exception is thrown.
This commit is contained in:
Dmitriy Dumanskiy 2019-07-16 14:12:17 +03:00 committed by Norman Maurer
parent b0b02d51d2
commit ef24732640
8 changed files with 34 additions and 33 deletions

View File

@ -138,10 +138,7 @@ public class CloseWebSocketFrame extends WebSocketFrame {
} }
binaryData.readerIndex(0); binaryData.readerIndex(0);
int statusCode = binaryData.readShort(); return binaryData.getShort(0);
binaryData.readerIndex(0);
return statusCode;
} }
/** /**

View File

@ -248,11 +248,10 @@ public abstract class WebSocketClientHandshaker {
* the {@link ChannelPromise} to be notified when the opening handshake is sent * the {@link ChannelPromise} to be notified when the opening handshake is sent
*/ */
public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) { public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
FullHttpRequest request = newHandshakeRequest(); ChannelPipeline pipeline = channel.pipeline();
HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class);
if (decoder == null) { if (decoder == null) {
HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class); HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
if (codec == null) { if (codec == null) {
promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " + promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
"a HttpResponseDecoder or HttpClientCodec")); "a HttpResponseDecoder or HttpClientCodec"));
@ -260,6 +259,8 @@ public abstract class WebSocketClientHandshaker {
} }
} }
FullHttpRequest request = newHandshakeRequest();
channel.writeAndFlush(request).addListener((ChannelFutureListener) future -> { channel.writeAndFlush(request).addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) { if (future.isSuccess()) {
ChannelPipeline p = future.channel().pipeline(); ChannelPipeline p = future.channel().pipeline();
@ -559,14 +560,15 @@ public abstract class WebSocketClientHandshaker {
return wsURL.getHost(); return wsURL.getHost();
} }
String host = wsURL.getHost(); String host = wsURL.getHost();
String scheme = wsURL.getScheme();
if (port == HttpScheme.HTTP.port()) { if (port == HttpScheme.HTTP.port()) {
return HttpScheme.HTTP.name().contentEquals(wsURL.getScheme()) return HttpScheme.HTTP.name().contentEquals(scheme)
|| WebSocketScheme.WS.name().contentEquals(wsURL.getScheme()) ? || WebSocketScheme.WS.name().contentEquals(scheme) ?
host : NetUtil.toSocketAddressString(host, port); host : NetUtil.toSocketAddressString(host, port);
} }
if (port == HttpScheme.HTTPS.port()) { if (port == HttpScheme.HTTPS.port()) {
return HttpScheme.HTTPS.name().contentEquals(wsURL.getScheme()) return HttpScheme.HTTPS.name().contentEquals(scheme)
|| WebSocketScheme.WSS.name().contentEquals(wsURL.getScheme()) ? || WebSocketScheme.WSS.name().contentEquals(scheme) ?
host : NetUtil.toSocketAddressString(host, port); host : NetUtil.toSocketAddressString(host, port);
} }

View File

@ -133,6 +133,12 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
boolean isHixie76 = req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY1) && boolean isHixie76 = req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY1) &&
req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY2); req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY2);
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
//throw before allocating FullHttpResponse
if (origin == null && !isHixie76) {
throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names());
}
// Create the WebSocket handshake response. // Create the WebSocket handshake response.
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, new HttpResponseStatus(101, FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, new HttpResponseStatus(101,
isHixie76 ? "WebSocket Protocol Handshake" : "Web Socket Protocol Handshake")); isHixie76 ? "WebSocket Protocol Handshake" : "Web Socket Protocol Handshake"));
@ -146,7 +152,7 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
// Fill in the headers and contents depending on handshake getMethod. // Fill in the headers and contents depending on handshake getMethod.
if (isHixie76) { if (isHixie76) {
// New handshake getMethod with a challenge: // New handshake getMethod with a challenge:
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, req.headers().get(HttpHeaderNames.ORIGIN)); res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, origin);
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_LOCATION, uri()); res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_LOCATION, uri());
String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL); String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
@ -176,10 +182,6 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
res.content().writeBytes(WebSocketUtil.md5(input.array())); res.content().writeBytes(WebSocketUtil.md5(input.array()));
} else { } else {
// Old Hixie 75 handshake getMethod with no challenge: // Old Hixie 75 handshake getMethod with no challenge:
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
if (origin == null) {
throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names());
}
res.headers().add(HttpHeaderNames.WEBSOCKET_ORIGIN, origin); res.headers().add(HttpHeaderNames.WEBSOCKET_ORIGIN, origin);
res.headers().add(HttpHeaderNames.WEBSOCKET_LOCATION, uri()); res.headers().add(HttpHeaderNames.WEBSOCKET_LOCATION, uri());

View File

@ -128,6 +128,10 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {
*/ */
@Override @Override
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
FullHttpResponse res = FullHttpResponse res =
new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
@ -136,10 +140,6 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {
res.headers().add(headers); res.headers().add(headers);
} }
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
String acceptSeed = key + WEBSOCKET_07_ACCEPT_GUID; String acceptSeed = key + WEBSOCKET_07_ACCEPT_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
String accept = WebSocketUtil.base64(sha1); String accept = WebSocketUtil.base64(sha1);

View File

@ -135,16 +135,17 @@ public class WebSocketServerHandshaker08 extends WebSocketServerHandshaker {
*/ */
@Override @Override
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
if (headers != null) { if (headers != null) {
res.headers().add(headers); res.headers().add(headers);
} }
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
String acceptSeed = key + WEBSOCKET_08_ACCEPT_GUID; String acceptSeed = key + WEBSOCKET_08_ACCEPT_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
String accept = WebSocketUtil.base64(sha1); String accept = WebSocketUtil.base64(sha1);

View File

@ -134,15 +134,16 @@ public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {
*/ */
@Override @Override
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
if (headers != null) { if (headers != null) {
res.headers().add(headers); res.headers().add(headers);
} }
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
}
String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID; String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
String accept = WebSocketUtil.base64(sha1); String accept = WebSocketUtil.base64(sha1);

View File

@ -212,13 +212,13 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
ChannelPipeline cp = ctx.pipeline(); ChannelPipeline cp = ctx.pipeline();
if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) { if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
// Add the WebSocketHandshakeHandler before this one. // Add the WebSocketHandshakeHandler before this one.
ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(), cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
new WebSocketServerProtocolHandshakeHandler( new WebSocketServerProtocolHandshakeHandler(
websocketPath, subprotocols, checkStartsWith, handshakeTimeoutMillis, decoderConfig)); websocketPath, subprotocols, checkStartsWith, handshakeTimeoutMillis, decoderConfig));
} }
if (cp.get(Utf8FrameValidator.class) == null) { if (cp.get(Utf8FrameValidator.class) == null) {
// Add the UFT8 checking before this one. // Add the UFT8 checking before this one.
ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(), cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
new Utf8FrameValidator()); new Utf8FrameValidator());
} }
} }

View File

@ -44,8 +44,6 @@ import static io.netty.util.internal.ObjectUtil.*;
*/ */
class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
private final String websocketPath; private final String websocketPath;
private final String subprotocols; private final String subprotocols;
private final boolean checkStartsWith; private final boolean checkStartsWith;