From 83ce8a9187a11b7bf38fd91841fc0b03d3e24d80 Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Sat, 28 Mar 2015 13:32:19 -0700 Subject: [PATCH] HTTP/2 Prevent modification of activeStreams while iterating Motivation: The Http2Connection interface exposes an activeStreams() method which allows direct iteration over the underlying collection. There are a few places that make copies of this collection to avoid modification while iterating, and a few places that do not make copies. The copy operation can be expensive on hot code paths and also we are not consistently iterating over the activeStreams collection. Modifications: - The Http2Connection interface should reduce the exposure of the underlying collection and just expose what is necessary for the interface to function. This is just a means to iterate over the collection. - The DefaultHttp2Connection should use this new interface and protect it's internal state while iteration is occurring. Result: Reduction in surface area of the Http2Connection interface. Consistent iteration of the set of active streams. Concurrent modification exceptions are handled in 1 encapsulated spot. --- .../codec/http2/DefaultHttp2Connection.java | 207 +++++++++++++----- .../DefaultHttp2LocalFlowController.java | 58 +++-- .../DefaultHttp2RemoteFlowController.java | 34 +-- .../handler/codec/http2/Http2Connection.java | 23 +- .../codec/http2/Http2ConnectionHandler.java | 17 +- .../DefaultHttp2ConnectionDecoderTest.java | 12 +- .../DefaultHttp2ConnectionEncoderTest.java | 13 +- .../http2/DefaultHttp2ConnectionTest.java | 30 ++- .../http2/Http2ConnectionHandlerTest.java | 24 +- 9 files changed, 300 insertions(+), 118 deletions(-) diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java index 8551634632..f44b571570 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java @@ -33,39 +33,39 @@ import static io.netty.handler.codec.http2.Http2Stream.State.OPEN; import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL; import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_REMOTE; import static io.netty.util.internal.ObjectUtil.checkNotNull; - import io.netty.buffer.ByteBuf; import io.netty.handler.codec.http2.Http2StreamRemovalPolicy.Action; import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; import io.netty.util.collection.PrimitiveCollections; +import io.netty.util.internal.PlatformDependent; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.Set; /** * Simple implementation of {@link Http2Connection}. */ public class DefaultHttp2Connection implements Http2Connection { - + // Fields accessed by inner classes + final IntObjectMap streamMap = new IntObjectHashMap(); + final ConnectionStream connectionStream = new ConnectionStream(); + final DefaultEndpoint localEndpoint; + final DefaultEndpoint remoteEndpoint; /** * We chose a {@link List} over a {@link Set} to avoid allocating an {@link Iterator} objects when iterating over * the listeners. */ - private final List listeners = new ArrayList(4); - private final IntObjectMap streamMap = new IntObjectHashMap(); - private final ConnectionStream connectionStream = new ConnectionStream(); - private final Set activeStreams = new LinkedHashSet(); - private final DefaultEndpoint localEndpoint; - private final DefaultEndpoint remoteEndpoint; - private final Http2StreamRemovalPolicy removalPolicy; + final List listeners = new ArrayList(4); + final ActiveStreams activeStreams; /** * Creates a connection with an immediate stream removal policy. @@ -86,7 +86,7 @@ public class DefaultHttp2Connection implements Http2Connection { * the policy to be used for removal of closed stream. */ public DefaultHttp2Connection(boolean server, Http2StreamRemovalPolicy removalPolicy) { - this.removalPolicy = checkNotNull(removalPolicy, "removalPolicy"); + activeStreams = new ActiveStreams(listeners, checkNotNull(removalPolicy, "removalPolicy")); localEndpoint = new DefaultEndpoint(server); remoteEndpoint = new DefaultEndpoint(!server); @@ -142,8 +142,8 @@ public class DefaultHttp2Connection implements Http2Connection { } @Override - public Set activeStreams() { - return Collections.unmodifiableSet(activeStreams); + public Http2Stream forEachActiveStream(StreamVisitor visitor) throws Http2Exception { + return activeStreams.forEachActiveStream(visitor); } @Override @@ -162,17 +162,24 @@ public class DefaultHttp2Connection implements Http2Connection { } @Override - public void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf debugData) { + public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf debugData) { localEndpoint.lastKnownStream(lastKnownStream); for (Listener listener : listeners) { listener.onGoAwayReceived(lastKnownStream, errorCode, debugData); } - Http2Stream[] streams = new Http2Stream[numActiveStreams()]; - for (Http2Stream stream : activeStreams().toArray(streams)) { - if (stream.id() > lastKnownStream && localEndpoint.createdStreamId(stream.id())) { - stream.close(); - } + try { + forEachActiveStream(new StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) { + if (stream.id() > lastKnownStream && localEndpoint.createdStreamId(stream.id())) { + stream.close(); + } + return true; + } + }); + } catch (Http2Exception e) { + PlatformDependent.throwException(e); } } @@ -182,17 +189,24 @@ public class DefaultHttp2Connection implements Http2Connection { } @Override - public void goAwaySent(int lastKnownStream, long errorCode, ByteBuf debugData) { + public void goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugData) { remoteEndpoint.lastKnownStream(lastKnownStream); for (Listener listener : listeners) { listener.onGoAwaySent(lastKnownStream, errorCode, debugData); } - Http2Stream[] streams = new Http2Stream[numActiveStreams()]; - for (Http2Stream stream : activeStreams().toArray(streams)) { - if (stream.id() > lastKnownStream && remoteEndpoint.createdStreamId(stream.id())) { - stream.close(); - } + try { + forEachActiveStream(new StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) { + if (stream.id() > lastKnownStream && remoteEndpoint.createdStreamId(stream.id())) { + stream.close(); + } + return true; + } + }); + } catch (Http2Exception e) { + PlatformDependent.throwException(e); } } @@ -384,15 +398,7 @@ public class DefaultHttp2Connection implements Http2Connection { throw streamError(id, PROTOCOL_ERROR, "Attempting to open a stream in an invalid state: " + state); } - if (activeStreams.add(this)) { - // Update the number of active streams initiated by the endpoint. - createdBy().numActiveStreams++; - - // Notify the listeners. - for (int i = 0; i < listeners.size(); i++) { - listeners.get(i).onStreamActive(this); - } - } + activeStreams.activate(this); return this; } @@ -404,20 +410,8 @@ public class DefaultHttp2Connection implements Http2Connection { state = CLOSED; decrementPrioritizableForTree(1); - if (activeStreams.remove(this)) { - try { - // Update the number of active streams initiated by the endpoint. - createdBy().numActiveStreams--; - // Notify the listeners. - for (int i = 0; i < listeners.size(); i++) { - listeners.get(i).onStreamClosed(this); - } - } finally { - // Mark this stream for removal. - removalPolicy.markForRemoval(this); - } - } + activeStreams.deactivate(this); return this; } @@ -790,8 +784,9 @@ public class DefaultHttp2Connection implements Http2Connection { private int lastKnownStream = -1; private boolean pushToAllowed = true; private F flowController; - private int numActiveStreams; private int maxActiveStreams; + // Fields accessed by inner classes + int numActiveStreams; DefaultEndpoint(boolean server) { this.server = server; @@ -951,8 +946,8 @@ public class DefaultHttp2Connection implements Http2Connection { throw new Http2NoMoreStreamIdsException(); } if (!createdStreamId(streamId)) { - throw connectionError(PROTOCOL_ERROR, "Request stream %d is not correct for %s connection", - streamId, server ? "server" : "client"); + throw connectionError(PROTOCOL_ERROR, "Request stream %d is not correct for %s connection", streamId, + server ? "server" : "client"); } // This check must be after all id validated checks, but before the max streams check because it may be // recoverable to some degree for handling frames which can be sent on closed streams. @@ -969,4 +964,116 @@ public class DefaultHttp2Connection implements Http2Connection { return this == localEndpoint; } } + + /** + * Default implementation of the {@link ActiveStreams} class. + */ + private static final class ActiveStreams { + /** + * Allows events which would modify {@link #streams} to be queued while iterating over {@link #streams}. + */ + interface Event { + /** + * Trigger the original intention of this event. Expect to modify {@link #streams}. + */ + void process(); + } + + private final List listeners; + private final Http2StreamRemovalPolicy removalPolicy; + private final Queue pendingEvents = new ArrayDeque(4); + private final Set streams = new LinkedHashSet(); + private int pendingIterations; + + public ActiveStreams(List listeners, Http2StreamRemovalPolicy removalPolicy) { + this.listeners = listeners; + this.removalPolicy = removalPolicy; + } + + public int size() { + return streams.size(); + } + + public void activate(final DefaultStream stream) { + if (allowModifications()) { + addToActiveStreams(stream); + } else { + pendingEvents.add(new Event() { + @Override + public void process() { + addToActiveStreams(stream); + } + }); + } + } + + public void deactivate(final DefaultStream stream) { + if (allowModifications()) { + removeFromActiveStreams(stream); + } else { + pendingEvents.add(new Event() { + @Override + public void process() { + removeFromActiveStreams(stream); + } + }); + } + } + + public Http2Stream forEachActiveStream(StreamVisitor visitor) throws Http2Exception { + ++pendingIterations; + Http2Stream resultStream = null; + try { + for (Http2Stream stream : streams) { + if (!visitor.visit(stream)) { + resultStream = stream; + break; + } + } + return resultStream; + } finally { + --pendingIterations; + if (allowModifications()) { + for (;;) { + Event event = pendingEvents.poll(); + if (event == null) { + break; + } + event.process(); + } + } + } + } + + void addToActiveStreams(DefaultStream stream) { + if (streams.add(stream)) { + // Update the number of active streams initiated by the endpoint. + stream.createdBy().numActiveStreams++; + + for (int i = 0; i < listeners.size(); i++) { + listeners.get(i).onStreamActive(stream); + } + } + } + + void removeFromActiveStreams(DefaultStream stream) { + if (streams.remove(stream)) { + try { + // Update the number of active streams initiated by the endpoint. + stream.createdBy().numActiveStreams--; + + for (int i = 0; i < listeners.size(); i++) { + listeners.get(i).onStreamClosed(stream); + } + } finally { + // Mark this stream for removal. + removalPolicy.markForRemoval(stream); + } + } + } + + private boolean allowModifications() { + return pendingIterations == 0; + } + } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java index b1cb8e133b..a12ae287c3 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java @@ -28,6 +28,7 @@ import static java.lang.Math.max; import static java.lang.Math.min; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2Connection.StreamVisitor; import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException; import io.netty.handler.codec.http2.Http2Exception.StreamException; @@ -35,7 +36,6 @@ import io.netty.handler.codec.http2.Http2Exception.StreamException; * Basic implementation of {@link Http2LocalFlowController}. */ public class DefaultHttp2LocalFlowController implements Http2LocalFlowController { - private static final int DEFAULT_COMPOSITE_EXCEPTION_SIZE = 4; /** * The default ratio of window size to initial window size below which a {@code WINDOW_UPDATE} * is sent to expand the window. @@ -82,23 +82,9 @@ public class DefaultHttp2LocalFlowController implements Http2LocalFlowController int delta = newWindowSize - initialWindowSize; initialWindowSize = newWindowSize; - CompositeStreamException compositeException = null; - for (Http2Stream stream : connection.activeStreams()) { - try { - // Increment flow control window first so state will be consistent if overflow is detected - FlowState state = state(stream); - state.incrementFlowControlWindows(delta); - state.incrementInitialStreamWindow(delta); - } catch (StreamException e) { - if (compositeException == null) { - compositeException = new CompositeStreamException(e.error(), DEFAULT_COMPOSITE_EXCEPTION_SIZE); - } - compositeException.add(e); - } - } - if (compositeException != null) { - throw compositeException; - } + WindowUpdateVisitor visitor = new WindowUpdateVisitor(delta); + connection.forEachActiveStream(visitor); + visitor.throwIfError(); } @Override @@ -208,7 +194,7 @@ public class DefaultHttp2LocalFlowController implements Http2LocalFlowController return state(connection.connectionStream()); } - private FlowState state(Http2Stream stream) { + private static FlowState state(Http2Stream stream) { checkNotNull(stream, "stream"); return stream.getProperty(FlowState.class); } @@ -391,4 +377,38 @@ public class DefaultHttp2LocalFlowController implements Http2LocalFlowController ctx.flush(); } } + + /** + * Provides a means to iterate over all active streams and increment the flow control windows. + */ + private static final class WindowUpdateVisitor implements StreamVisitor { + private CompositeStreamException compositeException; + private final int delta; + + public WindowUpdateVisitor(int delta) { + this.delta = delta; + } + + @Override + public boolean visit(Http2Stream stream) throws Http2Exception { + try { + // Increment flow control window first so state will be consistent if overflow is detected. + FlowState state = state(stream); + state.incrementFlowControlWindows(delta); + state.incrementInitialStreamWindow(delta); + } catch (StreamException e) { + if (compositeException == null) { + compositeException = new CompositeStreamException(e.error(), 4); + } + compositeException.add(e); + } + return true; + } + + public void throwIfError() throws CompositeStreamException { + if (compositeException != null) { + throw compositeException; + } + } + } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java index 04e4b19241..a6eebdc0e6 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java @@ -23,16 +23,23 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull; import static java.lang.Math.max; import static java.lang.Math.min; import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2Connection.StreamVisitor; import io.netty.handler.codec.http2.Http2Stream.State; import java.util.ArrayDeque; -import java.util.Collection; import java.util.Deque; /** * Basic implementation of {@link Http2RemoteFlowController}. */ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowController { + private static final StreamVisitor WRITE_ALLOCATED_BYTES = new StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) { + state(stream).writeAllocatedBytes(); + return true; + } + }; private final Http2Connection connection; private int initialWindowSize = DEFAULT_WINDOW_SIZE; private ChannelHandlerContext ctx; @@ -115,12 +122,16 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize); } - int delta = newWindowSize - initialWindowSize; + final int delta = newWindowSize - initialWindowSize; initialWindowSize = newWindowSize; - for (Http2Stream stream : connection.activeStreams()) { - // Verify that the maximum value is not exceeded by this change. - state(stream).incrementStreamWindow(delta); - } + connection.forEachActiveStream(new StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) throws Http2Exception { + // Verify that the maximum value is not exceeded by this change. + state(stream).incrementStreamWindow(delta); + return true; + } + }); if (delta > 0) { // The window size increased, send any pending frames for all streams. @@ -215,7 +226,7 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll /** * Writes as many pending bytes as possible, according to stream priority. */ - private void writePendingBytes() { + private void writePendingBytes() throws Http2Exception { Http2Stream connectionStream = connection.connectionStream(); int connectionWindow = state(connectionStream).window(); @@ -223,13 +234,8 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll // Allocate the bytes for the connection window to the streams, but do not write. allocateBytesForTree(connectionStream, connectionWindow); - // Now write all of the allocated bytes. Copying the activeStreams array to avoid - // side effects due to stream removal/addition which might occur as a result - // of end-of-stream or errors. - Collection streams = connection.activeStreams(); - for (Http2Stream stream : streams.toArray(new Http2Stream[streams.size()])) { - state(stream).writeAllocatedBytes(); - } + // Now write all of the allocated bytes. + connection.forEachActiveStream(WRITE_ALLOCATED_BYTES); flush(); } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java index 3b4ab03657..cf6a2fbc9e 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java @@ -17,8 +17,6 @@ package io.netty.handler.codec.http2; import io.netty.buffer.ByteBuf; -import java.util.Collection; - /** * Manager for the state of an HTTP/2 connection with the remote end-point. */ @@ -260,10 +258,25 @@ public interface Http2Connection { int numActiveStreams(); /** - * Gets all streams that are actively in use (i.e. {@code OPEN} or {@code HALF CLOSED}). The returned collection is - * sorted by priority. + * Provide a means of iterating over the collection of active streams. + * + * @param visitor The visitor which will visit each active stream. + * @return The stream before iteration stopped or {@code null} if iteration went past the end. */ - Collection activeStreams(); + Http2Stream forEachActiveStream(StreamVisitor visitor) throws Http2Exception; + + /** + * A visitor that allows iteration over a collection of streams. + */ + interface StreamVisitor { + /** + * @return
    + *
  • {@code true} if the processor wants to continue the loop and handle the entry.
  • + *
  • {@code false} if the processor wants to stop handling headers and abort the loop.
  • + *
+ */ + boolean visit(Http2Stream stream) throws Http2Exception; + } /** * Indicates whether or not the local endpoint for this connection is the server. diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java index 2f8189ced3..4f89a77256 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java @@ -33,6 +33,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandler; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.http2.Http2Connection.StreamVisitor; import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException; import io.netty.handler.codec.http2.Http2Exception.StreamException; import io.netty.util.concurrent.GenericFutureListener; @@ -40,7 +41,6 @@ import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import java.net.SocketAddress; -import java.util.Collection; import java.util.List; /** @@ -145,10 +145,17 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { - ChannelFuture future = ctx.newSucceededFuture(); - final Collection streams = connection().activeStreams(); - for (Http2Stream s : streams.toArray(new Http2Stream[streams.size()])) { - closeStream(s, future); + final Http2Connection connection = connection(); + // Check if there are streams to avoid the overhead of creating the ChannelFuture. + if (connection.numActiveStreams() > 0) { + final ChannelFuture future = ctx.newSucceededFuture(); + connection.forEachActiveStream(new StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) throws Http2Exception { + closeStream(stream, future); + return true; + } + }); } } finally { try { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java index 90bca7392b..76e10c9287 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java @@ -46,6 +46,7 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; +import io.netty.handler.codec.http2.Http2Connection.StreamVisitor; import io.netty.handler.codec.http2.Http2Exception.ClosedStreamCreationException; import java.util.Collections; @@ -124,7 +125,16 @@ public class DefaultHttp2ConnectionDecoderTest { when(stream.state()).thenReturn(OPEN); when(stream.open(anyBoolean())).thenReturn(stream); when(pushStream.id()).thenReturn(PUSH_STREAM_ID); - when(connection.activeStreams()).thenReturn(Collections.singletonList(stream)); + doAnswer(new Answer() { + @Override + public Http2Stream answer(InvocationOnMock in) throws Throwable { + StreamVisitor visitor = in.getArgumentAt(0, StreamVisitor.class); + if (!visitor.visit(stream)) { + return stream; + } + return null; + } + }).when(connection).forEachActiveStream(any(StreamVisitor.class)); when(connection.stream(STREAM_ID)).thenReturn(stream); when(connection.requireStream(STREAM_ID)).thenReturn(stream); when(connection.local()).thenReturn(local); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java index 529216ff71..45519a5260 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java @@ -53,12 +53,12 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; +import io.netty.handler.codec.http2.Http2Connection.StreamVisitor; import io.netty.handler.codec.http2.Http2Exception.ClosedStreamCreationException; import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled; import io.netty.util.concurrent.ImmediateEventExecutor; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import org.junit.Before; @@ -138,7 +138,16 @@ public class DefaultHttp2ConnectionEncoderTest { when(stream.state()).thenReturn(OPEN); when(stream.open(anyBoolean())).thenReturn(stream); when(pushStream.id()).thenReturn(PUSH_STREAM_ID); - when(connection.activeStreams()).thenReturn(Collections.singletonList(stream)); + doAnswer(new Answer() { + @Override + public Http2Stream answer(InvocationOnMock in) throws Throwable { + StreamVisitor visitor = in.getArgumentAt(0, StreamVisitor.class); + if (!visitor.visit(stream)) { + return stream; + } + return null; + } + }).when(connection).forEachActiveStream(any(StreamVisitor.class)); when(connection.stream(STREAM_ID)).thenReturn(stream); when(connection.requireStream(STREAM_ID)).thenReturn(stream); when(connection.local()).thenReturn(local); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java index cdcbecbd14..1b6419573a 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java @@ -21,7 +21,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyShort; import static org.mockito.Matchers.eq; @@ -30,7 +29,6 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; - import io.netty.buffer.Unpooled; import io.netty.handler.codec.http2.Http2Connection.Endpoint; import io.netty.handler.codec.http2.Http2Stream.State; @@ -118,25 +116,25 @@ public class DefaultHttp2ConnectionTest { Http2Stream stream = server.local().createStream(2).open(false); assertEquals(2, stream.id()); assertEquals(State.OPEN, stream.state()); - assertEquals(1, server.activeStreams().size()); + assertEquals(1, server.numActiveStreams()); assertEquals(2, server.local().lastStreamCreated()); stream = server.local().createStream(4).open(true); assertEquals(4, stream.id()); assertEquals(State.HALF_CLOSED_LOCAL, stream.state()); - assertEquals(2, server.activeStreams().size()); + assertEquals(2, server.numActiveStreams()); assertEquals(4, server.local().lastStreamCreated()); stream = server.remote().createStream(3).open(true); assertEquals(3, stream.id()); assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); - assertEquals(3, server.activeStreams().size()); + assertEquals(3, server.numActiveStreams()); assertEquals(3, server.remote().lastStreamCreated()); stream = server.remote().createStream(5).open(false); assertEquals(5, stream.id()); assertEquals(State.OPEN, stream.state()); - assertEquals(4, server.activeStreams().size()); + assertEquals(4, server.numActiveStreams()); assertEquals(5, server.remote().lastStreamCreated()); } @@ -145,25 +143,25 @@ public class DefaultHttp2ConnectionTest { Http2Stream stream = client.remote().createStream(2).open(false); assertEquals(2, stream.id()); assertEquals(State.OPEN, stream.state()); - assertEquals(1, client.activeStreams().size()); + assertEquals(1, client.numActiveStreams()); assertEquals(2, client.remote().lastStreamCreated()); stream = client.remote().createStream(4).open(true); assertEquals(4, stream.id()); assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); - assertEquals(2, client.activeStreams().size()); + assertEquals(2, client.numActiveStreams()); assertEquals(4, client.remote().lastStreamCreated()); stream = client.local().createStream(3).open(true); assertEquals(3, stream.id()); assertEquals(State.HALF_CLOSED_LOCAL, stream.state()); - assertEquals(3, client.activeStreams().size()); + assertEquals(3, client.numActiveStreams()); assertEquals(3, client.local().lastStreamCreated()); stream = client.local().createStream(5).open(false); assertEquals(5, stream.id()); assertEquals(State.OPEN, stream.state()); - assertEquals(4, client.activeStreams().size()); + assertEquals(4, client.numActiveStreams()); assertEquals(5, client.local().lastStreamCreated()); } @@ -173,7 +171,7 @@ public class DefaultHttp2ConnectionTest { Http2Stream pushStream = server.local().reservePushStream(2, stream); assertEquals(2, pushStream.id()); assertEquals(State.RESERVED_LOCAL, pushStream.state()); - assertEquals(1, server.activeStreams().size()); + assertEquals(1, server.numActiveStreams()); assertEquals(2, server.local().lastStreamCreated()); } @@ -183,7 +181,7 @@ public class DefaultHttp2ConnectionTest { Http2Stream pushStream = server.local().reservePushStream(4, stream); assertEquals(4, pushStream.id()); assertEquals(State.RESERVED_LOCAL, pushStream.state()); - assertEquals(1, server.activeStreams().size()); + assertEquals(1, server.numActiveStreams()); assertEquals(4, server.local().lastStreamCreated()); } @@ -226,7 +224,7 @@ public class DefaultHttp2ConnectionTest { Http2Stream stream = server.remote().createStream(3).open(true); stream.close(); assertEquals(State.CLOSED, stream.state()); - assertTrue(server.activeStreams().isEmpty()); + assertEquals(0, server.numActiveStreams()); } @Test @@ -234,7 +232,7 @@ public class DefaultHttp2ConnectionTest { Http2Stream stream = server.remote().createStream(3).open(false); stream.closeLocalSide(); assertEquals(State.HALF_CLOSED_LOCAL, stream.state()); - assertEquals(1, server.activeStreams().size()); + assertEquals(1, server.numActiveStreams()); } @Test @@ -242,7 +240,7 @@ public class DefaultHttp2ConnectionTest { Http2Stream stream = server.remote().createStream(3).open(false); stream.closeRemoteSide(); assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); - assertEquals(1, server.activeStreams().size()); + assertEquals(1, server.numActiveStreams()); } @Test @@ -250,7 +248,7 @@ public class DefaultHttp2ConnectionTest { Http2Stream stream = server.remote().createStream(3).open(true); stream.closeLocalSide(); assertEquals(State.CLOSED, stream.state()); - assertTrue(server.activeStreams().isEmpty()); + assertEquals(0, server.numActiveStreams()); } @Test(expected = Http2Exception.class) diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java index 8a3d3203e8..7762c3f6b4 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java @@ -44,10 +44,9 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; +import io.netty.handler.codec.http2.Http2Connection.StreamVisitor; import io.netty.util.concurrent.GenericFutureListener; -import java.util.Arrays; -import java.util.Collections; import java.util.List; import org.junit.After; @@ -113,8 +112,18 @@ public class Http2ConnectionHandlerTest { when(channel.isActive()).thenReturn(true); when(connection.remote()).thenReturn(remote); when(connection.local()).thenReturn(local); - when(connection.activeStreams()).thenReturn(Collections.singletonList(stream)); + doAnswer(new Answer() { + @Override + public Http2Stream answer(InvocationOnMock in) throws Throwable { + StreamVisitor visitor = in.getArgumentAt(0, StreamVisitor.class); + if (!visitor.visit(stream)) { + return stream; + } + return null; + } + }).when(connection).forEachActiveStream(any(StreamVisitor.class)); when(connection.stream(NON_EXISTANT_STREAM_ID)).thenReturn(null); + when(connection.numActiveStreams()).thenReturn(1); when(connection.stream(STREAM_ID)).thenReturn(stream); when(stream.open(anyBoolean())).thenReturn(stream); when(encoder.writeSettings(eq(ctx), any(Http2Settings.class), eq(promise))).thenReturn(future); @@ -266,8 +275,6 @@ public class Http2ConnectionHandlerTest { @Test public void closeListenerShouldBeNotifiedOnlyOneTime() throws Exception { handler = newHandler(); - when(connection.activeStreams()).thenReturn(Arrays.asList(stream)); - when(connection.numActiveStreams()).thenReturn(1); when(future.isDone()).thenReturn(true); when(future.isSuccess()).thenReturn(true); doAnswer(new Answer() { @@ -276,7 +283,12 @@ public class Http2ConnectionHandlerTest { Object[] args = invocation.getArguments(); GenericFutureListener listener = (GenericFutureListener) args[0]; // Simulate that all streams have become inactive by the time the future completes. - when(connection.activeStreams()).thenReturn(Collections.emptyList()); + doAnswer(new Answer() { + @Override + public Http2Stream answer(InvocationOnMock in) throws Throwable { + return null; + } + }).when(connection).forEachActiveStream(any(StreamVisitor.class)); when(connection.numActiveStreams()).thenReturn(0); // Simulate the future being completed. listener.operationComplete(future);