Provide new client and server websocket handshake exceptions (#10646)

Motivation:

At the moment we have only one base `WebSocketHandshakeException` for handling WebSocket upgrade issues.
Unfortunately, this message contains only a string message about the cause of the failure, which is inconvenient in handling.

Modification:

Provide new `WebSocketClientHandshakeException` with `HttpResponse` field  and `WebSocketServerHandshakeException` with `HttpRequest` field both of them without content for avoid reference counting 
problems. 

Result:

More information for more flexible handling.

Fixes #10277 #4528 #10639.
This commit is contained in:
Andrey Mizurov 2020-10-24 15:41:11 +03:00 committed by GitHub
parent 6ae7f9e1a8
commit 33de96f448
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 296 additions and 56 deletions

View File

@ -0,0 +1,55 @@
/*
* Copyright 2020 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.util.ReferenceCounted;
/**
* Client exception during handshaking process.
*
* <p><b>IMPORTANT</b>: This exception does not contain any {@link ReferenceCounted} fields
* e.g. {@link FullHttpResponse}, so no special treatment is needed.
*/
public final class WebSocketClientHandshakeException extends WebSocketHandshakeException {
private static final long serialVersionUID = 1L;
private final HttpResponse response;
public WebSocketClientHandshakeException(String message) {
this(message, null);
}
public WebSocketClientHandshakeException(String message, HttpResponse httpResponse) {
super(message);
if (httpResponse != null) {
response = new DefaultHttpResponse(httpResponse.protocolVersion(),
httpResponse.status(), httpResponse.headers());
} else {
response = null;
}
}
/**
* Returns a {@link HttpResponse response} if exception occurs during response validation otherwise {@code null}.
*/
public HttpResponse response() {
return response;
}
}

View File

@ -324,9 +324,9 @@ public abstract class WebSocketClientHandshaker {
} // else mixed cases - which are all errors } // else mixed cases - which are all errors
if (!protocolValid) { if (!protocolValid) {
throw new WebSocketHandshakeException(String.format( throw new WebSocketClientHandshakeException(String.format(
"Invalid subprotocol. Actual: %s. Expected one of: %s", "Invalid subprotocol. Actual: %s. Expected one of: %s",
receivedProtocol, expectedSubprotocol)); receivedProtocol, expectedSubprotocol), response);
} }
setHandshakeComplete(); setHandshakeComplete();

View File

@ -26,7 +26,6 @@ import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.AsciiString;
import java.net.URI; import java.net.URI;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -43,8 +42,6 @@ import java.nio.ByteBuffer;
*/ */
public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker { public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
private static final AsciiString WEBSOCKET = AsciiString.cached("WebSocket");
private ByteBuf expectedChallengeResponseBytes; private ByteBuf expectedChallengeResponseBytes;
/** /**
@ -186,7 +183,7 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
headers.add(customHeaders); headers.add(customHeaders);
} }
headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET) headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) .set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1)
@ -229,26 +226,25 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
*/ */
@Override @Override
protected void verify(FullHttpResponse response) { protected void verify(FullHttpResponse response) {
if (!response.status().equals(HttpResponseStatus.SWITCHING_PROTOCOLS)) { HttpResponseStatus status = response.status();
throw new WebSocketHandshakeException("Invalid handshake response getStatus: " + response.status()); if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) {
throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response);
} }
HttpHeaders headers = response.headers(); HttpHeaders headers = response.headers();
CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE); CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE);
if (!WEBSOCKET.contentEqualsIgnoreCase(upgrade)) { if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response);
+ upgrade);
} }
if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) { if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: " throw new WebSocketClientHandshakeException("Invalid handshake response connection: "
+ headers.get(HttpHeaderNames.CONNECTION)); + headers.get(HttpHeaderNames.CONNECTION), response);
} }
ByteBuf challenge = response.content(); ByteBuf challenge = response.content();
if (!challenge.equals(expectedChallengeResponseBytes)) { if (!challenge.equals(expectedChallengeResponseBytes)) {
throw new WebSocketHandshakeException("Invalid challenge"); throw new WebSocketClientHandshakeException("Invalid challenge", response);
} }
} }

View File

@ -264,27 +264,26 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
*/ */
@Override @Override
protected void verify(FullHttpResponse response) { protected void verify(FullHttpResponse response) {
final HttpResponseStatus status = HttpResponseStatus.SWITCHING_PROTOCOLS; HttpResponseStatus status = response.status();
final HttpHeaders headers = response.headers(); if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) {
throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response);
if (!response.status().equals(status)) {
throw new WebSocketHandshakeException("Invalid handshake response getStatus: " + response.status());
} }
HttpHeaders headers = response.headers();
CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE); CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE);
if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) { if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " + upgrade); throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response);
} }
if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) { if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: " throw new WebSocketClientHandshakeException("Invalid handshake response connection: "
+ headers.get(HttpHeaderNames.CONNECTION)); + headers.get(HttpHeaderNames.CONNECTION), response);
} }
CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT); CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
if (accept == null || !accept.equals(expectedChallengeResponseString)) { if (accept == null || !accept.equals(expectedChallengeResponseString)) {
throw new WebSocketHandshakeException(String.format( throw new WebSocketClientHandshakeException(String.format(
"Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString)); "Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString), response);
} }
} }

View File

@ -266,27 +266,26 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
*/ */
@Override @Override
protected void verify(FullHttpResponse response) { protected void verify(FullHttpResponse response) {
final HttpResponseStatus status = HttpResponseStatus.SWITCHING_PROTOCOLS; HttpResponseStatus status = response.status();
final HttpHeaders headers = response.headers(); if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) {
throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response);
if (!response.status().equals(status)) {
throw new WebSocketHandshakeException("Invalid handshake response getStatus: " + response.status());
} }
HttpHeaders headers = response.headers();
CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE); CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE);
if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) { if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " + upgrade); throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response);
} }
if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) { if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: " throw new WebSocketClientHandshakeException("Invalid handshake response connection: "
+ headers.get(HttpHeaderNames.CONNECTION)); + headers.get(HttpHeaderNames.CONNECTION), response);
} }
CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT); CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
if (accept == null || !accept.equals(expectedChallengeResponseString)) { if (accept == null || !accept.equals(expectedChallengeResponseString)) {
throw new WebSocketHandshakeException(String.format( throw new WebSocketClientHandshakeException(String.format(
"Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString)); "Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString), response);
} }
} }

View File

@ -267,27 +267,26 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
*/ */
@Override @Override
protected void verify(FullHttpResponse response) { protected void verify(FullHttpResponse response) {
final HttpResponseStatus status = HttpResponseStatus.SWITCHING_PROTOCOLS; HttpResponseStatus status = response.status();
final HttpHeaders headers = response.headers(); if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) {
throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response);
if (!response.status().equals(status)) {
throw new WebSocketHandshakeException("Invalid handshake response getStatus: " + response.status());
} }
HttpHeaders headers = response.headers();
CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE); CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE);
if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) { if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " + upgrade); throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response);
} }
if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) { if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: " throw new WebSocketClientHandshakeException("Invalid handshake response connection: "
+ headers.get(HttpHeaderNames.CONNECTION)); + headers.get(HttpHeaderNames.CONNECTION), response);
} }
CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT); CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
if (accept == null || !accept.equals(expectedChallengeResponseString)) { if (accept == null || !accept.equals(expectedChallengeResponseString)) {
throw new WebSocketHandshakeException(String.format( throw new WebSocketClientHandshakeException(String.format(
"Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString)); "Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString), response);
} }
} }

View File

@ -162,7 +162,7 @@ public final class WebSocketClientHandshakerFactory {
webSocketURL, V00, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis); webSocketURL, V00, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis);
} }
throw new WebSocketHandshakeException("Protocol version " + version + " not supported."); throw new WebSocketClientHandshakeException("Protocol version " + version + " not supported.");
} }
/** /**
@ -220,6 +220,6 @@ public final class WebSocketClientHandshakerFactory {
maxFramePayloadLength, forceCloseTimeoutMillis, absoluteUpgradeUrl); maxFramePayloadLength, forceCloseTimeoutMillis, absoluteUpgradeUrl);
} }
throw new WebSocketHandshakeException("Protocol version " + version + " not supported."); throw new WebSocketClientHandshakeException("Protocol version " + version + " not supported.");
} }
} }

View File

@ -370,6 +370,11 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
super.decode(ctx, frame, out); super.decode(ctx, frame, out);
} }
@Override
protected WebSocketClientHandshakeException buildHandshakeException(String message) {
return new WebSocketClientHandshakeException(message);
}
@Override @Override
public void handlerAdded(ChannelHandlerContext ctx) { public void handlerAdded(ChannelHandlerContext ctx) {
ChannelPipeline cp = ctx.pipeline(); ChannelPipeline cp = ctx.pipeline();

View File

@ -73,7 +73,8 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
@Override @Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception { public void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (!handshakePromise.isDone()) { if (!handshakePromise.isDone()) {
handshakePromise.tryFailure(new WebSocketHandshakeException("channel closed with handshake in progress")); handshakePromise.tryFailure(new WebSocketClientHandshakeException("channel closed with handshake " +
"in progress"));
} }
super.channelInactive(ctx); super.channelInactive(ctx);
@ -115,7 +116,7 @@ class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
return; return;
} }
if (localHandshakePromise.tryFailure(new WebSocketHandshakeException("handshake timed out"))) { if (localHandshakePromise.tryFailure(new WebSocketClientHandshakeException("handshake timed out"))) {
ctx.flush() ctx.flush()
.fireUserEventTriggered(ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) .fireUserEventTriggered(ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT)
.close(); .close();

View File

@ -128,7 +128,7 @@ abstract class WebSocketProtocolHandler extends MessageToMessageDecoder<WebSocke
@Override @Override
public void run() { public void run() {
if (!closeSent.isDone()) { if (!closeSent.isDone()) {
closeSent.tryFailure(new WebSocketHandshakeException("send close frame timed out")); closeSent.tryFailure(buildHandshakeException("send close frame timed out"));
} }
} }
}, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS); }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
@ -141,6 +141,14 @@ abstract class WebSocketProtocolHandler extends MessageToMessageDecoder<WebSocke
}); });
} }
/**
* Returns a {@link WebSocketHandshakeException} that depends on which client or server pipeline
* this handler belongs. Should be overridden in implementation otherwise a default exception is used.
*/
protected WebSocketHandshakeException buildHandshakeException(String message) {
return new WebSocketHandshakeException(message);
}
@Override @Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, public void bind(ChannelHandlerContext ctx, SocketAddress localAddress,
ChannelPromise promise) throws Exception { ChannelPromise promise) throws Exception {

View File

@ -0,0 +1,55 @@
/*
* Copyright 2020 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.util.ReferenceCounted;
/**
* Server exception during handshaking process.
*
* <p><b>IMPORTANT</b>: This exception does not contain any {@link ReferenceCounted} fields
* e.g. {@link FullHttpRequest}, so no special treatment is needed.
*/
public final class WebSocketServerHandshakeException extends WebSocketHandshakeException {
private static final long serialVersionUID = 1L;
private final HttpRequest request;
public WebSocketServerHandshakeException(String message) {
this(message, null);
}
public WebSocketServerHandshakeException(String message, HttpRequest httpRequest) {
super(message);
if (httpRequest != null) {
request = new DefaultHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
httpRequest.uri(), httpRequest.headers());
} else {
request = null;
}
}
/**
* Returns a {@link HttpRequest request} if exception occurs during request validation otherwise {@code null}.
*/
public HttpRequest request() {
return request;
}
}

View File

@ -126,7 +126,7 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
// Serve the WebSocket handshake request. // Serve the WebSocket handshake request.
if (!req.headers().containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true) if (!req.headers().containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)
|| !HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE))) { || !HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE))) {
throw new WebSocketHandshakeException("not a WebSocket handshake request: missing upgrade"); throw new WebSocketServerHandshakeException("not a WebSocket handshake request: missing upgrade", req);
} }
// Hixie 75 does not contain these headers while Hixie 76 does // Hixie 75 does not contain these headers while Hixie 76 does
@ -136,7 +136,8 @@ public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {
String origin = req.headers().get(HttpHeaderNames.ORIGIN); String origin = req.headers().get(HttpHeaderNames.ORIGIN);
//throw before allocating FullHttpResponse //throw before allocating FullHttpResponse
if (origin == null && !isHixie76) { if (origin == null && !isHixie76) {
throw new WebSocketHandshakeException("Missing origin header, got only " + req.headers().names()); throw new WebSocketServerHandshakeException("Missing origin header, got only " + req.headers().names(),
req);
} }
// Create the WebSocket handshake response. // Create the WebSocket handshake response.

View File

@ -130,7 +130,7 @@ public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) { if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key"); throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
} }
FullHttpResponse res = FullHttpResponse res =

View File

@ -137,7 +137,7 @@ public class WebSocketServerHandshaker08 extends WebSocketServerHandshaker {
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) { if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key"); throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
} }
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS, FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS,

View File

@ -136,7 +136,7 @@ public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) { if (key == null) {
throw new WebSocketHandshakeException("not a WebSocket request: missing key"); throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
} }
FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS, FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS,

View File

@ -246,6 +246,11 @@ public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
super.decode(ctx, frame, out); super.decode(ctx, frame, out);
} }
@Override
protected WebSocketServerHandshakeException buildHandshakeException(String message) {
return new WebSocketServerHandshakeException(message);
}
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (cause instanceof WebSocketHandshakeException) { if (cause instanceof WebSocketHandshakeException) {

View File

@ -158,7 +158,7 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
@Override @Override
public void run() { public void run() {
if (!localHandshakePromise.isDone() && if (!localHandshakePromise.isDone() &&
localHandshakePromise.tryFailure(new WebSocketHandshakeException("handshake timed out"))) { localHandshakePromise.tryFailure(new WebSocketServerHandshakeException("handshake timed out"))) {
ctx.flush() ctx.flush()
.fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT) .fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT)
.close(); .close();

View File

@ -21,6 +21,7 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.EmptyHttpHeaders; import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
@ -31,6 +32,8 @@ import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder; import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import org.junit.Test; import org.junit.Test;
@ -387,4 +390,24 @@ public abstract class WebSocketClientHandshakerTest {
request.release(); request.release();
} }
@Test
public void testWebSocketClientHandshakeException() {
URI uri = URI.create("ws://localhost:9999/exception");
WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, false);
FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED);
response.headers().set(HttpHeaderNames.WWW_AUTHENTICATE, "realm = access token required");
try {
handshaker.finishHandshake(null, response);
} catch (WebSocketClientHandshakeException exception) {
assertEquals("Invalid handshake response getStatus: 401 Unauthorized", exception.getMessage());
assertEquals(HttpResponseStatus.UNAUTHORIZED, exception.response().status());
assertTrue(exception.response().headers().contains(HttpHeaderNames.WWW_AUTHENTICATE,
"realm = access token required", false));
} finally {
response.release();
} }
}
}

View File

@ -0,0 +1,74 @@
/*
* Copyright 2020 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import org.junit.Test;
import static org.junit.Assert.*;
public class WebSocketHandshakeExceptionTest {
@Test
public void testClientExceptionWithoutResponse() {
WebSocketClientHandshakeException clientException = new WebSocketClientHandshakeException("client message");
assertNull(clientException.response());
assertEquals("client message", clientException.getMessage());
}
@Test
public void testClientExceptionWithResponse() {
HttpResponse httpResponse = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST);
httpResponse.headers().set("x-header", "x-value");
WebSocketClientHandshakeException clientException = new WebSocketClientHandshakeException("client message",
httpResponse);
assertNotNull(clientException.response());
assertEquals("client message", clientException.getMessage());
assertEquals(HttpResponseStatus.BAD_REQUEST, clientException.response().status());
assertEquals(httpResponse.headers(), clientException.response().headers());
}
@Test
public void testServerExceptionWithoutRequest() {
WebSocketServerHandshakeException serverException = new WebSocketServerHandshakeException("server message");
assertNull(serverException.request());
assertEquals("server message", serverException.getMessage());
}
@Test
public void testClientExceptionWithRequest() {
HttpRequest httpRequest = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET,
"ws://localhost:9999/ws");
httpRequest.headers().set("x-header", "x-value");
WebSocketServerHandshakeException serverException = new WebSocketServerHandshakeException("server message",
httpRequest);
assertNotNull(serverException.request());
assertEquals("server message", serverException.getMessage());
assertEquals(HttpMethod.GET, serverException.request().method());
assertEquals(httpRequest.headers(), serverException.request().headers());
assertEquals(httpRequest.uri(), serverException.request().uri());
}
}

View File

@ -78,4 +78,24 @@ public abstract class WebSocketServerHandshakerTest {
} }
} }
} }
@Test
public void testWebSocketServerHandshakeException() {
WebSocketServerHandshaker serverHandshaker = newHandshaker("ws://example.com/chat",
"chat", WebSocketDecoderConfig.DEFAULT);
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET,
"ws://example.com/chat");
request.headers().set("x-client-header", "value");
try {
serverHandshaker.handshake(null, request, null, null);
} catch (WebSocketServerHandshakeException exception) {
assertNotNull(exception.getMessage());
assertEquals(request.headers(), exception.request().headers());
assertEquals(HttpMethod.GET, exception.request().method());
} finally {
request.release();
} }
}
}