diff --git a/buffer/src/main/java/io/netty/buffer/ByteBufUtil.java b/buffer/src/main/java/io/netty/buffer/ByteBufUtil.java index fcf9f0fe71..871db9bcdd 100644 --- a/buffer/src/main/java/io/netty/buffer/ByteBufUtil.java +++ b/buffer/src/main/java/io/netty/buffer/ByteBufUtil.java @@ -178,7 +178,55 @@ public final class ByteBufUtil { /** * Returns {@code true} if and only if the two specified buffers are - * identical to each other as described in {@code ChannelBuffer#equals(Object)}. + * identical to each other for {@code length} bytes starting at {@code aStartIndex} + * index for the {@code a} buffer and {@code bStartIndex} index for the {@code b} buffer. + * A more compact way to express this is: + *
+ * {@code a[aStartIndex : aStartIndex + length] == b[bStartIndex : bStartIndex + length]} + */ + public static boolean equals(ByteBuf a, int aStartIndex, ByteBuf b, int bStartIndex, int length) { + if (aStartIndex < 0 || bStartIndex < 0 || length < 0) { + throw new IllegalArgumentException("All indexes and lengths must be non-negative"); + } + if (a.writerIndex() - length < aStartIndex || b.writerIndex() - length < bStartIndex) { + return false; + } + + final int longCount = length >>> 3; + final int byteCount = length & 7; + + if (a.order() == b.order()) { + for (int i = longCount; i > 0; i --) { + if (a.getLong(aStartIndex) != b.getLong(bStartIndex)) { + return false; + } + aStartIndex += 8; + bStartIndex += 8; + } + } else { + for (int i = longCount; i > 0; i --) { + if (a.getLong(aStartIndex) != swapLong(b.getLong(bStartIndex))) { + return false; + } + aStartIndex += 8; + bStartIndex += 8; + } + } + + for (int i = byteCount; i > 0; i --) { + if (a.getByte(aStartIndex) != b.getByte(bStartIndex)) { + return false; + } + aStartIndex ++; + bStartIndex ++; + } + + return true; + } + + /** + * Returns {@code true} if and only if the two specified buffers are + * identical to each other as described in {@link ByteBuf#equals(Object)}. * This method is useful when implementing a new buffer type. */ public static boolean equals(ByteBuf bufferA, ByteBuf bufferB) { @@ -186,40 +234,7 @@ public final class ByteBufUtil { if (aLen != bufferB.readableBytes()) { return false; } - - final int longCount = aLen >>> 3; - final int byteCount = aLen & 7; - - int aIndex = bufferA.readerIndex(); - int bIndex = bufferB.readerIndex(); - - if (bufferA.order() == bufferB.order()) { - for (int i = longCount; i > 0; i --) { - if (bufferA.getLong(aIndex) != bufferB.getLong(bIndex)) { - return false; - } - aIndex += 8; - bIndex += 8; - } - } else { - for (int i = longCount; i > 0; i --) { - if (bufferA.getLong(aIndex) != swapLong(bufferB.getLong(bIndex))) { - return false; - } - aIndex += 8; - bIndex += 8; - } - } - - for (int i = byteCount; i > 0; i --) { - if (bufferA.getByte(aIndex) != bufferB.getByte(bIndex)) { - return false; - } - aIndex ++; - bIndex ++; - } - - return true; + return equals(bufferA, bufferA.readerIndex(), bufferB, bufferB.readerIndex(), aLen); } /** diff --git a/buffer/src/test/java/io/netty/buffer/ByteBufUtilTest.java b/buffer/src/test/java/io/netty/buffer/ByteBufUtilTest.java index fb04c78ce7..fcdf987317 100644 --- a/buffer/src/test/java/io/netty/buffer/ByteBufUtilTest.java +++ b/buffer/src/test/java/io/netty/buffer/ByteBufUtilTest.java @@ -15,12 +15,74 @@ */ package io.netty.buffer; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; + +import java.util.Random; + import org.junit.Assert; import org.junit.Test; public class ByteBufUtilTest { + @Test + public void equalsBufferSubsections() { + byte[] b1 = new byte[128]; + byte[] b2 = new byte[256]; + Random rand = new Random(); + rand.nextBytes(b1); + rand.nextBytes(b2); + final int iB1 = b1.length / 2; + final int iB2 = iB1 + b1.length; + final int length = b1.length - iB1; + System.arraycopy(b1, iB1, b2, iB2, length); + assertTrue(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, length)); + } + + @Test + public void notEqualsBufferSubsections() { + byte[] b1 = new byte[50]; + byte[] b2 = new byte[256]; + Random rand = new Random(); + rand.nextBytes(b1); + rand.nextBytes(b2); + final int iB1 = b1.length / 2; + final int iB2 = iB1 + b1.length; + final int length = b1.length - iB1; + System.arraycopy(b1, iB1, b2, iB2, length - 1); + assertFalse(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, length)); + } + + @Test + public void notEqualsBufferOverflow() { + byte[] b1 = new byte[8]; + byte[] b2 = new byte[16]; + Random rand = new Random(); + rand.nextBytes(b1); + rand.nextBytes(b2); + final int iB1 = b1.length / 2; + final int iB2 = iB1 + b1.length; + final int length = b1.length - iB1; + System.arraycopy(b1, iB1, b2, iB2, length - 1); + assertFalse(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, + Math.max(b1.length, b2.length) * 2)); + } + + @Test (expected = IllegalArgumentException.class) + public void notEqualsBufferUnderflow() { + byte[] b1 = new byte[8]; + byte[] b2 = new byte[16]; + Random rand = new Random(); + rand.nextBytes(b1); + rand.nextBytes(b2); + final int iB1 = b1.length / 2; + final int iB2 = iB1 + b1.length; + final int length = b1.length - iB1; + System.arraycopy(b1, iB1, b2, iB2, length - 1); + assertFalse(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, + -1)); + } @Test public void testWriteUsAscii() { diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java index feff3c14ba..5c21461704 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java @@ -37,14 +37,13 @@ import java.util.List; * {@link Http2LocalFlowController} */ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { - private final Http2FrameListener internalFrameListener = new FrameReadListener(); + private Http2FrameListener internalFrameListener = new PrefaceFrameListener(); private final Http2Connection connection; private final Http2LifecycleManager lifecycleManager; private final Http2ConnectionEncoder encoder; private final Http2FrameReader frameReader; private final Http2FrameListener listener; private final Http2PromisedRequestVerifier requestVerifier; - private boolean prefaceReceived; /** * Builder for instances of {@link DefaultHttp2ConnectionDecoder}. @@ -138,7 +137,7 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { @Override public boolean prefaceReceived() { - return prefaceReceived; + return FrameReadListener.class == internalFrameListener.getClass(); } @Override @@ -213,16 +212,26 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { return flowController().unconsumedBytes(stream); } + void onGoAwayRead0(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + // Don't allow any more connections to be created. + connection.goAwayReceived(lastStreamId); + + listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); + } + + void onUnknownFrame0(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + ByteBuf payload) throws Http2Exception { + listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); + } + /** * Handles all inbound frames from the network. */ private final class FrameReadListener implements Http2FrameListener { - @Override public int onDataRead(final ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) throws Http2Exception { - verifyPrefaceReceived(); - // Check if we received a data frame for a stream which is half-closed Http2Stream stream = connection.requireStream(streamId); @@ -304,15 +313,6 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { } } - /** - * Verifies that the HTTP/2 connection preface has been received from the remote endpoint. - */ - private void verifyPrefaceReceived() throws Http2Exception { - if (!prefaceReceived) { - throw connectionError(PROTOCOL_ERROR, "Received non-SETTINGS as first frame."); - } - } - @Override public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, boolean endOfStream) throws Http2Exception { @@ -322,8 +322,6 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { @Override public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception { - verifyPrefaceReceived(); - Http2Stream stream = connection.stream(streamId); verifyGoAwayNotReceived(); if (shouldIgnoreFrame(stream, false)) { @@ -369,8 +367,6 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { @Override public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, boolean exclusive) throws Http2Exception { - verifyPrefaceReceived(); - Http2Stream stream = connection.stream(streamId); verifyGoAwayNotReceived(); if (shouldIgnoreFrame(stream, true)) { @@ -393,8 +389,6 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { @Override public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { - verifyPrefaceReceived(); - Http2Stream stream = connection.requireStream(streamId); if (stream.state() == CLOSED) { // RstStream frames must be ignored for closed streams. @@ -408,9 +402,7 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { @Override public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { - verifyPrefaceReceived(); - // Apply oldest outstanding local settings here. This is a synchronization point - // between endpoints. + // Apply oldest outstanding local settings here. This is a synchronization point between endpoints. Http2Settings settings = encoder.pollSentSettings(); if (settings != null) { @@ -469,16 +461,11 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { // Acknowledge receipt of the settings. encoder.writeSettingsAck(ctx, ctx.newPromise()); - // We've received at least one non-ack settings frame from the remote endpoint. - prefaceReceived = true; - listener.onSettingsRead(ctx, settings); } @Override public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { - verifyPrefaceReceived(); - // Send an ack back to the remote client. // Need to retain the buffer here since it will be released after the write completes. encoder.writePing(ctx, true, data.retain(), ctx.newPromise()); @@ -489,16 +476,12 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { @Override public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { - verifyPrefaceReceived(); - listener.onPingAckRead(ctx, data); } @Override public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, Http2Headers headers, int padding) throws Http2Exception { - verifyPrefaceReceived(); - Http2Stream parentStream = connection.requireStream(streamId); verifyGoAwayNotReceived(); if (shouldIgnoreFrame(parentStream, false)) { @@ -543,17 +526,12 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { @Override public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) throws Http2Exception { - // Don't allow any more connections to be created. - connection.goAwayReceived(lastStreamId); - - listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); + onGoAwayRead0(ctx, lastStreamId, errorCode, debugData); } @Override public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) throws Http2Exception { - verifyPrefaceReceived(); - Http2Stream stream = connection.requireStream(streamId); verifyGoAwayNotReceived(); if (stream.state() == CLOSED || shouldIgnoreFrame(stream, false)) { @@ -569,8 +547,8 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { @Override public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, - ByteBuf payload) { - listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); + ByteBuf payload) throws Http2Exception { + onUnknownFrame0(ctx, frameType, streamId, flags, payload); } /** @@ -600,4 +578,107 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { } } } + + private final class PrefaceFrameListener implements Http2FrameListener { + /** + * Verifies that the HTTP/2 connection preface has been received from the remote endpoint. + * It is possible that the current call to + * {@link Http2FrameReader#readFrame(ChannelHandlerContext, ByteBuf, Http2FrameListener)} will have multiple + * frames to dispatch. So it may be OK for this class to get legitimate frames for the first readFrame. + */ + private void verifyPrefaceReceived() throws Http2Exception { + if (!prefaceReceived()) { + throw connectionError(PROTOCOL_ERROR, "Received non-SETTINGS as first frame."); + } + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + verifyPrefaceReceived(); + return internalFrameListener.onDataRead(ctx, streamId, data, padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endOfStream) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onHeadersRead(ctx, streamId, headers, padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, + exclusive, padding, endOfStream); + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, + boolean exclusive) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onRstStreamRead(ctx, streamId, errorCode); + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onSettingsAckRead(ctx); + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception { + // The first settings should change the internalFrameListener to the "real" listener + // that expects the preface to be verified. + if (!prefaceReceived()) { + internalFrameListener = new FrameReadListener(); + } + internalFrameListener.onSettingsRead(ctx, settings); + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onPingRead(ctx, data); + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onPingAckRead(ctx, data); + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding); + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + onGoAwayRead0(ctx, lastStreamId, errorCode, debugData); + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement); + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + ByteBuf payload) throws Http2Exception { + onUnknownFrame0(ctx, frameType, streamId, flags, payload); + } + } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java index d3ee623ed1..94e1bce298 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java @@ -568,7 +568,8 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize listener); } - private void readUnknownFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener) { + private void readUnknownFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener) + throws Http2Exception { payload = payload.readSlice(payload.readableBytes()); listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java index 102807d079..47f1d1f240 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java @@ -24,6 +24,7 @@ import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2Exception.isStreamError; import static io.netty.util.internal.ObjectUtil.checkNotNull; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; @@ -50,9 +51,8 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http ChannelOutboundHandler { private final Http2ConnectionDecoder decoder; private final Http2ConnectionEncoder encoder; - private ByteBuf clientPrefaceString; - private boolean prefaceSent; private ChannelFutureListener closeListener; + private BaseDecoder byteDecoder; public Http2ConnectionHandler(boolean server, Http2FrameListener listener) { this(new DefaultHttp2Connection(server), listener); @@ -112,6 +112,10 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http return encoder; } + private boolean prefaceSent() { + return byteDecoder != null && byteDecoder.prefaceSent(); + } + /** * Handles the client-side (cleartext) upgrade from HTTP to HTTP/2. * Reserves local stream 1 for the HTTP/2 response. @@ -120,7 +124,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http if (connection().isServer()) { throw connectionError(PROTOCOL_ERROR, "Client-side HTTP upgrade requested for a server"); } - if (prefaceSent || decoder.prefaceReceived()) { + if (prefaceSent() || decoder.prefaceReceived()) { throw connectionError(PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received"); } @@ -136,7 +140,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http if (!connection().isServer()) { throw connectionError(PROTOCOL_ERROR, "Server-side HTTP upgrade requested for a client"); } - if (prefaceSent || decoder.prefaceReceived()) { + if (prefaceSent() || decoder.prefaceReceived()) { throw connectionError(PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is sent or received"); } @@ -147,32 +151,191 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http connection().remote().createStream(HTTP_UPGRADE_STREAM_ID).open(true); } - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - // The channel just became active - send the connection preface to the remote endpoint. - sendPreface(ctx); - super.channelActive(ctx); + private abstract class BaseDecoder { + public abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List