Cleanup in websockets, throw exception before allocating response if possible (#9361)
Motivation: While fixing #9359 found few places that could be patched / improved separately. Modification: On handshake response generation - throw exception before allocating response objects if request is invalid. Result: No more leaks when exception is thrown.
This commit is contained in:
parent
b0b02d51d2
commit
ef24732640
@ -138,10 +138,7 @@ public class CloseWebSocketFrame extends WebSocketFrame {
|
|||||||
}
|
}
|
||||||
|
|
||||||
binaryData.readerIndex(0);
|
binaryData.readerIndex(0);
|
||||||
int statusCode = binaryData.readShort();
|
return binaryData.getShort(0);
|
||||||
binaryData.readerIndex(0);
|
|
||||||
|
|
||||||
return statusCode;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -248,11 +248,10 @@ public abstract class WebSocketClientHandshaker {
|
|||||||
* the {@link ChannelPromise} to be notified when the opening handshake is sent
|
* the {@link ChannelPromise} to be notified when the opening handshake is sent
|
||||||
*/
|
*/
|
||||||
public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
|
public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
|
||||||
FullHttpRequest request = newHandshakeRequest();
|
ChannelPipeline pipeline = channel.pipeline();
|
||||||
|
HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
|
||||||
HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class);
|
|
||||||
if (decoder == null) {
|
if (decoder == null) {
|
||||||
HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class);
|
HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
|
||||||
if (codec == null) {
|
if (codec == null) {
|
||||||
promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
|
promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
|
||||||
"a HttpResponseDecoder or HttpClientCodec"));
|
"a HttpResponseDecoder or HttpClientCodec"));
|
||||||
@ -260,6 +259,8 @@ public abstract class WebSocketClientHandshaker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FullHttpRequest request = newHandshakeRequest();
|
||||||
|
|
||||||
channel.writeAndFlush(request).addListener((ChannelFutureListener) future -> {
|
channel.writeAndFlush(request).addListener((ChannelFutureListener) future -> {
|
||||||
if (future.isSuccess()) {
|
if (future.isSuccess()) {
|
||||||
ChannelPipeline p = future.channel().pipeline();
|
ChannelPipeline p = future.channel().pipeline();
|
||||||
@ -559,14 +560,15 @@ public abstract class WebSocketClientHandshaker {
|
|||||||
return wsURL.getHost();
|
return wsURL.getHost();
|
||||||
}
|
}
|
||||||
String host = wsURL.getHost();
|
String host = wsURL.getHost();
|
||||||
|
String scheme = wsURL.getScheme();
|
||||||
if (port == HttpScheme.HTTP.port()) {
|
if (port == HttpScheme.HTTP.port()) {
|
||||||
return HttpScheme.HTTP.name().contentEquals(wsURL.getScheme())
|
return HttpScheme.HTTP.name().contentEquals(scheme)
|
||||||
|| WebSocketScheme.WS.name().contentEquals(wsURL.getScheme()) ?
|
|| WebSocketScheme.WS.name().contentEquals(scheme) ?
|
||||||
host : NetUtil.toSocketAddressString(host, port);
|
host : NetUtil.toSocketAddressString(host, port);
|
||||||
}
|
}
|
||||||
if (port == HttpScheme.HTTPS.port()) {
|
if (port == HttpScheme.HTTPS.port()) {
|
||||||
return HttpScheme.HTTPS.name().contentEquals(wsURL.getScheme())
|
return HttpScheme.HTTPS.name().contentEquals(scheme)
|
||||||
|| WebSocketScheme.WSS.name().contentEquals(wsURL.getScheme()) ?
|
|| WebSocketScheme.WSS.name().contentEquals(scheme) ?
|
||||||
host : NetUtil.toSocketAddressString(host, port);
|
host : NetUtil.toSocketAddressString(host, port);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -133,6 +133,12 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
|
|||||||
boolean isHixie76 = req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY1) &&
|
boolean isHixie76 = req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY1) &&
|
||||||
req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY2);
|
req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY2);
|
||||||
|
|
||||||
|
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
|
||||||
|
//throw before allocating FullHttpResponse
|
||||||
|
if (origin == null && !isHixie76) {
|
||||||
|
throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names());
|
||||||
|
}
|
||||||
|
|
||||||
// Create the WebSocket handshake response.
|
// Create the WebSocket handshake response.
|
||||||
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, new HttpResponseStatus(101,
|
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, new HttpResponseStatus(101,
|
||||||
isHixie76 ? "WebSocket Protocol Handshake" : "Web Socket Protocol Handshake"));
|
isHixie76 ? "WebSocket Protocol Handshake" : "Web Socket Protocol Handshake"));
|
||||||
@ -146,7 +152,7 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
|
|||||||
// Fill in the headers and contents depending on handshake getMethod.
|
// Fill in the headers and contents depending on handshake getMethod.
|
||||||
if (isHixie76) {
|
if (isHixie76) {
|
||||||
// New handshake getMethod with a challenge:
|
// New handshake getMethod with a challenge:
|
||||||
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, req.headers().get(HttpHeaderNames.ORIGIN));
|
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, origin);
|
||||||
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_LOCATION, uri());
|
res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_LOCATION, uri());
|
||||||
|
|
||||||
String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
|
String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
|
||||||
@ -176,10 +182,6 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
|
|||||||
res.content().writeBytes(WebSocketUtil.md5(input.array()));
|
res.content().writeBytes(WebSocketUtil.md5(input.array()));
|
||||||
} else {
|
} else {
|
||||||
// Old Hixie 75 handshake getMethod with no challenge:
|
// Old Hixie 75 handshake getMethod with no challenge:
|
||||||
String origin = req.headers().get(HttpHeaderNames.ORIGIN);
|
|
||||||
if (origin == null) {
|
|
||||||
throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names());
|
|
||||||
}
|
|
||||||
res.headers().add(HttpHeaderNames.WEBSOCKET_ORIGIN, origin);
|
res.headers().add(HttpHeaderNames.WEBSOCKET_ORIGIN, origin);
|
||||||
res.headers().add(HttpHeaderNames.WEBSOCKET_LOCATION, uri());
|
res.headers().add(HttpHeaderNames.WEBSOCKET_LOCATION, uri());
|
||||||
|
|
||||||
|
@ -128,6 +128,10 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {
|
|||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
|
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
|
||||||
|
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
||||||
|
if (key == null) {
|
||||||
|
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
||||||
|
}
|
||||||
|
|
||||||
FullHttpResponse res =
|
FullHttpResponse res =
|
||||||
new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
|
new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
|
||||||
@ -136,10 +140,6 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {
|
|||||||
res.headers().add(headers);
|
res.headers().add(headers);
|
||||||
}
|
}
|
||||||
|
|
||||||
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
|
||||||
if (key == null) {
|
|
||||||
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
|
||||||
}
|
|
||||||
String acceptSeed = key + WEBSOCKET_07_ACCEPT_GUID;
|
String acceptSeed = key + WEBSOCKET_07_ACCEPT_GUID;
|
||||||
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
|
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
|
||||||
String accept = WebSocketUtil.base64(sha1);
|
String accept = WebSocketUtil.base64(sha1);
|
||||||
|
@ -135,16 +135,17 @@ public class WebSocketServerHandshaker08 extends WebSocketServerHandshaker {
|
|||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
|
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
|
||||||
|
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
||||||
|
if (key == null) {
|
||||||
|
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
||||||
|
}
|
||||||
|
|
||||||
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
|
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
|
||||||
|
|
||||||
if (headers != null) {
|
if (headers != null) {
|
||||||
res.headers().add(headers);
|
res.headers().add(headers);
|
||||||
}
|
}
|
||||||
|
|
||||||
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
|
||||||
if (key == null) {
|
|
||||||
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
|
||||||
}
|
|
||||||
String acceptSeed = key + WEBSOCKET_08_ACCEPT_GUID;
|
String acceptSeed = key + WEBSOCKET_08_ACCEPT_GUID;
|
||||||
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
|
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
|
||||||
String accept = WebSocketUtil.base64(sha1);
|
String accept = WebSocketUtil.base64(sha1);
|
||||||
|
@ -134,15 +134,16 @@ public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {
|
|||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
|
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
|
||||||
|
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
||||||
|
if (key == null) {
|
||||||
|
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
||||||
|
}
|
||||||
|
|
||||||
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
|
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
|
||||||
if (headers != null) {
|
if (headers != null) {
|
||||||
res.headers().add(headers);
|
res.headers().add(headers);
|
||||||
}
|
}
|
||||||
|
|
||||||
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
|
|
||||||
if (key == null) {
|
|
||||||
throw new WebSocketHandshakeException("not a WebSocket request: missing key");
|
|
||||||
}
|
|
||||||
String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID;
|
String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID;
|
||||||
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
|
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
|
||||||
String accept = WebSocketUtil.base64(sha1);
|
String accept = WebSocketUtil.base64(sha1);
|
||||||
|
@ -212,13 +212,13 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
|
|||||||
ChannelPipeline cp = ctx.pipeline();
|
ChannelPipeline cp = ctx.pipeline();
|
||||||
if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
|
if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
|
||||||
// Add the WebSocketHandshakeHandler before this one.
|
// Add the WebSocketHandshakeHandler before this one.
|
||||||
ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
|
cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
|
||||||
new WebSocketServerProtocolHandshakeHandler(
|
new WebSocketServerProtocolHandshakeHandler(
|
||||||
websocketPath, subprotocols, checkStartsWith, handshakeTimeoutMillis, decoderConfig));
|
websocketPath, subprotocols, checkStartsWith, handshakeTimeoutMillis, decoderConfig));
|
||||||
}
|
}
|
||||||
if (cp.get(Utf8FrameValidator.class) == null) {
|
if (cp.get(Utf8FrameValidator.class) == null) {
|
||||||
// Add the UFT8 checking before this one.
|
// Add the UFT8 checking before this one.
|
||||||
ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
|
cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
|
||||||
new Utf8FrameValidator());
|
new Utf8FrameValidator());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -44,8 +44,6 @@ import static io.netty.util.internal.ObjectUtil.*;
|
|||||||
*/
|
*/
|
||||||
class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
|
class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
|
||||||
|
|
||||||
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
|
|
||||||
|
|
||||||
private final String websocketPath;
|
private final String websocketPath;
|
||||||
private final String subprotocols;
|
private final String subprotocols;
|
||||||
private final boolean checkStartsWith;
|
private final boolean checkStartsWith;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user