From f4c2c1926f47e38a9d6156af1e651daa79ea0e0a Mon Sep 17 00:00:00 2001 From: chhsiao90 Date: Sat, 14 Jan 2017 18:51:30 +0800 Subject: [PATCH] Fixed rfc violation about sending extension frame in the middle of headers Motivation: At rfc7540 5.5, it said that it's not permitted to send extension frame in the middle of header block and need be treated as protocol error Modifications: When received a extension frame, in netty it's called unknown frame, will verify that is there an headersContinuation exists Result: When received a extension frame in the middle of header block, will throw connection error and closed the connection --- .../codec/http2/DefaultHttp2FrameReader.java | 25 +++ .../http2/DefaultHttp2FrameReaderTest.java | 156 ++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java index 527fb071d0..bc745e105c 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java @@ -125,8 +125,13 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize @Override public void close() { + closeHeadersContinuation(); + } + + private void closeHeadersContinuation() { if (headersContinuation != null) { headersContinuation.close(); + headersContinuation = null; } } @@ -398,6 +403,15 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize } } + private void verifyUnknownFrame() throws Http2Exception { + if (headersContinuation != null) { + int streamId = headersContinuation.getStreamId(); + closeHeadersContinuation(); + throw connectionError(PROTOCOL_ERROR, "Extension frames must not be in the middle of headers " + + "on stream %d", streamId); + } + } + private void readDataFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener) throws Http2Exception { int padding = readPadding(payload); @@ -452,6 +466,7 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize // Process the initial fragment, invoking the listener's callback if end of headers. headersContinuation.processFragment(flags.endOfHeaders(), fragment, listener); + resetHeadersContinuationIfEnd(flags.endOfHeaders()); return; } @@ -478,6 +493,13 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize // Process the initial fragment, invoking the listener's callback if end of headers. final ByteBuf fragment = payload.readSlice(lengthWithoutTrailingPadding(payload.readableBytes(), padding)); headersContinuation.processFragment(flags.endOfHeaders(), fragment, listener); + resetHeadersContinuationIfEnd(flags.endOfHeaders()); + } + + private void resetHeadersContinuationIfEnd(boolean endOfHeaders) { + if (endOfHeaders) { + closeHeadersContinuation(); + } } private void readPriorityFrame(ChannelHandlerContext ctx, ByteBuf payload, @@ -553,6 +575,7 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize // Process the initial fragment, invoking the listener's callback if end of headers. final ByteBuf fragment = payload.readSlice(lengthWithoutTrailingPadding(payload.readableBytes(), padding)); headersContinuation.processFragment(flags.endOfHeaders(), fragment, listener); + resetHeadersContinuationIfEnd(flags.endOfHeaders()); } private void readPingFrame(ChannelHandlerContext ctx, ByteBuf payload, @@ -589,10 +612,12 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize final ByteBuf continuationFragment = payload.readSlice(payload.readableBytes()); headersContinuation.processFragment(flags.endOfHeaders(), continuationFragment, listener); + resetHeadersContinuationIfEnd(flags.endOfHeaders()); } private void readUnknownFrame(ChannelHandlerContext ctx, ByteBuf payload, Http2FrameListener listener) throws Http2Exception { + verifyUnknownFrame(); payload = payload.readSlice(payload.readableBytes()); listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java new file mode 100644 index 0000000000..8392ce9d3a --- /dev/null +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java @@ -0,0 +1,156 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.internal.hpack.Encoder; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static io.netty.handler.codec.http2.Http2FrameTypes.CONTINUATION; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static io.netty.handler.codec.http2.Http2CodecUtil.writeFrameHeader; +import static io.netty.handler.codec.http2.Http2FrameTypes.HEADERS; + +/** + * Tests for {@link DefaultHttp2FrameReader}. + */ +public class DefaultHttp2FrameReaderTest { + @Mock + private Http2FrameListener listener; + + @Mock + private ChannelHandlerContext ctx; + + private DefaultHttp2FrameReader frameReader; + + // Used to generate frame + private Encoder encoder; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + + // Currently, frameReader only used alloc method from ChannelHandlerContext to allocate buffer, + // and used it as a parameter when calling listener + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + + frameReader = new DefaultHttp2FrameReader(); + encoder = new Encoder(); + } + + @After + public void tearDown() { + frameReader.close(); + } + + @Test + public void readHeaderFrame() throws Http2Exception { + final int streamId = 1; + + ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + Http2Flags flags = new Http2Flags().endOfHeaders(true).endOfStream(true); + writeHeaderFrame(input, streamId, headers, flags); + frameReader.readFrame(ctx, input, listener); + + verify(listener).onHeadersRead(ctx, 1, headers, 0, true); + } finally { + input.release(); + } + } + + @Test + public void readHeaderFrameAndContinuationFrame() throws Http2Exception { + final int streamId = 1; + + ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + writeHeaderFrame(input, streamId, headers, + new Http2Flags().endOfHeaders(false).endOfStream(true)); + writeContinuationFrame(input, streamId, new DefaultHttp2Headers().add("foo", "bar"), + new Http2Flags().endOfHeaders(true)); + + frameReader.readFrame(ctx, input, listener); + + verify(listener).onHeadersRead(ctx, 1, headers.add("foo", "bar"), 0, true); + } finally { + input.release(); + } + } + + @Test(expected = Http2Exception.class) + public void failedWhenUnknownFrameInMiddleOfHeaderBlock() throws Http2Exception { + final int streamId = 1; + + ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + Http2Flags flags = new Http2Flags().endOfHeaders(false).endOfStream(true); + writeHeaderFrame(input, streamId, headers, flags); + writeFrameHeader(input, 0, (byte) 0xff, new Http2Flags(), streamId); + frameReader.readFrame(ctx, input, listener); + } finally { + input.release(); + } + } + + private void writeHeaderFrame( + ByteBuf output, int streamId, Http2Headers headers, + Http2Flags flags) throws Http2Exception { + ByteBuf headerBlock = Unpooled.buffer(); + try { + encoder.encodeHeaders(streamId, headerBlock, headers, Http2HeadersEncoder.NEVER_SENSITIVE); + writeFrameHeader(output, headerBlock.readableBytes(), HEADERS, flags, streamId); + output.writeBytes(headerBlock, headerBlock.readableBytes()); + } finally { + headerBlock.release(); + } + } + + private void writeContinuationFrame( + ByteBuf output, int streamId, Http2Headers headers, + Http2Flags flags) throws Http2Exception { + ByteBuf headerBlock = Unpooled.buffer(); + try { + encoder.encodeHeaders(streamId, headerBlock, headers, Http2HeadersEncoder.NEVER_SENSITIVE); + writeFrameHeader(output, headerBlock.readableBytes(), CONTINUATION, flags, streamId); + output.writeBytes(headerBlock, headerBlock.readableBytes()); + } finally { + headerBlock.release(); + } + } +}