Adding ability omit the implicit #flush() call in EmbeddedChannel#writeOutbound() and

the implicit #fireChannelReadComplete() in EmbeddedChannel#writeInbound().

Motivation

We use EmbeddedChannels to implement a ProxyChannel of some sorts that shovels
messages between a source and a destination Channel. The latter are real network
channels (such as Epoll) and they may or may not be managed in a ChannelPool. We
could fuse both ends directly together but the EmbeddedChannel provides a nice
disposable section of a ChannelPipeline that can be used to instrument the messages
that are passing through the proxy portion.

The ideal flow looks abount like this:

source#channelRead() -> proxy#writeOutbound() -> destination#write()
source#channelReadComplete() -> proxy#flushOutbound() -> destination#flush()

destination#channelRead() -> proxy#writeInbound() -> source#write()
destination#channelReadComplete() -> proxy#flushInbound() -> source#flush()

The problem is that #writeOutbound() and #writeInbound() emit surplus #flush()
and #fireChannelReadComplete() events which in turn yield to surplus #flush()
calls on both ends of the pipeline.

Modifications

Introduce a new set of write methods that reain the same sematics as the #write()
method and #flushOutbound() and #flushInbound().

Result

It's possible to implement the above ideal flow.

Fix for EmbeddedChannel#ensureOpen() and Unit Tests for it

Some PR stuff.
This commit is contained in:
Roger Kapsi 2016-08-10 14:55:38 -04:00 committed by Scott Mitchell
parent 11f09bb5a6
commit c0ca20b1fc
2 changed files with 320 additions and 35 deletions

View File

@ -15,6 +15,11 @@
*/
package io.netty.channel.embedded;
import java.net.SocketAddress;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayDeque;
import java.util.Queue;
import io.netty.channel.AbstractChannel;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
@ -36,11 +41,6 @@ import io.netty.util.internal.RecyclableArrayList;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.net.SocketAddress;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayDeque;
import java.util.Queue;
/**
* Base class for {@link Channel} implementations that are used in an embedded fashion.
*/
@ -226,12 +226,53 @@ public class EmbeddedChannel extends AbstractChannel {
for (Object m: msgs) {
p.fireChannelRead(m);
}
p.fireChannelReadComplete();
runPendingTasks();
checkException();
flushInbound(false, voidPromise());
return isNotEmpty(inboundMessages);
}
/**
* Writes one message to the inbound of this {@link Channel} and does not flush it. This
* method is conceptually equivalent to {@link #write(Object)}.
*
* @see #writeOneOutbound(Object)
*/
public ChannelFuture writeOneInbound(Object msg) {
return writeOneInbound(msg, newPromise());
}
/**
* Writes one message to the inbound of this {@link Channel} and does not flush it. This
* method is conceptually equivalent to {@link #write(Object, ChannelPromise)}.
*
* @see #writeOneOutbound(Object, ChannelPromise)
*/
public ChannelFuture writeOneInbound(Object msg, ChannelPromise promise) {
if (checkOpen(true)) {
pipeline().fireChannelRead(msg);
}
return checkException(promise);
}
/**
* Flushes the inbound of this {@link Channel}. This method is conceptually equivalent to {@link #flush()}.
*
* @see #flushOutbound()
*/
public EmbeddedChannel flushInbound() {
flushInbound(true, voidPromise());
return this;
}
private ChannelFuture flushInbound(boolean recordException, ChannelPromise promise) {
if (checkOpen(recordException)) {
pipeline().fireChannelReadComplete();
runPendingTasks();
}
return checkException(promise);
}
/**
* Write messages to the outbound of this {@link Channel}.
*
@ -252,10 +293,8 @@ public class EmbeddedChannel extends AbstractChannel {
}
futures.add(write(m));
}
// We need to call runPendingTasks first as a ChannelOutboundHandler may used eventloop.execute(...) to
// delay the write on the next eventloop run.
runPendingTasks();
flush();
flushOutbound0();
int size = futures.size();
for (int i = 0; i < size; i++) {
@ -275,6 +314,50 @@ public class EmbeddedChannel extends AbstractChannel {
}
}
/**
* Writes one message to the outbound of this {@link Channel} and does not flush it. This
* method is conceptually equivalent to {@link #write(Object)}.
*
* @see #writeOneInbound(Object)
*/
public ChannelFuture writeOneOutbound(Object msg) {
return writeOneOutbound(msg, newPromise());
}
/**
* Writes one message to the outbound of this {@link Channel} and does not flush it. This
* method is conceptually equivalent to {@link #write(Object, ChannelPromise)}.
*
* @see #writeOneInbound(Object, ChannelPromise)
*/
public ChannelFuture writeOneOutbound(Object msg, ChannelPromise promise) {
if (checkOpen(true)) {
return write(msg, promise);
}
return checkException(promise);
}
/**
* Flushes the outbound of this {@link Channel}. This method is conceptually equivalent to {@link #flush()}.
*
* @see #flushInbound()
*/
public EmbeddedChannel flushOutbound() {
if (checkOpen(true)) {
flushOutbound0();
}
checkException(voidPromise());
return this;
}
private void flushOutbound0() {
// We need to call runPendingTasks first as a ChannelOutboundHandler may used eventloop.execute(...) to
// delay the write on the next eventloop run.
runPendingTasks();
flush();
}
/**
* Mark this {@link Channel} as finished. Any futher try to write data to it will fail.
*
@ -436,26 +519,51 @@ public class EmbeddedChannel extends AbstractChannel {
}
}
/**
* Checks for the presence of an {@link Exception}.
*/
private ChannelFuture checkException(ChannelPromise promise) {
Throwable t = lastException;
if (t != null) {
lastException = null;
if (promise.isVoid()) {
PlatformDependent.throwException(t);
}
return promise.setFailure(t);
}
return promise.setSuccess();
}
/**
* Check if there was any {@link Throwable} received and if so rethrow it.
*/
public void checkException() {
Throwable t = lastException;
if (t == null) {
return;
}
checkException(voidPromise());
}
lastException = null;
/**
* Returns {@code true} if the {@link Channel} is open and records optionally
* an {@link Exception} if it isn't.
*/
private boolean checkOpen(boolean recordException) {
if (!isOpen()) {
if (recordException) {
recordException(new ClosedChannelException());
}
return false;
}
PlatformDependent.throwException(t);
return true;
}
/**
* Ensure the {@link Channel} is open and if not throw an exception.
*/
protected final void ensureOpen() {
if (!isOpen()) {
recordException(new ClosedChannelException());
if (!checkOpen(true)) {
checkException();
}
}
@ -516,11 +624,27 @@ public class EmbeddedChannel extends AbstractChannel {
}
ReferenceCountUtil.retain(msg);
outboundMessages().add(msg);
handleOutboundMessage(msg);
in.remove();
}
}
/**
* Called for each outbound message.
*
* @see #doWrite(ChannelOutboundBuffer)
*/
protected void handleOutboundMessage(Object msg) {
outboundMessages().add(msg);
}
/**
* Called for each inbound message.
*/
protected void handleInboundMessage(Object msg) {
inboundMessages().add(msg);
}
private class DefaultUnsafe extends AbstractUnsafe {
@Override
public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
@ -540,7 +664,7 @@ public class EmbeddedChannel extends AbstractChannel {
@Override
protected void onUnhandledInboundMessage(Object msg) {
inboundMessages().add(msg);
handleInboundMessage(msg);
}
}
}

View File

@ -15,6 +15,23 @@
*/
package io.netty.channel.embedded;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Test;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
@ -28,22 +45,10 @@ import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ScheduledFuture;
import org.junit.Test;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
public class EmbeddedChannelTest {
@ -372,6 +377,162 @@ public class EmbeddedChannelTest {
assertNull(channel.readOutbound());
}
@Test
public void testFlushInbound() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() {
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
latch.countDown();
}
});
channel.flushInbound();
if (!latch.await(1L, TimeUnit.SECONDS)) {
fail("Nobody called #channelReadComplete() in time.");
}
}
@Test
public void testWriteOneInbound() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger flushCount = new AtomicInteger(0);
EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ReferenceCountUtil.release(msg);
latch.countDown();
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
flushCount.incrementAndGet();
}
});
channel.writeOneInbound("Hello, Netty!");
if (!latch.await(1L, TimeUnit.SECONDS)) {
fail("Nobody called #channelRead() in time.");
}
channel.close().syncUninterruptibly();
// There was no #flushInbound() call so nobody should have called
// #channelReadComplete()
assertEquals(flushCount.get(), 0);
}
@Test
public void testFlushOutbound() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() {
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
latch.countDown();
}
});
channel.flushOutbound();
if (!latch.await(1L, TimeUnit.SECONDS)) {
fail("Nobody called #flush() in time.");
}
}
@Test
public void testWriteOneOutbound() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger flushCount = new AtomicInteger(0);
EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
ctx.write(msg, promise);
latch.countDown();
}
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
flushCount.incrementAndGet();
}
});
// This shouldn't trigger a #flush()
channel.writeOneOutbound("Hello, Netty!");
if (!latch.await(1L, TimeUnit.SECONDS)) {
fail("Nobody called #write() in time.");
}
channel.close().syncUninterruptibly();
// There was no #flushOutbound() call so nobody should have called #flush()
assertEquals(0, flushCount.get());
}
@Test
public void testEnsureOpen() throws InterruptedException {
EmbeddedChannel channel = new EmbeddedChannel();
channel.close().syncUninterruptibly();
try {
channel.writeOutbound("Hello, Netty!");
fail("This should have failed with a ClosedChannelException");
} catch (Exception expected) {
assertTrue(expected instanceof ClosedChannelException);
}
try {
channel.writeInbound("Hello, Netty!");
fail("This should have failed with a ClosedChannelException");
} catch (Exception expected) {
assertTrue(expected instanceof ClosedChannelException);
}
}
@Test
public void testHandleInboundMessage() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
EmbeddedChannel channel = new EmbeddedChannel() {
@Override
protected void handleInboundMessage(Object msg) {
latch.countDown();
}
};
channel.writeOneInbound("Hello, Netty!");
if (!latch.await(1L, TimeUnit.SECONDS)) {
fail("Nobody called #handleInboundMessage() in time.");
}
}
@Test
public void testHandleOutboundMessage() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
EmbeddedChannel channel = new EmbeddedChannel() {
@Override
protected void handleOutboundMessage(Object msg) {
latch.countDown();
}
};
channel.writeOneOutbound("Hello, Netty!");
if (latch.await(50L, TimeUnit.MILLISECONDS)) {
fail("Somebody called unexpectedly #flush()");
}
channel.flushOutbound();
if (!latch.await(1L, TimeUnit.SECONDS)) {
fail("Nobody called #handleOutboundMessage() in time.");
}
}
private static void release(ByteBuf... buffers) {
for (ByteBuf buffer : buffers) {
if (buffer.refCnt() > 0) {