netty5/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java
Norman Maurer e4995be33c Use allocator when constructing ByteBufHolder sub-types or use Unpool… (#9377)
Motivation:

In many places Netty uses Unpooled.buffer(0) where it should use EMPTY_BUFFER. We can't change this in the constructors because of backward compatibility, but we can use Unpooled.EMPTY_BUFFER in some cases to ensure we do not allocate at all. In other cases we can directly use the allocator, either from the Channel / ChannelHandlerContext or from the request / response.

Modification:

- Use Unpooled.EMPTY_BUFFER where possible
- Use allocator where possible

Result:

Fixes #9345 for websockets and http package
2019-07-18 10:36:03 +02:00

165 lines
7.1 KiB
Java

/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import java.util.concurrent.TimeUnit;
import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static io.netty.handler.codec.http.HttpUtil.*;
import static io.netty.handler.codec.http.HttpVersion.*;
import static io.netty.util.internal.ObjectUtil.*;
/**
 * Performs the HTTP part of the WebSocket handshake (the HTTP Upgrade request) on behalf of
 * {@link WebSocketServerProtocolHandler}.
 * <p>
 * Requests whose URI does not match the configured WebSocket path are forwarded untouched to the next
 * handler. Matching requests that are not {@code GET} are answered with {@code 403 Forbidden}. For all
 * other matching requests a {@link WebSocketServerHandshaker} is looked up and, if one is available, the
 * handshake is started, a timeout is armed and this handler replaces itself in the pipeline with the
 * handler returned by {@link WebSocketServerProtocolHandler#forbiddenHttpRequestResponder()}.
 */
class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler {

    private final String websocketPath;
    private final String subprotocols;
    private final boolean checkStartsWith;
    private final long handshakeTimeoutMillis;
    private final WebSocketDecoderConfig decoderConfig;

    // Both are assigned in handlerAdded(...).
    private ChannelHandlerContext ctx;
    private ChannelPromise handshakePromise;

    /**
     * Creates a new handshake handler.
     *
     * @param websocketPath          the path a request URI is compared against
     * @param subprotocols           the sub-protocols handed to the {@link WebSocketServerHandshakerFactory}
     * @param checkStartsWith        {@code true} to accept every URI that starts with {@code websocketPath},
     *                               {@code false} to require an exact match
     * @param handshakeTimeoutMillis the handshake timeout in milliseconds, validated with
     *                               {@code checkPositive}
     * @param decoderConfig          the decoder configuration handed to the handshaker factory, validated
     *                               with {@code checkNotNull}
     */
    WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
            boolean checkStartsWith, long handshakeTimeoutMillis, WebSocketDecoderConfig decoderConfig) {
        // Validate in the same order as the parameters that can be rejected appear.
        this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
        this.decoderConfig = checkNotNull(decoderConfig, "decoderConfig");
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.checkStartsWith = checkStartsWith;
    }

    /**
     * Remembers the context and creates the promise that tracks the outcome of the handshake.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        handshakePromise = ctx.newPromise();
        this.ctx = ctx;
    }

    /**
     * Dispatches an incoming {@link FullHttpRequest}: forwards it when the path does not match, rejects it
     * with {@code 403 Forbidden} when the method is not {@code GET}, and starts the handshake otherwise.
     * A request that is handled here (i.e. not forwarded) is always released.
     */
    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        final FullHttpRequest req = (FullHttpRequest) msg;
        if (isNotWebSocketPath(req)) {
            // Not ours: ownership of the message moves on with it, so no release here.
            ctx.fireChannelRead(msg);
            return;
        }

        try {
            if (GET.equals(req.method())) {
                startHandshake(ctx, req);
            } else {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN, ctx.alloc().buffer(0)));
            }
        } finally {
            req.release();
        }
    }

    /**
     * Selects a handshaker for {@code req} and kicks off the handshake, or sends the
     * "unsupported version" response when no handshaker is available.
     */
    private void startHandshake(final ChannelHandlerContext ctx, final FullHttpRequest req) {
        final String location = getWebSocketLocation(ctx.pipeline(), req, websocketPath);
        final WebSocketServerHandshakerFactory factory =
                new WebSocketServerHandshakerFactory(location, subprotocols, decoderConfig);
        final WebSocketServerHandshaker handshaker = factory.newHandshaker(req);
        final ChannelPromise promise = handshakePromise;

        if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }

        final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
        handshakeFuture.addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                promise.trySuccess();
                // Kept for compatibility
                ctx.fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                ctx.fireUserEventTriggered(new WebSocketServerProtocolHandler.HandshakeComplete(
                        req.uri(), req.headers(), handshaker.selectedSubprotocol()));
            } else {
                final Throwable cause = future.cause();
                promise.tryFailure(cause);
                ctx.fireExceptionCaught(cause);
            }
        });

        applyHandshakeTimeout();
        WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
        ctx.pipeline().replace(this, "WS403Responder",
                WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
    }

    /**
     * Returns {@code true} when the request URI does not address the configured WebSocket path, using a
     * prefix comparison if {@code checkStartsWith} is set and an exact comparison otherwise.
     */
    private boolean isNotWebSocketPath(FullHttpRequest req) {
        final String uri = req.uri();
        final boolean matches;
        if (checkStartsWith) {
            matches = uri.startsWith(websocketPath);
        } else {
            matches = uri.equals(websocketPath);
        }
        return !matches;
    }

    /**
     * Writes and flushes {@code res} on the channel. The {@link ChannelFutureListener#CLOSE} listener is
     * attached to the write unless the request is keep-alive and the response status code is 200.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        final ChannelFuture writeFuture = ctx.channel().writeAndFlush(res);
        final boolean keepOpen = isKeepAlive(req) && res.status().code() == 200;
        if (!keepOpen) {
            writeFuture.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Builds the WebSocket location from the scheme, the request's {@code Host} header and {@code path}.
     * The scheme is {@code wss} when the pipeline contains an {@link SslHandler} and {@code ws} otherwise.
     */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        final String scheme;
        if (cp.get(SslHandler.class) == null) {
            scheme = "ws";
        } else {
            // SSL in use so use Secure WebSockets
            scheme = "wss";
        }
        final String host = req.headers().get(HttpHeaderNames.HOST);
        return scheme + "://" + host + path;
    }

    /**
     * Schedules a task that fails the handshake promise once the configured timeout has elapsed. When that
     * failure is applied, the context is flushed, a {@link ServerHandshakeStateEvent#HANDSHAKE_TIMEOUT}
     * event is fired and the context is closed. Nothing is scheduled if the promise is already done.
     */
    private void applyHandshakeTimeout() {
        final ChannelPromise promise = handshakePromise;
        final long timeoutMillis = handshakeTimeoutMillis;
        if (timeoutMillis > 0 && !promise.isDone()) {
            final Future<?> timeoutFuture = ctx.executor().schedule(() -> {
                if (promise.isDone()) {
                    return;
                }
                if (promise.tryFailure(new WebSocketHandshakeException("handshake timed out"))) {
                    ctx.flush()
                       .fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT)
                       .close();
                }
            }, timeoutMillis, TimeUnit.MILLISECONDS);

            // Cancel the handshake timeout when handshake is finished.
            promise.addListener((FutureListener<Void>) f -> timeoutFuture.cancel(false));
        }
    }
}