Add max frame length for web socket to limit chance of DOS attack (#283)

- Contributed by @veebs
This commit is contained in:
Trustin Lee 2012-05-30 17:13:00 -07:00
parent ec43aa121f
commit 67ee22e23a
17 changed files with 164 additions and 58 deletions

View File

@ -34,7 +34,7 @@ public class WebSocket00FrameDecoder extends ReplayingDecoder<WebSocketFrame, Vo
private static final int DEFAULT_MAX_FRAME_SIZE = 16384; private static final int DEFAULT_MAX_FRAME_SIZE = 16384;
private final int maxFrameSize; private final long maxFrameSize;
private boolean receivedClosingHandshake; private boolean receivedClosingHandshake;
public WebSocket00FrameDecoder() { public WebSocket00FrameDecoder() {
@ -48,7 +48,7 @@ public class WebSocket00FrameDecoder extends ReplayingDecoder<WebSocketFrame, Vo
* @param maxFrameSize * @param maxFrameSize
* the maximum frame size to decode * the maximum frame size to decode
*/ */
public WebSocket00FrameDecoder(int maxFrameSize) { public WebSocket00FrameDecoder(long maxFrameSize) {
this.maxFrameSize = maxFrameSize; this.maxFrameSize = maxFrameSize;
} }

View File

@ -81,6 +81,7 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocketFrame, We
private UTF8Output fragmentedFramesText; private UTF8Output fragmentedFramesText;
private int fragmentedFramesCount; private int fragmentedFramesCount;
private final long maxFramePayloadLength;
private boolean frameFinalFlag; private boolean frameFinalFlag;
private int frameRsv; private int frameRsv;
private int frameOpcode; private int frameOpcode;
@ -105,11 +106,15 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocketFrame, We
* must set this to false. * must set this to false.
* @param allowExtensions * @param allowExtensions
* Flag to allow reserved extension bits to be used or not * Flag to allow reserved extension bits to be used or not
* @param maxFramePayloadLength
* Maximum length of a frame's payload. Setting this to an appropriate value for you application
* helps check for denial of services attacks.
*/ */
public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions) { public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions, long maxFramePayloadLength) {
super(State.FRAME_START); super(State.FRAME_START);
this.maskedPayload = maskedPayload; this.maskedPayload = maskedPayload;
this.allowExtensions = allowExtensions; this.allowExtensions = allowExtensions;
this.maxFramePayloadLength = maxFramePayloadLength;
} }
@Override @Override
@ -219,6 +224,11 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocketFrame, We
framePayloadLength = framePayloadLen1; framePayloadLength = framePayloadLen1;
} }
if (framePayloadLength > maxFramePayloadLength) {
protocolViolation(ctx, "Max frame length of " + maxFramePayloadLength + " has been exceeded.");
return null;
}
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Decoding WebSocket Frame length=" + framePayloadLength); logger.debug("Decoding WebSocket Frame length=" + framePayloadLength);
} }
@ -235,7 +245,7 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocketFrame, We
int rbytes = actualReadableBytes(); int rbytes = actualReadableBytes();
ChannelBuffer payloadBuffer = null; ChannelBuffer payloadBuffer = null;
int willHaveReadByteCount = framePayloadBytesRead + rbytes; long willHaveReadByteCount = framePayloadBytesRead + rbytes;
// logger.debug("Frame rbytes=" + rbytes + " willHaveReadByteCount=" // logger.debug("Frame rbytes=" + rbytes + " willHaveReadByteCount="
// + willHaveReadByteCount + " framePayloadLength=" + // + willHaveReadByteCount + " framePayloadLength=" +
// framePayloadLength); // framePayloadLength);

View File

@ -66,8 +66,11 @@ public class WebSocket13FrameDecoder extends WebSocket08FrameDecoder {
* must set this to false. * must set this to false.
* @param allowExtensions * @param allowExtensions
* Flag to allow reserved extension bits to be used or not * Flag to allow reserved extension bits to be used or not
* @param maxFramePayloadLength
* Maximum length of a frame's payload. Setting this to an appropriate value for you application
* helps check for denial of services attacks.
*/ */
public WebSocket13FrameDecoder(boolean maskedPayload, boolean allowExtensions) { public WebSocket13FrameDecoder(boolean maskedPayload, boolean allowExtensions, long maxFramePayloadLength) {
super(maskedPayload, allowExtensions); super(maskedPayload, allowExtensions, maxFramePayloadLength);
} }
} }

View File

@ -15,13 +15,13 @@
*/ */
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import java.net.URI;
import java.util.Map;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import java.net.URI;
import java.util.Map;
/** /**
* Base class for web socket client handshake implementations * Base class for web socket client handshake implementations
*/ */
@ -39,6 +39,8 @@ public abstract class WebSocketClientHandshaker {
protected final Map<String, String> customHeaders; protected final Map<String, String> customHeaders;
private final long maxFramePayloadLength;
/** /**
* Base constructor * Base constructor
* *
@ -51,13 +53,16 @@ public abstract class WebSocketClientHandshaker {
* Sub protocol request sent to the server. * Sub protocol request sent to the server.
* @param customHeaders * @param customHeaders
* Map of custom headers to add to the client request * Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
*/ */
public WebSocketClientHandshaker(URI webSocketUrl, WebSocketVersion version, String subprotocol, public WebSocketClientHandshaker(URI webSocketUrl, WebSocketVersion version, String subprotocol,
Map<String, String> customHeaders) { Map<String, String> customHeaders, long maxFramePayloadLength) {
this.webSocketUrl = webSocketUrl; this.webSocketUrl = webSocketUrl;
this.version = version; this.version = version;
expectedSubprotocol = subprotocol; expectedSubprotocol = subprotocol;
this.customHeaders = customHeaders; this.customHeaders = customHeaders;
this.maxFramePayloadLength = maxFramePayloadLength;
} }
/** /**
@ -74,6 +79,13 @@ public abstract class WebSocketClientHandshaker {
return version; return version;
} }
/**
* Returns the max length for any frame's payload
*/
public long getMaxFramePayloadLength() {
return maxFramePayloadLength;
}
/** /**
* Flag to indicate if the opening handshake is complete * Flag to indicate if the opening handshake is complete
*/ */

View File

@ -60,10 +60,12 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
* Sub protocol request sent to the server. * Sub protocol request sent to the server.
* @param customHeaders * @param customHeaders
* Map of custom headers to add to the client request * Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
*/ */
public WebSocketClientHandshaker00(URI webSocketURL, WebSocketVersion version, String subprotocol, public WebSocketClientHandshaker00(URI webSocketURL, WebSocketVersion version, String subprotocol,
Map<String, String> customHeaders) { Map<String, String> customHeaders, long maxFramePayloadLength) {
super(webSocketURL, version, subprotocol, customHeaders); super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
} }
@ -219,7 +221,9 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
String protocol = response.getHeader(Names.SEC_WEBSOCKET_PROTOCOL); String protocol = response.getHeader(Names.SEC_WEBSOCKET_PROTOCOL);
setActualSubprotocol(protocol); setActualSubprotocol(protocol);
channel.pipeline().replace(HttpResponseDecoder.class, "ws-decoder", new WebSocket00FrameDecoder()); channel.pipeline().replace(
HttpResponseDecoder.class, "ws-decoder",
new WebSocket00FrameDecoder(getMaxFramePayloadLength()));
setHandshakeComplete(); setHandshakeComplete();
} }

View File

@ -54,7 +54,7 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
private final boolean allowExtensions; private final boolean allowExtensions;
/** /**
* Constructor specifying the destination web socket location and version to initiate * Creates a new instance.
* *
* @param webSocketURL * @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
@ -67,10 +67,12 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
* Allow extensions to be used in the reserved bits of the web socket frame * Allow extensions to be used in the reserved bits of the web socket frame
* @param customHeaders * @param customHeaders
* Map of custom headers to add to the client request * Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
*/ */
public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol, public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, Map<String, String> customHeaders) { boolean allowExtensions, Map<String, String> customHeaders, long maxFramePayloadLength) {
super(webSocketURL, version, subprotocol, customHeaders); super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
this.allowExtensions = allowExtensions; this.allowExtensions = allowExtensions;
} }
@ -196,7 +198,7 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
} }
channel.pipeline().replace(HttpResponseDecoder.class, "ws-decoder", channel.pipeline().replace(HttpResponseDecoder.class, "ws-decoder",
new WebSocket08FrameDecoder(false, allowExtensions)); new WebSocket08FrameDecoder(false, allowExtensions, getMaxFramePayloadLength()));
setHandshakeComplete(); setHandshakeComplete();
} }

View File

@ -54,7 +54,7 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
private final boolean allowExtensions; private final boolean allowExtensions;
/** /**
* Constructor specifying the destination web socket location and version to initiate * Creates a new instance.
* *
* @param webSocketURL * @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
@ -67,10 +67,12 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
* Allow extensions to be used in the reserved bits of the web socket frame * Allow extensions to be used in the reserved bits of the web socket frame
* @param customHeaders * @param customHeaders
* Map of custom headers to add to the client request * Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
*/ */
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, Map<String, String> customHeaders) { boolean allowExtensions, Map<String, String> customHeaders, long maxFramePayloadLength) {
super(webSocketURL, version, subprotocol, customHeaders); super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
this.allowExtensions = allowExtensions; this.allowExtensions = allowExtensions;
} }
@ -196,7 +198,7 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
} }
channel.pipeline().replace(HttpResponseDecoder.class, "ws-decoder", channel.pipeline().replace(HttpResponseDecoder.class, "ws-decoder",
new WebSocket13FrameDecoder(false, allowExtensions)); new WebSocket13FrameDecoder(false, allowExtensions, getMaxFramePayloadLength()));
setHandshakeComplete(); setHandshakeComplete();
} }

View File

@ -24,7 +24,7 @@ import java.util.Map;
public class WebSocketClientHandshakerFactory { public class WebSocketClientHandshakerFactory {
/** /**
* Instances a new handshaker * Creates a new handshaker.
* *
* @param webSocketURL * @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
@ -37,21 +37,44 @@ public class WebSocketClientHandshakerFactory {
* Allow extensions to be used in the reserved bits of the web socket frame * Allow extensions to be used in the reserved bits of the web socket frame
* @param customHeaders * @param customHeaders
* Custom HTTP headers to send during the handshake * Custom HTTP headers to send during the handshake
* @throws WebSocketHandshakeException
*/ */
public WebSocketClientHandshaker newHandshaker(URI webSocketURL, WebSocketVersion version, String subprotocol, public WebSocketClientHandshaker newHandshaker(
boolean allowExtensions, Map<String, String> customHeaders) throws WebSocketHandshakeException { URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, Map<String, String> customHeaders) {
return newHandshaker(webSocketURL, version, subprotocol, allowExtensions, customHeaders, 65536);
}
/**
* Creates a new handshaker.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server. Null if no sub-protocol support is required.
* @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame
* @param customHeaders
* Custom HTTP headers to send during the handshake
* @param maxFramePayloadLength
* Maximum allowable frame payload length. Setting this value to your application's requirement may
* reduce denial of service attacks using long data frames.
*/
public WebSocketClientHandshaker newHandshaker(
URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, Map<String, String> customHeaders, long maxFramePayloadLength) {
if (version == WebSocketVersion.V13) { if (version == WebSocketVersion.V13) {
return new WebSocketClientHandshaker13(webSocketURL, version, subprotocol, allowExtensions, customHeaders); return new WebSocketClientHandshaker13(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength);
} }
if (version == WebSocketVersion.V08) { if (version == WebSocketVersion.V08) {
return new WebSocketClientHandshaker08(webSocketURL, version, subprotocol, allowExtensions, customHeaders); return new WebSocketClientHandshaker08(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength);
} }
if (version == WebSocketVersion.V00) { if (version == WebSocketVersion.V00) {
return new WebSocketClientHandshaker00(webSocketURL, version, subprotocol, customHeaders); return new WebSocketClientHandshaker00(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
} }
throw new WebSocketHandshakeException("Protocol version " + version.toString() + " not supported."); throw new WebSocketHandshakeException("Protocol version " + version.toString() + " not supported.");
} }
} }

View File

@ -15,13 +15,13 @@
*/ */
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import java.util.LinkedHashSet;
import java.util.Set;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequest;
import java.util.LinkedHashSet;
import java.util.Set;
/** /**
* Base class for server side web socket opening and closing handshakes * Base class for server side web socket opening and closing handshakes
*/ */
@ -33,6 +33,8 @@ public abstract class WebSocketServerHandshaker {
private final WebSocketVersion version; private final WebSocketVersion version;
private final long maxFramePayloadLength;
/** /**
* Constructor specifying the destination web socket location * Constructor specifying the destination web socket location
* *
@ -43,9 +45,12 @@ public abstract class WebSocketServerHandshaker {
* sent to this URL. * sent to this URL.
* @param subprotocols * @param subprotocols
* CSV of supported protocols. Null if sub protocols not supported. * CSV of supported protocols. Null if sub protocols not supported.
* @param maxFramePayloadLength
* Maximum length of a frame's payload
*/ */
protected WebSocketServerHandshaker( protected WebSocketServerHandshaker(
WebSocketVersion version, String webSocketUrl, String subprotocols) { WebSocketVersion version, String webSocketUrl, String subprotocols,
long maxFramePayloadLength) {
this.version = version; this.version = version;
this.webSocketUrl = webSocketUrl; this.webSocketUrl = webSocketUrl;
if (subprotocols != null) { if (subprotocols != null) {
@ -57,6 +62,7 @@ public abstract class WebSocketServerHandshaker {
} else { } else {
this.subprotocols = new String[0]; this.subprotocols = new String[0];
} }
this.maxFramePayloadLength = maxFramePayloadLength;
} }
/** /**
@ -71,7 +77,7 @@ public abstract class WebSocketServerHandshaker {
*/ */
public Set<String> getSubprotocols() { public Set<String> getSubprotocols() {
Set<String> ret = new LinkedHashSet<String>(); Set<String> ret = new LinkedHashSet<String>();
for (String p: this.subprotocols) { for (String p: subprotocols) {
ret.add(p); ret.add(p);
} }
return ret; return ret;
@ -84,6 +90,13 @@ public abstract class WebSocketServerHandshaker {
return version; return version;
} }
/**
* Returns the max length for any frame's payload.
*/
public long getMaxFramePayloadLength() {
return maxFramePayloadLength;
}
/** /**
* Performs the opening handshake * Performs the opening handshake
* *

View File

@ -57,9 +57,12 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
* sent to this URL. * sent to this URL.
* @param subprotocols * @param subprotocols
* CSV of supported protocols * CSV of supported protocols
* @param maxFramePayloadLength
* Maximum allowable frame payload length. Setting this value to your application's requirement may
* reduce denial of service attacks using long data frames.
*/ */
public WebSocketServerHandshaker00(String webSocketURL, String subprotocols) { public WebSocketServerHandshaker00(String webSocketURL, String subprotocols, long maxFramePayloadLength) {
super(WebSocketVersion.V00, webSocketURL, subprotocols); super(WebSocketVersion.V00, webSocketURL, subprotocols, maxFramePayloadLength);
} }
/** /**
@ -166,7 +169,8 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
if (p.get(HttpChunkAggregator.class) != null) { if (p.get(HttpChunkAggregator.class) != null) {
p.remove(HttpChunkAggregator.class); p.remove(HttpChunkAggregator.class);
} }
p.replace(HttpRequestDecoder.class, "wsdecoder", new WebSocket00FrameDecoder()); p.replace(HttpRequestDecoder.class, "wsdecoder",
new WebSocket00FrameDecoder(getMaxFramePayloadLength()));
ChannelFuture future = channel.write(res); ChannelFuture future = channel.write(res);

View File

@ -58,9 +58,12 @@ public class WebSocketServerHandshaker08 extends WebSocketServerHandshaker {
* CSV of supported protocols * CSV of supported protocols
* @param allowExtensions * @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame * Allow extensions to be used in the reserved bits of the web socket frame
* @param maxFramePayloadLength
* Maximum allowable frame payload length. Setting this value to your application's requirement may
* reduce denial of service attacks using long data frames.
*/ */
public WebSocketServerHandshaker08(String webSocketURL, String subprotocols, boolean allowExtensions) { public WebSocketServerHandshaker08(String webSocketURL, String subprotocols, boolean allowExtensions, long maxFramePayloadLength) {
super(WebSocketVersion.V08, webSocketURL, subprotocols); super(WebSocketVersion.V08, webSocketURL, subprotocols, maxFramePayloadLength);
this.allowExtensions = allowExtensions; this.allowExtensions = allowExtensions;
} }
@ -141,7 +144,8 @@ public class WebSocketServerHandshaker08 extends WebSocketServerHandshaker {
p.remove(HttpChunkAggregator.class); p.remove(HttpChunkAggregator.class);
} }
p.replace(HttpRequestDecoder.class, "wsdecoder", new WebSocket08FrameDecoder(true, allowExtensions)); p.replace(HttpRequestDecoder.class, "wsdecoder",
new WebSocket08FrameDecoder(true, allowExtensions, getMaxFramePayloadLength()));
p.replace(HttpResponseEncoder.class, "wsencoder", new WebSocket08FrameEncoder(false)); p.replace(HttpResponseEncoder.class, "wsencoder", new WebSocket08FrameEncoder(false));
return future; return future;

View File

@ -59,9 +59,12 @@ public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {
* CSV of supported protocols * CSV of supported protocols
* @param allowExtensions * @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame * Allow extensions to be used in the reserved bits of the web socket frame
* @param maxFramePayloadLength
* Maximum allowable frame payload length. Setting this value to your application's requirement may
* reduce denial of service attacks using long data frames.
*/ */
public WebSocketServerHandshaker13(String webSocketURL, String subprotocols, boolean allowExtensions) { public WebSocketServerHandshaker13(String webSocketURL, String subprotocols, boolean allowExtensions, long maxFramePayloadLength) {
super(WebSocketVersion.V13, webSocketURL, subprotocols); super(WebSocketVersion.V13, webSocketURL, subprotocols, maxFramePayloadLength);
this.allowExtensions = allowExtensions; this.allowExtensions = allowExtensions;
} }
@ -142,7 +145,8 @@ public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {
p.remove(HttpChunkAggregator.class); p.remove(HttpChunkAggregator.class);
} }
p.replace(HttpRequestDecoder.class, "wsdecoder", new WebSocket13FrameDecoder(true, allowExtensions)); p.replace(HttpRequestDecoder.class, "wsdecoder",
new WebSocket13FrameDecoder(true, allowExtensions, getMaxFramePayloadLength()));
p.replace(HttpResponseEncoder.class, "wsencoder", new WebSocket13FrameEncoder(false)); p.replace(HttpResponseEncoder.class, "wsencoder", new WebSocket13FrameEncoder(false));
return future; return future;

View File

@ -17,11 +17,11 @@ package io.netty.handler.codec.http.websocketx;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.handler.codec.http.DefaultHttpResponse; import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpHeaders.Names;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.HttpHeaders.Names;
/** /**
* Instances the appropriate handshake class to use for servers * Instances the appropriate handshake class to use for servers
@ -34,6 +34,8 @@ public class WebSocketServerHandshakerFactory {
private final boolean allowExtensions; private final boolean allowExtensions;
private final long maxFramePayloadLength;
/** /**
* Constructor specifying the destination web socket location * Constructor specifying the destination web socket location
* *
@ -45,10 +47,32 @@ public class WebSocketServerHandshakerFactory {
* @param allowExtensions * @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame * Allow extensions to be used in the reserved bits of the web socket frame
*/ */
public WebSocketServerHandshakerFactory(String webSocketURL, String subprotocols, boolean allowExtensions) { public WebSocketServerHandshakerFactory(
String webSocketURL, String subprotocols, boolean allowExtensions) {
this(webSocketURL, subprotocols, allowExtensions, 65536);
}
/**
* Constructor specifying the destination web socket location
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param subprotocols
* CSV of supported protocols. Null if sub protocols not supported.
* @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame
* @param maxFramePayloadLength
* Maximum allowable frame payload length. Setting this value to your application's requirement may
* reduce denial of service attacks using long data frames.
*/
public WebSocketServerHandshakerFactory(
String webSocketURL, String subprotocols, boolean allowExtensions,
long maxFramePayloadLength) {
this.webSocketURL = webSocketURL; this.webSocketURL = webSocketURL;
this.subprotocols = subprotocols; this.subprotocols = subprotocols;
this.allowExtensions = allowExtensions; this.allowExtensions = allowExtensions;
this.maxFramePayloadLength = maxFramePayloadLength;
} }
/** /**
@ -63,16 +87,16 @@ public class WebSocketServerHandshakerFactory {
if (version != null) { if (version != null) {
if (version.equals(WebSocketVersion.V13.toHttpHeaderValue())) { if (version.equals(WebSocketVersion.V13.toHttpHeaderValue())) {
// Version 13 of the wire protocol - RFC 6455 (version 17 of the draft hybi specification). // Version 13 of the wire protocol - RFC 6455 (version 17 of the draft hybi specification).
return new WebSocketServerHandshaker13(webSocketURL, subprotocols, allowExtensions); return new WebSocketServerHandshaker13(webSocketURL, subprotocols, allowExtensions, maxFramePayloadLength);
} else if (version.equals(WebSocketVersion.V08.toHttpHeaderValue())) { } else if (version.equals(WebSocketVersion.V08.toHttpHeaderValue())) {
// Version 8 of the wire protocol - version 10 of the draft hybi specification. // Version 8 of the wire protocol - version 10 of the draft hybi specification.
return new WebSocketServerHandshaker08(webSocketURL, subprotocols, allowExtensions); return new WebSocketServerHandshaker08(webSocketURL, subprotocols, allowExtensions, maxFramePayloadLength);
} else { } else {
return null; return null;
} }
} else { } else {
// Assume version 00 where version header was not specified // Assume version 00 where version header was not specified
return new WebSocketServerHandshaker00(webSocketURL, subprotocols); return new WebSocketServerHandshaker00(webSocketURL, subprotocols, maxFramePayloadLength);
} }
} }

View File

@ -18,6 +18,7 @@ package io.netty.handler.codec.http;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import io.netty.buffer.ChannelBuffer; import io.netty.buffer.ChannelBuffer;
import io.netty.buffer.ChannelBuffers; import io.netty.buffer.ChannelBuffers;
import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.PrematureChannelClosureException; import io.netty.handler.codec.PrematureChannelClosureException;
import io.netty.handler.codec.embedder.DecoderEmbedder; import io.netty.handler.codec.embedder.DecoderEmbedder;
import io.netty.handler.codec.embedder.EncoderEmbedder; import io.netty.handler.codec.embedder.EncoderEmbedder;
@ -72,7 +73,7 @@ public class HttpClientCodecTest {
try { try {
encoder.finish(); encoder.finish();
fail(); fail();
} catch (CodecEmbedderException e) { } catch (CodecException e) {
assertTrue(e.getCause() instanceof PrematureChannelClosureException); assertTrue(e.getCause() instanceof PrematureChannelClosureException);
} }
@ -92,7 +93,7 @@ public class HttpClientCodecTest {
encoder.finish(); encoder.finish();
decoder.finish(); decoder.finish();
fail(); fail();
} catch (CodecEmbedderException e) { } catch (CodecException e) {
assertTrue(e.getCause() instanceof PrematureChannelClosureException); assertTrue(e.getCause() instanceof PrematureChannelClosureException);
} }

View File

@ -74,7 +74,7 @@ public class WebSocketServerHandshaker00Test {
ChannelBuffer buffer = ChannelBuffers.copiedBuffer("^n:ds[4U", Charset.defaultCharset()); ChannelBuffer buffer = ChannelBuffers.copiedBuffer("^n:ds[4U", Charset.defaultCharset());
req.setContent(buffer); req.setContent(buffer);
WebSocketServerHandshaker00 handsaker = new WebSocketServerHandshaker00("ws://example.com/chat", "chat"); WebSocketServerHandshaker00 handsaker = new WebSocketServerHandshaker00("ws://example.com/chat", "chat", Long.MAX_VALUE);
handsaker.handshake(channelMock, req); handsaker.handshake(channelMock, req);
Assert.assertEquals("ws://example.com/chat", res.getValue().getHeader(Names.SEC_WEBSOCKET_LOCATION)); Assert.assertEquals("ws://example.com/chat", res.getValue().getHeader(Names.SEC_WEBSOCKET_LOCATION));

View File

@ -67,7 +67,7 @@ public class WebSocketServerHandshaker08Test {
req.setHeader(Names.SEC_WEBSOCKET_PROTOCOL, "chat, superchat"); req.setHeader(Names.SEC_WEBSOCKET_PROTOCOL, "chat, superchat");
req.setHeader(Names.SEC_WEBSOCKET_VERSION, "8"); req.setHeader(Names.SEC_WEBSOCKET_VERSION, "8");
WebSocketServerHandshaker08 handsaker = new WebSocketServerHandshaker08("ws://example.com/chat", "chat", false); WebSocketServerHandshaker08 handsaker = new WebSocketServerHandshaker08("ws://example.com/chat", "chat", false, Long.MAX_VALUE);
handsaker.handshake(channelMock, req); handsaker.handshake(channelMock, req);
Assert.assertEquals("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.getValue().getHeader(Names.SEC_WEBSOCKET_ACCEPT)); Assert.assertEquals("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.getValue().getHeader(Names.SEC_WEBSOCKET_ACCEPT));

View File

@ -66,7 +66,7 @@ public class WebSocketServerHandshaker13Test {
req.setHeader(Names.SEC_WEBSOCKET_ORIGIN, "http://example.com"); req.setHeader(Names.SEC_WEBSOCKET_ORIGIN, "http://example.com");
req.setHeader(Names.SEC_WEBSOCKET_PROTOCOL, "chat, superchat"); req.setHeader(Names.SEC_WEBSOCKET_PROTOCOL, "chat, superchat");
req.setHeader(Names.SEC_WEBSOCKET_VERSION, "13"); req.setHeader(Names.SEC_WEBSOCKET_VERSION, "13");
WebSocketServerHandshaker13 handsaker = new WebSocketServerHandshaker13("ws://example.com/chat", "chat", false); WebSocketServerHandshaker13 handsaker = new WebSocketServerHandshaker13("ws://example.com/chat", "chat", false, Long.MAX_VALUE);
handsaker.handshake(channelMock, req); handsaker.handshake(channelMock, req);
Assert.assertEquals("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.getValue().getHeader(Names.SEC_WEBSOCKET_ACCEPT)); Assert.assertEquals("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.getValue().getHeader(Names.SEC_WEBSOCKET_ACCEPT));