Simplify WebSocket handlers constructor arguments hell #9698 (#9699)

### Motivation:

Introduction of `WebSocketDecoderConfig` made our server-side code more elegant and simpler for support.

However there is still some problem with maintenance and new features development for WebSocket codecs (`WebSocketServerProtocolHandler`, `WebSocketServerProtocolHandler`).

Particularly, it makes me ~~crying with blood~~ extremely sad to add new parameter and yet another one constructor into these handlers, when I want to contribute new feature.

### Modification:

I've extracted all parameters for client and server WebSocket handlers into config/builder structures, like it was made for decoders in PR #9116.

### Result:

* Fixes #9698: Simplify WebSocket handlers constructor arguments hell
* Unblock further development in this module (configurable close frame handling on server-side; automatic close-frame sending, when missed; memory leaks on protocol violations; etc...)

Bonuses:

* All defaults are gathered in one place and could be easily found/reused.
* New API greatly simplifies usage, but does NOT allow inheritance or modification.
* New API would simplify long-term maintenance of WebSockets module.

### Example

    WebSocketClientProtocolConfig config = WebSocketClientProtocolConfig.newBuilder()
        .webSocketUri("wss://localhost:8443/fx-spot")
        .subprotocol("trading")
        .handshakeTimeoutMillis(15000L)
        .build();
    ctx.pipeline().addLast(new WebSocketClientProtocolHandler(config));
This commit is contained in:
ursa 2019-10-29 19:48:18 +00:00 committed by Norman Maurer
parent 63729a310f
commit db84735975
8 changed files with 675 additions and 81 deletions

View File

@ -0,0 +1,329 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent;
import io.netty.util.internal.ObjectUtil;
import java.net.URI;
import static io.netty.util.internal.ObjectUtil.checkPositive;
/**
* WebSocket server configuration.
*/
public final class WebSocketClientProtocolConfig {
static final WebSocketClientProtocolConfig DEFAULT = new WebSocketClientProtocolConfig(
URI.create("https://localhost/"), null, WebSocketVersion.V13, false,
EmptyHttpHeaders.INSTANCE, 65536, true, false, true, true, 10000L, -1, false);
private final URI webSocketUri;
private final String subprotocol;
private final WebSocketVersion version;
private final boolean allowExtensions;
private final HttpHeaders customHeaders;
private final int maxFramePayloadLength;
private final boolean performMasking;
private final boolean allowMaskMismatch;
private final boolean handleCloseFrames;
private final boolean dropPongFrames;
private final long handshakeTimeoutMillis;
private final long forceCloseTimeoutMillis;
private final boolean absoluteUpgradeUrl;
private WebSocketClientProtocolConfig(
URI webSocketUri,
String subprotocol,
WebSocketVersion version,
boolean allowExtensions,
HttpHeaders customHeaders,
int maxFramePayloadLength,
boolean performMasking,
boolean allowMaskMismatch,
boolean handleCloseFrames,
boolean dropPongFrames,
long handshakeTimeoutMillis,
long forceCloseTimeoutMillis,
boolean absoluteUpgradeUrl
) {
this.webSocketUri = webSocketUri;
this.subprotocol = subprotocol;
this.version = version;
this.allowExtensions = allowExtensions;
this.customHeaders = customHeaders;
this.maxFramePayloadLength = maxFramePayloadLength;
this.performMasking = performMasking;
this.allowMaskMismatch = allowMaskMismatch;
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
this.handleCloseFrames = handleCloseFrames;
this.dropPongFrames = dropPongFrames;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
this.absoluteUpgradeUrl = absoluteUpgradeUrl;
}
public URI webSocketUri() {
return webSocketUri;
}
public String subprotocol() {
return subprotocol;
}
public WebSocketVersion version() {
return version;
}
public boolean allowExtensions() {
return allowExtensions;
}
public HttpHeaders customHeaders() {
return customHeaders;
}
public int maxFramePayloadLength() {
return maxFramePayloadLength;
}
public boolean performMasking() {
return performMasking;
}
public boolean allowMaskMismatch() {
return allowMaskMismatch;
}
public boolean handleCloseFrames() {
return handleCloseFrames;
}
public boolean dropPongFrames() {
return dropPongFrames;
}
public long handshakeTimeoutMillis() {
return handshakeTimeoutMillis;
}
public long forceCloseTimeoutMillis() {
return forceCloseTimeoutMillis;
}
public boolean absoluteUpgradeUrl() {
return absoluteUpgradeUrl;
}
@Override
public String toString() {
return "WebSocketClientProtocolConfig" +
" {webSocketUri=" + webSocketUri +
", subprotocol=" + subprotocol +
", version=" + version +
", allowExtensions=" + allowExtensions +
", customHeaders=" + customHeaders +
", maxFramePayloadLength=" + maxFramePayloadLength +
", performMasking=" + performMasking +
", allowMaskMismatch=" + allowMaskMismatch +
", handleCloseFrames=" + handleCloseFrames +
", dropPongFrames=" + dropPongFrames +
", handshakeTimeoutMillis=" + handshakeTimeoutMillis +
", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis +
", absoluteUpgradeUrl=" + absoluteUpgradeUrl +
"}";
}
public Builder toBuilder() {
return new Builder(this);
}
public static Builder newBuilder() {
return new Builder(DEFAULT);
}
public static final class Builder {
private URI webSocketUri;
private String subprotocol;
private WebSocketVersion version;
private boolean allowExtensions;
private HttpHeaders customHeaders;
private int maxFramePayloadLength;
private boolean performMasking;
private boolean allowMaskMismatch;
private boolean handleCloseFrames;
private boolean dropPongFrames;
private long handshakeTimeoutMillis;
private long forceCloseTimeoutMillis;
private boolean absoluteUpgradeUrl;
private Builder(WebSocketClientProtocolConfig clientConfig) {
ObjectUtil.checkNotNull(clientConfig, "clientConfig");
this.webSocketUri = clientConfig.webSocketUri();
this.subprotocol = clientConfig.subprotocol();
this.version = clientConfig.version();
this.allowExtensions = clientConfig.allowExtensions();
this.customHeaders = clientConfig.customHeaders();
this.maxFramePayloadLength = clientConfig.maxFramePayloadLength();
this.performMasking = clientConfig.performMasking();
this.allowMaskMismatch = clientConfig.allowMaskMismatch();
this.handleCloseFrames = clientConfig.handleCloseFrames();
this.dropPongFrames = clientConfig.dropPongFrames();
this.handshakeTimeoutMillis = clientConfig.handshakeTimeoutMillis();
this.forceCloseTimeoutMillis = clientConfig.forceCloseTimeoutMillis();
this.absoluteUpgradeUrl = clientConfig.absoluteUpgradeUrl();
}
/**
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
*/
public Builder webSocketUri(String webSocketUri) {
return webSocketUri(URI.create(webSocketUri));
}
/**
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
*/
public Builder webSocketUri(URI webSocketUri) {
this.webSocketUri = webSocketUri;
return this;
}
/**
* Sub protocol request sent to the server.
*/
public Builder subprotocol(String subprotocol) {
this.subprotocol = subprotocol;
return this;
}
/**
* Version of web socket specification to use to connect to the server
*/
public Builder version(WebSocketVersion version) {
this.version = version;
return this;
}
/**
* Allow extensions to be used in the reserved bits of the web socket frame
*/
public Builder allowExtensions(boolean allowExtensions) {
this.allowExtensions = allowExtensions;
return this;
}
/**
* Map of custom headers to add to the client request
*/
public Builder customHeaders(HttpHeaders customHeaders) {
this.customHeaders = customHeaders;
return this;
}
/**
* Maximum length of a frame's payload
*/
public Builder maxFramePayloadLength(int maxFramePayloadLength) {
this.maxFramePayloadLength = maxFramePayloadLength;
return this;
}
/**
* Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
* with the websocket specifications. Client applications that communicate with a non-standard server
* which doesn't require masking might set this to false to achieve a higher performance.
*/
public Builder performMasking(boolean performMasking) {
this.performMasking = performMasking;
return this;
}
/**
* When set to true, frames which are not masked properly according to the standard will still be accepted.
*/
public Builder allowMaskMismatch(boolean allowMaskMismatch) {
this.allowMaskMismatch = allowMaskMismatch;
return this;
}
/**
* {@code true} if close frames should not be forwarded and just close the channel
*/
public Builder handleCloseFrames(boolean handleCloseFrames) {
this.handleCloseFrames = handleCloseFrames;
return this;
}
/**
* {@code true} if pong frames should not be forwarded
*/
public Builder dropPongFrames(boolean dropPongFrames) {
this.dropPongFrames = dropPongFrames;
return this;
}
/**
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public Builder handshakeTimeoutMillis(long handshakeTimeoutMillis) {
this.handshakeTimeoutMillis = handshakeTimeoutMillis;
return this;
}
/**
* Close the connection if it was not closed by the server after timeout specified
*/
public Builder forceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
return this;
}
/**
* Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over clear HTTP
*/
public Builder absoluteUpgradeUrl(boolean absoluteUpgradeUrl) {
this.absoluteUpgradeUrl = absoluteUpgradeUrl;
return this;
}
/**
* Build unmodifiable client protocol configuration.
*/
public WebSocketClientProtocolConfig build() {
return new WebSocketClientProtocolConfig(
webSocketUri,
subprotocol,
version,
allowExtensions,
customHeaders,
maxFramePayloadLength,
performMasking,
allowMaskMismatch,
handleCloseFrames,
dropPongFrames,
handshakeTimeoutMillis,
forceCloseTimeoutMillis,
absoluteUpgradeUrl
);
}
}
}

View File

@ -23,6 +23,7 @@ import io.netty.handler.codec.http.HttpHeaders;
import java.net.URI; import java.net.URI;
import java.util.List; import java.util.List;
import static io.netty.handler.codec.http.websocketx.WebSocketClientProtocolConfig.DEFAULT;
import static io.netty.util.internal.ObjectUtil.*; import static io.netty.util.internal.ObjectUtil.*;
/** /**
@ -40,8 +41,6 @@ import static io.netty.util.internal.ObjectUtil.*;
* {@link ClientHandshakeStateEvent#HANDSHAKE_ISSUED} or {@link ClientHandshakeStateEvent#HANDSHAKE_COMPLETE}. * {@link ClientHandshakeStateEvent#HANDSHAKE_ISSUED} or {@link ClientHandshakeStateEvent#HANDSHAKE_COMPLETE}.
*/ */
public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;
private final WebSocketClientHandshaker handshaker; private final WebSocketClientHandshaker handshaker;
private final boolean handleCloseFrames; private final boolean handleCloseFrames;
private final long handshakeTimeoutMillis; private final long handshakeTimeoutMillis;
@ -73,6 +72,30 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
HANDSHAKE_COMPLETE HANDSHAKE_COMPLETE
} }
/**
* Base constructor
*
* @param clientConfig
* Client protocol configuration.
*/
public WebSocketClientProtocolHandler(WebSocketClientProtocolConfig clientConfig) {
super(checkNotNull(clientConfig, "clientConfig").dropPongFrames());
this.handshaker = WebSocketClientHandshakerFactory.newHandshaker(
clientConfig.webSocketUri(),
clientConfig.version(),
clientConfig.subprotocol(),
clientConfig.allowExtensions(),
clientConfig.customHeaders(),
clientConfig.maxFramePayloadLength(),
clientConfig.performMasking(),
clientConfig.allowMaskMismatch(),
clientConfig.forceCloseTimeoutMillis(),
clientConfig.absoluteUpgradeUrl()
);
this.handleCloseFrames = clientConfig.handleCloseFrames();
this.handshakeTimeoutMillis = clientConfig.handshakeTimeoutMillis();
}
/** /**
* Base constructor * Base constructor
* *
@ -101,8 +124,8 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
boolean allowExtensions, HttpHeaders customHeaders, boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, boolean handleCloseFrames, int maxFramePayloadLength, boolean handleCloseFrames,
boolean performMasking, boolean allowMaskMismatch) { boolean performMasking, boolean allowMaskMismatch) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
maxFramePayloadLength, handleCloseFrames, performMasking, allowMaskMismatch, DEFAULT_HANDSHAKE_TIMEOUT_MS); handleCloseFrames, performMasking, allowMaskMismatch, DEFAULT.handshakeTimeoutMillis());
} }
/** /**
@ -163,7 +186,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
boolean allowExtensions, HttpHeaders customHeaders, boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, boolean handleCloseFrames) { int maxFramePayloadLength, boolean handleCloseFrames) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); handleCloseFrames, DEFAULT.handshakeTimeoutMillis());
} }
/** /**
@ -190,7 +213,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean handleCloseFrames, long handshakeTimeoutMillis) { boolean handleCloseFrames, long handshakeTimeoutMillis) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
handleCloseFrames, true, false, handshakeTimeoutMillis); handleCloseFrames, DEFAULT.performMasking(), DEFAULT.allowMaskMismatch(), handshakeTimeoutMillis);
} }
/** /**
@ -211,8 +234,8 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol, public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength) { int maxFramePayloadLength) {
this(webSocketURL, version, subprotocol, this(webSocketURL, version, subprotocol, allowExtensions,
allowExtensions, customHeaders, maxFramePayloadLength, DEFAULT_HANDSHAKE_TIMEOUT_MS); customHeaders, maxFramePayloadLength, DEFAULT.handshakeTimeoutMillis());
} }
/** /**
@ -236,8 +259,8 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol, public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, boolean allowExtensions, HttpHeaders customHeaders,
int maxFramePayloadLength, long handshakeTimeoutMillis) { int maxFramePayloadLength, long handshakeTimeoutMillis) {
this(webSocketURL, version, subprotocol, this(webSocketURL, version, subprotocol, allowExtensions, customHeaders,
allowExtensions, customHeaders, maxFramePayloadLength, true, handshakeTimeoutMillis); maxFramePayloadLength, DEFAULT.handleCloseFrames(), handshakeTimeoutMillis);
} }
/** /**
@ -250,7 +273,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
* {@code true} if close frames should not be forwarded and just close the channel * {@code true} if close frames should not be forwarded and just close the channel
*/ */
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames) { public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames) {
this(handshaker, handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); this(handshaker, handleCloseFrames, DEFAULT.handshakeTimeoutMillis());
} }
/** /**
@ -267,7 +290,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
*/ */
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames, public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
long handshakeTimeoutMillis) { long handshakeTimeoutMillis) {
this(handshaker, handleCloseFrames, true, handshakeTimeoutMillis); this(handshaker, handleCloseFrames, DEFAULT.dropPongFrames(), handshakeTimeoutMillis);
} }
/** /**
@ -283,7 +306,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
*/ */
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames, public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
boolean dropPongFrames) { boolean dropPongFrames) {
this(handshaker, handleCloseFrames, dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); this(handshaker, handleCloseFrames, dropPongFrames, DEFAULT.handshakeTimeoutMillis());
} }
/** /**
@ -316,7 +339,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
* was established to the remote peer. * was established to the remote peer.
*/ */
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker) { public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker) {
this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS); this(handshaker, DEFAULT.handshakeTimeoutMillis());
} }
/** /**
@ -330,7 +353,7 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/ */
public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) { public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) {
this(handshaker, true, handshakeTimeoutMillis); this(handshaker, DEFAULT.handleCloseFrames(), handshakeTimeoutMillis);
} }
@Override @Override

View File

@ -22,6 +22,9 @@ import io.netty.util.internal.ObjectUtil;
*/ */
public final class WebSocketDecoderConfig { public final class WebSocketDecoderConfig {
static final WebSocketDecoderConfig DEFAULT =
new WebSocketDecoderConfig(65536, true, false, false, true, true);
private final int maxFramePayloadLength; private final int maxFramePayloadLength;
private final boolean expectMaskedFrames; private final boolean expectMaskedFrames;
private final boolean allowMaskMismatch; private final boolean allowMaskMismatch;
@ -102,20 +105,16 @@ public final class WebSocketDecoderConfig {
} }
public static Builder newBuilder() { public static Builder newBuilder() {
return new Builder(); return new Builder(DEFAULT);
} }
public static final class Builder { public static final class Builder {
private int maxFramePayloadLength = 65536; private int maxFramePayloadLength;
private boolean expectMaskedFrames = true; private boolean expectMaskedFrames;
private boolean allowMaskMismatch; private boolean allowMaskMismatch;
private boolean allowExtensions; private boolean allowExtensions;
private boolean closeOnProtocolViolation = true; private boolean closeOnProtocolViolation;
private boolean withUTF8Validator = true; private boolean withUTF8Validator;
private Builder() {
/* No-op */
}
private Builder(WebSocketDecoderConfig decoderConfig) { private Builder(WebSocketDecoderConfig decoderConfig) {
ObjectUtil.checkNotNull(decoderConfig, "decoderConfig"); ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");

View File

@ -0,0 +1,238 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent;
import io.netty.util.internal.ObjectUtil;
import static io.netty.util.internal.ObjectUtil.checkPositive;
/**
* WebSocket server configuration.
*/
public final class WebSocketServerProtocolConfig {
static final WebSocketServerProtocolConfig DEFAULT =
new WebSocketServerProtocolConfig("/", null, false, 10000L, true, true, WebSocketDecoderConfig.DEFAULT);
private final String websocketPath;
private final String subprotocols;
private final boolean checkStartsWith;
private final long handshakeTimeoutMillis;
private final boolean handleCloseFrames;
private final boolean dropPongFrames;
private final WebSocketDecoderConfig decoderConfig;
private WebSocketServerProtocolConfig(
String websocketPath,
String subprotocols,
boolean checkStartsWith,
long handshakeTimeoutMillis,
boolean handleCloseFrames,
boolean dropPongFrames,
WebSocketDecoderConfig decoderConfig
) {
this.websocketPath = websocketPath;
this.subprotocols = subprotocols;
this.checkStartsWith = checkStartsWith;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
this.handleCloseFrames = handleCloseFrames;
this.dropPongFrames = dropPongFrames;
this.decoderConfig = decoderConfig == null ? WebSocketDecoderConfig.DEFAULT : decoderConfig;
}
public String websocketPath() {
return websocketPath;
}
public String subprotocols() {
return subprotocols;
}
public boolean checkStartsWith() {
return checkStartsWith;
}
public long handshakeTimeoutMillis() {
return handshakeTimeoutMillis;
}
public boolean handleCloseFrames() {
return handleCloseFrames;
}
public boolean dropPongFrames() {
return dropPongFrames;
}
public WebSocketDecoderConfig decoderConfig() {
return decoderConfig;
}
@Override
public String toString() {
return "WebSocketServerProtocolConfig" +
" {websocketPath=" + websocketPath +
", subprotocols=" + subprotocols +
", checkStartsWith=" + checkStartsWith +
", handshakeTimeoutMillis=" + handshakeTimeoutMillis +
", handleCloseFrames=" + handleCloseFrames +
", dropPongFrames=" + dropPongFrames +
", decoderConfig=" + decoderConfig +
"}";
}
public Builder toBuilder() {
return new Builder(this);
}
public static Builder newBuilder() {
return new Builder(DEFAULT);
}
public static final class Builder {
private String websocketPath;
private String subprotocols;
private boolean checkStartsWith;
private long handshakeTimeoutMillis;
private boolean handleCloseFrames;
private boolean dropPongFrames;
private WebSocketDecoderConfig decoderConfig;
private WebSocketDecoderConfig.Builder decoderConfigBuilder;
private Builder(WebSocketServerProtocolConfig serverConfig) {
ObjectUtil.checkNotNull(serverConfig, "serverConfig");
websocketPath = serverConfig.websocketPath();
subprotocols = serverConfig.subprotocols();
checkStartsWith = serverConfig.checkStartsWith();
handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis();
handleCloseFrames = serverConfig.handleCloseFrames();
dropPongFrames = serverConfig.dropPongFrames();
decoderConfig = serverConfig.decoderConfig();
}
/**
* URI path component to handle websocket upgrade requests on.
*/
public Builder websocketPath(String websocketPath) {
this.websocketPath = websocketPath;
return this;
}
/**
* CSV of supported protocols
*/
public Builder subprotocols(String subprotocols) {
this.subprotocols = subprotocols;
return this;
}
/**
* {@code true} to handle all requests, where URI path component starts from
* {@link WebSocketServerProtocolConfig#websocketPath()}, {@code false} for exact match (default).
*/
public Builder checkStartsWith(boolean checkStartsWith) {
this.checkStartsWith = checkStartsWith;
return this;
}
/**
* Handshake timeout in mills, when handshake timeout, will trigger user
* event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
*/
public Builder handshakeTimeoutMillis(long handshakeTimeoutMillis) {
this.handshakeTimeoutMillis = handshakeTimeoutMillis;
return this;
}
/**
* {@code true} if close frames should not be forwarded and just close the channel
*/
public Builder handleCloseFrames(boolean handleCloseFrames) {
this.handleCloseFrames = handleCloseFrames;
return this;
}
/**
* {@code true} if pong frames should not be forwarded
*/
public Builder dropPongFrames(boolean dropPongFrames) {
this.dropPongFrames = dropPongFrames;
return this;
}
/**
* Frames decoder configuration.
*/
public Builder decoderConfig(WebSocketDecoderConfig decoderConfig) {
this.decoderConfig = decoderConfig == null ? WebSocketDecoderConfig.DEFAULT : decoderConfig;
this.decoderConfigBuilder = null;
return this;
}
private WebSocketDecoderConfig.Builder decoderConfigBuilder() {
if (decoderConfigBuilder == null) {
decoderConfigBuilder = decoderConfig.toBuilder();
}
return decoderConfigBuilder;
}
public Builder maxFramePayloadLength(int maxFramePayloadLength) {
decoderConfigBuilder().maxFramePayloadLength(maxFramePayloadLength);
return this;
}
public Builder expectMaskedFrames(boolean expectMaskedFrames) {
decoderConfigBuilder().expectMaskedFrames(expectMaskedFrames);
return this;
}
public Builder allowMaskMismatch(boolean allowMaskMismatch) {
decoderConfigBuilder().allowMaskMismatch(allowMaskMismatch);
return this;
}
public Builder allowExtensions(boolean allowExtensions) {
decoderConfigBuilder().allowExtensions(allowExtensions);
return this;
}
public Builder closeOnProtocolViolation(boolean closeOnProtocolViolation) {
decoderConfigBuilder().closeOnProtocolViolation(closeOnProtocolViolation);
return this;
}
public Builder withUTF8Validator(boolean withUTF8Validator) {
decoderConfigBuilder().withUTF8Validator(withUTF8Validator);
return this;
}
/**
* Build unmodifiable server protocol configuration.
*/
public WebSocketServerProtocolConfig build() {
return new WebSocketServerProtocolConfig(
websocketPath,
subprotocols,
checkStartsWith,
handshakeTimeoutMillis,
handleCloseFrames,
dropPongFrames,
decoderConfigBuilder == null ? decoderConfig : decoderConfigBuilder.build()
);
}
}
}

View File

@ -32,6 +32,7 @@ import io.netty.util.AttributeKey;
import java.util.List; import java.util.List;
import static io.netty.handler.codec.http.HttpVersion.*; import static io.netty.handler.codec.http.HttpVersion.*;
import static io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig.DEFAULT;
import static io.netty.util.internal.ObjectUtil.*; import static io.netty.util.internal.ObjectUtil.*;
/** /**
@ -102,16 +103,21 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY = private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER"); AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");
private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L; private final WebSocketServerProtocolConfig serverConfig;
private final String websocketPath; /**
private final String subprotocols; * Base constructor
private final boolean checkStartsWith; *
private final long handshakeTimeoutMillis; * @param serverConfig
private final WebSocketDecoderConfig decoderConfig; * Server protocol configuration.
*/
public WebSocketServerProtocolHandler(WebSocketServerProtocolConfig serverConfig) {
super(checkNotNull(serverConfig, "serverConfig").dropPongFrames());
this.serverConfig = serverConfig;
}
public WebSocketServerProtocolHandler(String websocketPath) { public WebSocketServerProtocolHandler(String websocketPath) {
this(websocketPath, DEFAULT_HANDSHAKE_TIMEOUT_MS); this(websocketPath, DEFAULT.handshakeTimeoutMillis());
} }
public WebSocketServerProtocolHandler(String websocketPath, long handshakeTimeoutMillis) { public WebSocketServerProtocolHandler(String websocketPath, long handshakeTimeoutMillis) {
@ -119,7 +125,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
} }
public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) { public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
this(websocketPath, checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MS); this(websocketPath, checkStartsWith, DEFAULT.handshakeTimeoutMillis());
} }
public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith, long handshakeTimeoutMillis) { public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith, long handshakeTimeoutMillis) {
@ -127,7 +133,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
} }
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) { public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
this(websocketPath, subprotocols, DEFAULT_HANDSHAKE_TIMEOUT_MS); this(websocketPath, subprotocols, DEFAULT.handshakeTimeoutMillis());
} }
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, long handshakeTimeoutMillis) { public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, long handshakeTimeoutMillis) {
@ -135,7 +141,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
} }
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) { public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
this(websocketPath, subprotocols, allowExtensions, DEFAULT_HANDSHAKE_TIMEOUT_MS); this(websocketPath, subprotocols, allowExtensions, DEFAULT.handshakeTimeoutMillis());
} }
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
@ -145,7 +151,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize) { boolean allowExtensions, int maxFrameSize) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, DEFAULT_HANDSHAKE_TIMEOUT_MS); this(websocketPath, subprotocols, allowExtensions, maxFrameSize, DEFAULT.handshakeTimeoutMillis());
} }
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
@ -156,7 +162,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) { boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
DEFAULT_HANDSHAKE_TIMEOUT_MS); DEFAULT.handshakeTimeoutMillis());
} }
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
@ -168,7 +174,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) { boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
DEFAULT_HANDSHAKE_TIMEOUT_MS); DEFAULT.handshakeTimeoutMillis());
} }
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
@ -182,7 +188,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
boolean checkStartsWith, boolean dropPongFrames) { boolean checkStartsWith, boolean dropPongFrames) {
this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MS); dropPongFrames, DEFAULT.handshakeTimeoutMillis());
} }
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
@ -199,12 +205,14 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean checkStartsWith, public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean checkStartsWith,
boolean dropPongFrames, long handshakeTimeoutMillis, boolean dropPongFrames, long handshakeTimeoutMillis,
WebSocketDecoderConfig decoderConfig) { WebSocketDecoderConfig decoderConfig) {
super(dropPongFrames); this(WebSocketServerProtocolConfig.newBuilder()
this.websocketPath = websocketPath; .websocketPath(websocketPath)
this.subprotocols = subprotocols; .subprotocols(subprotocols)
this.checkStartsWith = checkStartsWith; .checkStartsWith(checkStartsWith)
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); .handshakeTimeoutMillis(handshakeTimeoutMillis)
this.decoderConfig = checkNotNull(decoderConfig, "decoderConfig"); .dropPongFrames(dropPongFrames)
.decoderConfig(decoderConfig)
.build());
} }
@Override @Override
@ -213,10 +221,9 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) { if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
// Add the WebSocketHandshakeHandler before this one. // Add the WebSocketHandshakeHandler before this one.
cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(), cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
new WebSocketServerProtocolHandshakeHandler( new WebSocketServerProtocolHandshakeHandler(serverConfig));
websocketPath, subprotocols, checkStartsWith, handshakeTimeoutMillis, decoderConfig));
} }
if (decoderConfig.withUTF8Validator() && cp.get(Utf8FrameValidator.class) == null) { if (serverConfig.decoderConfig().withUTF8Validator() && cp.get(Utf8FrameValidator.class) == null) {
// Add the UFT8 checking before this one. // Add the UFT8 checking before this one.
cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
new Utf8FrameValidator()); new Utf8FrameValidator());
@ -225,7 +232,7 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
@Override @Override
protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception { protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
if (frame instanceof CloseWebSocketFrame) { if (serverConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame) {
WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel()); WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel());
if (handshaker != null) { if (handshaker != null) {
frame.retain(); frame.retain();

View File

@ -44,21 +44,12 @@ import static io.netty.util.internal.ObjectUtil.*;
*/ */
class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
private final String websocketPath; private final WebSocketServerProtocolConfig serverConfig;
private final String subprotocols;
private final boolean checkStartsWith;
private final long handshakeTimeoutMillis;
private final WebSocketDecoderConfig decoderConfig;
private ChannelHandlerContext ctx; private ChannelHandlerContext ctx;
private ChannelPromise handshakePromise; private ChannelPromise handshakePromise;
WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols, WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) {
boolean checkStartsWith, long handshakeTimeoutMillis, WebSocketDecoderConfig decoderConfig) { this.serverConfig = checkNotNull(serverConfig, "serverConfig");
this.websocketPath = websocketPath;
this.subprotocols = subprotocols;
this.checkStartsWith = checkStartsWith;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
this.decoderConfig = checkNotNull(decoderConfig, "decoderConfig");
} }
@Override @Override
@ -82,7 +73,8 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
} }
final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory( final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols, decoderConfig); getWebSocketLocation(ctx.pipeline(), req, serverConfig.websocketPath()),
serverConfig.subprotocols(), serverConfig.decoderConfig());
final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req); final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
final ChannelPromise localHandshakePromise = handshakePromise; final ChannelPromise localHandshakePromise = handshakePromise;
if (handshaker == null) { if (handshaker == null) {
@ -120,7 +112,8 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
} }
private boolean isNotWebSocketPath(FullHttpRequest req) { private boolean isNotWebSocketPath(FullHttpRequest req) {
return checkStartsWith ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath); String websocketPath = serverConfig.websocketPath();
return serverConfig.checkStartsWith() ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath);
} }
private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) { private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
@ -142,7 +135,7 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {
private void applyHandshakeTimeout() { private void applyHandshakeTimeout() {
final ChannelPromise localHandshakePromise = handshakePromise; final ChannelPromise localHandshakePromise = handshakePromise;
final long handshakeTimeoutMillis = this.handshakeTimeoutMillis; final long handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis();
if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
return; return;
} }

View File

@ -274,12 +274,25 @@ public class WebSocketHandshakeHandOverTest {
} }
private static EmbeddedChannel createClientChannel(ChannelHandler handler) throws Exception { private static EmbeddedChannel createClientChannel(ChannelHandler handler) throws Exception {
return createClientChannel(handler, WebSocketClientProtocolConfig.newBuilder()
.webSocketUri("ws://localhost:1234/test")
.subprotocol("test-proto-2")
.build());
}
private static EmbeddedChannel createClientChannel(ChannelHandler handler, long timeoutMillis) throws Exception {
return createClientChannel(handler, WebSocketClientProtocolConfig.newBuilder()
.webSocketUri("ws://localhost:1234/test")
.subprotocol("test-proto-2")
.handshakeTimeoutMillis(timeoutMillis)
.build());
}
private static EmbeddedChannel createClientChannel(ChannelHandler handler, WebSocketClientProtocolConfig config) {
return new EmbeddedChannel( return new EmbeddedChannel(
new HttpClientCodec(), new HttpClientCodec(),
new HttpObjectAggregator(8192), new HttpObjectAggregator(8192),
new WebSocketClientProtocolHandler(new URI("ws://localhost:1234/test"), new WebSocketClientProtocolHandler(config),
WebSocketVersion.V13, "test-proto-2",
false, null, 65536),
handler); handler);
} }
@ -309,14 +322,4 @@ public class WebSocketHandshakeHandOverTest {
webSocketHandler, webSocketHandler,
handler); handler);
} }
private static EmbeddedChannel createClientChannel(ChannelHandler handler, long timeoutMillis) throws Exception {
return new EmbeddedChannel(
new HttpClientCodec(),
new HttpObjectAggregator(8192),
new WebSocketClientProtocolHandler(new URI("ws://localhost:1234/test"),
WebSocketVersion.V13, "test-proto-2",
false, null, 65536, timeoutMillis),
handler);
}
} }

View File

@ -144,12 +144,13 @@ public class WebSocketServerProtocolHandlerTest {
@Test @Test
public void testCreateUTF8Validator() { public void testCreateUTF8Validator() {
WebSocketDecoderConfig config = WebSocketDecoderConfig.newBuilder() WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
.websocketPath("/test")
.withUTF8Validator(true) .withUTF8Validator(true)
.build(); .build();
EmbeddedChannel ch = new EmbeddedChannel( EmbeddedChannel ch = new EmbeddedChannel(
new WebSocketServerProtocolHandler("/test", null, false, false, 1000L, config), new WebSocketServerProtocolHandler(config),
new HttpRequestDecoder(), new HttpRequestDecoder(),
new HttpResponseEncoder(), new HttpResponseEncoder(),
new MockOutboundHandler()); new MockOutboundHandler());
@ -164,12 +165,13 @@ public class WebSocketServerProtocolHandlerTest {
@Test @Test
public void testDoNotCreateUTF8Validator() { public void testDoNotCreateUTF8Validator() {
WebSocketDecoderConfig config = WebSocketDecoderConfig.newBuilder() WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
.websocketPath("/test")
.withUTF8Validator(false) .withUTF8Validator(false)
.build(); .build();
EmbeddedChannel ch = new EmbeddedChannel( EmbeddedChannel ch = new EmbeddedChannel(
new WebSocketServerProtocolHandler("/test", null, false, false, 1000L, config), new WebSocketServerProtocolHandler(config),
new HttpRequestDecoder(), new HttpRequestDecoder(),
new HttpResponseEncoder(), new HttpResponseEncoder(),
new MockOutboundHandler()); new MockOutboundHandler());
@ -210,7 +212,7 @@ public class WebSocketServerProtocolHandlerTest {
private EmbeddedChannel createChannel(ChannelHandler handler) { private EmbeddedChannel createChannel(ChannelHandler handler) {
return new EmbeddedChannel( return new EmbeddedChannel(
new WebSocketServerProtocolHandler("/test", null, false), new WebSocketServerProtocolHandler("/test"),
new HttpRequestDecoder(), new HttpRequestDecoder(),
new HttpResponseEncoder(), new HttpResponseEncoder(),
new MockOutboundHandler(), new MockOutboundHandler(),