Set (and override) websocket handshake headers after custom headers (#7975)

Motivation:

Currently, when passing custom headers to a WebSocketClientHandshaker,
if values are added for headers that are reserved for use in the
websocket handshake performed with the server, these custom values can
be used by the server to compute the websocket handshake challenge. If
the server computes the response to the challenge with the custom header
values, rather than the values computed by the client handshaker, the
handshake may fail.

Modifications:

Update the client handshaker implementations to add the custom header
values first, and then set the reserved websocket header values.

Result:

Reserved websocket handshake headers, if present in the custom headers
passed to the client handshaker, will not be propagated to the server.
Instead the client handshaker will propagate the values it generates.

Fixes #7973.
This commit is contained in:
Nick Travers 2018-05-30 10:52:40 -07:00 committed by Norman Maurer
parent b53cf045a7
commit 48911e0b63
9 changed files with 143 additions and 61 deletions

View File

@ -131,22 +131,23 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
// Format request // Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
HttpHeaders headers = request.headers(); HttpHeaders headers = request.headers();
headers.add(HttpHeaderNames.UPGRADE, WEBSOCKET)
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.add(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.add(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL))
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1)
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2);
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
if (customHeaders != null) { if (customHeaders != null) {
headers.add(customHeaders); headers.add(customHeaders);
} }
headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2);
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
// Set Content-Length to workaround some known defect. // Set Content-Length to workaround some known defect.
// See also: http://www.ietf.org/mail-archive/web/hybi/current/msg02149.html // See also: http://www.ietf.org/mail-archive/web/hybi/current/msg02149.html
headers.set(HttpHeaderNames.CONTENT_LENGTH, key3.length); headers.set(HttpHeaderNames.CONTENT_LENGTH, key3.length);

View File

@ -145,22 +145,22 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
HttpHeaders headers = request.headers(); HttpHeaders headers = request.headers();
headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.add(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
headers.add(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "7");
if (customHeaders != null) { if (customHeaders != null) {
headers.add(customHeaders); headers.add(customHeaders);
} }
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "7");
return request; return request;
} }

View File

@ -146,22 +146,22 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
HttpHeaders headers = request.headers(); HttpHeaders headers = request.headers();
headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.add(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
headers.add(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "8");
if (customHeaders != null) { if (customHeaders != null) {
headers.add(customHeaders); headers.add(customHeaders);
} }
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "8");
return request; return request;
} }

View File

@ -146,22 +146,22 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
HttpHeaders headers = request.headers(); HttpHeaders headers = request.headers();
headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.add(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
headers.add(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13");
if (customHeaders != null) { if (customHeaders != null) {
headers.add(customHeaders); headers.add(customHeaders);
} }
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13");
return request; return request;
} }

View File

@ -16,17 +16,35 @@
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import java.net.URI; import java.net.URI;
public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTest { public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTest {
@Override @Override
protected WebSocketClientHandshaker newHandshaker(URI uri) { protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers) {
return new WebSocketClientHandshaker00(uri, WebSocketVersion.V00, null, null, 1024); return new WebSocketClientHandshaker00(uri, WebSocketVersion.V00, subprotocol, headers, 1024);
} }
@Override @Override
protected CharSequence getOriginHeaderName() { protected CharSequence getOriginHeaderName() {
return HttpHeaderNames.ORIGIN; return HttpHeaderNames.ORIGIN;
} }
@Override
protected CharSequence getProtocolHeaderName() {
return HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL;
}
@Override
protected CharSequence[] getHandshakeHeaderNames() {
return new CharSequence[] {
HttpHeaderNames.CONNECTION,
HttpHeaderNames.UPGRADE,
HttpHeaderNames.HOST,
HttpHeaderNames.ORIGIN,
HttpHeaderNames.SEC_WEBSOCKET_KEY1,
HttpHeaderNames.SEC_WEBSOCKET_KEY2,
};
}
} }

View File

@ -16,17 +16,35 @@
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import java.net.URI; import java.net.URI;
public class WebSocketClientHandshaker07Test extends WebSocketClientHandshakerTest { public class WebSocketClientHandshaker07Test extends WebSocketClientHandshakerTest {
@Override @Override
protected WebSocketClientHandshaker newHandshaker(URI uri) { protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers) {
return new WebSocketClientHandshaker07(uri, WebSocketVersion.V07, null, false, null, 1024); return new WebSocketClientHandshaker07(uri, WebSocketVersion.V07, subprotocol, false, headers, 1024);
} }
@Override @Override
protected CharSequence getOriginHeaderName() { protected CharSequence getOriginHeaderName() {
return HttpHeaderNames.SEC_WEBSOCKET_ORIGIN; return HttpHeaderNames.SEC_WEBSOCKET_ORIGIN;
} }
@Override
protected CharSequence getProtocolHeaderName() {
return HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL;
}
@Override
protected CharSequence[] getHandshakeHeaderNames() {
return new CharSequence[] {
HttpHeaderNames.UPGRADE,
HttpHeaderNames.CONNECTION,
HttpHeaderNames.SEC_WEBSOCKET_KEY,
HttpHeaderNames.HOST,
HttpHeaderNames.SEC_WEBSOCKET_ORIGIN,
HttpHeaderNames.SEC_WEBSOCKET_VERSION,
};
}
} }

View File

@ -15,11 +15,13 @@
*/ */
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.HttpHeaders;
import java.net.URI; import java.net.URI;
public class WebSocketClientHandshaker08Test extends WebSocketClientHandshaker07Test { public class WebSocketClientHandshaker08Test extends WebSocketClientHandshaker07Test {
@Override @Override
protected WebSocketClientHandshaker newHandshaker(URI uri) { protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers) {
return new WebSocketClientHandshaker08(uri, WebSocketVersion.V08, null, false, null, 1024); return new WebSocketClientHandshaker08(uri, WebSocketVersion.V08, subprotocol, false, headers, 1024);
} }
} }

View File

@ -15,11 +15,13 @@
*/ */
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.HttpHeaders;
import java.net.URI; import java.net.URI;
public class WebSocketClientHandshaker13Test extends WebSocketClientHandshaker07Test { public class WebSocketClientHandshaker13Test extends WebSocketClientHandshaker07Test {
@Override @Override
protected WebSocketClientHandshaker newHandshaker(URI uri) { protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers) {
return new WebSocketClientHandshaker13(uri, WebSocketVersion.V13, null, false, null, 1024); return new WebSocketClientHandshaker13(uri, WebSocketVersion.V13, subprotocol, false, headers, 1024);
} }
} }

View File

@ -21,11 +21,13 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.EmptyHttpHeaders; import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder; import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseDecoder;
@ -35,14 +37,21 @@ import org.junit.Test;
import java.net.URI; import java.net.URI;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.*;
import static org.junit.Assert.assertTrue;
public abstract class WebSocketClientHandshakerTest { public abstract class WebSocketClientHandshakerTest {
protected abstract WebSocketClientHandshaker newHandshaker(URI uri); protected abstract WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers);
protected WebSocketClientHandshaker newHandshaker(URI uri) {
return newHandshaker(uri, null, null);
}
protected abstract CharSequence getOriginHeaderName(); protected abstract CharSequence getOriginHeaderName();
protected abstract CharSequence getProtocolHeaderName();
protected abstract CharSequence[] getHandshakeHeaderNames();
@Test @Test
public void hostHeaderWs() { public void hostHeaderWs() {
for (String scheme : new String[]{"ws://", "http://"}) { for (String scheme : new String[]{"ws://", "http://"}) {
@ -292,4 +301,36 @@ public abstract class WebSocketClientHandshakerTest {
frame.release(); frame.release();
} }
} }
@Test
public void testDuplicateWebsocketHandshakeHeaders() {
URI uri = URI.create("ws://localhost:9999/foo");
HttpHeaders inputHeaders = new DefaultHttpHeaders();
String bogusSubProtocol = "bogusSubProtocol";
String bogusHeaderValue = "bogusHeaderValue";
// add values for the headers that are reserved for use in the websockets handshake
for (CharSequence header : getHandshakeHeaderNames()) {
inputHeaders.add(header, bogusHeaderValue);
}
inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol);
String realSubProtocol = "realSubProtocol";
WebSocketClientHandshaker handshaker = newHandshaker(uri, realSubProtocol, inputHeaders);
FullHttpRequest request = handshaker.newHandshakeRequest();
HttpHeaders outputHeaders = request.headers();
// the header values passed in originally have been replaced with values generated by the Handshaker
for (CharSequence header : getHandshakeHeaderNames()) {
assertEquals(1, outputHeaders.getAll(header).size());
assertNotEquals(bogusHeaderValue, outputHeaders.get(header));
}
// the subprotocol header value is that of the subprotocol string passed into the Handshaker
assertEquals(1, outputHeaders.getAll(getProtocolHeaderName()).size());
assertEquals(realSubProtocol, outputHeaders.get(getProtocolHeaderName()));
request.release();
}
} }