Additional fix for lost first WebSocket frame after handshake

Related issue: #2179

Motivation:

Previous fix e71cbb9308bf8788e1e0fb8db99766d89156386d was not enough.

Modifications:

- Add more test cases for WebSocket handshake
- Fix a bug in HttpMessageDecoder where it does not always enter
  UPGRADED state
- Fix incorrect decoder replacement logic in WebSocketClientHandshaker
  implementations
  - Add WebSocketClientHandshaker.replaceDecoder() as a helper

Result:

We never lose the first WebSocket frame for all WebSocket protocol
versions.
This commit is contained in:
Trustin Lee 2014-08-06 11:10:12 -07:00
parent c5e1ab6403
commit 788d3dea42
7 changed files with 62 additions and 21 deletions

View File

@ -205,6 +205,7 @@ public abstract class HttpMessageDecoder extends ReplayingDecoder<State> {
// Remove the headers which are not supposed to be present not // Remove the headers which are not supposed to be present not
// to confuse subsequent handlers. // to confuse subsequent handlers.
message.headers().remove(HttpHeaders.Names.TRANSFER_ENCODING); message.headers().remove(HttpHeaders.Names.TRANSFER_ENCODING);
resetState();
return message; return message;
} }
long contentLength = HttpHeaders.getContentLength(message, -1); long contentLength = HttpHeaders.getContentLength(message, -1);
@ -451,18 +452,23 @@ public abstract class HttpMessageDecoder extends ReplayingDecoder<State> {
message.setContent(content); message.setContent(content);
this.content = null; this.content = null;
} }
resetState();
this.message = null; this.message = null;
return message;
}
private void resetState() {
if (!isDecodingRequest()) { if (!isDecodingRequest()) {
HttpResponse res = (HttpResponse) message; HttpResponse res = (HttpResponse) message;
if (res != null && res.getStatus().getCode() == 101) { if (res != null && res.getStatus().getCode() == 101) {
checkpoint(State.UPGRADED); checkpoint(State.UPGRADED);
return message; return;
} }
} }
checkpoint(State.SKIP_CONTROL_CHARS); checkpoint(State.SKIP_CONTROL_CHARS);
return message;
} }
private static void skipControlCharacters(ChannelBuffer buffer) { private static void skipControlCharacters(ChannelBuffer buffer) {

View File

@ -17,7 +17,11 @@ package org.jboss.netty.handler.codec.http.websocketx;
import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture; import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelHandler;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.handler.codec.http.HttpResponse; import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseDecoder;
import java.net.URI; import java.net.URI;
import java.util.Map; import java.util.Map;
@ -151,4 +155,22 @@ public abstract class WebSocketClientHandshaker {
* HTTP response containing the closing handshake details * HTTP response containing the closing handshake details
*/ */
public abstract void finishHandshake(Channel channel, HttpResponse response); public abstract void finishHandshake(Channel channel, HttpResponse response);
/**
* Replace the HTTP decoder with a new Web Socket decoder.
* Note that we do not use {@link ChannelPipeline#replace(String, String, ChannelHandler)}, because the server
* might have sent the first frame immediately after the upgrade response. In such a case, the HTTP decoder might
* have the first frame in its cumulation buffer and the HTTP decoder will forward it to the next handler.
* The Web Socket decoder will not receive it if we simply replaced it. For more information, refer to
* {@link HttpResponseDecoder} and its unit tests.
*/
static void replaceDecoder(Channel channel, ChannelHandler wsDecoder) {
ChannelPipeline p = channel.getPipeline();
ChannelHandlerContext httpDecoderCtx = p.getContext(HttpResponseDecoder.class);
if (httpDecoderCtx == null) {
throw new IllegalStateException("can't find an HTTP decoder from the pipeline");
}
p.addAfter(httpDecoderCtx.getName(), "ws-decoder", wsDecoder);
p.remove(httpDecoderCtx.getName());
}
} }

View File

@ -29,7 +29,6 @@ import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpRequest; import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpRequestEncoder; import org.jboss.netty.handler.codec.http.HttpRequestEncoder;
import org.jboss.netty.handler.codec.http.HttpResponse; import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseDecoder;
import org.jboss.netty.handler.codec.http.HttpResponseStatus; import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion; import org.jboss.netty.handler.codec.http.HttpVersion;
@ -262,9 +261,7 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
setActualSubprotocol(subprotocol); setActualSubprotocol(subprotocol);
setHandshakeComplete(); setHandshakeComplete();
replaceDecoder(channel, new WebSocket00FrameDecoder(getMaxFramePayloadLength()));
channel.getPipeline().get(HttpResponseDecoder.class).replace("ws-decoder",
new WebSocket00FrameDecoder(getMaxFramePayloadLength()));
} }
private static String insertRandomCharacters(String key) { private static String insertRandomCharacters(String key) {

View File

@ -29,7 +29,6 @@ import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpRequest; import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpRequestEncoder; import org.jboss.netty.handler.codec.http.HttpRequestEncoder;
import org.jboss.netty.handler.codec.http.HttpResponse; import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseDecoder;
import org.jboss.netty.handler.codec.http.HttpResponseStatus; import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion; import org.jboss.netty.handler.codec.http.HttpVersion;
import org.jboss.netty.logging.InternalLogger; import org.jboss.netty.logging.InternalLogger;
@ -225,10 +224,8 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
setActualSubprotocol(subprotocol); setActualSubprotocol(subprotocol);
setHandshakeComplete(); setHandshakeComplete();
replaceDecoder(
ChannelPipeline p = channel.getPipeline(); channel,
p.get(HttpResponseDecoder.class).replace(
"ws-decoder",
new WebSocket07FrameDecoder(false, allowExtensions, getMaxFramePayloadLength())); new WebSocket07FrameDecoder(false, allowExtensions, getMaxFramePayloadLength()));
} }
} }

View File

@ -29,7 +29,6 @@ import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpRequest; import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpRequestEncoder; import org.jboss.netty.handler.codec.http.HttpRequestEncoder;
import org.jboss.netty.handler.codec.http.HttpResponse; import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseDecoder;
import org.jboss.netty.handler.codec.http.HttpResponseStatus; import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion; import org.jboss.netty.handler.codec.http.HttpVersion;
import org.jboss.netty.logging.InternalLogger; import org.jboss.netty.logging.InternalLogger;
@ -249,8 +248,8 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
setActualSubprotocol(subprotocol); setActualSubprotocol(subprotocol);
setHandshakeComplete(); setHandshakeComplete();
replaceDecoder(
channel.getPipeline().get(HttpResponseDecoder.class).replace("ws-decoder", channel,
new WebSocket08FrameDecoder(false, allowExtensions, getMaxFramePayloadLength())); new WebSocket08FrameDecoder(false, allowExtensions, getMaxFramePayloadLength()));
} }
} }

View File

@ -29,7 +29,6 @@ import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpRequest; import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpRequestEncoder; import org.jboss.netty.handler.codec.http.HttpRequestEncoder;
import org.jboss.netty.handler.codec.http.HttpResponse; import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseDecoder;
import org.jboss.netty.handler.codec.http.HttpResponseStatus; import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion; import org.jboss.netty.handler.codec.http.HttpVersion;
import org.jboss.netty.logging.InternalLogger; import org.jboss.netty.logging.InternalLogger;
@ -246,8 +245,8 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
setActualSubprotocol(subprotocol); setActualSubprotocol(subprotocol);
setHandshakeComplete(); setHandshakeComplete();
replaceDecoder(
channel.getPipeline().get(HttpResponseDecoder.class).replace("ws-decoder", channel,
new WebSocket13FrameDecoder(false, allowExtensions, getMaxFramePayloadLength())); new WebSocket13FrameDecoder(false, allowExtensions, getMaxFramePayloadLength()));
} }
} }

View File

@ -18,6 +18,7 @@ package org.jboss.netty.handler.codec.http;
import org.jboss.netty.buffer.ChannelBuffers; import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.handler.codec.embedder.DecoderEmbedder; import org.jboss.netty.handler.codec.embedder.DecoderEmbedder;
import org.jboss.netty.util.CharsetUtil;
import org.junit.Test; import org.junit.Test;
import static org.hamcrest.CoreMatchers.*; import static org.hamcrest.CoreMatchers.*;
@ -56,11 +57,10 @@ public class HttpResponseDecoderTest {
"Sec-WebSocket-Origin: http://localhost:8080\r\n" + "Sec-WebSocket-Origin: http://localhost:8080\r\n" +
"Sec-WebSocket-Location: ws://localhost/some/path\r\n" + "Sec-WebSocket-Location: ws://localhost/some/path\r\n" +
"\r\n" + "\r\n" +
"1234567812345678").getBytes(); "1234567812345678EXTRA").getBytes(CharsetUtil.US_ASCII);
byte[] otherData = {1, 2, 3, 4};
DecoderEmbedder<Object> ch = new DecoderEmbedder<Object>(new HttpResponseDecoder()); DecoderEmbedder<Object> ch = new DecoderEmbedder<Object>(new HttpResponseDecoder());
ch.offer(ChannelBuffers.wrappedBuffer(data, otherData)); ch.offer(ChannelBuffers.wrappedBuffer(data));
HttpResponse res = (HttpResponse) ch.poll(); HttpResponse res = (HttpResponse) ch.poll();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
@ -69,6 +69,27 @@ public class HttpResponseDecoderTest {
assertThat(ch.finish(), is(true)); assertThat(ch.finish(), is(true));
assertEquals(ch.poll(), ChannelBuffers.wrappedBuffer(otherData)); assertEquals(ch.poll(), ChannelBuffers.wrappedBuffer("EXTRA".getBytes(CharsetUtil.US_ASCII)));
}
@Test
public void testWebSocketResponseWithDataFollowing2() {
byte[] data = ("HTTP/1.1 101 Switching Protocols\n" +
"Upgrade: websocket\n" +
"Connection: Upgrade\n" +
"Sec-WebSocket-Accept: fd6T8bTOMVN65WHXymeKp6WTWfA=\n\n" +
"EXTRA").getBytes(CharsetUtil.US_ASCII);
DecoderEmbedder<Object> ch = new DecoderEmbedder<Object>(new HttpResponseDecoder());
ch.offer(ChannelBuffers.wrappedBuffer(data));
HttpResponse res = (HttpResponse) ch.poll();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.getStatus(), is(HttpResponseStatus.SWITCHING_PROTOCOLS));
assertThat(res.getContent().readableBytes(), is(0));
assertThat(ch.finish(), is(true));
assertEquals(ch.poll(), ChannelBuffers.wrappedBuffer("EXTRA".getBytes(CharsetUtil.US_ASCII)));
} }
} }