Merge WebSocket extensions, close #10792 (#10951)

Motivation:

We currently append extensions to the user defined "sec-websocket-extensions" headers. This can cause duplicated entries.

Modifications:

* Replace existing `WebSocketExtensionUtil#appendExtension` private helper with a new `computeMergeExtensionsHeaderValue`. User defined parameters have higher precedence.
* Add tests (existing method wasn't tested)
* Reuse code for both client and server side (code was duplicated).

Result:

No more duplicated entries when user defined extensions overlap with the ones Netty generated.
This commit is contained in:
Stephane Landelle 2021-01-21 13:58:52 +01:00 committed by GitHub
parent 41e79835f2
commit 4fbbcf8702
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 103 additions and 30 deletions

View File

@ -63,14 +63,15 @@ public class WebSocketClientExtensionHandler extends ChannelDuplexHandler {
if (msg instanceof HttpRequest && WebSocketExtensionUtil.isWebsocketUpgrade(((HttpRequest) msg).headers())) { if (msg instanceof HttpRequest && WebSocketExtensionUtil.isWebsocketUpgrade(((HttpRequest) msg).headers())) {
HttpRequest request = (HttpRequest) msg; HttpRequest request = (HttpRequest) msg;
String headerValue = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS); String headerValue = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
List<WebSocketExtensionData> extraExtensions =
new ArrayList<WebSocketExtensionData>(extensionHandshakers.size());
for (WebSocketClientExtensionHandshaker extensionHandshaker : extensionHandshakers) { for (WebSocketClientExtensionHandshaker extensionHandshaker : extensionHandshakers) {
WebSocketExtensionData extensionData = extensionHandshaker.newRequestData(); extraExtensions.add(extensionHandshaker.newRequestData());
headerValue = WebSocketExtensionUtil.appendExtension(headerValue,
extensionData.name(), extensionData.parameters());
} }
String newHeaderValue = WebSocketExtensionUtil
.computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
request.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue); request.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
} }
super.write(ctx, msg, promise); super.write(ctx, msg, promise);

View File

@ -72,25 +72,53 @@ public final class WebSocketExtensionUtil {
} }
} }
static String appendExtension(String currentHeaderValue, String extensionName, static String computeMergeExtensionsHeaderValue(String userDefinedHeaderValue,
Map<String, String> extensionParameters) { List<WebSocketExtensionData> extraExtensions) {
List<WebSocketExtensionData> userDefinedExtensions =
userDefinedHeaderValue != null ?
extractExtensions(userDefinedHeaderValue) :
Collections.<WebSocketExtensionData>emptyList();
StringBuilder newHeaderValue = new StringBuilder( for (WebSocketExtensionData userDefined: userDefinedExtensions) {
currentHeaderValue != null ? currentHeaderValue.length() : extensionName.length() + 1); WebSocketExtensionData matchingExtra = null;
if (currentHeaderValue != null && !currentHeaderValue.trim().isEmpty()) { int i;
newHeaderValue.append(currentHeaderValue); for (i = 0; i < extraExtensions.size(); i ++) {
newHeaderValue.append(EXTENSION_SEPARATOR); WebSocketExtensionData extra = extraExtensions.get(i);
} if (extra.name().equals(userDefined.name())) {
newHeaderValue.append(extensionName); matchingExtra = extra;
for (Entry<String, String> extensionParameter : extensionParameters.entrySet()) { break;
newHeaderValue.append(PARAMETER_SEPARATOR);
newHeaderValue.append(extensionParameter.getKey());
if (extensionParameter.getValue() != null) {
newHeaderValue.append(PARAMETER_EQUAL);
newHeaderValue.append(extensionParameter.getValue());
} }
} }
return newHeaderValue.toString(); if (matchingExtra == null) {
extraExtensions.add(userDefined);
} else {
// merge with higher precedence to user defined parameters
Map<String, String> mergedParameters = new HashMap<String, String>(matchingExtra.parameters());
mergedParameters.putAll(userDefined.parameters());
extraExtensions.set(i, new WebSocketExtensionData(matchingExtra.name(), mergedParameters));
}
}
StringBuilder sb = new StringBuilder(150);
for (WebSocketExtensionData data: extraExtensions) {
sb.append(data.name());
for (Entry<String, String> parameter : data.parameters().entrySet()) {
sb.append(PARAMETER_SEPARATOR);
sb.append(parameter.getKey());
if (parameter.getValue() != null) {
sb.append(PARAMETER_EQUAL);
sb.append(parameter.getValue());
}
}
sb.append(EXTENSION_SEPARATOR);
}
if (!extraExtensions.isEmpty()) {
sb.setLength(sb.length() - EXTENSION_SEPARATOR.length());
}
return sb.toString();
} }
private WebSocketExtensionUtil() { private WebSocketExtensionUtil() {

View File

@ -124,13 +124,13 @@ public class WebSocketServerExtensionHandler extends ChannelDuplexHandler {
if (validExtensions != null) { if (validExtensions != null) {
String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS); String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
List<WebSocketExtensionData> extraExtensions =
new ArrayList<WebSocketExtensionData>(extensionHandshakers.size());
for (WebSocketServerExtension extension : validExtensions) { for (WebSocketServerExtension extension : validExtensions) {
WebSocketExtensionData extensionData = extension.newReponseData(); extraExtensions.add(extension.newReponseData());
headerValue = WebSocketExtensionUtil.appendExtension(headerValue,
extensionData.name(),
extensionData.parameters());
} }
String newHeaderValue = WebSocketExtensionUtil
.computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
promise.addListener(new ChannelFutureListener() { promise.addListener(new ChannelFutureListener() {
@Override @Override
public void operationComplete(ChannelFuture future) { public void operationComplete(ChannelFuture future) {
@ -148,7 +148,7 @@ public class WebSocketServerExtensionHandler extends ChannelDuplexHandler {
}); });
if (headerValue != null) { if (headerValue != null) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue); headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
} }
} }

View File

@ -21,19 +21,63 @@ import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import org.junit.Test; import org.junit.Test;
import java.util.List;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionUtil.*;
public class WebSocketExtensionUtilTest { public class WebSocketExtensionUtilTest {
@Test @Test
public void testIsWebsocketUpgrade() { public void testIsWebsocketUpgrade() {
HttpHeaders headers = new DefaultHttpHeaders(); HttpHeaders headers = new DefaultHttpHeaders();
assertFalse(WebSocketExtensionUtil.isWebsocketUpgrade(headers)); assertFalse(isWebsocketUpgrade(headers));
headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET); headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET);
assertFalse(WebSocketExtensionUtil.isWebsocketUpgrade(headers)); assertFalse(isWebsocketUpgrade(headers));
headers.add(HttpHeaderNames.CONNECTION, "Keep-Alive, Upgrade"); headers.add(HttpHeaderNames.CONNECTION, "Keep-Alive, Upgrade");
assertTrue(WebSocketExtensionUtil.isWebsocketUpgrade(headers)); assertTrue(isWebsocketUpgrade(headers));
}
@Test
public void computeMergeExtensionsHeaderValueWhenNoUserDefinedHeader() {
List<WebSocketExtensionData> extras = extractExtensions("permessage-deflate; client_max_window_bits," +
"permessage-deflate; client_no_context_takeover; client_max_window_bits," +
"deflate-frame," +
"x-webkit-deflate-frame");
String newHeaderValue = computeMergeExtensionsHeaderValue(null, extras);
assertEquals("permessage-deflate;client_max_window_bits," +
"permessage-deflate;client_no_context_takeover;client_max_window_bits," +
"deflate-frame," +
"x-webkit-deflate-frame", newHeaderValue);
}
@Test
public void computeMergeExtensionsHeaderValueWhenNoConflictingUserDefinedHeader() {
List<WebSocketExtensionData> extras = extractExtensions("permessage-deflate; client_max_window_bits," +
"permessage-deflate; client_no_context_takeover; client_max_window_bits," +
"deflate-frame," +
"x-webkit-deflate-frame");
String newHeaderValue = computeMergeExtensionsHeaderValue("foo, bar", extras);
assertEquals("permessage-deflate;client_max_window_bits," +
"permessage-deflate;client_no_context_takeover;client_max_window_bits," +
"deflate-frame," +
"x-webkit-deflate-frame," +
"foo," +
"bar", newHeaderValue);
}
@Test
public void computeMergeExtensionsHeaderValueWhenConflictingUserDefinedHeader() {
List<WebSocketExtensionData> extras = extractExtensions("permessage-deflate; client_max_window_bits," +
"permessage-deflate; client_no_context_takeover; client_max_window_bits," +
"deflate-frame," +
"x-webkit-deflate-frame");
String newHeaderValue = computeMergeExtensionsHeaderValue("permessage-deflate; client_max_window_bits", extras);
assertEquals("permessage-deflate;client_max_window_bits," +
"permessage-deflate;client_no_context_takeover;client_max_window_bits," +
"deflate-frame," +
"x-webkit-deflate-frame", newHeaderValue);
} }
} }