netty5/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java
Norman Maurer c4dbbe39c9
Add executor() to ChannelOutboundInvoker and let it replace eventLoop() (#11617)
Motivation:

We should just add `executor()` to the `ChannelOutboundInvoker` interface and override this method in `Channel` to return `EventLoop`.

Modifications:

- Add `executor()` method to `ChannelOutboundInvoker`
- Let `Channel` override this method and return `EventLoop`.
- Adjust all usages of `eventLoop()`
- Add some default implementations

Result:

API cleanup
2021-08-25 18:31:24 +02:00

381 lines
14 KiB
Java

/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
public class PendingWriteQueueTest {
/**
 * Writes a single message through a {@link TestHandler} whose {@code flush} drains the queue with
 * {@code removeAndWrite()}, and checks that the queue is empty once that write completes.
 */
@Test
public void testRemoveAndWrite() {
    final ChannelHandler drainSingle = new TestHandler() {
        @Override
        public void flush(ChannelHandlerContext ctx) {
            assertFalse(ctx.channel().isWritable(), "Should not be writable anymore");
            // Verify the queue state only after the removed write has been completed.
            queue.removeAndWrite().addListener(done -> assertQueueEmpty(queue));
            super.flush(ctx);
        }
    };
    assertWrite(drainSingle, 1);
}
/**
 * Writes three messages through a {@link TestHandler} whose {@code flush} drains the queue with
 * {@code removeAndWriteAll()}, and checks that the queue is empty once the returned future completes.
 */
@Test
public void testRemoveAndWriteAll() {
    final ChannelHandler drainEverything = new TestHandler() {
        @Override
        public void flush(ChannelHandlerContext ctx) {
            assertFalse(ctx.channel().isWritable(), "Should not be writable anymore");
            // Verify the queue state only after the combined write future has been completed.
            queue.removeAndWriteAll().addListener(done -> assertQueueEmpty(queue));
            super.flush(ctx);
        }
    };
    assertWrite(drainEverything, 3);
}
/**
 * Writes a single message through a {@link TestHandler} whose {@code flush} fails the head of the
 * queue with a {@link TestException}; {@code assertWriteFails} then expects that exception.
 */
@Test
public void testRemoveAndFail() {
    final ChannelHandler failSingle = new TestHandler() {
        @Override
        public void flush(ChannelHandlerContext ctx) {
            final TestException failure = new TestException();
            queue.removeAndFail(failure);
            super.flush(ctx);
        }
    };
    assertWriteFails(failSingle, 1);
}
/**
 * Writes three messages through a {@link TestHandler} whose {@code flush} fails every queued
 * write with a {@link TestException}; {@code assertWriteFails} then expects that exception.
 */
@Test
public void testRemoveAndFailAll() {
    final ChannelHandler failEverything = new TestHandler() {
        @Override
        public void flush(ChannelHandlerContext ctx) {
            final TestException failure = new TestException();
            queue.removeAndFailAll(failure);
            super.flush(ctx);
        }
    };
    assertWriteFails(failEverything, 3);
}
/**
 * Removes the current entry from inside {@code channelWritabilityChanged(...)} and checks, via the
 * message's reference count, that the entry is released exactly once.
 */
@Test
public void shouldFireChannelWritabilityChangedAfterRemoval() {
    final ByteBuf msg = Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII);
    final AtomicReference<PendingWriteQueue> queueHolder = new AtomicReference<>();

    final ChannelHandler removeOnWritabilityChange = new ChannelHandler() {
        @Override
        public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
            queueHolder.set(new PendingWriteQueue(ctx));
        }

        @Override
        public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
            final PendingWriteQueue pending = queueHolder.get();
            final ByteBuf head = (ByteBuf) pending.current();
            if (head != null) {
                assertThat(head.refCnt(), is(1));

                // This call will trigger another channelWritabilityChanged() event because the number of
                // pending bytes will go below the low watermark.
                //
                // If PendingWriteQueue.remove() did not remove the current entry before triggering
                // channelWritabilityChanged() event, we will end up with attempting to remove the same
                // element twice, resulting in the double release.
                pending.remove();

                assertThat(head.refCnt(), is(0));
            }
        }
    };

    final EmbeddedChannel channel = new EmbeddedChannel(removeOnWritabilityChange);
    // Low mark is set before the high mark, exactly as in the other helpers of this test.
    // NOTE(review): presumably the config rejects a high mark below the current low mark - confirm.
    channel.config().setWriteBufferLowWaterMark(1);
    channel.config().setWriteBufferHighWaterMark(3);

    final PendingWriteQueue queue = queueHolder.get();

    // Trigger channelWritabilityChanged() by adding a message that's larger than the high watermark.
    channel.executor().execute(() -> queue.add(msg, channel.newPromise()));

    channel.finish();

    assertThat(msg.refCnt(), is(0));
}
/**
 * Writes {@code count} retained duplicates of a "Test" buffer through a channel built around
 * {@code handler} and asserts that exactly {@code count} equal buffers come out on the outbound side.
 *
 * @param handler the handler under test, installed as the only handler of an {@link EmbeddedChannel}
 * @param count   how many messages to write and to expect on the outbound side
 */
private static void assertWrite(ChannelHandler handler, int count) {
    final ByteBuf expected = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII);
    final EmbeddedChannel channel = new EmbeddedChannel(handler);
    // Water marks of 1 and 3; the handlers used with this helper assert that the channel is
    // no longer writable by the time flush is called.
    channel.config().setWriteBufferLowWaterMark(1);
    channel.config().setWriteBufferHighWaterMark(3);

    final ByteBuf[] outbound = new ByteBuf[count];
    int filled = 0;
    while (filled < outbound.length) {
        outbound[filled++] = expected.retainedDuplicate();
    }

    assertTrue(channel.writeOutbound(outbound));
    assertTrue(channel.finish());
    channel.closeFuture().syncUninterruptibly();

    // Every written duplicate must be readable from the outbound side and equal the source buffer.
    for (int remaining = outbound.length; remaining > 0; remaining--) {
        assertBuffer(channel, expected);
    }
    expected.release();
    assertNull(channel.readOutbound());
}
/**
 * Reads the next outbound message from {@code channel}, asserts that it equals {@code buffer},
 * and releases the message that was read.
 *
 * @param channel the channel to read one outbound message from
 * @param buffer  the expected content
 */
private static void assertBuffer(EmbeddedChannel channel, ByteBuf buffer) {
    final ByteBuf actual = channel.readOutbound();
    assertEquals(buffer, actual);
    actual.release();
}
/**
 * Asserts that {@code queue} looks empty through every accessor checked here: {@code isEmpty()},
 * {@code size()}, {@code bytes()} and {@code current()}, and that both {@code removeAndWrite()}
 * and {@code removeAndWriteAll()} return {@code null}.
 *
 * @param queue the queue expected to hold no pending writes
 */
private static void assertQueueEmpty(PendingWriteQueue queue) {
assertTrue(queue.isEmpty());
assertEquals(0, queue.size());
assertEquals(0, queue.bytes());
assertNull(queue.current());
// The two removal calls come last: on a non-empty queue they would modify it.
assertNull(queue.removeAndWrite());
assertNull(queue.removeAndWriteAll());
}
/**
 * Writes {@code count} retained duplicates of a "Test" buffer through a channel built around
 * {@code handler} and asserts that the write surfaces a {@link TestException} and that nothing
 * reaches the outbound side.
 *
 * @param handler the handler under test, installed as the only handler of an {@link EmbeddedChannel}
 * @param count   how many messages to write
 */
private static void assertWriteFails(ChannelHandler handler, int count) {
    final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII);
    final EmbeddedChannel channel = new EmbeddedChannel(handler);
    ByteBuf[] buffers = new ByteBuf[count];
    for (int i = 0; i < buffers.length; i++) {
        buffers[i] = buffer.retainedDuplicate();
    }
    try {
        final boolean wrote = channel.writeOutbound(buffers);
        // Returning normally is a failure whatever the result; say so explicitly.
        fail("writeOutbound(...) should have thrown TestException but returned " + wrote);
    } catch (Exception e) {
        if (!(e instanceof TestException)) {
            // Keep the unexpected exception as the cause so its type and stack trace are reported,
            // rather than reducing it to a bare "expected true but was false".
            fail("Expected TestException but caught: " + e, e);
        }
    }
    assertFalse(channel.finish());
    channel.closeFuture().syncUninterruptibly();

    buffer.release();
    assertNull(channel.readOutbound());
}
/**
 * Creates an {@link EmbeddedChannel} with a single handler that overrides nothing.
 *
 * @return the new channel
 */
private static EmbeddedChannel newChannel() {
    // Add a handler so we can access a ChannelHandlerContext via the ChannelPipeline.
    final ChannelHandler placeholder = new ChannelHandler() { };
    return new EmbeddedChannel(placeholder);
}
/**
 * Calls {@code removeAndFailAll(...)} on a queue whose first promise has a listener that calls
 * {@code removeAndFailAll(...)} again, and checks that both queued promises end up failed.
 */
@Test
public void testRemoveAndFailAllReentrantFailAll() {
    final EmbeddedChannel channel = newChannel();
    final ChannelHandlerContext ctx = channel.pipeline().firstContext();
    final PendingWriteQueue queue = new PendingWriteQueue(ctx);

    final Promise<Void> first = channel.newPromise();
    // The listener calls back into the queue when the first promise completes.
    first.addListener(ignored -> queue.removeAndFailAll(new IllegalStateException()));
    final Promise<Void> second = channel.newPromise();

    final Runnable enqueueAndFail = () -> {
        queue.add(1L, first);
        queue.add(2L, second);
        queue.removeAndFailAll(new Exception());
    };
    channel.executor().execute(enqueueAndFail);

    assertTrue(first.isDone());
    assertFalse(first.isSuccess());
    assertTrue(second.isDone());
    assertFalse(second.isSuccess());
    assertFalse(channel.finish());
}
/**
 * Adds a third write to the queue from a listener on the first promise while
 * {@code removeAndWriteAll()} is draining the first two, and checks that the third write stays
 * pending until {@code removeAndWriteAll()} is called a second time.
 */
@Test
public void testRemoveAndWriteAllReentrantWrite() {
    final ChannelHandler flushOnWrite = new ChannelHandler() {
        @Override
        public Future<Void> write(ChannelHandlerContext ctx, Object msg) {
            // Convert to writeAndFlush(...) so the promise will be notified by the transport.
            return ctx.writeAndFlush(msg);
        }
    };
    final EmbeddedChannel channel = new EmbeddedChannel(flushOnWrite, new ChannelHandler() { });
    final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().lastContext());

    final Promise<Void> first = channel.newPromise();
    final Promise<Void> third = channel.newPromise();
    // Completing the first write enqueues a third one.
    first.addListener(ignored -> queue.add(3L, third));
    final Promise<Void> second = channel.newPromise();

    channel.executor().execute(() -> {
        queue.add(1L, first);
        queue.add(2L, second);
        queue.removeAndWriteAll();
    });

    assertTrue(first.isDone());
    assertTrue(first.isSuccess());
    assertTrue(second.isDone());
    assertTrue(second.isSuccess());
    // The write added by the listener must not have been written by the first drain.
    assertFalse(third.isDone());
    assertFalse(third.isSuccess());

    channel.executor().execute(() -> queue.removeAndWriteAll());
    assertTrue(third.isDone());
    assertTrue(third.isSuccess());

    channel.runPendingTasks();
    assertTrue(channel.finish());

    final Long firstOut = channel.readOutbound();
    assertEquals(1L, firstOut);
    final Long secondOut = channel.readOutbound();
    assertEquals(2L, secondOut);
    final Long thirdOut = channel.readOutbound();
    assertEquals(3L, thirdOut);
}
/**
 * Adds a third write from a listener on the first promise while {@code removeAndFailAll(...)} is
 * running, and expects all three promises to fail in the order 1, 2, 3. Currently disabled.
 */
@Disabled("Need to verify and think about if the assumptions made by this test are valid at all.")
@Test
public void testRemoveAndFailAllReentrantWrite() {
    final List<Integer> failOrder = Collections.synchronizedList(new ArrayList<>());
    final EmbeddedChannel channel = newChannel();
    final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().firstContext());

    final Promise<Void> first = channel.newPromise();
    final Promise<Void> third = channel.newPromise();
    third.addListener(ignored -> failOrder.add(3));
    // Completing the first promise records it and enqueues the third write.
    first.addListener(ignored -> {
        failOrder.add(1);
        queue.add(3L, third);
    });
    final Promise<Void> second = channel.newPromise();
    second.addListener(ignored -> failOrder.add(2));

    final Runnable enqueueAndFail = () -> {
        queue.add(1L, first);
        queue.add(2L, second);
        queue.removeAndFailAll(new Exception());
    };
    channel.executor().execute(enqueueAndFail);

    assertTrue(first.isDone());
    assertFalse(first.isSuccess());
    assertTrue(second.isDone());
    assertFalse(second.isSuccess());
    assertTrue(third.isDone());
    assertFalse(third.isSuccess());
    assertFalse(channel.finish());

    // The first three recorded completions must be 1, 2, 3 in that order.
    for (int position = 0; position < 3; position++) {
        assertEquals(position + 1, (int) failOrder.get(position));
    }
}
/**
 * Calls {@code removeAndWriteAll()} on a queue whose first promise has a listener that calls
 * {@code removeAndWriteAll()} again, and checks that each message is written exactly once, in order.
 */
@Test
public void testRemoveAndWriteAllReentrance() {
    final EmbeddedChannel channel = newChannel();
    final ChannelHandlerContext ctx = channel.pipeline().firstContext();
    final PendingWriteQueue queue = new PendingWriteQueue(ctx);

    final Promise<Void> first = channel.newPromise();
    // The listener calls back into the queue when the first promise completes.
    first.addListener(ignored -> queue.removeAndWriteAll());
    final Promise<Void> second = channel.newPromise();

    final Runnable enqueueAndDrain = () -> {
        queue.add(1L, first);
        queue.add(2L, second);
        queue.removeAndWriteAll();
    };
    channel.executor().execute(enqueueAndDrain);
    channel.flush();

    assertTrue(first.isSuccess());
    assertTrue(second.isSuccess());
    assertTrue(channel.finish());

    final Long firstOut = channel.readOutbound();
    assertEquals(1L, firstOut);
    final Long secondOut = channel.readOutbound();
    assertEquals(2L, secondOut);
    // Nothing beyond the two queued messages may show up on either side.
    assertNull(channel.readOutbound());
    assertNull(channel.readInbound());
}
// See https://github.com/netty/netty/issues/3967
/**
 * Builds a {@link PendingWriteQueue} on a context whose channel has already been closed and checks
 * that {@code removeAndFailAll(...)} fails a queued promise with the exact cause passed in.
 */
@Test
public void testCloseChannelOnCreation() {
    final EmbeddedChannel channel = newChannel();
    final ChannelHandlerContext ctx = channel.pipeline().firstContext();
    // Close first: the queue below is created against the already-closed channel.
    channel.close().syncUninterruptibly();

    final PendingWriteQueue queue = new PendingWriteQueue(ctx);
    final IllegalStateException expectedCause = new IllegalStateException();
    final Promise<Void> pending = channel.newPromise();

    final Runnable enqueueAndFail = () -> {
        queue.add(1L, pending);
        queue.removeAndFailAll(expectedCause);
    };
    channel.executor().execute(enqueueAndFail);

    assertSame(expectedCause, pending.cause());
}
/**
 * Base handler for the write tests: it creates a {@link PendingWriteQueue} when added to the
 * pipeline and parks every write in that queue instead of passing it on. Subclasses override
 * {@code flush} to decide what happens to the queued writes.
 */
private static class TestHandler implements ChannelHandler {
    /** Queue created in {@link #handlerAdded}; read by the anonymous subclasses in the tests. */
    protected PendingWriteQueue queue;
    /** Number of writes this handler has queued so far. */
    private int queuedWrites;

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        queue = new PendingWriteQueue(ctx);
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        ctx.fireChannelActive();
        // Before any write the queue must be empty and the channel writable.
        assertQueueEmpty(queue);
        assertTrue(ctx.channel().isWritable(), "Should be writable");
    }

    @Override
    public Future<Void> write(ChannelHandlerContext ctx, Object msg) {
        final Promise<Void> completion = ctx.newPromise();
        queue.add(msg, completion);

        assertFalse(queue.isEmpty());
        // Each write must grow the queue by exactly one entry.
        queuedWrites++;
        assertEquals(queuedWrites, queue.size());
        assertNotNull(queue.current());
        return completion;
    }
}
/**
 * Marker exception that the failure tests pass to {@code removeAndFail(...)} and
 * {@code removeAndFailAll(...)}, and that {@code assertWriteFails} expects to catch.
 */
private static final class TestException extends Exception {
private static final long serialVersionUID = -9018570103039458401L;
}
}