Correct filling an origin header for WS client

Motivation:
An `origin`/`sec-websocket-origin` header value in websocket client is filling incorrect in some cases:
- Hostname is not converting to lower-case as prescribed by RFC 6354 (see [1]).
- Selecting a `http` scheme when source URI has `wss`/`https` scheme and non-standard port.

Modifications:
- Convert uri-host to lower-case.
- Use a `https` scheme if source URI scheme is `wss`/`https`, or if source scheme is null and port == 443.

Result:
Correct filling an `origin` header for WS client.

[1] https://tools.ietf.org/html/rfc6454#section-4
This commit is contained in:
Nikolay Fedorovskikh 2017-10-22 21:39:36 +05:00 committed by Norman Maurer
parent d93c607f93
commit dc98eae5a5
8 changed files with 159 additions and 79 deletions

View File

@ -39,6 +39,7 @@ import io.netty.util.internal.ThrowableUtil;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.util.Locale;
/**
* Base class for web socket client handshake implementations
@ -47,6 +48,9 @@ public abstract class WebSocketClientHandshaker {
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace(
new ClosedChannelException(), WebSocketClientHandshaker.class, "processHandshake(...)");
private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
private final URI uri;
private final WebSocketVersion version;
@ -443,18 +447,6 @@ public abstract class WebSocketClientHandshaker {
return path == null || path.isEmpty() ? "/" : path;
}
static int websocketPort(URI wsURL) {
// Format request
int wsPort = wsURL.getPort();
// check if the URI contained a port if not set the correct one depending on the schema.
// See https://github.com/netty/netty/pull/1558
if (wsPort == -1) {
return WebSocketScheme.WSS.name().contentEquals(wsURL.getScheme())
? WebSocketScheme.WSS.port() : WebSocketScheme.WS.port();
}
return wsPort;
}
static CharSequence websocketHostValue(URI wsURL) {
int port = wsURL.getPort();
if (port == -1) {
@ -477,14 +469,30 @@ public abstract class WebSocketClientHandshaker {
return NetUtil.toSocketAddressString(host, port);
}
static CharSequence websocketOriginValue(String host, int wsPort) {
String originValue = (wsPort == HttpScheme.HTTPS.port() ?
HttpScheme.HTTPS.name() : HttpScheme.HTTP.name()) + "://" + host;
if (wsPort != HttpScheme.HTTP.port() && wsPort != HttpScheme.HTTPS.port()) {
static CharSequence websocketOriginValue(URI wsURL) {
String scheme = wsURL.getScheme();
final String schemePrefix;
int port = wsURL.getPort();
final int defaultPort;
if (WebSocketScheme.WSS.name().contentEquals(scheme)
|| HttpScheme.HTTPS.name().contentEquals(scheme)
|| (scheme == null && port == WebSocketScheme.WSS.port())) {
schemePrefix = HTTPS_SCHEME_PREFIX;
defaultPort = WebSocketScheme.WSS.port();
} else {
schemePrefix = HTTP_SCHEME_PREFIX;
defaultPort = WebSocketScheme.WS.port();
}
// Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI")
String host = wsURL.getHost().toLowerCase(Locale.US);
if (port != defaultPort && port != -1) {
// if the port is not standard (80/443) its needed to add the port to the header.
// See http://tools.ietf.org/html/rfc6454#section-6.2
return NetUtil.toSocketAddressString(originValue, wsPort);
return schemePrefix + NetUtil.toSocketAddressString(host, port);
}
return originValue;
return schemePrefix + host;
}
}

View File

@ -127,8 +127,6 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
// Get path
URI wsURL = uri();
String path = rawPath(wsURL);
int wsPort = websocketPort(wsURL);
String host = wsURL.getHost();
// Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
@ -136,7 +134,7 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
headers.add(HttpHeaderNames.UPGRADE, WEBSOCKET)
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.add(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.add(HttpHeaderNames.ORIGIN, websocketOriginValue(host, wsPort))
.add(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL))
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1)
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2);

View File

@ -141,9 +141,6 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
key, expectedChallengeResponseString);
}
int wsPort = websocketPort(wsURL);
String host = wsURL.getHost();
// Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
HttpHeaders headers = request.headers();
@ -152,7 +149,7 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.add(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(host, wsPort));
.add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {

View File

@ -142,9 +142,6 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
key, expectedChallengeResponseString);
}
int wsPort = websocketPort(wsURL);
String host = wsURL.getHost();
// Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
HttpHeaders headers = request.headers();
@ -153,7 +150,7 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.add(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(host, wsPort));
.add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {

View File

@ -143,8 +143,6 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
}
// Format request
int wsPort = websocketPort(wsURL);
String host = wsURL.getHost();
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
HttpHeaders headers = request.headers();
@ -152,7 +150,7 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
.add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.add(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.add(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(host, wsPort));
.add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {

View File

@ -15,6 +15,8 @@
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.HttpHeaderNames;
import java.net.URI;
public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTest {
@ -22,4 +24,9 @@ public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTe
protected WebSocketClientHandshaker newHandshaker(URI uri) {
return new WebSocketClientHandshaker00(uri, WebSocketVersion.V00, null, null, 1024);
}
@Override
protected CharSequence getOriginHeaderName() {
return HttpHeaderNames.ORIGIN;
}
}

View File

@ -15,41 +15,18 @@
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import org.junit.Test;
import java.net.URI;
import static org.junit.Assert.assertEquals;
public class WebSocketClientHandshaker07Test extends WebSocketClientHandshakerTest {
@Override
protected WebSocketClientHandshaker newHandshaker(URI uri) {
return new WebSocketClientHandshaker07(uri, WebSocketVersion.V07, null, false, null, 1024);
}
@Test
public void testSecOriginWss() {
URI uri = URI.create("wss://localhost/path%20with%20ws");
WebSocketClientHandshaker handshaker = newHandshaker(uri);
FullHttpRequest request = handshaker.newHandshakeRequest();
try {
assertEquals("https://localhost", request.headers().get(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN));
} finally {
request.release();
}
}
@Test
public void testSecOriginWs() {
URI uri = URI.create("ws://localhost/path%20with%20ws");
WebSocketClientHandshaker handshaker = newHandshaker(uri);
FullHttpRequest request = handshaker.newHandshakeRequest();
try {
assertEquals("http://localhost", request.headers().get(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN));
} finally {
request.release();
}
@Override
protected CharSequence getOriginHeaderName() {
return HttpHeaderNames.SEC_WEBSOCKET_ORIGIN;
}
}

View File

@ -41,38 +41,136 @@ import static org.junit.Assert.assertTrue;
public abstract class WebSocketClientHandshakerTest {
protected abstract WebSocketClientHandshaker newHandshaker(URI uri);
@Test
public void testHostHeader() {
testHostHeaderDefaultHttp(URI.create("ws://localhost:80/"), "localhost");
testHostHeaderDefaultHttp(URI.create("http://localhost:80/"), "localhost");
testHostHeaderDefaultHttp(URI.create("ws://[::1]:80/"), "[::1]");
testHostHeaderDefaultHttp(URI.create("http://[::1]:80/"), "[::1]");
testHostHeaderDefaultHttp(URI.create("ws://localhost:9999/"), "localhost:9999");
testHostHeaderDefaultHttp(URI.create("http://localhost:9999/"), "localhost:9999");
testHostHeaderDefaultHttp(URI.create("ws://[::1]:9999/"), "[::1]:9999");
testHostHeaderDefaultHttp(URI.create("http://[::1]:9999/"), "[::1]:9999");
protected abstract CharSequence getOriginHeaderName();
testHostHeaderDefaultHttp(URI.create("wss://localhost:443/"), "localhost");
testHostHeaderDefaultHttp(URI.create("https://localhost:443/"), "localhost");
testHostHeaderDefaultHttp(URI.create("wss://[::1]:443/"), "[::1]");
testHostHeaderDefaultHttp(URI.create("https://[::1]:443/"), "[::1]");
testHostHeaderDefaultHttp(URI.create("wss://localhost:9999/"), "localhost:9999");
testHostHeaderDefaultHttp(URI.create("https://localhost:9999/"), "localhost:9999");
testHostHeaderDefaultHttp(URI.create("wss://[::1]:9999/"), "[::1]:9999");
testHostHeaderDefaultHttp(URI.create("https://[::1]:9999/"), "[::1]:9999");
@Test
public void hostHeaderWs() {
for (String scheme : new String[]{"ws://", "http://"}) {
for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "Netty.io"}) {
String enter = scheme + host;
testHostHeader(enter, host);
testHostHeader(enter + '/', host);
testHostHeader(enter + ":80", host);
testHostHeader(enter + ":443", host + ":443");
testHostHeader(enter + ":9999", host + ":9999");
testHostHeader(enter + "/path", host);
testHostHeader(enter + ":80/path", host);
testHostHeader(enter + ":443/path", host + ":443");
testHostHeader(enter + ":9999/path", host + ":9999");
}
}
}
private void testHostHeaderDefaultHttp(URI uri, String expected) {
WebSocketClientHandshaker handshaker = newHandshaker(uri);
@Test
public void hostHeaderWss() {
for (String scheme : new String[]{"wss://", "https://"}) {
for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "Netty.io"}) {
String enter = scheme + host;
testHostHeader(enter, host);
testHostHeader(enter + '/', host);
testHostHeader(enter + ":80", host + ":80");
testHostHeader(enter + ":443", host);
testHostHeader(enter + ":9999", host + ":9999");
testHostHeader(enter + "/path", host);
testHostHeader(enter + ":80/path", host + ":80");
testHostHeader(enter + ":443/path", host);
testHostHeader(enter + ":9999/path", host + ":9999");
}
}
}
@Test
public void hostHeaderWithoutScheme() {
testHostHeader("//localhost/", "localhost");
testHostHeader("//localhost/path", "localhost");
testHostHeader("//localhost:80/", "localhost:80");
testHostHeader("//localhost:443/", "localhost:443");
testHostHeader("//localhost:9999/", "localhost:9999");
}
@Test
public void originHeaderWs() {
for (String scheme : new String[]{"ws://", "http://"}) {
for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "NETTY.IO"}) {
String enter = scheme + host;
String expect = "http://" + host.toLowerCase();
testOriginHeader(enter, expect);
testOriginHeader(enter + '/', expect);
testOriginHeader(enter + ":80", expect);
testOriginHeader(enter + ":443", expect + ":443");
testOriginHeader(enter + ":9999", expect + ":9999");
testOriginHeader(enter + "/path%20with%20ws", expect);
testOriginHeader(enter + ":80/path%20with%20ws", expect);
testOriginHeader(enter + ":443/path%20with%20ws", expect + ":443");
testOriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999");
}
}
}
@Test
public void originHeaderWss() {
for (String scheme : new String[]{"wss://", "https://"}) {
for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "NETTY.IO"}) {
String enter = scheme + host;
String expect = "https://" + host.toLowerCase();
testOriginHeader(enter, expect);
testOriginHeader(enter + '/', expect);
testOriginHeader(enter + ":80", expect + ":80");
testOriginHeader(enter + ":443", expect);
testOriginHeader(enter + ":9999", expect + ":9999");
testOriginHeader(enter + "/path%20with%20ws", expect);
testOriginHeader(enter + ":80/path%20with%20ws", expect + ":80");
testOriginHeader(enter + ":443/path%20with%20ws", expect);
testOriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999");
}
}
}
@Test
public void originHeaderWithoutScheme() {
testOriginHeader("//localhost/", "http://localhost");
testOriginHeader("//localhost/path", "http://localhost");
// http scheme by port
testOriginHeader("//localhost:80/", "http://localhost");
testOriginHeader("//localhost:80/path", "http://localhost");
// https scheme by port
testOriginHeader("//localhost:443/", "https://localhost");
testOriginHeader("//localhost:443/path", "https://localhost");
// http scheme for non standard port
testOriginHeader("//localhost:9999/", "http://localhost:9999");
testOriginHeader("//localhost:9999/path", "http://localhost:9999");
// convert host to lower case
testOriginHeader("//LOCALHOST/", "http://localhost");
}
private void testHostHeader(String uri, String expected) {
testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected);
}
private void testOriginHeader(String uri, String expected) {
testHeaderDefaultHttp(uri, getOriginHeaderName(), expected);
}
protected void testHeaderDefaultHttp(String uri, CharSequence header, String expectedValue) {
WebSocketClientHandshaker handshaker = newHandshaker(URI.create(uri));
FullHttpRequest request = handshaker.newHandshakeRequest();
try {
assertEquals(expected, request.headers().get(HttpHeaderNames.HOST));
assertEquals(expectedValue, request.headers().get(header));
} finally {
request.release();
}
}
@Test
@SuppressWarnings("deprecation")
public void testRawPath() {
URI uri = URI.create("ws://localhost:9999/path%20with%20ws");
WebSocketClientHandshaker handshaker = newHandshaker(uri);