Fix GOAWAY logic in Http2Encoder and Http2Decoder.

Motivation:

1) The current implementation doesn't allow for HEADERS, DATA, PING, PRIORITY and SETTINGS
   frames to be sent after GOAWAY.

2) When receiving or sending a GOAWAY frame, all streams with ids greater than the lastStreamId
   of the GOAWAY frame should be closed. That's not happening.

Modifications:

1) Allow sending of HEADERS and DATA frames after GOAWAY for streams with ids < lastStreamId.
2) Always allow sending PING, PRIORITY AND SETTINGS frames.
3) Allow sending multiple GOAWAY frames with decreasing lastStreamIds.
4) After receiving or sending a GOAWAY frame, close all streams with ids > lastStreamId.

Result:

The GOAWAY handling is more correct.
This commit is contained in:
Jakob Buchgraber 2015-03-19 18:36:24 -07:00 committed by Scott Mitchell
parent ac8b05e0fb
commit 19c864505d
8 changed files with 268 additions and 176 deletions

View File

@ -163,11 +163,15 @@ public class DefaultHttp2Connection implements Http2Connection {
@Override @Override
public void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf debugData) { public void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf debugData) {
boolean alreadyNotified = goAwayReceived();
localEndpoint.lastKnownStream(lastKnownStream); localEndpoint.lastKnownStream(lastKnownStream);
if (!alreadyNotified) { for (Listener listener : listeners) {
for (int i = 0; i < listeners.size(); i++) { listener.onGoAwayReceived(lastKnownStream, errorCode, debugData);
listeners.get(i).onGoAwayReceived(lastKnownStream, errorCode, debugData); }
Http2Stream[] streams = new Http2Stream[numActiveStreams()];
for (Http2Stream stream : activeStreams().toArray(streams)) {
if (stream.id() > lastKnownStream && localEndpoint.createdStreamId(stream.id())) {
stream.close();
} }
} }
} }
@ -179,11 +183,15 @@ public class DefaultHttp2Connection implements Http2Connection {
@Override @Override
public void goAwaySent(int lastKnownStream, long errorCode, ByteBuf debugData) { public void goAwaySent(int lastKnownStream, long errorCode, ByteBuf debugData) {
boolean alreadyNotified = goAwaySent();
remoteEndpoint.lastKnownStream(lastKnownStream); remoteEndpoint.lastKnownStream(lastKnownStream);
if (!alreadyNotified) { for (Listener listener : listeners) {
for (int i = 0; i < listeners.size(); i++) { listener.onGoAwaySent(lastKnownStream, errorCode, debugData);
listeners.get(i).onGoAwaySent(lastKnownStream, errorCode, debugData); }
Http2Stream[] streams = new Http2Stream[numActiveStreams()];
for (Http2Stream stream : activeStreams().toArray(streams)) {
if (stream.id() > lastKnownStream && remoteEndpoint.createdStreamId(stream.id())) {
stream.close();
} }
} }
} }
@ -911,7 +919,7 @@ public class DefaultHttp2Connection implements Http2Connection {
@Override @Override
public int lastKnownStream() { public int lastKnownStream() {
return lastKnownStream >= 0 ? lastKnownStream : lastStreamCreated; return lastKnownStream;
} }
private void lastKnownStream(int lastKnownStream) { private void lastKnownStream(int lastKnownStream) {
@ -934,8 +942,10 @@ public class DefaultHttp2Connection implements Http2Connection {
} }
private void checkNewStreamAllowed(int streamId) throws Http2Exception { private void checkNewStreamAllowed(int streamId) throws Http2Exception {
if (goAwaySent() || goAwayReceived()) { if (goAwayReceived() && streamId > localEndpoint.lastKnownStream()) {
throw connectionError(PROTOCOL_ERROR, "Cannot create a stream since the connection is going away"); throw connectionError(PROTOCOL_ERROR, "Cannot create stream %d since this endpoint has received a " +
"GOAWAY frame with last stream id %d.", streamId,
localEndpoint.lastKnownStream());
} }
if (streamId < 0) { if (streamId < 0) {
throw new Http2NoMoreStreamIdsException(); throw new Http2NoMoreStreamIdsException();

View File

@ -168,10 +168,8 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
void onGoAwayRead0(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) void onGoAwayRead0(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
throws Http2Exception { throws Http2Exception {
// Don't allow any more connections to be created.
connection.goAwayReceived(lastStreamId, errorCode, debugData);
listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
connection.goAwayReceived(lastStreamId, errorCode, debugData);
} }
void onUnknownFrame0(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, void onUnknownFrame0(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
@ -186,53 +184,42 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public int onDataRead(final ChannelHandlerContext ctx, int streamId, ByteBuf data, public int onDataRead(final ChannelHandlerContext ctx, int streamId, ByteBuf data,
int padding, boolean endOfStream) throws Http2Exception { int padding, boolean endOfStream) throws Http2Exception {
// Check if we received a data frame for a stream which is half-closed
Http2Stream stream = connection.requireStream(streamId); Http2Stream stream = connection.requireStream(streamId);
Http2LocalFlowController flowController = flowController();
int bytesToReturn = data.readableBytes() + padding;
verifyGoAwayNotReceived(); if (stream.isResetSent() || streamCreatedAfterGoAwaySent(stream)) {
// Count the frame towards the connection flow control window and don't process it further.
flowController.receiveFlowControlledFrame(ctx, stream, data, padding, endOfStream);
flowController.consumeBytes(ctx, stream, bytesToReturn);
// Since no bytes are consumed, return them all.
return bytesToReturn;
}
// We should ignore this frame if RST_STREAM was sent or if GO_AWAY was sent with a
// lower stream ID.
boolean shouldIgnore = shouldIgnoreFrame(stream, false);
Http2Exception error = null; Http2Exception error = null;
switch (stream.state()) { switch (stream.state()) {
case OPEN: case OPEN:
case HALF_CLOSED_LOCAL: case HALF_CLOSED_LOCAL:
break; break;
case HALF_CLOSED_REMOTE: case HALF_CLOSED_REMOTE:
// Always fail the stream if we've more data after the remote endpoint half-closed. case CLOSED:
error = streamError(stream.id(), STREAM_CLOSED, "Stream %d in unexpected state: %s", error = streamError(stream.id(), STREAM_CLOSED, "Stream %d in unexpected state: %s",
stream.id(), stream.state()); stream.id(), stream.state());
break; break;
case CLOSED:
if (!shouldIgnore) {
error = streamError(stream.id(), STREAM_CLOSED, "Stream %d in unexpected state: %s",
stream.id(), stream.state());
}
break;
default: default:
if (!shouldIgnore) { error = streamError(stream.id(), PROTOCOL_ERROR,
error = streamError(stream.id(), PROTOCOL_ERROR, "Stream %d in unexpected state: %s", stream.id(), stream.state());
"Stream %d in unexpected state: %s", stream.id(), stream.state());
}
break; break;
} }
int bytesToReturn = data.readableBytes() + padding;
int unconsumedBytes = unconsumedBytes(stream); int unconsumedBytes = unconsumedBytes(stream);
Http2LocalFlowController flowController = flowController();
try { try {
// If we should apply flow control, do so now.
flowController.receiveFlowControlledFrame(ctx, stream, data, padding, endOfStream); flowController.receiveFlowControlledFrame(ctx, stream, data, padding, endOfStream);
// Update the unconsumed bytes after flow control is applied. // Update the unconsumed bytes after flow control is applied.
unconsumedBytes = unconsumedBytes(stream); unconsumedBytes = unconsumedBytes(stream);
// If we should ignore this frame, do so now. // If the stream is in an invalid state to receive the frame, throw the error.
if (shouldIgnore) {
return bytesToReturn;
}
// If the stream was in an invalid state to receive the frame, throw the error.
if (error != null) { if (error != null) {
throw error; throw error;
} }
@ -256,7 +243,7 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
bytesToReturn -= delta; bytesToReturn -= delta;
throw e; throw e;
} finally { } finally {
// If appropriate, returned the processed bytes to the flow controller. // If appropriate, return the processed bytes to the flow controller.
if (bytesToReturn > 0) { if (bytesToReturn > 0) {
flowController.consumeBytes(ctx, stream, bytesToReturn); flowController.consumeBytes(ctx, stream, bytesToReturn);
} }
@ -277,14 +264,12 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency,
short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception { short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception {
Http2Stream stream = connection.stream(streamId); Http2Stream stream = connection.stream(streamId);
verifyGoAwayNotReceived();
if (shouldIgnoreFrame(stream, false)) {
// Ignore this frame.
return;
}
if (stream == null) { if (stream == null) {
stream = connection.remote().createStream(streamId).open(endOfStream); stream = connection.remote().createStream(streamId).open(endOfStream);
} else if (stream.isResetSent() || streamCreatedAfterGoAwaySent(stream)) {
// Ignore this frame.
return;
} else { } else {
switch (stream.state()) { switch (stream.state()) {
case RESERVED_REMOTE: case RESERVED_REMOTE:
@ -297,9 +282,8 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
break; break;
case HALF_CLOSED_REMOTE: case HALF_CLOSED_REMOTE:
case CLOSED: case CLOSED:
// Stream error.
throw streamError(stream.id(), STREAM_CLOSED, "Stream %d in unexpected state: %s", throw streamError(stream.id(), STREAM_CLOSED, "Stream %d in unexpected state: %s",
stream.id(), stream.state()); stream.id(), stream.state());
default: default:
// Connection error. // Connection error.
throw connectionError(PROTOCOL_ERROR, "Stream %d in unexpected state: %s", stream.id(), throw connectionError(PROTOCOL_ERROR, "Stream %d in unexpected state: %s", stream.id(),
@ -316,8 +300,7 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
// In this case we should ignore the exception and allow the frame to be sent. // In this case we should ignore the exception and allow the frame to be sent.
} }
listener.onHeadersRead(ctx, streamId, headers, listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream);
streamDependency, weight, exclusive, padding, endOfStream);
// If the headers completes this stream, close it. // If the headers completes this stream, close it.
if (endOfStream) { if (endOfStream) {
@ -329,17 +312,15 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight,
boolean exclusive) throws Http2Exception { boolean exclusive) throws Http2Exception {
Http2Stream stream = connection.stream(streamId); Http2Stream stream = connection.stream(streamId);
verifyGoAwayNotReceived();
if (shouldIgnoreFrame(stream, true)) {
// Ignore this frame.
return;
}
try { try {
if (stream == null) { if (stream == null) {
// PRIORITY frames always identify a stream. This means that if a PRIORITY frame is the // PRIORITY frames always identify a stream. This means that if a PRIORITY frame is the
// first frame to be received for a stream that we must create the stream. // first frame to be received for a stream that we must create the stream.
stream = connection.remote().createStream(streamId); stream = connection.remote().createStream(streamId);
} else if (streamCreatedAfterGoAwaySent(stream)) {
// Ignore this frame.
return;
} }
// This call will create a stream for streamDependency if necessary. // This call will create a stream for streamDependency if necessary.
@ -356,6 +337,7 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
@Override @Override
public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception {
Http2Stream stream = connection.requireStream(streamId); Http2Stream stream = connection.requireStream(streamId);
switch(stream.state()) { switch(stream.state()) {
case IDLE: case IDLE:
throw connectionError(PROTOCOL_ERROR, "RST_STREAM received for IDLE stream %d", streamId); throw connectionError(PROTOCOL_ERROR, "RST_STREAM received for IDLE stream %d", streamId);
@ -453,9 +435,8 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId,
Http2Headers headers, int padding) throws Http2Exception { Http2Headers headers, int padding) throws Http2Exception {
Http2Stream parentStream = connection.requireStream(streamId); Http2Stream parentStream = connection.requireStream(streamId);
verifyGoAwayNotReceived();
if (shouldIgnoreFrame(parentStream, false)) { if (streamCreatedAfterGoAwaySent(parentStream)) {
// Ignore frames for any stream created after we sent a go-away.
return; return;
} }
@ -503,13 +484,12 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement)
throws Http2Exception { throws Http2Exception {
Http2Stream stream = connection.requireStream(streamId); Http2Stream stream = connection.requireStream(streamId);
verifyGoAwayNotReceived();
if (stream.state() == CLOSED || shouldIgnoreFrame(stream, false)) { if (stream.state() == CLOSED || streamCreatedAfterGoAwaySent(stream)) {
// Ignore frames for any stream created after we sent a go-away.
return; return;
} }
// Update the outbound flow controller. // Update the outbound flow control window.
encoder.flowController().incrementWindowSize(ctx, stream, windowSizeIncrement); encoder.flowController().incrementWindowSize(ctx, stream, windowSizeIncrement);
listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement); listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement);
@ -521,31 +501,10 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
onUnknownFrame0(ctx, frameType, streamId, flags, payload); onUnknownFrame0(ctx, frameType, streamId, flags, payload);
} }
/** private boolean streamCreatedAfterGoAwaySent(Http2Stream stream) {
* Indicates whether or not frames for the given stream should be ignored based on the state of the // Ignore inbound frames after a GOAWAY was sent and the stream id is greater than
* stream/connection. // the last stream id set in the GOAWAY frame.
*/ return connection().goAwaySent() && stream.id() > connection().remote().lastKnownStream();
private boolean shouldIgnoreFrame(Http2Stream stream, boolean allowResetSent) {
if (connection.goAwaySent() &&
(stream == null || connection.remote().lastStreamCreated() <= stream.id())) {
// Frames from streams created after we sent a go-away should be ignored.
// Frames for the connection stream ID (i.e. 0) will always be allowed.
return true;
}
// Also ignore inbound frames after we sent a RST_STREAM frame.
return stream != null && !allowResetSent && stream.isResetSent();
}
/**
* Verifies that a GO_AWAY frame was not previously received from the remote endpoint. If it was, throws a
* connection error.
*/
private void verifyGoAwayNotReceived() throws Http2Exception {
if (connection.goAwayReceived()) {
// Connection error.
throw connectionError(PROTOCOL_ERROR, "Received frames after receiving GO_AWAY");
}
} }
} }

View File

@ -112,12 +112,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
final boolean endOfStream, ChannelPromise promise) { final boolean endOfStream, ChannelPromise promise) {
final Http2Stream stream; final Http2Stream stream;
try { try {
if (connection.goAwayReceived() || connection.goAwaySent()) {
throw new IllegalStateException("Sending data after connection going away.");
}
stream = connection.requireStream(streamId); stream = connection.requireStream(streamId);
// Verify that the stream is in the appropriate state for sending DATA frames. // Verify that the stream is in the appropriate state for sending DATA frames.
switch (stream.state()) { switch (stream.state()) {
case OPEN: case OPEN:
@ -151,9 +146,6 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
final boolean exclusive, final int padding, final boolean endOfStream, final boolean exclusive, final int padding, final boolean endOfStream,
final ChannelPromise promise) { final ChannelPromise promise) {
try { try {
if (connection.goAwayReceived() || connection.goAwaySent()) {
throw connectionError(PROTOCOL_ERROR, "Sending headers after connection going away.");
}
Http2Stream stream = connection.stream(streamId); Http2Stream stream = connection.stream(streamId);
if (stream == null) { if (stream == null) {
stream = connection.local().createStream(streamId); stream = connection.local().createStream(streamId);
@ -190,10 +182,6 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
public ChannelFuture writePriority(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, public ChannelFuture writePriority(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight,
boolean exclusive, ChannelPromise promise) { boolean exclusive, ChannelPromise promise) {
try { try {
if (connection.goAwayReceived() || connection.goAwaySent()) {
throw connectionError(PROTOCOL_ERROR, "Sending priority after connection going away.");
}
// Update the priority on this stream. // Update the priority on this stream.
Http2Stream stream = connection.stream(streamId); Http2Stream stream = connection.stream(streamId);
if (stream == null) { if (stream == null) {
@ -227,10 +215,6 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
ChannelPromise promise) { ChannelPromise promise) {
outstandingLocalSettingsQueue.add(settings); outstandingLocalSettingsQueue.add(settings);
try { try {
if (connection.goAwayReceived() || connection.goAwaySent()) {
throw connectionError(PROTOCOL_ERROR, "Sending settings after connection going away.");
}
Boolean pushEnabled = settings.pushEnabled(); Boolean pushEnabled = settings.pushEnabled();
if (pushEnabled != null && connection.isServer()) { if (pushEnabled != null && connection.isServer()) {
throw connectionError(PROTOCOL_ERROR, "Server sending SETTINGS frame with ENABLE_PUSH specified"); throw connectionError(PROTOCOL_ERROR, "Server sending SETTINGS frame with ENABLE_PUSH specified");
@ -252,13 +236,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
} }
@Override @Override
public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, ByteBuf data, public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, ByteBuf data, ChannelPromise promise) {
ChannelPromise promise) {
if (connection.goAwayReceived() || connection.goAwaySent()) {
data.release();
return promise.setFailure(connectionError(PROTOCOL_ERROR, "Sending ping after connection going away."));
}
ChannelFuture future = frameWriter.writePing(ctx, ack, data, promise); ChannelFuture future = frameWriter.writePing(ctx, ack, data, promise);
ctx.flush(); ctx.flush();
return future; return future;
@ -268,12 +246,12 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, int promisedStreamId, public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, int promisedStreamId,
Http2Headers headers, int padding, ChannelPromise promise) { Http2Headers headers, int padding, ChannelPromise promise) {
try { try {
if (connection.goAwayReceived() || connection.goAwaySent()) { if (connection.goAwayReceived()) {
throw connectionError(PROTOCOL_ERROR, "Sending push promise after connection going away."); throw connectionError(PROTOCOL_ERROR, "Sending PUSH_PROMISE after GO_AWAY received.");
} }
// Reserve the promised stream.
Http2Stream stream = connection.requireStream(streamId); Http2Stream stream = connection.requireStream(streamId);
// Reserve the promised stream.
connection.local().reservePushStream(promisedStreamId, stream); connection.local().reservePushStream(promisedStreamId, stream);
} catch (Throwable e) { } catch (Throwable e) {
return promise.setFailure(e); return promise.setFailure(e);

View File

@ -97,7 +97,7 @@ public interface Http2Connection {
* Called when a {@code GOAWAY} was received from the remote endpoint. This event handler duplicates {@link * Called when a {@code GOAWAY} was received from the remote endpoint. This event handler duplicates {@link
* Http2FrameListener#onGoAwayRead(io.netty.channel.ChannelHandlerContext, int, long, io.netty.buffer.ByteBuf)} * Http2FrameListener#onGoAwayRead(io.netty.channel.ChannelHandlerContext, int, long, io.netty.buffer.ByteBuf)}
* but is added here in order to simplify application logic for handling {@code GOAWAY} in a uniform way. An * but is added here in order to simplify application logic for handling {@code GOAWAY} in a uniform way. An
* application should generally not handle both events, but if it does this method is called first, before * application should generally not handle both events, but if it does this method is called second, after
* notifying the {@link Http2FrameListener}. * notifying the {@link Http2FrameListener}.
* *
* @param lastStreamId the last known stream of the remote endpoint. * @param lastStreamId the last known stream of the remote endpoint.
@ -206,9 +206,8 @@ public interface Http2Connection {
int lastStreamCreated(); int lastStreamCreated();
/** /**
* Gets the last stream created by this endpoint that is "known" by the opposite endpoint.
* If a GOAWAY was received for this endpoint, this will be the last stream ID from the * If a GOAWAY was received for this endpoint, this will be the last stream ID from the
* GOAWAY frame. Otherwise, this will be same as {@link #lastStreamCreated()}. * GOAWAY frame. Otherwise, this will be {@code -1}.
*/ */
int lastKnownStream(); int lastKnownStream();

View File

@ -23,6 +23,8 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.handler.codec.http2.Http2Exception.isStreamError; import static io.netty.handler.codec.http2.Http2Exception.isStreamError;
import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static java.lang.String.format;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
@ -32,6 +34,9 @@ import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException; import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException;
import io.netty.handler.codec.http2.Http2Exception.StreamException; import io.netty.handler.codec.http2.Http2Exception.StreamException;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
@ -50,6 +55,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
private final Http2ConnectionEncoder encoder; private final Http2ConnectionEncoder encoder;
private ChannelFutureListener closeListener; private ChannelFutureListener closeListener;
private BaseDecoder byteDecoder; private BaseDecoder byteDecoder;
private static final InternalLogger logger = InternalLoggerFactory.getInstance(Http2ConnectionHandler.class);
public Http2ConnectionHandler(boolean server, Http2FrameListener listener) { public Http2ConnectionHandler(boolean server, Http2FrameListener listener) {
this(new DefaultHttp2Connection(server), listener); this(new DefaultHttp2Connection(server), listener);
@ -510,36 +516,47 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
} }
@Override @Override
public ChannelFuture goAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode, public ChannelFuture goAway(final ChannelHandlerContext ctx, final int lastStreamId, final long errorCode,
ByteBuf debugData, ChannelPromise promise) { final ByteBuf debugData, ChannelPromise promise) {
Http2Connection connection = connection(); try {
if (connection.goAwayReceived() || connection.goAwaySent()) { final Http2Connection connection = connection();
if (connection.goAwaySent() && connection.remote().lastKnownStream() < lastStreamId) {
throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " +
"sending multiple GOAWAY frames (was '%d', is '%d').",
connection.remote().lastKnownStream(),
lastStreamId);
}
connection.goAwaySent(lastStreamId, errorCode, debugData);
ChannelFuture future = frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);
ctx.flush();
future.addListener(new GenericFutureListener<ChannelFuture>() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
String msg = format("Sending GOAWAY failed: lastStreamId '%d', errorCode '%d', " +
"debugData '%s'.", lastStreamId, errorCode, debugData);
logger.error(msg, future.cause());
ctx.channel().close();
}
}
});
return future;
} catch (Http2Exception e) {
debugData.release(); debugData.release();
return ctx.newSucceededFuture(); return promise.setFailure(e);
} }
connection.goAwaySent(lastStreamId, errorCode, debugData);
ChannelFuture future = frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);
ctx.flush();
return future;
} }
/** /**
* Close the remote endpoint with with a {@code GO_AWAY} frame. * Close the remote endpoint with with a {@code GO_AWAY} frame.
*/ */
private ChannelFuture goAway(ChannelHandlerContext ctx, Http2Exception cause) { private ChannelFuture goAway(ChannelHandlerContext ctx, Http2Exception cause) {
Http2Connection connection = connection();
if (connection.goAwayReceived() || connection.goAwaySent()) {
return ctx.newSucceededFuture();
}
// The connection isn't alredy going away, send the GO_AWAY frame now to start
// the process.
long errorCode = cause != null ? cause.error().code() : NO_ERROR.code(); long errorCode = cause != null ? cause.error().code() : NO_ERROR.code();
ByteBuf debugData = Http2CodecUtil.toByteBuf(ctx, cause); ByteBuf debugData = Http2CodecUtil.toByteBuf(ctx, cause);
int lastKnownStream = connection.remote().lastStreamCreated(); int lastKnownStream = connection().remote().lastStreamCreated();
return goAway(ctx, lastKnownStream, errorCode, debugData, ctx.newPromise()); return goAway(ctx, lastKnownStream, errorCode, debugData, ctx.newPromise());
} }

View File

@ -18,13 +18,17 @@ import static io.netty.buffer.Unpooled.wrappedBuffer;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static io.netty.handler.codec.http2.Http2CodecUtil.emptyPingBuf; import static io.netty.handler.codec.http2.Http2CodecUtil.emptyPingBuf;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL; import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL;
import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; import static io.netty.handler.codec.http2.Http2Stream.State.IDLE;
import static io.netty.handler.codec.http2.Http2Stream.State.OPEN; import static io.netty.handler.codec.http2.Http2Stream.State.OPEN;
import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL; import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL;
import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.CharsetUtil.UTF_8;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
@ -35,8 +39,10 @@ import static org.mockito.Matchers.anyShort;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
@ -199,20 +205,6 @@ public class DefaultHttp2ConnectionEncoderTest {
encoder.lifecycleManager(lifecycleManager); encoder.lifecycleManager(lifecycleManager);
} }
@Test
public void dataWriteAfterGoAwayShouldFail() throws Exception {
when(connection.goAwayReceived()).thenReturn(true);
final ByteBuf data = dummyData();
try {
ChannelFuture future = encoder.writeData(ctx, STREAM_ID, data, 0, true, promise);
assertTrue(future.awaitUninterruptibly().cause() instanceof IllegalStateException);
} finally {
while (data.refCnt() > 0) {
data.release();
}
}
}
@Test @Test
public void dataWriteShouldSucceed() throws Exception { public void dataWriteShouldSucceed() throws Exception {
final ByteBuf data = dummyData(); final ByteBuf data = dummyData();
@ -302,18 +294,6 @@ public class DefaultHttp2ConnectionEncoderTest {
assertEquals(5, (int) writtenPadding.get(1)); assertEquals(5, (int) writtenPadding.get(1));
} }
@Test
public void headersWriteAfterGoAwayShouldFail() throws Exception {
when(connection.goAwayReceived()).thenReturn(true);
ChannelFuture future = encoder.writeHeaders(
ctx, 5, EmptyHttp2Headers.INSTANCE, 0, (short) 255, false, 0, false, promise);
verify(local, never()).createStream(anyInt());
verify(stream, never()).open(anyBoolean());
verify(writer, never()).writeHeaders(eq(ctx), anyInt(), any(Http2Headers.class), anyInt(), anyBoolean(),
eq(promise));
assertTrue(future.awaitUninterruptibly().cause() instanceof Http2Exception);
}
@Test @Test
public void headersWriteForUnknownStreamShouldCreateStream() throws Exception { public void headersWriteForUnknownStreamShouldCreateStream() throws Exception {
int streamId = 5; int streamId = 5;
@ -344,11 +324,10 @@ public class DefaultHttp2ConnectionEncoderTest {
} }
@Test @Test
public void pushPromiseWriteAfterGoAwayShouldFail() throws Exception { public void pushPromiseWriteAfterGoAwayReceivedShouldFail() throws Exception {
when(connection.goAwayReceived()).thenReturn(true); when(connection.goAwayReceived()).thenReturn(true);
ChannelFuture future = ChannelFuture future = encoder.writePushPromise(ctx, STREAM_ID, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0,
encoder.writePushPromise(ctx, STREAM_ID, PUSH_STREAM_ID, promise);
EmptyHttp2Headers.INSTANCE, 0, promise);
assertTrue(future.awaitUninterruptibly().cause() instanceof Http2Exception); assertTrue(future.awaitUninterruptibly().cause() instanceof Http2Exception);
} }
@ -361,10 +340,10 @@ public class DefaultHttp2ConnectionEncoderTest {
} }
@Test @Test
public void priorityWriteAfterGoAwayShouldFail() throws Exception { public void priorityWriteAfterGoAwayShouldSucceed() throws Exception {
when(connection.goAwayReceived()).thenReturn(true); when(connection.goAwayReceived()).thenReturn(true);
ChannelFuture future = encoder.writePriority(ctx, STREAM_ID, 0, (short) 255, true, promise); encoder.writePriority(ctx, STREAM_ID, 0, (short) 255, true, promise);
assertTrue(future.awaitUninterruptibly().cause() instanceof Http2Exception); verify(writer).writePriority(eq(ctx), eq(STREAM_ID), eq(0), eq((short) 255), eq(true), eq(promise));
} }
@Test @Test
@ -425,10 +404,10 @@ public class DefaultHttp2ConnectionEncoderTest {
} }
@Test @Test
public void pingWriteAfterGoAwayShouldFail() throws Exception { public void pingWriteAfterGoAwayShouldSucceed() throws Exception {
when(connection.goAwayReceived()).thenReturn(true); when(connection.goAwayReceived()).thenReturn(true);
ChannelFuture future = encoder.writePing(ctx, false, emptyPingBuf(), promise); encoder.writePing(ctx, false, emptyPingBuf(), promise);
assertTrue(future.awaitUninterruptibly().cause() instanceof Http2Exception); verify(writer).writePing(eq(ctx), eq(false), eq(emptyPingBuf()), eq(promise));
} }
@Test @Test
@ -438,10 +417,10 @@ public class DefaultHttp2ConnectionEncoderTest {
} }
@Test @Test
public void settingsWriteAfterGoAwayShouldFail() throws Exception { public void settingsWriteAfterGoAwayShouldSucceed() throws Exception {
when(connection.goAwayReceived()).thenReturn(true); when(connection.goAwayReceived()).thenReturn(true);
ChannelFuture future = encoder.writeSettings(ctx, new Http2Settings(), promise); encoder.writeSettings(ctx, new Http2Settings(), promise);
assertTrue(future.awaitUninterruptibly().cause() instanceof Http2Exception); verify(writer).writeSettings(eq(ctx), any(Http2Settings.class), eq(promise));
} }
@Test @Test
@ -501,6 +480,70 @@ public class DefaultHttp2ConnectionEncoderTest {
verify(lifecycleManager).closeStreamLocal(eq(stream), eq(promise)); verify(lifecycleManager).closeStreamLocal(eq(stream), eq(promise));
} }
@Test
public void encoderDelegatesGoAwayToLifeCycleManager() {
encoder.writeGoAway(ctx, STREAM_ID, Http2Error.INTERNAL_ERROR.code(), null, promise);
verify(lifecycleManager).goAway(eq(ctx), eq(STREAM_ID), eq(Http2Error.INTERNAL_ERROR.code()),
eq((ByteBuf) null), eq(promise));
verifyNoMoreInteractions(writer);
}
@Test
public void dataWriteToClosedStreamShouldFail() {
when(stream.state()).thenReturn(CLOSED);
ByteBuf data = mock(ByteBuf.class);
encoder.writeData(ctx, STREAM_ID, data, 0, false, promise);
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
assertThat(promise.cause(), instanceOf(IllegalStateException.class));
verify(data).release();
}
@Test
public void dataWriteToHalfClosedLocalStreamShouldFail() {
when(stream.state()).thenReturn(HALF_CLOSED_LOCAL);
ByteBuf data = mock(ByteBuf.class);
encoder.writeData(ctx, STREAM_ID, data, 0, false, promise);
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
assertThat(promise.cause(), instanceOf(IllegalStateException.class));
verify(data).release();
}
@Test
public void canWriteDataFrameAfterGoAwaySent() {
when(connection.goAwaySent()).thenReturn(true);
when(remote.lastKnownStream()).thenReturn(0);
ByteBuf data = mock(ByteBuf.class);
encoder.writeData(ctx, STREAM_ID, data, 0, false, promise);
verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class));
}
@Test
public void canWriteHeaderFrameAfterGoAwaySent() {
when(connection.goAwaySent()).thenReturn(true);
when(remote.lastKnownStream()).thenReturn(0);
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class));
}
@Test
public void canWriteDataFrameAfterGoAwayReceived() {
when(connection.goAwayReceived()).thenReturn(true);
when(local.lastKnownStream()).thenReturn(STREAM_ID);
ByteBuf data = mock(ByteBuf.class);
encoder.writeData(ctx, STREAM_ID, data, 0, false, promise);
verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class));
}
@Test
public void canWriteHeaderFrameAfterGoAwayReceived() {
when(connection.goAwayReceived()).thenReturn(true);
when(local.lastKnownStream()).thenReturn(STREAM_ID);
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class));
}
private void mockSendFlowControlledWriteEverything() { private void mockSendFlowControlledWriteEverything() {
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override @Override

View File

@ -74,6 +74,45 @@ public class DefaultHttp2ConnectionTest {
assertNull(server.stream(100)); assertNull(server.stream(100));
} }
@Test
public void goAwayReceivedShouldCloseStreamsGreaterThanLastStream() throws Exception {
Http2Stream stream1 = client.local().createStream(3).open(false);
Http2Stream stream2 = client.local().createStream(5).open(false);
Http2Stream remoteStream = client.remote().createStream(4).open(false);
assertEquals(State.OPEN, stream1.state());
assertEquals(State.OPEN, stream2.state());
client.goAwayReceived(3, 8, null);
assertEquals(State.OPEN, stream1.state());
assertEquals(State.CLOSED, stream2.state());
assertEquals(State.OPEN, remoteStream.state());
assertEquals(3, client.local().lastKnownStream());
assertEquals(5, client.local().lastStreamCreated());
// The remote endpoint must not be affected by a received GOAWAY frame.
assertEquals(-1, client.remote().lastKnownStream());
assertEquals(State.OPEN, remoteStream.state());
}
@Test
public void goAwaySentShouldCloseStreamsGreaterThanLastStream() throws Exception {
Http2Stream stream1 = server.remote().createStream(3).open(false);
Http2Stream stream2 = server.remote().createStream(5).open(false);
Http2Stream localStream = server.local().createStream(4).open(false);
server.goAwaySent(3, 8, null);
assertEquals(State.OPEN, stream1.state());
assertEquals(State.CLOSED, stream2.state());
assertEquals(3, server.remote().lastKnownStream());
assertEquals(5, server.remote().lastStreamCreated());
// The local endpoint must not be affected by a sent GOAWAY frame.
assertEquals(-1, server.local().lastKnownStream());
assertEquals(State.OPEN, localStream.state());
}
@Test @Test
public void serverCreateStreamShouldSucceed() throws Http2Exception { public void serverCreateStreamShouldSucceed() throws Http2Exception {
Http2Stream stream = server.local().createStream(2).open(false); Http2Stream stream = server.local().createStream(2).open(false);

View File

@ -22,6 +22,7 @@ import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED; import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyBoolean;
@ -29,16 +30,17 @@ import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultChannelPromise;
@ -272,7 +274,7 @@ public class Http2ConnectionHandlerTest {
@Override @Override
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments(); Object[] args = invocation.getArguments();
ChannelFutureListener listener = (ChannelFutureListener) args[0]; GenericFutureListener<ChannelFuture> listener = (GenericFutureListener<ChannelFuture>) args[0];
// Simulate that all streams have become inactive by the time the future completes. // Simulate that all streams have become inactive by the time the future completes.
when(connection.activeStreams()).thenReturn(Collections.<Http2Stream>emptyList()); when(connection.activeStreams()).thenReturn(Collections.<Http2Stream>emptyList());
when(connection.numActiveStreams()).thenReturn(0); when(connection.numActiveStreams()).thenReturn(0);
@ -287,4 +289,49 @@ public class Http2ConnectionHandlerTest {
handler.closeStream(stream, future); handler.closeStream(stream, future);
verify(ctx, times(1)).close(any(ChannelPromise.class)); verify(ctx, times(1)).close(any(ChannelPromise.class));
} }
public void canSendGoAwayFrame() throws Exception {
handler = newHandler();
ByteBuf data = mock(ByteBuf.class);
long errorCode = Http2Error.INTERNAL_ERROR.code();
handler.goAway(ctx, STREAM_ID, errorCode, data, promise);
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
}
@Test
public void canSendGoAwayFramesWithDecreasingLastStreamIds() throws Exception {
handler = newHandler();
ByteBuf data = mock(ByteBuf.class);
long errorCode = Http2Error.INTERNAL_ERROR.code();
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID + 2), eq(errorCode), eq(data), eq(promise));
verify(connection).goAwaySent(eq(STREAM_ID + 2), eq(errorCode), eq(data));
handler.goAway(ctx, STREAM_ID, errorCode, data, promise);
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
}
@Test
public void cannotSendGoAwayFrameWithIncreasingLastStreamIds() throws Exception {
handler = newHandler();
ByteBuf data = mock(ByteBuf.class);
long errorCode = Http2Error.INTERNAL_ERROR.code();
handler.goAway(ctx, STREAM_ID, errorCode, data, promise);
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
// The frameWriter is only mocked, so it should not have interacted with the promise.
assertFalse(promise.isDone());
when(connection.goAwaySent()).thenReturn(true);
when(remote.lastKnownStream()).thenReturn(STREAM_ID);
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
verify(data).release();
verifyNoMoreInteractions(frameWriter);
}
} }