Change the WebSocket API to use HttpHeaders instead of Map<String, String> for custom headers / Cleanup

This commit is contained in:
Trustin Lee 2013-01-17 00:33:40 +09:00
parent 540bc99549
commit 3b79008eda
19 changed files with 206 additions and 206 deletions

View File

@ -65,7 +65,7 @@ public abstract class DefaultHttpMessage extends DefaultHttpObject implements Ht
}
void appendHeaders(StringBuilder buf) {
for (Map.Entry<String, String> e: headers().entries()) {
for (Map.Entry<String, String> e: headers()) {
buf.append(e.getKey());
buf.append(": ");
buf.append(e.getValue());

View File

@ -44,7 +44,7 @@ import io.netty.handler.codec.MessageToMessageDecoder;
public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object> {
private EmbeddedByteChannel decoder;
private HttpMessage header;
private HttpMessage message;
private boolean decodeStarted;
/**
@ -61,8 +61,8 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object>
return msg;
}
if (msg instanceof HttpMessage) {
assert header == null;
header = (HttpMessage) msg;
assert message == null;
message = (HttpMessage) msg;
cleanup();
}
@ -72,11 +72,12 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object>
if (!decodeStarted) {
decodeStarted = true;
HttpMessage header = this.header;
this.header = null;
HttpMessage message = this.message;
HttpHeaders headers = message.headers();
this.message = null;
// Determine the content encoding.
String contentEncoding = header.headers().get(HttpHeaders.Names.CONTENT_ENCODING);
String contentEncoding = headers.get(HttpHeaders.Names.CONTENT_ENCODING);
if (contentEncoding != null) {
contentEncoding = contentEncoding.trim();
} else {
@ -90,21 +91,21 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object>
if (HttpHeaders.Values.IDENTITY.equals(targetContentEncoding)) {
// Do NOT set the 'Content-Encoding' header if the target encoding is 'identity'
// as per: http://tools.ietf.org/html/rfc2616#section-14.11
header.headers().remove(HttpHeaders.Names.CONTENT_ENCODING);
headers.remove(HttpHeaders.Names.CONTENT_ENCODING);
} else {
header.headers().set(HttpHeaders.Names.CONTENT_ENCODING, targetContentEncoding);
headers.set(HttpHeaders.Names.CONTENT_ENCODING, targetContentEncoding);
}
Object[] decoded = decodeContent(header, c);
Object[] decoded = decodeContent(message, c);
// Replace the content.
if (header.headers().contains(HttpHeaders.Names.CONTENT_LENGTH)) {
header.headers().set(
if (headers.contains(HttpHeaders.Names.CONTENT_LENGTH)) {
headers.set(
HttpHeaders.Names.CONTENT_LENGTH,
Integer.toString(((ByteBufHolder) decoded[1]).data().readableBytes()));
}
return decoded;
}
return new Object[] { header, c };
return new Object[] { message, c };
}
return decodeContent(null, c);
}

View File

@ -51,7 +51,7 @@ public abstract class HttpContentEncoder extends MessageToMessageCodec<HttpMessa
private final Queue<String> acceptEncodingQueue = new ArrayDeque<String>();
private EmbeddedByteChannel encoder;
private HttpMessage header;
private HttpMessage message;
private boolean encodeStarted;
/**
@ -84,26 +84,24 @@ public abstract class HttpContentEncoder extends MessageToMessageCodec<HttpMessa
return msg;
}
if (msg instanceof HttpMessage) {
assert header == null;
assert message == null;
// check if this message is also of type HttpContent is such case just make a safe copy of the headers
// as the content will get handled later and this simplify the handling
if (msg instanceof HttpContent) {
if (msg instanceof HttpRequest) {
HttpRequest reqHeader = (HttpRequest) msg;
header = new DefaultHttpRequest(reqHeader.protocolVersion(), reqHeader.method(),
reqHeader.uri());
HttpHeaders.setHeaders(reqHeader, header);
HttpRequest req = (HttpRequest) msg;
message = new DefaultHttpRequest(req.protocolVersion(), req.method(), req.uri());
message.headers().set(req.headers());
} else if (msg instanceof HttpResponse) {
HttpResponse responseHeader = (HttpResponse) msg;
header = new DefaultHttpResponse(responseHeader.protocolVersion(),
responseHeader.status());
HttpHeaders.setHeaders(responseHeader, header);
HttpResponse res = (HttpResponse) msg;
message = new DefaultHttpResponse(res.protocolVersion(), res.status());
message.headers().set(res.headers());
} else {
return msg;
}
} else {
header = (HttpMessage) msg;
message = (HttpMessage) msg;
}
cleanup();
@ -114,21 +112,22 @@ public abstract class HttpContentEncoder extends MessageToMessageCodec<HttpMessa
if (!encodeStarted) {
encodeStarted = true;
HttpMessage header = this.header;
this.header = null;
HttpMessage message = this.message;
HttpHeaders headers = message.headers();
this.message = null;
// Determine the content encoding.
String acceptEncoding = acceptEncodingQueue.poll();
if (acceptEncoding == null) {
throw new IllegalStateException("cannot send more responses than requests");
}
Result result = beginEncode(header, c, acceptEncoding);
Result result = beginEncode(message, c, acceptEncoding);
if (result == null) {
if (c instanceof LastHttpContent) {
return new Object[] { header, new DefaultLastHttpContent(c.data()) };
return new Object[] { message, new DefaultLastHttpContent(c.data()) };
} else {
return new Object[] { header, new DefaultHttpContent(c.data()) };
return new Object[] { message, new DefaultHttpContent(c.data()) };
}
}
@ -136,18 +135,18 @@ public abstract class HttpContentEncoder extends MessageToMessageCodec<HttpMessa
// Encode the content and remove or replace the existing headers
// so that the message looks like a decoded message.
header.headers().set(
headers.set(
HttpHeaders.Names.CONTENT_ENCODING,
result.getTargetContentEncoding());
Object[] encoded = encodeContent(header, c);
Object[] encoded = encodeContent(message, c);
if (!HttpHeaders.isTransferEncodingChunked(header) && encoded.length == 3) {
if (header.headers().contains(HttpHeaders.Names.CONTENT_LENGTH)) {
if (!HttpHeaders.isTransferEncodingChunked(message) && encoded.length == 3) {
if (headers.contains(HttpHeaders.Names.CONTENT_LENGTH)) {
long length = ((ByteBufHolder) encoded[1]).data().readableBytes() +
((ByteBufHolder) encoded[2]).data().readableBytes();
header.headers().set(
headers.set(
HttpHeaders.Names.CONTENT_LENGTH,
Long.toString(length));
}

View File

@ -584,17 +584,18 @@ public abstract class HttpHeaders implements Iterable<Map.Entry<String, String>>
* </ul>
*/
public static void setKeepAlive(HttpMessage message, boolean keepAlive) {
HttpHeaders h = message.headers();
if (message.protocolVersion().isKeepAliveDefault()) {
if (keepAlive) {
message.headers().remove(Names.CONNECTION);
h.remove(Names.CONNECTION);
} else {
message.headers().set(Names.CONNECTION, Values.CLOSE);
h.set(Names.CONNECTION, Values.CLOSE);
}
} else {
if (keepAlive) {
message.headers().set(Names.CONNECTION, Values.KEEP_ALIVE);
h.set(Names.CONNECTION, Values.KEEP_ALIVE);
} else {
message.headers().remove(Names.CONNECTION);
h.remove(Names.CONNECTION);
}
}
}
@ -879,18 +880,19 @@ public abstract class HttpHeaders implements Iterable<Map.Entry<String, String>>
*/
private static int getWebSocketContentLength(HttpMessage message) {
// WebSockset messages have constant content-lengths.
HttpHeaders h = message.headers();
if (message instanceof HttpRequest) {
HttpRequest req = (HttpRequest) message;
if (HttpMethod.GET.equals(req.method()) &&
req.headers().contains(Names.SEC_WEBSOCKET_KEY1) &&
req.headers().contains(Names.SEC_WEBSOCKET_KEY2)) {
h.contains(Names.SEC_WEBSOCKET_KEY1) &&
h.contains(Names.SEC_WEBSOCKET_KEY2)) {
return 8;
}
} else if (message instanceof HttpResponse) {
HttpResponse res = (HttpResponse) message;
if (res.status().code() == 101 &&
res.headers().contains(Names.SEC_WEBSOCKET_ORIGIN) &&
res.headers().contains(Names.SEC_WEBSOCKET_LOCATION)) {
h.contains(Names.SEC_WEBSOCKET_ORIGIN) &&
h.contains(Names.SEC_WEBSOCKET_LOCATION)) {
return 16;
}
}
@ -1015,14 +1017,6 @@ public abstract class HttpHeaders implements Iterable<Map.Entry<String, String>>
}
}
/**
* Set the headers on the dst like they are set on the src
*/
public static void setHeaders(HttpMessage src, HttpMessage dst) {
for (String name: src.headers().names()) {
dst.headers().set(name, src.headers().getAll(name));
}
}
/**
* Validates the name of a header
*

View File

@ -26,8 +26,6 @@ import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.util.CharsetUtil;
import java.util.Map.Entry;
import static io.netty.handler.codec.http.HttpHeaders.*;
/**
@ -146,10 +144,8 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
throw new Error();
}
HttpHeaders headers = currentMessage.headers();
for (String name: m.headers().names()) {
headers.set(name, m.headers().get(name));
}
currentMessage.headers().set(m.headers());
// A streamed message - initialize the cumulative buffer, and wait for incoming chunks.
removeTransferEncodingChunked(currentMessage);
return null;
@ -194,9 +190,7 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
// Merge trailing headers into the message.
if (chunk instanceof LastHttpContent) {
LastHttpContent trailer = (LastHttpContent) chunk;
for (Entry<String, String> header: trailer.trailingHeaders()) {
currentMessage.headers().add(header.getKey(), header.getValue());
}
currentMessage.headers().add(trailer.trailingHeaders());
}
// Set the 'Content-Length' header.

View File

@ -493,18 +493,20 @@ public abstract class HttpObjectDecoder extends ReplayingDecoder<HttpObjectDecod
private State readHeaders(ByteBuf buffer) {
headerSize = 0;
final HttpMessage message = this.message;
final HttpHeaders headers = message.headers();
String line = readHeader(buffer);
String name = null;
String value = null;
if (!line.isEmpty()) {
message.headers().clear();
headers.clear();
do {
char firstChar = line.charAt(0);
if (name != null && (firstChar == ' ' || firstChar == '\t')) {
value = value + ' ' + line.trim();
} else {
if (name != null) {
message.headers().add(name, value);
headers.add(name, value);
}
String[] header = splitHeader(line);
name = header[0];
@ -516,7 +518,7 @@ public abstract class HttpObjectDecoder extends ReplayingDecoder<HttpObjectDecod
// Add the last header.
if (name != null) {
message.headers().add(name, value);
headers.add(name, value);
}
}

View File

@ -195,8 +195,10 @@ public class HttpPostRequestDecoder {
this.charset = charset;
this.factory = factory;
// Fill default values
if (this.request.headers().contains(HttpHeaders.Names.CONTENT_TYPE)) {
checkMultipart(this.request.headers().get(HttpHeaders.Names.CONTENT_TYPE));
String contentType = this.request.headers().get(HttpHeaders.Names.CONTENT_TYPE);
if (contentType != null) {
checkMultipart(contentType);
} else {
isMultipart = false;
}

View File

@ -618,10 +618,12 @@ public class HttpPostRequestEncoder implements ChunkedMessageInput<HttpContent>
} else {
throw new ErrorDataEncoderException("Header already encoded");
}
List<String> contentTypes = request.headers().getAll(HttpHeaders.Names.CONTENT_TYPE);
List<String> transferEncoding = request.headers().getAll(HttpHeaders.Names.TRANSFER_ENCODING);
HttpHeaders headers = request.headers();
List<String> contentTypes = headers.getAll(HttpHeaders.Names.CONTENT_TYPE);
List<String> transferEncoding = headers.getAll(HttpHeaders.Names.TRANSFER_ENCODING);
if (contentTypes != null) {
request.headers().remove(HttpHeaders.Names.CONTENT_TYPE);
headers.remove(HttpHeaders.Names.CONTENT_TYPE);
for (String contentType : contentTypes) {
// "multipart/form-data; boundary=--89421926422648"
if (contentType.toLowerCase().startsWith(HttpHeaders.Values.MULTIPART_FORM_DATA)) {
@ -629,17 +631,17 @@ public class HttpPostRequestEncoder implements ChunkedMessageInput<HttpContent>
} else if (contentType.toLowerCase().startsWith(HttpHeaders.Values.APPLICATION_X_WWW_FORM_URLENCODED)) {
// ignore
} else {
request.headers().add(HttpHeaders.Names.CONTENT_TYPE, contentType);
headers.add(HttpHeaders.Names.CONTENT_TYPE, contentType);
}
}
}
if (isMultipart) {
String value = HttpHeaders.Values.MULTIPART_FORM_DATA + "; " + HttpHeaders.Values.BOUNDARY + '='
+ multipartDataBoundary;
request.headers().add(HttpHeaders.Names.CONTENT_TYPE, value);
headers.add(HttpHeaders.Names.CONTENT_TYPE, value);
} else {
// Not multipart
request.headers().add(HttpHeaders.Names.CONTENT_TYPE, HttpHeaders.Values.APPLICATION_X_WWW_FORM_URLENCODED);
headers.add(HttpHeaders.Names.CONTENT_TYPE, HttpHeaders.Values.APPLICATION_X_WWW_FORM_URLENCODED);
}
// Now consider size for chunk or not
long realSize = globalBodySize;
@ -649,16 +651,16 @@ public class HttpPostRequestEncoder implements ChunkedMessageInput<HttpContent>
realSize -= 1; // last '&' removed
iterator = multipartHttpDatas.listIterator();
}
request.headers().set(HttpHeaders.Names.CONTENT_LENGTH, String.valueOf(realSize));
headers.set(HttpHeaders.Names.CONTENT_LENGTH, String.valueOf(realSize));
if (realSize > HttpPostBodyUtil.chunkSize || isMultipart) {
isChunked = true;
if (transferEncoding != null) {
request.headers().remove(HttpHeaders.Names.TRANSFER_ENCODING);
headers.remove(HttpHeaders.Names.TRANSFER_ENCODING);
for (String v : transferEncoding) {
if (v.equalsIgnoreCase(HttpHeaders.Values.CHUNKED)) {
// ignore
} else {
request.headers().add(HttpHeaders.Names.TRANSFER_ENCODING, v);
headers.add(HttpHeaders.Names.TRANSFER_ENCODING, v);
}
}
}

View File

@ -19,9 +19,9 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import java.net.URI;
import java.util.Map;
/**
* Base class for web socket client handshake implementations
@ -38,7 +38,7 @@ public abstract class WebSocketClientHandshaker {
private String actualSubprotocol;
protected final Map<String, String> customHeaders;
protected final HttpHeaders customHeaders;
private final int maxFramePayloadLength;
@ -58,7 +58,7 @@ public abstract class WebSocketClientHandshaker {
* Maximum length of a frame's payload
*/
protected WebSocketClientHandshaker(URI webSocketUrl, WebSocketVersion version, String subprotocol,
Map<String, String> customHeaders, int maxFramePayloadLength) {
HttpHeaders customHeaders, int maxFramePayloadLength) {
this.webSocketUrl = webSocketUrl;
this.version = version;
expectedSubprotocol = subprotocol;

View File

@ -23,6 +23,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpHeaders.Names;
import io.netty.handler.codec.http.HttpHeaders.Values;
import io.netty.handler.codec.http.HttpMethod;
@ -34,7 +35,6 @@ import io.netty.handler.codec.http.HttpVersion;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Map;
/**
* <p>
@ -66,7 +66,7 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
* Maximum length of a frame's payload
*/
public WebSocketClientHandshaker00(URI webSocketURL, WebSocketVersion version, String subprotocol,
Map<String, String> customHeaders, int maxFramePayloadLength) {
HttpHeaders customHeaders, int maxFramePayloadLength) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
}
@ -142,9 +142,10 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
// Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
request.headers().add(Names.UPGRADE, Values.WEBSOCKET);
request.headers().add(Names.CONNECTION, Values.UPGRADE);
request.headers().add(Names.HOST, wsURL.getHost());
HttpHeaders headers = request.headers();
headers.add(Names.UPGRADE, Values.WEBSOCKET)
.add(Names.CONNECTION, Values.UPGRADE)
.add(Names.HOST, wsURL.getHost());
int wsPort = wsURL.getPort();
String originValue = "http://" + wsURL.getHost();
@ -154,23 +155,22 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
originValue = originValue + ':' + wsPort;
}
request.headers().add(Names.ORIGIN, originValue);
request.headers().add(Names.SEC_WEBSOCKET_KEY1, key1);
request.headers().add(Names.SEC_WEBSOCKET_KEY2, key2);
headers.add(Names.ORIGIN, originValue)
.add(Names.SEC_WEBSOCKET_KEY1, key1)
.add(Names.SEC_WEBSOCKET_KEY2, key2);
String expectedSubprotocol = getExpectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
request.headers().add(Names.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
headers.add(Names.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
if (customHeaders != null) {
for (Map.Entry<String, String> e : customHeaders.entrySet()) {
request.headers().add(e.getKey(), e.getValue());
}
headers.add(customHeaders);
}
// Set Content-Length to workaround some known defect.
// See also: http://www.ietf.org/mail-archive/web/hybi/current/msg02149.html
request.headers().set(Names.CONTENT_LENGTH, key3.length);
headers.set(Names.CONTENT_LENGTH, key3.length);
request.data().writeBytes(key3);
ChannelFuture future = channel.write(request);
@ -223,13 +223,15 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
throw new WebSocketHandshakeException("Invalid handshake response status: " + response.status());
}
String upgrade = response.headers().get(Names.UPGRADE);
HttpHeaders headers = response.headers();
String upgrade = headers.get(Names.UPGRADE);
if (!Values.WEBSOCKET.equalsIgnoreCase(upgrade)) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: "
+ upgrade);
}
String connection = response.headers().get(Names.CONNECTION);
String connection = headers.get(Names.CONNECTION);
if (!Values.UPGRADE.equalsIgnoreCase(connection)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: "
+ connection);
@ -240,7 +242,7 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
throw new WebSocketHandshakeException("Invalid challenge");
}
String subprotocol = response.headers().get(Names.SEC_WEBSOCKET_PROTOCOL);
String subprotocol = headers.get(Names.SEC_WEBSOCKET_PROTOCOL);
setActualSubprotocol(subprotocol);
setHandshakeComplete();

View File

@ -23,6 +23,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpHeaders.Names;
import io.netty.handler.codec.http.HttpHeaders.Values;
import io.netty.handler.codec.http.HttpMethod;
@ -35,7 +36,6 @@ import io.netty.logging.InternalLoggerFactory;
import io.netty.util.CharsetUtil;
import java.net.URI;
import java.util.Map;
/**
* <p>
@ -72,7 +72,7 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
* Maximum length of a frame's payload
*/
public WebSocketClientHandshaker07(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, Map<String, String> customHeaders, int maxFramePayloadLength) {
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
this.allowExtensions = allowExtensions;
}
@ -125,10 +125,12 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
// Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
request.headers().add(Names.UPGRADE, Values.WEBSOCKET.toLowerCase());
request.headers().add(Names.CONNECTION, Values.UPGRADE);
request.headers().add(Names.SEC_WEBSOCKET_KEY, key);
request.headers().add(Names.HOST, wsURL.getHost());
HttpHeaders headers = request.headers();
headers.add(Names.UPGRADE, Values.WEBSOCKET.toLowerCase())
.add(Names.CONNECTION, Values.UPGRADE)
.add(Names.SEC_WEBSOCKET_KEY, key)
.add(Names.HOST, wsURL.getHost());
int wsPort = wsURL.getPort();
String originValue = "http://" + wsURL.getHost();
@ -137,19 +139,17 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
// See http://tools.ietf.org/html/rfc6454#section-6.2
originValue = originValue + ':' + wsPort;
}
request.headers().add(Names.SEC_WEBSOCKET_ORIGIN, originValue);
headers.add(Names.SEC_WEBSOCKET_ORIGIN, originValue);
String expectedSubprotocol = getExpectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
request.headers().add(Names.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
headers.add(Names.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
request.headers().add(Names.SEC_WEBSOCKET_VERSION, "7");
headers.add(Names.SEC_WEBSOCKET_VERSION, "7");
if (customHeaders != null) {
for (Map.Entry<String, String> e : customHeaders.entrySet()) {
request.headers().add(e.getKey(), e.getValue());
}
headers.add(customHeaders);
}
ChannelFuture future = channel.write(request);
@ -194,30 +194,29 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
@Override
public void finishHandshake(Channel channel, FullHttpResponse response) {
final HttpResponseStatus status = HttpResponseStatus.SWITCHING_PROTOCOLS;
final HttpHeaders headers = response.headers();
if (!response.status().equals(status)) {
throw new WebSocketHandshakeException("Invalid handshake response status: " + response.status());
}
String upgrade = response.headers().get(Names.UPGRADE);
String upgrade = headers.get(Names.UPGRADE);
if (!Values.WEBSOCKET.equalsIgnoreCase(upgrade)) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: "
+ response.headers().get(Names.UPGRADE));
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " + upgrade);
}
String connection = response.headers().get(Names.CONNECTION);
String connection = headers.get(Names.CONNECTION);
if (!Values.UPGRADE.equalsIgnoreCase(connection)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: "
+ response.headers().get(Names.CONNECTION));
throw new WebSocketHandshakeException("Invalid handshake response connection: " + connection);
}
String accept = response.headers().get(Names.SEC_WEBSOCKET_ACCEPT);
String accept = headers.get(Names.SEC_WEBSOCKET_ACCEPT);
if (accept == null || !accept.equals(expectedChallengeResponseString)) {
throw new WebSocketHandshakeException(String.format("Invalid challenge. Actual: %s. Expected: %s", accept,
expectedChallengeResponseString));
throw new WebSocketHandshakeException(String.format(
"Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString));
}
String subprotocol = response.headers().get(Names.SEC_WEBSOCKET_PROTOCOL);
String subprotocol = headers.get(Names.SEC_WEBSOCKET_PROTOCOL);
setActualSubprotocol(subprotocol);
setHandshakeComplete();

View File

@ -23,6 +23,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpHeaders.Names;
import io.netty.handler.codec.http.HttpHeaders.Values;
import io.netty.handler.codec.http.HttpMethod;
@ -35,7 +36,6 @@ import io.netty.logging.InternalLoggerFactory;
import io.netty.util.CharsetUtil;
import java.net.URI;
import java.util.Map;
/**
* <p>
@ -72,7 +72,7 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
* Maximum length of a frame's payload
*/
public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, Map<String, String> customHeaders, int maxFramePayloadLength) {
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
this.allowExtensions = allowExtensions;
}
@ -125,10 +125,12 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
// Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
request.headers().add(Names.UPGRADE, Values.WEBSOCKET.toLowerCase());
request.headers().add(Names.CONNECTION, Values.UPGRADE);
request.headers().add(Names.SEC_WEBSOCKET_KEY, key);
request.headers().add(Names.HOST, wsURL.getHost());
HttpHeaders headers = request.headers();
headers.add(Names.UPGRADE, Values.WEBSOCKET.toLowerCase())
.add(Names.CONNECTION, Values.UPGRADE)
.add(Names.SEC_WEBSOCKET_KEY, key)
.add(Names.HOST, wsURL.getHost());
int wsPort = wsURL.getPort();
String originValue = "http://" + wsURL.getHost();
@ -137,19 +139,17 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
// See http://tools.ietf.org/html/rfc6454#section-6.2
originValue = originValue + ':' + wsPort;
}
request.headers().add(Names.SEC_WEBSOCKET_ORIGIN, originValue);
headers.add(Names.SEC_WEBSOCKET_ORIGIN, originValue);
String expectedSubprotocol = getExpectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
request.headers().add(Names.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
headers.add(Names.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
request.headers().add(Names.SEC_WEBSOCKET_VERSION, "8");
headers.add(Names.SEC_WEBSOCKET_VERSION, "8");
if (customHeaders != null) {
for (Map.Entry<String, String> e : customHeaders.entrySet()) {
request.headers().add(e.getKey(), e.getValue());
}
headers.add(customHeaders);
}
ChannelFuture future = channel.write(request);
@ -194,30 +194,29 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
@Override
public void finishHandshake(Channel channel, FullHttpResponse response) {
final HttpResponseStatus status = HttpResponseStatus.SWITCHING_PROTOCOLS;
final HttpHeaders headers = response.headers();
if (!response.status().equals(status)) {
throw new WebSocketHandshakeException("Invalid handshake response status: " + response.status());
}
String upgrade = response.headers().get(Names.UPGRADE);
String upgrade = headers.get(Names.UPGRADE);
if (!Values.WEBSOCKET.equalsIgnoreCase(upgrade)) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: "
+ response.headers().get(Names.UPGRADE));
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " + upgrade);
}
String connection = response.headers().get(Names.CONNECTION);
String connection = headers.get(Names.CONNECTION);
if (!Values.UPGRADE.equalsIgnoreCase(connection)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: "
+ response.headers().get(Names.CONNECTION));
throw new WebSocketHandshakeException("Invalid handshake response connection: " + connection);
}
String accept = response.headers().get(Names.SEC_WEBSOCKET_ACCEPT);
String accept = headers.get(Names.SEC_WEBSOCKET_ACCEPT);
if (accept == null || !accept.equals(expectedChallengeResponseString)) {
throw new WebSocketHandshakeException(String.format("Invalid challenge. Actual: %s. Expected: %s", accept,
expectedChallengeResponseString));
throw new WebSocketHandshakeException(String.format(
"Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString));
}
String subprotocol = response.headers().get(Names.SEC_WEBSOCKET_PROTOCOL);
String subprotocol = headers.get(Names.SEC_WEBSOCKET_PROTOCOL);
setActualSubprotocol(subprotocol);
setHandshakeComplete();

View File

@ -22,6 +22,7 @@ import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpHeaders.Names;
import io.netty.handler.codec.http.HttpHeaders.Values;
import io.netty.handler.codec.http.HttpMethod;
@ -35,7 +36,6 @@ import io.netty.logging.InternalLoggerFactory;
import io.netty.util.CharsetUtil;
import java.net.URI;
import java.util.Map;
/**
* <p>
@ -72,7 +72,7 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
* Maximum length of a frame's payload
*/
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, Map<String, String> customHeaders, int maxFramePayloadLength) {
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength);
this.allowExtensions = allowExtensions;
}
@ -125,10 +125,12 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
// Format request
HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
request.headers().add(Names.UPGRADE, Values.WEBSOCKET.toLowerCase());
request.headers().add(Names.CONNECTION, Values.UPGRADE);
request.headers().add(Names.SEC_WEBSOCKET_KEY, key);
request.headers().add(Names.HOST, wsURL.getHost());
HttpHeaders headers = request.headers();
headers.add(Names.UPGRADE, Values.WEBSOCKET.toLowerCase())
.add(Names.CONNECTION, Values.UPGRADE)
.add(Names.SEC_WEBSOCKET_KEY, key)
.add(Names.HOST, wsURL.getHost());
int wsPort = wsURL.getPort();
String originValue = "http://" + wsURL.getHost();
@ -137,19 +139,17 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
// See http://tools.ietf.org/html/rfc6454#section-6.2
originValue = originValue + ':' + wsPort;
}
request.headers().add(Names.SEC_WEBSOCKET_ORIGIN, originValue);
headers.add(Names.SEC_WEBSOCKET_ORIGIN, originValue);
String expectedSubprotocol = getExpectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
request.headers().add(Names.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
headers.add(Names.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
request.headers().add(Names.SEC_WEBSOCKET_VERSION, "13");
headers.add(Names.SEC_WEBSOCKET_VERSION, "13");
if (customHeaders != null) {
for (Map.Entry<String, String> e: customHeaders.entrySet()) {
request.headers().add(e.getKey(), e.getValue());
}
headers.add(customHeaders);
}
ChannelFuture future = channel.write(request);
@ -193,30 +193,29 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
@Override
public void finishHandshake(Channel channel, FullHttpResponse response) {
final HttpResponseStatus status = HttpResponseStatus.SWITCHING_PROTOCOLS;
final HttpHeaders headers = response.headers();
if (!response.status().equals(status)) {
throw new WebSocketHandshakeException("Invalid handshake response status: " + response.status());
}
String upgrade = response.headers().get(Names.UPGRADE);
String upgrade = headers.get(Names.UPGRADE);
if (!Values.WEBSOCKET.equalsIgnoreCase(upgrade)) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: "
+ response.headers().get(Names.UPGRADE));
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " + upgrade);
}
String connection = response.headers().get(Names.CONNECTION);
String connection = headers.get(Names.CONNECTION);
if (!Values.UPGRADE.equalsIgnoreCase(connection)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: "
+ response.headers().get(Names.CONNECTION));
throw new WebSocketHandshakeException("Invalid handshake response connection: " + connection);
}
String accept = response.headers().get(Names.SEC_WEBSOCKET_ACCEPT);
String accept = headers.get(Names.SEC_WEBSOCKET_ACCEPT);
if (accept == null || !accept.equals(expectedChallengeResponseString)) {
throw new WebSocketHandshakeException(String.format("Invalid challenge. Actual: %s. Expected: %s", accept,
expectedChallengeResponseString));
throw new WebSocketHandshakeException(String.format(
"Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString));
}
String subprotocol = response.headers().get(Names.SEC_WEBSOCKET_PROTOCOL);
String subprotocol = headers.get(Names.SEC_WEBSOCKET_PROTOCOL);
setActualSubprotocol(subprotocol);
setHandshakeComplete();

View File

@ -15,8 +15,9 @@
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.HttpHeaders;
import java.net.URI;
import java.util.Map;
import static io.netty.handler.codec.http.websocketx.WebSocketVersion.*;
@ -48,7 +49,7 @@ public final class WebSocketClientHandshakerFactory {
*/
public static WebSocketClientHandshaker newHandshaker(
URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, Map<String, String> customHeaders) {
boolean allowExtensions, HttpHeaders customHeaders) {
return newHandshaker(webSocketURL, version, subprotocol, allowExtensions, customHeaders, 65536);
}
@ -72,7 +73,7 @@ public final class WebSocketClientHandshakerFactory {
*/
public static WebSocketClientHandshaker newHandshaker(
URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, Map<String, String> customHeaders, int maxFramePayloadLength) {
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
if (version == V13) {
return new WebSocketClientHandshaker13(
webSocketURL, V13, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength);

View File

@ -235,7 +235,7 @@ public class SpdyHttpDecoder extends MessageToMessageDecoder<Object> {
SpdyHeaders.removeUrl(spdyVersion, requestFrame);
SpdyHeaders.removeVersion(spdyVersion, requestFrame);
FullHttpRequest httpRequestWithEntity = new DefaultFullHttpRequest(httpVersion, method, url);
FullHttpRequest req = new DefaultFullHttpRequest(httpVersion, method, url);
// Remove the scheme header
SpdyHeaders.removeScheme(spdyVersion, requestFrame);
@ -244,20 +244,20 @@ public class SpdyHttpDecoder extends MessageToMessageDecoder<Object> {
// Replace the SPDY host header with the HTTP host header
String host = SpdyHeaders.getHost(requestFrame);
SpdyHeaders.removeHost(requestFrame);
HttpHeaders.setHost(httpRequestWithEntity, host);
HttpHeaders.setHost(req, host);
}
for (Map.Entry<String, String> e: requestFrame.getHeaders()) {
httpRequestWithEntity.headers().add(e.getKey(), e.getValue());
req.headers().add(e.getKey(), e.getValue());
}
// The Connection and Keep-Alive headers are no longer valid
HttpHeaders.setKeepAlive(httpRequestWithEntity, true);
HttpHeaders.setKeepAlive(req, true);
// Transfer-Encoding header is not valid
httpRequestWithEntity.headers().remove(HttpHeaders.Names.TRANSFER_ENCODING);
req.headers().remove(HttpHeaders.Names.TRANSFER_ENCODING);
return httpRequestWithEntity;
return req;
}
private static FullHttpResponse createHttpResponse(int spdyVersion, SpdyHeaderBlock responseFrame)
@ -268,18 +268,18 @@ public class SpdyHttpDecoder extends MessageToMessageDecoder<Object> {
SpdyHeaders.removeStatus(spdyVersion, responseFrame);
SpdyHeaders.removeVersion(spdyVersion, responseFrame);
FullHttpResponse httpResponseWithEntity = new DefaultFullHttpResponse(version, status);
FullHttpResponse res = new DefaultFullHttpResponse(version, status);
for (Map.Entry<String, String> e: responseFrame.getHeaders()) {
httpResponseWithEntity.headers().add(e.getKey(), e.getValue());
res.headers().add(e.getKey(), e.getValue());
}
// The Connection and Keep-Alive headers are no longer valid
HttpHeaders.setKeepAlive(httpResponseWithEntity, true);
HttpHeaders.setKeepAlive(res, true);
// Transfer-Encoding header is not valid
httpResponseWithEntity.headers().remove(HttpHeaders.Names.TRANSFER_ENCODING);
httpResponseWithEntity.headers().remove(HttpHeaders.Names.TRAILER);
res.headers().remove(HttpHeaders.Names.TRANSFER_ENCODING);
res.headers().remove(HttpHeaders.Names.TRAILER);
return httpResponseWithEntity;
return res;
}
}

View File

@ -15,32 +15,33 @@
*/
package io.netty.handler.codec.http;
import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.*;
public class DefaultHttpRequestTest {
@Test
public void testHeaderRemoval() {
HttpMessage m = new DefaultHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
HttpMessage m = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
HttpHeaders h = m.headers();
// Insert sample keys.
for (int i = 0; i < 1000; i ++) {
m.headers().set(String.valueOf(i), "");
h.set(String.valueOf(i), "");
}
// Remove in reversed order.
for (int i = 999; i >= 0; i --) {
m.headers().remove(String.valueOf(i));
h.remove(String.valueOf(i));
}
// Check if random access returns nothing.
for (int i = 0; i < 1000; i ++) {
Assert.assertNull(m.headers().get(String.valueOf(i)));
assertNull(h.get(String.valueOf(i)));
}
// Check if sequential access returns nothing.
Assert.assertTrue(m.headers().isEmpty());
assertTrue(h.isEmpty());
}
}

View File

@ -20,11 +20,12 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedByteChannel;
import io.netty.handler.codec.DecoderResult;
import io.netty.util.CharsetUtil;
import org.junit.Assert;
import org.junit.Test;
import java.util.Random;
import static org.junit.Assert.*;
public class HttpInvalidMessageTest {
private final Random rnd = new Random();
@ -35,8 +36,8 @@ public class HttpInvalidMessageTest {
ch.writeInbound(Unpooled.copiedBuffer("GET / HTTP/1.0 with extra\r\n", CharsetUtil.UTF_8));
HttpRequest req = (HttpRequest) ch.readInbound();
DecoderResult dr = req.decoderResult();
Assert.assertFalse(dr.isSuccess());
Assert.assertFalse(dr.isPartialFailure());
assertFalse(dr.isSuccess());
assertFalse(dr.isPartialFailure());
ensureInboundTrafficDiscarded(ch);
}
@ -49,10 +50,10 @@ public class HttpInvalidMessageTest {
ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.UTF_8));
HttpRequest req = (HttpRequest) ch.readInbound();
DecoderResult dr = req.decoderResult();
Assert.assertFalse(dr.isSuccess());
Assert.assertTrue(dr.isPartialFailure());
Assert.assertEquals("Good Value", req.headers().get("Good_Name"));
Assert.assertEquals("/maybe-something", req.uri());
assertFalse(dr.isSuccess());
assertTrue(dr.isPartialFailure());
assertEquals("Good Value", req.headers().get("Good_Name"));
assertEquals("/maybe-something", req.uri());
ensureInboundTrafficDiscarded(ch);
}
@ -62,8 +63,8 @@ public class HttpInvalidMessageTest {
ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.0 BAD_CODE Bad Server\r\n", CharsetUtil.UTF_8));
HttpResponse res = (HttpResponse) ch.readInbound();
DecoderResult dr = res.decoderResult();
Assert.assertFalse(dr.isSuccess());
Assert.assertFalse(dr.isPartialFailure());
assertFalse(dr.isSuccess());
assertFalse(dr.isPartialFailure());
ensureInboundTrafficDiscarded(ch);
}
@ -76,10 +77,10 @@ public class HttpInvalidMessageTest {
ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.UTF_8));
HttpResponse res = (HttpResponse) ch.readInbound();
DecoderResult dr = res.decoderResult();
Assert.assertFalse(dr.isSuccess());
Assert.assertTrue(dr.isPartialFailure());
Assert.assertEquals("Maybe OK", res.status().reasonPhrase());
Assert.assertEquals("Good Value", res.headers().get("Good_Name"));
assertFalse(dr.isSuccess());
assertTrue(dr.isPartialFailure());
assertEquals("Maybe OK", res.status().reasonPhrase());
assertEquals("Good Value", res.headers().get("Good_Name"));
ensureInboundTrafficDiscarded(ch);
}
@ -91,12 +92,12 @@ public class HttpInvalidMessageTest {
ch.writeInbound(Unpooled.copiedBuffer("BAD_LENGTH\r\n", CharsetUtil.UTF_8));
HttpRequest req = (HttpRequest) ch.readInbound();
Assert.assertTrue(req.decoderResult().isSuccess());
assertTrue(req.decoderResult().isSuccess());
HttpContent chunk = (HttpContent) ch.readInbound();
DecoderResult dr = chunk.decoderResult();
Assert.assertFalse(dr.isSuccess());
Assert.assertFalse(dr.isPartialFailure());
assertFalse(dr.isSuccess());
assertFalse(dr.isPartialFailure());
ensureInboundTrafficDiscarded(ch);
}
@ -110,7 +111,7 @@ public class HttpInvalidMessageTest {
buf.setIndex(0, data.length);
ch.writeInbound(buf);
ch.checkException();
Assert.assertNull(ch.readInbound());
assertNull(ch.readInbound());
}
}
}

View File

@ -17,6 +17,7 @@ package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpHeaders.Names;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
@ -98,23 +99,25 @@ public class WebSocketRequestBuilder {
public FullHttpRequest build() {
FullHttpRequest req = new DefaultFullHttpRequest(httpVersion, method, uri);
HttpHeaders headers = req.headers();
if (host != null) {
req.headers().set(Names.HOST, host);
headers.set(Names.HOST, host);
}
if (upgrade != null) {
req.headers().set(Names.UPGRADE, upgrade);
headers.set(Names.UPGRADE, upgrade);
}
if (connection != null) {
req.headers().set(Names.CONNECTION, connection);
headers.set(Names.CONNECTION, connection);
}
if (key != null) {
req.headers().set(Names.SEC_WEBSOCKET_KEY, key);
headers.set(Names.SEC_WEBSOCKET_KEY, key);
}
if (origin != null) {
req.headers().set(Names.SEC_WEBSOCKET_ORIGIN, origin);
headers.set(Names.SEC_WEBSOCKET_ORIGIN, origin);
}
if (version != null) {
req.headers().set(Names.SEC_WEBSOCKET_VERSION, version.toHttpHeaderValue());
headers.set(Names.SEC_WEBSOCKET_VERSION, version.toHttpHeaderValue());
}
return req;
}

View File

@ -44,6 +44,8 @@ import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponseDecoder;
@ -54,7 +56,6 @@ import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import java.net.URI;
import java.util.HashMap;
public class WebSocketClient {
@ -72,8 +73,8 @@ public class WebSocketClient {
throw new IllegalArgumentException("Unsupported protocol: " + protocol);
}
HashMap<String, String> customHeaders = new HashMap<String, String>();
customHeaders.put("MyHeader", "MyValue");
HttpHeaders customHeaders = new DefaultHttpHeaders();
customHeaders.add("MyHeader", "MyValue");
// Connect with V13 (RFC 6455 aka HyBi-17). You can change it to V08 or V00.
// If you change it to V00, ping is not supported and remember to change