HTTP/2 Prevent modification of activeStreams while iterating

Motivation:
The Http2Connection interface exposes an activeStreams() method which allows direct iteration over the underlying collection. There are a few places that make copies of this collection to avoid modification while iterating, and a few places that do not make copies. The copy operation can be expensive on hot code paths and also we are not consistently iterating over the activeStreams collection.

Modifications:
- The Http2Connection interface should reduce the exposure of the underlying collection and just expose what is necessary for the interface to function.  This is just a means to iterate over the collection.
- The DefaultHttp2Connection should use this new interface and protect it's internal state while iteration is occurring.

Result:
Reduction in surface area of the Http2Connection interface.  Consistent iteration of the set of active streams.  Concurrent modification exceptions are handled in 1 encapsulated spot.
This commit is contained in:
Scott Mitchell 2015-03-28 13:32:19 -07:00
parent 19c864505d
commit 28bd5a55c8
9 changed files with 300 additions and 118 deletions

View File

@ -33,39 +33,39 @@ import static io.netty.handler.codec.http2.Http2Stream.State.OPEN;
import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL; import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL;
import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_REMOTE; import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_REMOTE;
import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkNotNull;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http2.Http2StreamRemovalPolicy.Action; import io.netty.handler.codec.http2.Http2StreamRemovalPolicy.Action;
import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.collection.IntObjectMap; import io.netty.util.collection.IntObjectMap;
import io.netty.util.collection.PrimitiveCollections; import io.netty.util.collection.PrimitiveCollections;
import io.netty.util.internal.PlatformDependent;
import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Queue;
import java.util.Set; import java.util.Set;
/** /**
* Simple implementation of {@link Http2Connection}. * Simple implementation of {@link Http2Connection}.
*/ */
public class DefaultHttp2Connection implements Http2Connection { public class DefaultHttp2Connection implements Http2Connection {
// Fields accessed by inner classes
final IntObjectMap<Http2Stream> streamMap = new IntObjectHashMap<Http2Stream>();
final ConnectionStream connectionStream = new ConnectionStream();
final DefaultEndpoint<Http2LocalFlowController> localEndpoint;
final DefaultEndpoint<Http2RemoteFlowController> remoteEndpoint;
/** /**
* We chose a {@link List} over a {@link Set} to avoid allocating an {@link Iterator} objects when iterating over * We chose a {@link List} over a {@link Set} to avoid allocating an {@link Iterator} objects when iterating over
* the listeners. * the listeners.
*/ */
private final List<Listener> listeners = new ArrayList<Listener>(4); final List<Listener> listeners = new ArrayList<Listener>(4);
private final IntObjectMap<Http2Stream> streamMap = new IntObjectHashMap<Http2Stream>(); final ActiveStreams activeStreams;
private final ConnectionStream connectionStream = new ConnectionStream();
private final Set<Http2Stream> activeStreams = new LinkedHashSet<Http2Stream>();
private final DefaultEndpoint<Http2LocalFlowController> localEndpoint;
private final DefaultEndpoint<Http2RemoteFlowController> remoteEndpoint;
private final Http2StreamRemovalPolicy removalPolicy;
/** /**
* Creates a connection with an immediate stream removal policy. * Creates a connection with an immediate stream removal policy.
@ -86,7 +86,7 @@ public class DefaultHttp2Connection implements Http2Connection {
* the policy to be used for removal of closed stream. * the policy to be used for removal of closed stream.
*/ */
public DefaultHttp2Connection(boolean server, Http2StreamRemovalPolicy removalPolicy) { public DefaultHttp2Connection(boolean server, Http2StreamRemovalPolicy removalPolicy) {
this.removalPolicy = checkNotNull(removalPolicy, "removalPolicy"); activeStreams = new ActiveStreams(listeners, checkNotNull(removalPolicy, "removalPolicy"));
localEndpoint = new DefaultEndpoint<Http2LocalFlowController>(server); localEndpoint = new DefaultEndpoint<Http2LocalFlowController>(server);
remoteEndpoint = new DefaultEndpoint<Http2RemoteFlowController>(!server); remoteEndpoint = new DefaultEndpoint<Http2RemoteFlowController>(!server);
@ -142,8 +142,8 @@ public class DefaultHttp2Connection implements Http2Connection {
} }
@Override @Override
public Set<Http2Stream> activeStreams() { public Http2Stream forEachActiveStream(StreamVisitor visitor) throws Http2Exception {
return Collections.unmodifiableSet(activeStreams); return activeStreams.forEachActiveStream(visitor);
} }
@Override @Override
@ -162,17 +162,24 @@ public class DefaultHttp2Connection implements Http2Connection {
} }
@Override @Override
public void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf debugData) { public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf debugData) {
localEndpoint.lastKnownStream(lastKnownStream); localEndpoint.lastKnownStream(lastKnownStream);
for (Listener listener : listeners) { for (Listener listener : listeners) {
listener.onGoAwayReceived(lastKnownStream, errorCode, debugData); listener.onGoAwayReceived(lastKnownStream, errorCode, debugData);
} }
Http2Stream[] streams = new Http2Stream[numActiveStreams()]; try {
for (Http2Stream stream : activeStreams().toArray(streams)) { forEachActiveStream(new StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) {
if (stream.id() > lastKnownStream && localEndpoint.createdStreamId(stream.id())) { if (stream.id() > lastKnownStream && localEndpoint.createdStreamId(stream.id())) {
stream.close(); stream.close();
} }
return true;
}
});
} catch (Http2Exception e) {
PlatformDependent.throwException(e);
} }
} }
@ -182,17 +189,24 @@ public class DefaultHttp2Connection implements Http2Connection {
} }
@Override @Override
public void goAwaySent(int lastKnownStream, long errorCode, ByteBuf debugData) { public void goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugData) {
remoteEndpoint.lastKnownStream(lastKnownStream); remoteEndpoint.lastKnownStream(lastKnownStream);
for (Listener listener : listeners) { for (Listener listener : listeners) {
listener.onGoAwaySent(lastKnownStream, errorCode, debugData); listener.onGoAwaySent(lastKnownStream, errorCode, debugData);
} }
Http2Stream[] streams = new Http2Stream[numActiveStreams()]; try {
for (Http2Stream stream : activeStreams().toArray(streams)) { forEachActiveStream(new StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) {
if (stream.id() > lastKnownStream && remoteEndpoint.createdStreamId(stream.id())) { if (stream.id() > lastKnownStream && remoteEndpoint.createdStreamId(stream.id())) {
stream.close(); stream.close();
} }
return true;
}
});
} catch (Http2Exception e) {
PlatformDependent.throwException(e);
} }
} }
@ -384,15 +398,7 @@ public class DefaultHttp2Connection implements Http2Connection {
throw streamError(id, PROTOCOL_ERROR, "Attempting to open a stream in an invalid state: " + state); throw streamError(id, PROTOCOL_ERROR, "Attempting to open a stream in an invalid state: " + state);
} }
if (activeStreams.add(this)) { activeStreams.activate(this);
// Update the number of active streams initiated by the endpoint.
createdBy().numActiveStreams++;
// Notify the listeners.
for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).onStreamActive(this);
}
}
return this; return this;
} }
@ -404,20 +410,8 @@ public class DefaultHttp2Connection implements Http2Connection {
state = CLOSED; state = CLOSED;
decrementPrioritizableForTree(1); decrementPrioritizableForTree(1);
if (activeStreams.remove(this)) {
try {
// Update the number of active streams initiated by the endpoint.
createdBy().numActiveStreams--;
// Notify the listeners. activeStreams.deactivate(this);
for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).onStreamClosed(this);
}
} finally {
// Mark this stream for removal.
removalPolicy.markForRemoval(this);
}
}
return this; return this;
} }
@ -790,8 +784,9 @@ public class DefaultHttp2Connection implements Http2Connection {
private int lastKnownStream = -1; private int lastKnownStream = -1;
private boolean pushToAllowed = true; private boolean pushToAllowed = true;
private F flowController; private F flowController;
private int numActiveStreams;
private int maxActiveStreams; private int maxActiveStreams;
// Fields accessed by inner classes
int numActiveStreams;
DefaultEndpoint(boolean server) { DefaultEndpoint(boolean server) {
this.server = server; this.server = server;
@ -951,8 +946,8 @@ public class DefaultHttp2Connection implements Http2Connection {
throw new Http2NoMoreStreamIdsException(); throw new Http2NoMoreStreamIdsException();
} }
if (!createdStreamId(streamId)) { if (!createdStreamId(streamId)) {
throw connectionError(PROTOCOL_ERROR, "Request stream %d is not correct for %s connection", throw connectionError(PROTOCOL_ERROR, "Request stream %d is not correct for %s connection", streamId,
streamId, server ? "server" : "client"); server ? "server" : "client");
} }
// This check must be after all id validated checks, but before the max streams check because it may be // This check must be after all id validated checks, but before the max streams check because it may be
// recoverable to some degree for handling frames which can be sent on closed streams. // recoverable to some degree for handling frames which can be sent on closed streams.
@ -969,4 +964,116 @@ public class DefaultHttp2Connection implements Http2Connection {
return this == localEndpoint; return this == localEndpoint;
} }
} }
/**
* Default implementation of the {@link ActiveStreams} class.
*/
private static final class ActiveStreams {
/**
* Allows events which would modify {@link #streams} to be queued while iterating over {@link #streams}.
*/
interface Event {
/**
* Trigger the original intention of this event. Expect to modify {@link #streams}.
*/
void process();
}
private final List<Listener> listeners;
private final Http2StreamRemovalPolicy removalPolicy;
private final Queue<Event> pendingEvents = new ArrayDeque<Event>(4);
private final Set<Http2Stream> streams = new LinkedHashSet<Http2Stream>();
private int pendingIterations;
public ActiveStreams(List<Listener> listeners, Http2StreamRemovalPolicy removalPolicy) {
this.listeners = listeners;
this.removalPolicy = removalPolicy;
}
public int size() {
return streams.size();
}
public void activate(final DefaultStream stream) {
if (allowModifications()) {
addToActiveStreams(stream);
} else {
pendingEvents.add(new Event() {
@Override
public void process() {
addToActiveStreams(stream);
}
});
}
}
public void deactivate(final DefaultStream stream) {
if (allowModifications()) {
removeFromActiveStreams(stream);
} else {
pendingEvents.add(new Event() {
@Override
public void process() {
removeFromActiveStreams(stream);
}
});
}
}
public Http2Stream forEachActiveStream(StreamVisitor visitor) throws Http2Exception {
++pendingIterations;
Http2Stream resultStream = null;
try {
for (Http2Stream stream : streams) {
if (!visitor.visit(stream)) {
resultStream = stream;
break;
}
}
return resultStream;
} finally {
--pendingIterations;
if (allowModifications()) {
for (;;) {
Event event = pendingEvents.poll();
if (event == null) {
break;
}
event.process();
}
}
}
}
void addToActiveStreams(DefaultStream stream) {
if (streams.add(stream)) {
// Update the number of active streams initiated by the endpoint.
stream.createdBy().numActiveStreams++;
for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).onStreamActive(stream);
}
}
}
void removeFromActiveStreams(DefaultStream stream) {
if (streams.remove(stream)) {
try {
// Update the number of active streams initiated by the endpoint.
stream.createdBy().numActiveStreams--;
for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).onStreamClosed(stream);
}
} finally {
// Mark this stream for removal.
removalPolicy.markForRemoval(stream);
}
}
}
private boolean allowModifications() {
return pendingIterations == 0;
}
}
} }

View File

@ -28,6 +28,7 @@ import static java.lang.Math.max;
import static java.lang.Math.min; import static java.lang.Math.min;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http2.Http2Connection.StreamVisitor;
import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException; import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException;
import io.netty.handler.codec.http2.Http2Exception.StreamException; import io.netty.handler.codec.http2.Http2Exception.StreamException;
@ -35,7 +36,6 @@ import io.netty.handler.codec.http2.Http2Exception.StreamException;
* Basic implementation of {@link Http2LocalFlowController}. * Basic implementation of {@link Http2LocalFlowController}.
*/ */
public class DefaultHttp2LocalFlowController implements Http2LocalFlowController { public class DefaultHttp2LocalFlowController implements Http2LocalFlowController {
private static final int DEFAULT_COMPOSITE_EXCEPTION_SIZE = 4;
/** /**
* The default ratio of window size to initial window size below which a {@code WINDOW_UPDATE} * The default ratio of window size to initial window size below which a {@code WINDOW_UPDATE}
* is sent to expand the window. * is sent to expand the window.
@ -82,23 +82,9 @@ public class DefaultHttp2LocalFlowController implements Http2LocalFlowController
int delta = newWindowSize - initialWindowSize; int delta = newWindowSize - initialWindowSize;
initialWindowSize = newWindowSize; initialWindowSize = newWindowSize;
CompositeStreamException compositeException = null; WindowUpdateVisitor visitor = new WindowUpdateVisitor(delta);
for (Http2Stream stream : connection.activeStreams()) { connection.forEachActiveStream(visitor);
try { visitor.throwIfError();
// Increment flow control window first so state will be consistent if overflow is detected
FlowState state = state(stream);
state.incrementFlowControlWindows(delta);
state.incrementInitialStreamWindow(delta);
} catch (StreamException e) {
if (compositeException == null) {
compositeException = new CompositeStreamException(e.error(), DEFAULT_COMPOSITE_EXCEPTION_SIZE);
}
compositeException.add(e);
}
}
if (compositeException != null) {
throw compositeException;
}
} }
@Override @Override
@ -208,7 +194,7 @@ public class DefaultHttp2LocalFlowController implements Http2LocalFlowController
return state(connection.connectionStream()); return state(connection.connectionStream());
} }
private FlowState state(Http2Stream stream) { private static FlowState state(Http2Stream stream) {
checkNotNull(stream, "stream"); checkNotNull(stream, "stream");
return stream.getProperty(FlowState.class); return stream.getProperty(FlowState.class);
} }
@ -391,4 +377,38 @@ public class DefaultHttp2LocalFlowController implements Http2LocalFlowController
ctx.flush(); ctx.flush();
} }
} }
/**
* Provides a means to iterate over all active streams and increment the flow control windows.
*/
private static final class WindowUpdateVisitor implements StreamVisitor {
private CompositeStreamException compositeException;
private final int delta;
public WindowUpdateVisitor(int delta) {
this.delta = delta;
}
@Override
public boolean visit(Http2Stream stream) throws Http2Exception {
try {
// Increment flow control window first so state will be consistent if overflow is detected.
FlowState state = state(stream);
state.incrementFlowControlWindows(delta);
state.incrementInitialStreamWindow(delta);
} catch (StreamException e) {
if (compositeException == null) {
compositeException = new CompositeStreamException(e.error(), 4);
}
compositeException.add(e);
}
return true;
}
public void throwIfError() throws CompositeStreamException {
if (compositeException != null) {
throw compositeException;
}
}
}
} }

View File

@ -23,16 +23,23 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static java.lang.Math.max; import static java.lang.Math.max;
import static java.lang.Math.min; import static java.lang.Math.min;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http2.Http2Connection.StreamVisitor;
import io.netty.handler.codec.http2.Http2Stream.State; import io.netty.handler.codec.http2.Http2Stream.State;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque; import java.util.Deque;
/** /**
* Basic implementation of {@link Http2RemoteFlowController}. * Basic implementation of {@link Http2RemoteFlowController}.
*/ */
public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowController { public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowController {
private static final StreamVisitor WRITE_ALLOCATED_BYTES = new StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) {
state(stream).writeAllocatedBytes();
return true;
}
};
private final Http2Connection connection; private final Http2Connection connection;
private int initialWindowSize = DEFAULT_WINDOW_SIZE; private int initialWindowSize = DEFAULT_WINDOW_SIZE;
private ChannelHandlerContext ctx; private ChannelHandlerContext ctx;
@ -115,12 +122,16 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize); throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize);
} }
int delta = newWindowSize - initialWindowSize; final int delta = newWindowSize - initialWindowSize;
initialWindowSize = newWindowSize; initialWindowSize = newWindowSize;
for (Http2Stream stream : connection.activeStreams()) { connection.forEachActiveStream(new StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) throws Http2Exception {
// Verify that the maximum value is not exceeded by this change. // Verify that the maximum value is not exceeded by this change.
state(stream).incrementStreamWindow(delta); state(stream).incrementStreamWindow(delta);
return true;
} }
});
if (delta > 0) { if (delta > 0) {
// The window size increased, send any pending frames for all streams. // The window size increased, send any pending frames for all streams.
@ -215,7 +226,7 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
/** /**
* Writes as many pending bytes as possible, according to stream priority. * Writes as many pending bytes as possible, according to stream priority.
*/ */
private void writePendingBytes() { private void writePendingBytes() throws Http2Exception {
Http2Stream connectionStream = connection.connectionStream(); Http2Stream connectionStream = connection.connectionStream();
int connectionWindow = state(connectionStream).window(); int connectionWindow = state(connectionStream).window();
@ -223,13 +234,8 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
// Allocate the bytes for the connection window to the streams, but do not write. // Allocate the bytes for the connection window to the streams, but do not write.
allocateBytesForTree(connectionStream, connectionWindow); allocateBytesForTree(connectionStream, connectionWindow);
// Now write all of the allocated bytes. Copying the activeStreams array to avoid // Now write all of the allocated bytes.
// side effects due to stream removal/addition which might occur as a result connection.forEachActiveStream(WRITE_ALLOCATED_BYTES);
// of end-of-stream or errors.
Collection<Http2Stream> streams = connection.activeStreams();
for (Http2Stream stream : streams.toArray(new Http2Stream[streams.size()])) {
state(stream).writeAllocatedBytes();
}
flush(); flush();
} }
} }

View File

@ -17,8 +17,6 @@ package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import java.util.Collection;
/** /**
* Manager for the state of an HTTP/2 connection with the remote end-point. * Manager for the state of an HTTP/2 connection with the remote end-point.
*/ */
@ -260,10 +258,25 @@ public interface Http2Connection {
int numActiveStreams(); int numActiveStreams();
/** /**
* Gets all streams that are actively in use (i.e. {@code OPEN} or {@code HALF CLOSED}). The returned collection is * Provide a means of iterating over the collection of active streams.
* sorted by priority. *
* @param visitor The visitor which will visit each active stream.
* @return The stream before iteration stopped or {@code null} if iteration went past the end.
*/ */
Collection<Http2Stream> activeStreams(); Http2Stream forEachActiveStream(StreamVisitor visitor) throws Http2Exception;
/**
* A visitor that allows iteration over a collection of streams.
*/
interface StreamVisitor {
/**
* @return <ul>
* <li>{@code true} if the processor wants to continue the loop and handle the entry.</li>
* <li>{@code false} if the processor wants to stop handling headers and abort the loop.</li>
* </ul>
*/
boolean visit(Http2Stream stream) throws Http2Exception;
}
/** /**
* Indicates whether or not the local endpoint for this connection is the server. * Indicates whether or not the local endpoint for this connection is the server.

View File

@ -32,13 +32,13 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http2.Http2Connection.StreamVisitor;
import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException; import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException;
import io.netty.handler.codec.http2.Http2Exception.StreamException; import io.netty.handler.codec.http2.Http2Exception.StreamException;
import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.Collection;
import java.util.List; import java.util.List;
/** /**
@ -142,10 +142,17 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
public void channelInactive(ChannelHandlerContext ctx) throws Exception { public void channelInactive(ChannelHandlerContext ctx) throws Exception {
try { try {
ChannelFuture future = ctx.newSucceededFuture(); final Http2Connection connection = connection();
final Collection<Http2Stream> streams = connection().activeStreams(); // Check if there are streams to avoid the overhead of creating the ChannelFuture.
for (Http2Stream s : streams.toArray(new Http2Stream[streams.size()])) { if (connection.numActiveStreams() > 0) {
closeStream(s, future); final ChannelFuture future = ctx.newSucceededFuture();
connection.forEachActiveStream(new StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) throws Http2Exception {
closeStream(stream, future);
return true;
}
});
} }
} finally { } finally {
try { try {

View File

@ -46,6 +46,7 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2Connection.StreamVisitor;
import io.netty.handler.codec.http2.Http2Exception.ClosedStreamCreationException; import io.netty.handler.codec.http2.Http2Exception.ClosedStreamCreationException;
import java.util.Collections; import java.util.Collections;
@ -124,7 +125,16 @@ public class DefaultHttp2ConnectionDecoderTest {
when(stream.state()).thenReturn(OPEN); when(stream.state()).thenReturn(OPEN);
when(stream.open(anyBoolean())).thenReturn(stream); when(stream.open(anyBoolean())).thenReturn(stream);
when(pushStream.id()).thenReturn(PUSH_STREAM_ID); when(pushStream.id()).thenReturn(PUSH_STREAM_ID);
when(connection.activeStreams()).thenReturn(Collections.singletonList(stream)); doAnswer(new Answer<Http2Stream>() {
@Override
public Http2Stream answer(InvocationOnMock in) throws Throwable {
StreamVisitor visitor = in.getArgumentAt(0, StreamVisitor.class);
if (!visitor.visit(stream)) {
return stream;
}
return null;
}
}).when(connection).forEachActiveStream(any(StreamVisitor.class));
when(connection.stream(STREAM_ID)).thenReturn(stream); when(connection.stream(STREAM_ID)).thenReturn(stream);
when(connection.requireStream(STREAM_ID)).thenReturn(stream); when(connection.requireStream(STREAM_ID)).thenReturn(stream);
when(connection.local()).thenReturn(local); when(connection.local()).thenReturn(local);

View File

@ -53,12 +53,12 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2Connection.StreamVisitor;
import io.netty.handler.codec.http2.Http2Exception.ClosedStreamCreationException; import io.netty.handler.codec.http2.Http2Exception.ClosedStreamCreationException;
import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled; import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled;
import io.netty.util.concurrent.ImmediateEventExecutor; import io.netty.util.concurrent.ImmediateEventExecutor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import org.junit.Before; import org.junit.Before;
@ -138,7 +138,16 @@ public class DefaultHttp2ConnectionEncoderTest {
when(stream.state()).thenReturn(OPEN); when(stream.state()).thenReturn(OPEN);
when(stream.open(anyBoolean())).thenReturn(stream); when(stream.open(anyBoolean())).thenReturn(stream);
when(pushStream.id()).thenReturn(PUSH_STREAM_ID); when(pushStream.id()).thenReturn(PUSH_STREAM_ID);
when(connection.activeStreams()).thenReturn(Collections.singletonList(stream)); doAnswer(new Answer<Http2Stream>() {
@Override
public Http2Stream answer(InvocationOnMock in) throws Throwable {
StreamVisitor visitor = in.getArgumentAt(0, StreamVisitor.class);
if (!visitor.visit(stream)) {
return stream;
}
return null;
}
}).when(connection).forEachActiveStream(any(StreamVisitor.class));
when(connection.stream(STREAM_ID)).thenReturn(stream); when(connection.stream(STREAM_ID)).thenReturn(stream);
when(connection.requireStream(STREAM_ID)).thenReturn(stream); when(connection.requireStream(STREAM_ID)).thenReturn(stream);
when(connection.local()).thenReturn(local); when(connection.local()).thenReturn(local);

View File

@ -21,7 +21,6 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyShort; import static org.mockito.Matchers.anyShort;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
@ -30,7 +29,6 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http2.Http2Connection.Endpoint; import io.netty.handler.codec.http2.Http2Connection.Endpoint;
import io.netty.handler.codec.http2.Http2Stream.State; import io.netty.handler.codec.http2.Http2Stream.State;
@ -118,25 +116,25 @@ public class DefaultHttp2ConnectionTest {
Http2Stream stream = server.local().createStream(2).open(false); Http2Stream stream = server.local().createStream(2).open(false);
assertEquals(2, stream.id()); assertEquals(2, stream.id());
assertEquals(State.OPEN, stream.state()); assertEquals(State.OPEN, stream.state());
assertEquals(1, server.activeStreams().size()); assertEquals(1, server.numActiveStreams());
assertEquals(2, server.local().lastStreamCreated()); assertEquals(2, server.local().lastStreamCreated());
stream = server.local().createStream(4).open(true); stream = server.local().createStream(4).open(true);
assertEquals(4, stream.id()); assertEquals(4, stream.id());
assertEquals(State.HALF_CLOSED_LOCAL, stream.state()); assertEquals(State.HALF_CLOSED_LOCAL, stream.state());
assertEquals(2, server.activeStreams().size()); assertEquals(2, server.numActiveStreams());
assertEquals(4, server.local().lastStreamCreated()); assertEquals(4, server.local().lastStreamCreated());
stream = server.remote().createStream(3).open(true); stream = server.remote().createStream(3).open(true);
assertEquals(3, stream.id()); assertEquals(3, stream.id());
assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); assertEquals(State.HALF_CLOSED_REMOTE, stream.state());
assertEquals(3, server.activeStreams().size()); assertEquals(3, server.numActiveStreams());
assertEquals(3, server.remote().lastStreamCreated()); assertEquals(3, server.remote().lastStreamCreated());
stream = server.remote().createStream(5).open(false); stream = server.remote().createStream(5).open(false);
assertEquals(5, stream.id()); assertEquals(5, stream.id());
assertEquals(State.OPEN, stream.state()); assertEquals(State.OPEN, stream.state());
assertEquals(4, server.activeStreams().size()); assertEquals(4, server.numActiveStreams());
assertEquals(5, server.remote().lastStreamCreated()); assertEquals(5, server.remote().lastStreamCreated());
} }
@ -145,25 +143,25 @@ public class DefaultHttp2ConnectionTest {
Http2Stream stream = client.remote().createStream(2).open(false); Http2Stream stream = client.remote().createStream(2).open(false);
assertEquals(2, stream.id()); assertEquals(2, stream.id());
assertEquals(State.OPEN, stream.state()); assertEquals(State.OPEN, stream.state());
assertEquals(1, client.activeStreams().size()); assertEquals(1, client.numActiveStreams());
assertEquals(2, client.remote().lastStreamCreated()); assertEquals(2, client.remote().lastStreamCreated());
stream = client.remote().createStream(4).open(true); stream = client.remote().createStream(4).open(true);
assertEquals(4, stream.id()); assertEquals(4, stream.id());
assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); assertEquals(State.HALF_CLOSED_REMOTE, stream.state());
assertEquals(2, client.activeStreams().size()); assertEquals(2, client.numActiveStreams());
assertEquals(4, client.remote().lastStreamCreated()); assertEquals(4, client.remote().lastStreamCreated());
stream = client.local().createStream(3).open(true); stream = client.local().createStream(3).open(true);
assertEquals(3, stream.id()); assertEquals(3, stream.id());
assertEquals(State.HALF_CLOSED_LOCAL, stream.state()); assertEquals(State.HALF_CLOSED_LOCAL, stream.state());
assertEquals(3, client.activeStreams().size()); assertEquals(3, client.numActiveStreams());
assertEquals(3, client.local().lastStreamCreated()); assertEquals(3, client.local().lastStreamCreated());
stream = client.local().createStream(5).open(false); stream = client.local().createStream(5).open(false);
assertEquals(5, stream.id()); assertEquals(5, stream.id());
assertEquals(State.OPEN, stream.state()); assertEquals(State.OPEN, stream.state());
assertEquals(4, client.activeStreams().size()); assertEquals(4, client.numActiveStreams());
assertEquals(5, client.local().lastStreamCreated()); assertEquals(5, client.local().lastStreamCreated());
} }
@ -173,7 +171,7 @@ public class DefaultHttp2ConnectionTest {
Http2Stream pushStream = server.local().reservePushStream(2, stream); Http2Stream pushStream = server.local().reservePushStream(2, stream);
assertEquals(2, pushStream.id()); assertEquals(2, pushStream.id());
assertEquals(State.RESERVED_LOCAL, pushStream.state()); assertEquals(State.RESERVED_LOCAL, pushStream.state());
assertEquals(1, server.activeStreams().size()); assertEquals(1, server.numActiveStreams());
assertEquals(2, server.local().lastStreamCreated()); assertEquals(2, server.local().lastStreamCreated());
} }
@ -183,7 +181,7 @@ public class DefaultHttp2ConnectionTest {
Http2Stream pushStream = server.local().reservePushStream(4, stream); Http2Stream pushStream = server.local().reservePushStream(4, stream);
assertEquals(4, pushStream.id()); assertEquals(4, pushStream.id());
assertEquals(State.RESERVED_LOCAL, pushStream.state()); assertEquals(State.RESERVED_LOCAL, pushStream.state());
assertEquals(1, server.activeStreams().size()); assertEquals(1, server.numActiveStreams());
assertEquals(4, server.local().lastStreamCreated()); assertEquals(4, server.local().lastStreamCreated());
} }
@ -226,7 +224,7 @@ public class DefaultHttp2ConnectionTest {
Http2Stream stream = server.remote().createStream(3).open(true); Http2Stream stream = server.remote().createStream(3).open(true);
stream.close(); stream.close();
assertEquals(State.CLOSED, stream.state()); assertEquals(State.CLOSED, stream.state());
assertTrue(server.activeStreams().isEmpty()); assertEquals(0, server.numActiveStreams());
} }
@Test @Test
@ -234,7 +232,7 @@ public class DefaultHttp2ConnectionTest {
Http2Stream stream = server.remote().createStream(3).open(false); Http2Stream stream = server.remote().createStream(3).open(false);
stream.closeLocalSide(); stream.closeLocalSide();
assertEquals(State.HALF_CLOSED_LOCAL, stream.state()); assertEquals(State.HALF_CLOSED_LOCAL, stream.state());
assertEquals(1, server.activeStreams().size()); assertEquals(1, server.numActiveStreams());
} }
@Test @Test
@ -242,7 +240,7 @@ public class DefaultHttp2ConnectionTest {
Http2Stream stream = server.remote().createStream(3).open(false); Http2Stream stream = server.remote().createStream(3).open(false);
stream.closeRemoteSide(); stream.closeRemoteSide();
assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); assertEquals(State.HALF_CLOSED_REMOTE, stream.state());
assertEquals(1, server.activeStreams().size()); assertEquals(1, server.numActiveStreams());
} }
@Test @Test
@ -250,7 +248,7 @@ public class DefaultHttp2ConnectionTest {
Http2Stream stream = server.remote().createStream(3).open(true); Http2Stream stream = server.remote().createStream(3).open(true);
stream.closeLocalSide(); stream.closeLocalSide();
assertEquals(State.CLOSED, stream.state()); assertEquals(State.CLOSED, stream.state());
assertTrue(server.activeStreams().isEmpty()); assertEquals(0, server.numActiveStreams());
} }
@Test(expected = Http2Exception.class) @Test(expected = Http2Exception.class)

View File

@ -44,10 +44,9 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2Connection.StreamVisitor;
import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.GenericFutureListener;
import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import org.junit.After; import org.junit.After;
@ -113,8 +112,18 @@ public class Http2ConnectionHandlerTest {
when(channel.isActive()).thenReturn(true); when(channel.isActive()).thenReturn(true);
when(connection.remote()).thenReturn(remote); when(connection.remote()).thenReturn(remote);
when(connection.local()).thenReturn(local); when(connection.local()).thenReturn(local);
when(connection.activeStreams()).thenReturn(Collections.singletonList(stream)); doAnswer(new Answer<Http2Stream>() {
@Override
public Http2Stream answer(InvocationOnMock in) throws Throwable {
StreamVisitor visitor = in.getArgumentAt(0, StreamVisitor.class);
if (!visitor.visit(stream)) {
return stream;
}
return null;
}
}).when(connection).forEachActiveStream(any(StreamVisitor.class));
when(connection.stream(NON_EXISTANT_STREAM_ID)).thenReturn(null); when(connection.stream(NON_EXISTANT_STREAM_ID)).thenReturn(null);
when(connection.numActiveStreams()).thenReturn(1);
when(connection.stream(STREAM_ID)).thenReturn(stream); when(connection.stream(STREAM_ID)).thenReturn(stream);
when(stream.open(anyBoolean())).thenReturn(stream); when(stream.open(anyBoolean())).thenReturn(stream);
when(encoder.writeSettings(eq(ctx), any(Http2Settings.class), eq(promise))).thenReturn(future); when(encoder.writeSettings(eq(ctx), any(Http2Settings.class), eq(promise))).thenReturn(future);
@ -266,8 +275,6 @@ public class Http2ConnectionHandlerTest {
@Test @Test
public void closeListenerShouldBeNotifiedOnlyOneTime() throws Exception { public void closeListenerShouldBeNotifiedOnlyOneTime() throws Exception {
handler = newHandler(); handler = newHandler();
when(connection.activeStreams()).thenReturn(Arrays.asList(stream));
when(connection.numActiveStreams()).thenReturn(1);
when(future.isDone()).thenReturn(true); when(future.isDone()).thenReturn(true);
when(future.isSuccess()).thenReturn(true); when(future.isSuccess()).thenReturn(true);
doAnswer(new Answer<ChannelFuture>() { doAnswer(new Answer<ChannelFuture>() {
@ -276,7 +283,12 @@ public class Http2ConnectionHandlerTest {
Object[] args = invocation.getArguments(); Object[] args = invocation.getArguments();
GenericFutureListener<ChannelFuture> listener = (GenericFutureListener<ChannelFuture>) args[0]; GenericFutureListener<ChannelFuture> listener = (GenericFutureListener<ChannelFuture>) args[0];
// Simulate that all streams have become inactive by the time the future completes. // Simulate that all streams have become inactive by the time the future completes.
when(connection.activeStreams()).thenReturn(Collections.<Http2Stream>emptyList()); doAnswer(new Answer<Http2Stream>() {
@Override
public Http2Stream answer(InvocationOnMock in) throws Throwable {
return null;
}
}).when(connection).forEachActiveStream(any(StreamVisitor.class));
when(connection.numActiveStreams()).thenReturn(0); when(connection.numActiveStreams()).thenReturn(0);
// Simulate the future being completed. // Simulate the future being completed.
listener.operationComplete(future); listener.operationComplete(future);