HTTP/2 Decoder reduce preface conditional checks

Motivation:
The DefaultHttp2ConnectionDecoder class is calling verifyPrefaceReceived() for almost every frame event at all times.
The Http2ConnectionHandler class is calling readClientPrefaceString() on every decode event.

Modifications:
- DefaultHttp2ConnectionDecoder should not have to continuously call verifyPrefaceReceived() because it transitions boolean state 1 time for each connection.
- Http2ConnectionHandler should not have to continuously call readClientPrefaceString() because it transitions boolean state 1 time for each connection.

Result:
- Less conditional checks for the mainstream usage of the connection.
This commit is contained in:
Scott Mitchell 2015-03-23 10:52:11 -07:00
parent 2bf592c50f
commit 0d3a6e0511
11 changed files with 462 additions and 193 deletions

View File

@ -178,7 +178,55 @@ public final class ByteBufUtil {
/** /**
* Returns {@code true} if and only if the two specified buffers are * Returns {@code true} if and only if the two specified buffers are
* identical to each other as described in {@code ChannelBuffer#equals(Object)}. * identical to each other for {@code length} bytes starting at {@code aStartIndex}
* index for the {@code a} buffer and {@code bStartIndex} index for the {@code b} buffer.
* A more compact way to express this is:
* <p>
* {@code a[aStartIndex : aStartIndex + length] == b[bStartIndex : bStartIndex + length]}
*/
public static boolean equals(ByteBuf a, int aStartIndex, ByteBuf b, int bStartIndex, int length) {
if (aStartIndex < 0 || bStartIndex < 0 || length < 0) {
throw new IllegalArgumentException("All indexes and lengths must be non-negative");
}
if (a.writerIndex() - length < aStartIndex || b.writerIndex() - length < bStartIndex) {
return false;
}
final int longCount = length >>> 3;
final int byteCount = length & 7;
if (a.order() == b.order()) {
for (int i = longCount; i > 0; i --) {
if (a.getLong(aStartIndex) != b.getLong(bStartIndex)) {
return false;
}
aStartIndex += 8;
bStartIndex += 8;
}
} else {
for (int i = longCount; i > 0; i --) {
if (a.getLong(aStartIndex) != swapLong(b.getLong(bStartIndex))) {
return false;
}
aStartIndex += 8;
bStartIndex += 8;
}
}
for (int i = byteCount; i > 0; i --) {
if (a.getByte(aStartIndex) != b.getByte(bStartIndex)) {
return false;
}
aStartIndex ++;
bStartIndex ++;
}
return true;
}
/**
* Returns {@code true} if and only if the two specified buffers are
* identical to each other as described in {@link ByteBuf#equals(Object)}.
* This method is useful when implementing a new buffer type. * This method is useful when implementing a new buffer type.
*/ */
public static boolean equals(ByteBuf bufferA, ByteBuf bufferB) { public static boolean equals(ByteBuf bufferA, ByteBuf bufferB) {
@ -186,40 +234,7 @@ public final class ByteBufUtil {
if (aLen != bufferB.readableBytes()) { if (aLen != bufferB.readableBytes()) {
return false; return false;
} }
return equals(bufferA, bufferA.readerIndex(), bufferB, bufferB.readerIndex(), aLen);
final int longCount = aLen >>> 3;
final int byteCount = aLen & 7;
int aIndex = bufferA.readerIndex();
int bIndex = bufferB.readerIndex();
if (bufferA.order() == bufferB.order()) {
for (int i = longCount; i > 0; i --) {
if (bufferA.getLong(aIndex) != bufferB.getLong(bIndex)) {
return false;
}
aIndex += 8;
bIndex += 8;
}
} else {
for (int i = longCount; i > 0; i --) {
if (bufferA.getLong(aIndex) != swapLong(bufferB.getLong(bIndex))) {
return false;
}
aIndex += 8;
bIndex += 8;
}
}
for (int i = byteCount; i > 0; i --) {
if (bufferA.getByte(aIndex) != bufferB.getByte(bIndex)) {
return false;
}
aIndex ++;
bIndex ++;
}
return true;
} }
/** /**

View File

@ -15,12 +15,74 @@
*/ */
package io.netty.buffer; package io.netty.buffer;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import java.util.Random;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
public class ByteBufUtilTest { public class ByteBufUtilTest {
@Test
public void equalsBufferSubsections() {
byte[] b1 = new byte[128];
byte[] b2 = new byte[256];
Random rand = new Random();
rand.nextBytes(b1);
rand.nextBytes(b2);
final int iB1 = b1.length / 2;
final int iB2 = iB1 + b1.length;
final int length = b1.length - iB1;
System.arraycopy(b1, iB1, b2, iB2, length);
assertTrue(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, length));
}
@Test
public void notEqualsBufferSubsections() {
byte[] b1 = new byte[50];
byte[] b2 = new byte[256];
Random rand = new Random();
rand.nextBytes(b1);
rand.nextBytes(b2);
final int iB1 = b1.length / 2;
final int iB2 = iB1 + b1.length;
final int length = b1.length - iB1;
System.arraycopy(b1, iB1, b2, iB2, length - 1);
assertFalse(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, length));
}
@Test
public void notEqualsBufferOverflow() {
byte[] b1 = new byte[8];
byte[] b2 = new byte[16];
Random rand = new Random();
rand.nextBytes(b1);
rand.nextBytes(b2);
final int iB1 = b1.length / 2;
final int iB2 = iB1 + b1.length;
final int length = b1.length - iB1;
System.arraycopy(b1, iB1, b2, iB2, length - 1);
assertFalse(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2,
Math.max(b1.length, b2.length) * 2));
}
@Test (expected = IllegalArgumentException.class)
public void notEqualsBufferUnderflow() {
byte[] b1 = new byte[8];
byte[] b2 = new byte[16];
Random rand = new Random();
rand.nextBytes(b1);
rand.nextBytes(b2);
final int iB1 = b1.length / 2;
final int iB2 = iB1 + b1.length;
final int length = b1.length - iB1;
System.arraycopy(b1, iB1, b2, iB2, length - 1);
assertFalse(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2,
-1));
}
@Test @Test
public void testWriteUsAscii() { public void testWriteUsAscii() {

View File

@ -37,14 +37,13 @@ import java.util.List;
* {@link Http2LocalFlowController} * {@link Http2LocalFlowController}
*/ */
public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
private final Http2FrameListener internalFrameListener = new FrameReadListener(); private Http2FrameListener internalFrameListener = new PrefaceFrameListener();
private final Http2Connection connection; private final Http2Connection connection;
private final Http2LifecycleManager lifecycleManager; private final Http2LifecycleManager lifecycleManager;
private final Http2ConnectionEncoder encoder; private final Http2ConnectionEncoder encoder;
private final Http2FrameReader frameReader; private final Http2FrameReader frameReader;
private final Http2FrameListener listener; private final Http2FrameListener listener;
private final Http2PromisedRequestVerifier requestVerifier; private final Http2PromisedRequestVerifier requestVerifier;
private boolean prefaceReceived;
/** /**
* Builder for instances of {@link DefaultHttp2ConnectionDecoder}. * Builder for instances of {@link DefaultHttp2ConnectionDecoder}.
@ -138,7 +137,7 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public boolean prefaceReceived() { public boolean prefaceReceived() {
return prefaceReceived; return FrameReadListener.class == internalFrameListener.getClass();
} }
@Override @Override
@ -213,16 +212,26 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
return flowController().unconsumedBytes(stream); return flowController().unconsumedBytes(stream);
} }
void onGoAwayRead0(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
throws Http2Exception {
// Don't allow any more connections to be created.
connection.goAwayReceived(lastStreamId);
listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
}
void onUnknownFrame0(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
ByteBuf payload) throws Http2Exception {
listener.onUnknownFrame(ctx, frameType, streamId, flags, payload);
}
/** /**
* Handles all inbound frames from the network. * Handles all inbound frames from the network.
*/ */
private final class FrameReadListener implements Http2FrameListener { private final class FrameReadListener implements Http2FrameListener {
@Override @Override
public int onDataRead(final ChannelHandlerContext ctx, int streamId, ByteBuf data, public int onDataRead(final ChannelHandlerContext ctx, int streamId, ByteBuf data,
int padding, boolean endOfStream) throws Http2Exception { int padding, boolean endOfStream) throws Http2Exception {
verifyPrefaceReceived();
// Check if we received a data frame for a stream which is half-closed // Check if we received a data frame for a stream which is half-closed
Http2Stream stream = connection.requireStream(streamId); Http2Stream stream = connection.requireStream(streamId);
@ -304,15 +313,6 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
} }
} }
/**
* Verifies that the HTTP/2 connection preface has been received from the remote endpoint.
*/
private void verifyPrefaceReceived() throws Http2Exception {
if (!prefaceReceived) {
throw connectionError(PROTOCOL_ERROR, "Received non-SETTINGS as first frame.");
}
}
@Override @Override
public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
boolean endOfStream) throws Http2Exception { boolean endOfStream) throws Http2Exception {
@ -322,8 +322,6 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency,
short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception { short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception {
verifyPrefaceReceived();
Http2Stream stream = connection.stream(streamId); Http2Stream stream = connection.stream(streamId);
verifyGoAwayNotReceived(); verifyGoAwayNotReceived();
if (shouldIgnoreFrame(stream, false)) { if (shouldIgnoreFrame(stream, false)) {
@ -369,8 +367,6 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight,
boolean exclusive) throws Http2Exception { boolean exclusive) throws Http2Exception {
verifyPrefaceReceived();
Http2Stream stream = connection.stream(streamId); Http2Stream stream = connection.stream(streamId);
verifyGoAwayNotReceived(); verifyGoAwayNotReceived();
if (shouldIgnoreFrame(stream, true)) { if (shouldIgnoreFrame(stream, true)) {
@ -393,8 +389,6 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception {
verifyPrefaceReceived();
Http2Stream stream = connection.requireStream(streamId); Http2Stream stream = connection.requireStream(streamId);
if (stream.state() == CLOSED) { if (stream.state() == CLOSED) {
// RstStream frames must be ignored for closed streams. // RstStream frames must be ignored for closed streams.
@ -408,9 +402,7 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception {
verifyPrefaceReceived(); // Apply oldest outstanding local settings here. This is a synchronization point between endpoints.
// Apply oldest outstanding local settings here. This is a synchronization point
// between endpoints.
Http2Settings settings = encoder.pollSentSettings(); Http2Settings settings = encoder.pollSentSettings();
if (settings != null) { if (settings != null) {
@ -469,16 +461,11 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
// Acknowledge receipt of the settings. // Acknowledge receipt of the settings.
encoder.writeSettingsAck(ctx, ctx.newPromise()); encoder.writeSettingsAck(ctx, ctx.newPromise());
// We've received at least one non-ack settings frame from the remote endpoint.
prefaceReceived = true;
listener.onSettingsRead(ctx, settings); listener.onSettingsRead(ctx, settings);
} }
@Override @Override
public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
verifyPrefaceReceived();
// Send an ack back to the remote client. // Send an ack back to the remote client.
// Need to retain the buffer here since it will be released after the write completes. // Need to retain the buffer here since it will be released after the write completes.
encoder.writePing(ctx, true, data.retain(), ctx.newPromise()); encoder.writePing(ctx, true, data.retain(), ctx.newPromise());
@ -489,16 +476,12 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
verifyPrefaceReceived();
listener.onPingAckRead(ctx, data); listener.onPingAckRead(ctx, data);
} }
@Override @Override
public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId,
Http2Headers headers, int padding) throws Http2Exception { Http2Headers headers, int padding) throws Http2Exception {
verifyPrefaceReceived();
Http2Stream parentStream = connection.requireStream(streamId); Http2Stream parentStream = connection.requireStream(streamId);
verifyGoAwayNotReceived(); verifyGoAwayNotReceived();
if (shouldIgnoreFrame(parentStream, false)) { if (shouldIgnoreFrame(parentStream, false)) {
@ -543,17 +526,12 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
throws Http2Exception { throws Http2Exception {
// Don't allow any more connections to be created. onGoAwayRead0(ctx, lastStreamId, errorCode, debugData);
connection.goAwayReceived(lastStreamId);
listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
} }
@Override @Override
public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement)
throws Http2Exception { throws Http2Exception {
verifyPrefaceReceived();
Http2Stream stream = connection.requireStream(streamId); Http2Stream stream = connection.requireStream(streamId);
verifyGoAwayNotReceived(); verifyGoAwayNotReceived();
if (stream.state() == CLOSED || shouldIgnoreFrame(stream, false)) { if (stream.state() == CLOSED || shouldIgnoreFrame(stream, false)) {
@ -569,8 +547,8 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
ByteBuf payload) { ByteBuf payload) throws Http2Exception {
listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); onUnknownFrame0(ctx, frameType, streamId, flags, payload);
} }
/** /**
@ -600,4 +578,107 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
} }
} }
} }
private final class PrefaceFrameListener implements Http2FrameListener {
/**
* Verifies that the HTTP/2 connection preface has been received from the remote endpoint.
* It is possible that the current call to
* {@link Http2FrameReader#readFrame(ChannelHandlerContext, ByteBuf, Http2FrameListener)} will have multiple
* frames to dispatch. So it may be OK for this class to get legitimate frames for the first readFrame.
*/
private void verifyPrefaceReceived() throws Http2Exception {
if (!prefaceReceived()) {
throw connectionError(PROTOCOL_ERROR, "Received non-SETTINGS as first frame.");
}
}
@Override
public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream)
throws Http2Exception {
verifyPrefaceReceived();
return internalFrameListener.onDataRead(ctx, streamId, data, padding, endOfStream);
}
@Override
public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
boolean endOfStream) throws Http2Exception {
verifyPrefaceReceived();
internalFrameListener.onHeadersRead(ctx, streamId, headers, padding, endOfStream);
}
@Override
public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency,
short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception {
verifyPrefaceReceived();
internalFrameListener.onHeadersRead(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream);
}
@Override
public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight,
boolean exclusive) throws Http2Exception {
verifyPrefaceReceived();
internalFrameListener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive);
}
@Override
public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception {
verifyPrefaceReceived();
internalFrameListener.onRstStreamRead(ctx, streamId, errorCode);
}
@Override
public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception {
verifyPrefaceReceived();
internalFrameListener.onSettingsAckRead(ctx);
}
@Override
public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception {
// The first settings should change the internalFrameListener to the "real" listener
// that expects the preface to be verified.
if (!prefaceReceived()) {
internalFrameListener = new FrameReadListener();
}
internalFrameListener.onSettingsRead(ctx, settings);
}
@Override
public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
verifyPrefaceReceived();
internalFrameListener.onPingRead(ctx, data);
}
@Override
public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
verifyPrefaceReceived();
internalFrameListener.onPingAckRead(ctx, data);
}
@Override
public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId,
Http2Headers headers, int padding) throws Http2Exception {
verifyPrefaceReceived();
internalFrameListener.onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding);
}
@Override
public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
throws Http2Exception {
onGoAwayRead0(ctx, lastStreamId, errorCode, debugData);
}
@Override
public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement)
throws Http2Exception {
verifyPrefaceReceived();
internalFrameListener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement);
}
@Override
public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
ByteBuf payload) throws Http2Exception {
onUnknownFrame0(ctx, frameType, streamId, flags, payload);
}
}
} }

View File

@ -568,7 +568,8 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize
listener); listener);
} }
private void readUnknownFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener) { private void readUnknownFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener)
throws Http2Exception {
payload = payload.readSlice(payload.readableBytes()); payload = payload.readSlice(payload.readableBytes());
listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); listener.onUnknownFrame(ctx, frameType, streamId, flags, payload);
} }

View File

@ -24,6 +24,7 @@ import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.handler.codec.http2.Http2Exception.isStreamError; import static io.netty.handler.codec.http2.Http2Exception.isStreamError;
import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkNotNull;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
@ -50,9 +51,8 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
ChannelOutboundHandler { ChannelOutboundHandler {
private final Http2ConnectionDecoder decoder; private final Http2ConnectionDecoder decoder;
private final Http2ConnectionEncoder encoder; private final Http2ConnectionEncoder encoder;
private ByteBuf clientPrefaceString;
private boolean prefaceSent;
private ChannelFutureListener closeListener; private ChannelFutureListener closeListener;
private BaseDecoder byteDecoder;
public Http2ConnectionHandler(boolean server, Http2FrameListener listener) { public Http2ConnectionHandler(boolean server, Http2FrameListener listener) {
this(new DefaultHttp2Connection(server), listener); this(new DefaultHttp2Connection(server), listener);
@ -112,6 +112,10 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
return encoder; return encoder;
} }
private boolean prefaceSent() {
return byteDecoder != null && byteDecoder.prefaceSent();
}
/** /**
* Handles the client-side (cleartext) upgrade from HTTP to HTTP/2. * Handles the client-side (cleartext) upgrade from HTTP to HTTP/2.
* Reserves local stream 1 for the HTTP/2 response. * Reserves local stream 1 for the HTTP/2 response.
@ -120,7 +124,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
if (connection().isServer()) { if (connection().isServer()) {
throw connectionError(PROTOCOL_ERROR, "Client-side HTTP upgrade requested for a server"); throw connectionError(PROTOCOL_ERROR, "Client-side HTTP upgrade requested for a server");
} }
if (prefaceSent || decoder.prefaceReceived()) { if (prefaceSent() || decoder.prefaceReceived()) {
throw connectionError(PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received"); throw connectionError(PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received");
} }
@ -136,7 +140,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
if (!connection().isServer()) { if (!connection().isServer()) {
throw connectionError(PROTOCOL_ERROR, "Server-side HTTP upgrade requested for a client"); throw connectionError(PROTOCOL_ERROR, "Server-side HTTP upgrade requested for a client");
} }
if (prefaceSent || decoder.prefaceReceived()) { if (prefaceSent() || decoder.prefaceReceived()) {
throw connectionError(PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received"); throw connectionError(PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received");
} }
@ -147,32 +151,191 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
connection().remote().createStream(HTTP_UPGRADE_STREAM_ID).open(true); connection().remote().createStream(HTTP_UPGRADE_STREAM_ID).open(true);
} }
@Override private abstract class BaseDecoder {
public void channelActive(ChannelHandlerContext ctx) throws Exception { public abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception;
// The channel just became active - send the connection preface to the remote endpoint. public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { }
sendPreface(ctx); public void channelActive(ChannelHandlerContext ctx) throws Exception { }
super.channelActive(ctx);
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
try {
ChannelFuture future = ctx.newSucceededFuture();
final Collection<Http2Stream> streams = connection().activeStreams();
for (Http2Stream s : streams.toArray(new Http2Stream[streams.size()])) {
closeStream(s, future);
}
} finally {
try {
encoder().close();
} finally {
decoder().close();
}
}
}
/**
* Determine if the HTTP/2 connection preface been sent.
*/
public boolean prefaceSent() {
return true;
}
}
private final class PrefaceDecoder extends BaseDecoder {
private ByteBuf clientPrefaceString;
private boolean prefaceSent;
public PrefaceDecoder(ChannelHandlerContext ctx) {
clientPrefaceString = clientPrefaceString(encoder.connection());
// This handler was just added to the context. In case it was handled after
// the connection became active, send the connection preface now.
sendPreface(ctx);
}
@Override
public boolean prefaceSent() {
return prefaceSent;
}
@Override
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
try {
if (readClientPrefaceString(in)) {
// After the preface is read, it is time to hand over control to the post initialized decoder.
Http2ConnectionHandler.this.byteDecoder = new FrameDecoder();
Http2ConnectionHandler.this.byteDecoder.decode(ctx, in, out);
}
} catch (Throwable e) {
onException(ctx, e);
}
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
// The channel just became active - send the connection preface to the remote endpoint.
sendPreface(ctx);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
cleanup();
super.channelInactive(ctx);
}
/**
* Releases the {@code clientPrefaceString}. Any active streams will be left in the open.
*/
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
cleanup();
}
/**
* Releases the {@code clientPrefaceString}. Any active streams will be left in the open.
*/
private void cleanup() {
if (clientPrefaceString != null) {
clientPrefaceString.release();
clientPrefaceString = null;
}
}
/**
* Decodes the client connection preface string from the input buffer.
*
* @return {@code true} if processing of the client preface string is complete. Since client preface strings can
* only be received by servers, returns true immediately for client endpoints.
*/
private boolean readClientPrefaceString(ByteBuf in) throws Http2Exception {
if (clientPrefaceString == null) {
return true;
}
int prefaceRemaining = clientPrefaceString.readableBytes();
int bytesRead = Math.min(in.readableBytes(), prefaceRemaining);
// If the input so far doesn't match the preface, break the connection.
if (bytesRead == 0 || !ByteBufUtil.equals(in, in.readerIndex(),
clientPrefaceString, clientPrefaceString.readerIndex(), bytesRead)) {
throw connectionError(PROTOCOL_ERROR, "HTTP/2 client preface string missing or corrupt.");
}
in.skipBytes(bytesRead);
clientPrefaceString.skipBytes(bytesRead);
if (!clientPrefaceString.isReadable()) {
// Entire preface has been read.
clientPrefaceString.release();
clientPrefaceString = null;
return true;
}
return false;
}
/**
* Sends the HTTP/2 connection preface upon establishment of the connection, if not already sent.
*/
private void sendPreface(ChannelHandlerContext ctx) {
if (prefaceSent || !ctx.channel().isActive()) {
return;
}
prefaceSent = true;
if (!connection().isServer()) {
// Clients must send the preface string as the first bytes on the connection.
ctx.write(connectionPrefaceBuf()).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
// Both client and server must send their initial settings.
encoder.writeSettings(ctx, decoder.localSettings(), ctx.newPromise()).addListener(
ChannelFutureListener.CLOSE_ON_FAILURE);
}
}
private final class FrameDecoder extends BaseDecoder {
@Override
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
try {
decoder.decodeFrame(ctx, in, out);
} catch (Throwable e) {
onException(ctx, e);
}
}
} }
@Override @Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception { public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
clientPrefaceString = clientPrefaceString(encoder.connection()); byteDecoder = new PrefaceDecoder(ctx);
// This handler was just added to the context. In case it was handled after
// the connection became active, send the connection preface now.
sendPreface(ctx);
} }
/**
* Releases the {@code clientPrefaceString}. Any active streams will be left in the open.
*/
@Override @Override
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
if (clientPrefaceString != null) { if (byteDecoder != null) {
clientPrefaceString.release(); byteDecoder.handlerRemoved(ctx);
clientPrefaceString = null; byteDecoder = null;
} }
} }
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (byteDecoder == null) {
byteDecoder = new PrefaceDecoder(ctx);
}
byteDecoder.channelActive(ctx);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (byteDecoder != null) {
byteDecoder.channelInactive(ctx);
byteDecoder = null;
}
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
byteDecoder.decode(ctx, in, out);
}
@Override @Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
ctx.bind(localAddress, promise); ctx.bind(localAddress, promise);
@ -228,24 +391,6 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
ctx.flush(); ctx.flush();
} }
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
try {
ChannelFuture future = ctx.newSucceededFuture();
final Collection<Http2Stream> streams = connection().activeStreams();
for (Http2Stream s : streams.toArray(new Http2Stream[streams.size()])) {
closeStream(s, future);
}
} finally {
try {
encoder().close();
} finally {
decoder().close();
}
}
super.channelInactive(ctx);
}
/** /**
* Handles {@link Http2Exception} objects that were thrown from other handlers. Ignores all other exceptions. * Handles {@link Http2Exception} objects that were thrown from other handlers. Ignores all other exceptions.
*/ */
@ -318,7 +463,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
if (closeListener != null && connection().numActiveStreams() == 0) { if (closeListener != null && connection().numActiveStreams() == 0) {
ChannelFutureListener closeListener = Http2ConnectionHandler.this.closeListener; ChannelFutureListener closeListener = Http2ConnectionHandler.this.closeListener;
// This method could be called multiple times // This method could be called multiple times
// and we don't want to notify the closeListener multiple times // and we don't want to notify the closeListener multiple times.
Http2ConnectionHandler.this.closeListener = null; Http2ConnectionHandler.this.closeListener = null;
closeListener.operationComplete(future); closeListener.operationComplete(future);
} }
@ -446,76 +591,6 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
return writeGoAway(ctx, lastKnownStream, errorCode, debugData, ctx.newPromise()); return writeGoAway(ctx, lastKnownStream, errorCode, debugData, ctx.newPromise());
} }
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
try {
// Read the remaining of the client preface string if we haven't already.
// If this is a client endpoint, always returns true.
if (!readClientPrefaceString(in)) {
// Still processing the client preface.
return;
}
decoder.decodeFrame(ctx, in, out);
} catch (Throwable e) {
onException(ctx, e);
}
}
/**
* Sends the HTTP/2 connection preface upon establishment of the connection, if not already sent.
*/
private void sendPreface(final ChannelHandlerContext ctx) {
if (prefaceSent || !ctx.channel().isActive()) {
return;
}
prefaceSent = true;
if (!connection().isServer()) {
// Clients must send the preface string as the first bytes on the connection.
ctx.write(connectionPrefaceBuf()).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
// Both client and server must send their initial settings.
encoder.writeSettings(ctx, decoder.localSettings(), ctx.newPromise()).addListener(
ChannelFutureListener.CLOSE_ON_FAILURE);
}
/**
* Decodes the client connection preface string from the input buffer.
*
* @return {@code true} if processing of the client preface string is complete. Since client preface strings can
* only be received by servers, returns true immediately for client endpoints.
*/
private boolean readClientPrefaceString(ByteBuf in) throws Http2Exception {
if (clientPrefaceString == null) {
return true;
}
int prefaceRemaining = clientPrefaceString.readableBytes();
int bytesRead = Math.min(in.readableBytes(), prefaceRemaining);
// Read the portion of the input up to the length of the preface, if reached.
ByteBuf sourceSlice = in.readSlice(bytesRead);
// Read the same number of bytes from the preface buffer.
ByteBuf prefaceSlice = clientPrefaceString.readSlice(bytesRead);
// If the input so far doesn't match the preface, break the connection.
if (bytesRead == 0 || !prefaceSlice.equals(sourceSlice)) {
throw connectionError(PROTOCOL_ERROR, "HTTP/2 client preface string missing or corrupt.");
}
if (!clientPrefaceString.isReadable()) {
// Entire preface has been read.
clientPrefaceString.release();
clientPrefaceString = null;
return true;
}
return false;
}
/** /**
* Returns the client preface string if this is a client connection, otherwise returns {@code null}. * Returns the client preface string if this is a client connection, otherwise returns {@code null}.
*/ */

View File

@ -211,5 +211,6 @@ public interface Http2FrameListener {
* @param flags the flags in the frame header. * @param flags the flags in the frame header.
* @param payload the payload of the frame. * @param payload the payload of the frame.
*/ */
void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload); void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload)
throws Http2Exception;
} }

View File

@ -97,7 +97,7 @@ public class Http2FrameListenerDecorator implements Http2FrameListener {
@Override @Override
public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
ByteBuf payload) { ByteBuf payload) throws Http2Exception {
listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); listener.onUnknownFrame(ctx, frameType, streamId, flags, payload);
} }
} }

View File

@ -126,7 +126,7 @@ public class Http2InboundFrameLogger implements Http2FrameReader {
@Override @Override
public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId,
Http2Flags flags, ByteBuf payload) { Http2Flags flags, ByteBuf payload) throws Http2Exception {
logger.logUnknownFrame(INBOUND, frameType, streamId, flags, payload); logger.logUnknownFrame(INBOUND, frameType, streamId, flags, payload);
listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); listener.onUnknownFrame(ctx, frameType, streamId, flags, payload);
} }

View File

@ -34,6 +34,7 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
@ -176,10 +177,43 @@ public class Http2ConnectionHandlerTest {
@Test @Test
public void serverReceivingValidClientPrefaceStringShouldContinueReadingFrames() throws Exception { public void serverReceivingValidClientPrefaceStringShouldContinueReadingFrames() throws Exception {
when(connection.isServer()).thenReturn(true);
handler = newHandler();
ByteBuf preface = connectionPrefaceBuf();
ByteBuf prefacePlusSome = Unpooled.wrappedBuffer(new byte[preface.readableBytes() + 1]);
prefacePlusSome.resetWriterIndex().writeBytes(preface).writeByte(0);
handler.channelRead(ctx, prefacePlusSome);
verify(decoder, times(2)).decodeFrame(eq(ctx), any(ByteBuf.class), Matchers.<List<Object>>any());
}
@Test
public void serverReceivingValidClientPrefaceStringShouldOnlyReadWholeFrame() throws Exception {
when(connection.isServer()).thenReturn(true); when(connection.isServer()).thenReturn(true);
handler = newHandler(); handler = newHandler();
handler.channelRead(ctx, connectionPrefaceBuf()); handler.channelRead(ctx, connectionPrefaceBuf());
verify(decoder).decodeFrame(eq(ctx), any(ByteBuf.class), Matchers.<List<Object>>any()); verify(decoder).decodeFrame(any(ChannelHandlerContext.class),
any(ByteBuf.class), Matchers.<List<Object>>any());
}
@Test
public void verifyChannelHandlerCanBeReusedInPipeline() throws Exception {
when(connection.isServer()).thenReturn(true);
handler = newHandler();
// Only read the connection preface...after preface is read internal state of Http2ConnectionHandler
// is expected to change relative to the pipeline.
ByteBuf preface = connectionPrefaceBuf();
verify(decoder, never()).decodeFrame(any(ChannelHandlerContext.class),
any(ByteBuf.class), Matchers.<List<Object>>any());
// Now remove and add the handler...this is setting up the test condition.
handler.handlerRemoved(ctx);
handler.handlerAdded(ctx);
// Now verify we can continue as normal, reading connection preface plus more.
ByteBuf prefacePlusSome = Unpooled.wrappedBuffer(new byte[preface.readableBytes() + 1]);
prefacePlusSome.resetWriterIndex().writeBytes(preface).writeByte(0);
handler.channelRead(ctx, prefacePlusSome);
verify(decoder, times(2)).decodeFrame(eq(ctx), any(ByteBuf.class), Matchers.<List<Object>>any());
} }
@Test @Test

View File

@ -247,7 +247,7 @@ final class Http2TestUtil {
@Override @Override
public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
ByteBuf payload) { ByteBuf payload) throws Http2Exception {
listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); listener.onUnknownFrame(ctx, frameType, streamId, flags, payload);
latch.countDown(); latch.countDown();
} }
@ -373,7 +373,7 @@ final class Http2TestUtil {
@Override @Override
public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
ByteBuf payload) { ByteBuf payload) throws Http2Exception {
listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); listener.onUnknownFrame(ctx, frameType, streamId, flags, payload);
messageLatch.countDown(); messageLatch.countDown();
} }

View File

@ -20,7 +20,6 @@ import static io.netty.buffer.Unpooled.unreleasableBuffer;
import static io.netty.example.http2.Http2ExampleUtil.UPGRADE_RESPONSE_HEADER; import static io.netty.example.http2.Http2ExampleUtil.UPGRADE_RESPONSE_HEADER;
import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.logging.LogLevel.INFO; import static io.netty.handler.logging.LogLevel.INFO;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.AsciiString; import io.netty.handler.codec.AsciiString;
@ -79,7 +78,8 @@ public class HelloWorldHttp2Handler extends Http2ConnectionHandler {
} }
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
super.exceptionCaught(ctx, cause);
cause.printStackTrace(); cause.printStackTrace();
ctx.close(); ctx.close();
} }