Future methods getNow() and cause() now throw on incomplete futures (#11594)

Motivation:
Since most futures in Netty are of the `Void` type, methods like `getNow()` and `cause()` cannot distinguish if the future has finished or not.
This can cause data race bugs which, in the case of `Void` futures, can be silent.

Modification:
The methods `getNow()` and `cause()` now throw an `IllegalStateException` if the future has not yet completed.
Most use of these methods are inside listeners, and so are not impacted.
One place in `AbstractBootstrap` was doing a racy read and has been adjusted.

Result:
Data race bugs around `getNow()` and `cause()` are no longer silent.
This commit is contained in:
Chris Vest 2021-08-24 15:47:27 +02:00 committed by GitHub
parent 11cdf1d3cf
commit b8e1341142
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 88 additions and 36 deletions

View File

@ -234,8 +234,8 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder, Ht
weight, exclusive, padding, endOfStream, promise); weight, exclusive, padding, endOfStream, promise);
// Writing headers may fail during the encode state if they violate HPACK limits. // Writing headers may fail during the encode state if they violate HPACK limits.
Throwable failureCause = future.cause();
if (failureCause == null) { if (future.isSuccess() || !future.isDone()) {
// Synchronously set the headersSent flag to ensure that we do not subsequently write // Synchronously set the headersSent flag to ensure that we do not subsequently write
// other headers containing pseudo-header fields. // other headers containing pseudo-header fields.
// //
@ -248,6 +248,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder, Ht
notifyLifecycleManagerOnError(future, ctx); notifyLifecycleManagerOnError(future, ctx);
} }
} else { } else {
Throwable failureCause = future.cause();
lifecycleManager.onError(ctx, true, failureCause); lifecycleManager.onError(ctx, true, failureCause);
} }
@ -351,8 +352,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder, Ht
Future<Void> future = frameWriter.writePushPromise(ctx, streamId, promisedStreamId, headers, padding, Future<Void> future = frameWriter.writePushPromise(ctx, streamId, promisedStreamId, headers, padding,
promise); promise);
// Writing headers may fail during the encode state if they violate HPACK limits. // Writing headers may fail during the encode state if they violate HPACK limits.
Throwable failureCause = future.cause(); if (future.isSuccess() || !future.isDone()) {
if (failureCause == null) {
// This just sets internal stream state which is used elsewhere in the codec and doesn't // This just sets internal stream state which is used elsewhere in the codec and doesn't
// necessarily mean the write will complete successfully. // necessarily mean the write will complete successfully.
stream.pushPromiseSent(); stream.pushPromiseSent();
@ -362,6 +362,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder, Ht
notifyLifecycleManagerOnError(future, ctx); notifyLifecycleManagerOnError(future, ctx);
} }
} else { } else {
Throwable failureCause = future.cause();
lifecycleManager.onError(ctx, true, failureCause); lifecycleManager.onError(ctx, true, failureCause);
} }
return future; return future;
@ -581,8 +582,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder, Ht
Future<Void> f = sendHeaders(frameWriter, ctx, stream.id(), headers, hasPriority, streamDependency, Future<Void> f = sendHeaders(frameWriter, ctx, stream.id(), headers, hasPriority, streamDependency,
weight, exclusive, padding, endOfStream, promise); weight, exclusive, padding, endOfStream, promise);
// Writing headers may fail during the encode state if they violate HPACK limits. // Writing headers may fail during the encode state if they violate HPACK limits.
Throwable failureCause = f.cause(); if (f.isSuccess() || !f.isDone()) {
if (failureCause == null) {
// This just sets internal stream state which is used elsewhere in the codec and doesn't // This just sets internal stream state which is used elsewhere in the codec and doesn't
// necessarily mean the write will complete successfully. // necessarily mean the write will complete successfully.
stream.headersSent(isInformational); stream.headersSent(isInformational);

View File

@ -256,7 +256,7 @@ public class StreamBufferingEncoderTest {
assertEquals(0, encoder.numBufferedStreams()); assertEquals(0, encoder.numBufferedStreams());
int failCount = 0; int failCount = 0;
for (Future<Void> f : futures) { for (Future<Void> f : futures) {
if (f.cause() != null) { if (f.isFailed()) {
assertTrue(f.cause() instanceof Http2GoAwayException); assertTrue(f.cause() instanceof Http2GoAwayException);
failCount++; failCount++;
} }
@ -272,7 +272,7 @@ public class StreamBufferingEncoderTest {
connection.goAwayReceived(11, 8, EMPTY_BUFFER); connection.goAwayReceived(11, 8, EMPTY_BUFFER);
Future<Void> f = encoderWriteHeaders(5, newPromise()); Future<Void> f = encoderWriteHeaders(5, newPromise());
assertTrue(f.cause() instanceof Http2GoAwayException); assertTrue(f.awaitUninterruptibly().cause() instanceof Http2GoAwayException);
assertEquals(0, encoder.numBufferedStreams()); assertEquals(0, encoder.numBufferedStreams());
} }
@ -461,7 +461,7 @@ public class StreamBufferingEncoderTest {
Future<Void> f = encoderWriteHeaders(-1, newPromise()); Future<Void> f = encoderWriteHeaders(-1, newPromise());
// Verify that the write fails. // Verify that the write fails.
assertNotNull(f.cause()); assertNotNull(f.awaitUninterruptibly().cause());
} }
@Test @Test
@ -493,9 +493,9 @@ public class StreamBufferingEncoderTest {
Future<Void> f3 = encoderWriteHeaders(7, newPromise()); Future<Void> f3 = encoderWriteHeaders(7, newPromise());
encoder.close(); encoder.close();
assertNotNull(f1.cause()); assertNotNull(f1.awaitUninterruptibly().cause());
assertNotNull(f2.cause()); assertNotNull(f2.awaitUninterruptibly().cause());
assertNotNull(f3.cause()); assertNotNull(f3.awaitUninterruptibly().cause());
} }
@Test @Test

View File

@ -398,6 +398,11 @@ public abstract class AbstractScheduledEventExecutor extends AbstractEventExecut
return future.isSuccess(); return future.isSuccess();
} }
@Override
public boolean isFailed() {
return future.isFailed();
}
@Override @Override
public boolean isCancellable() { public boolean isCancellable() {
return future.isCancellable(); return future.isCancellable();

View File

@ -163,6 +163,11 @@ public class DefaultPromise<V> implements Promise<V> {
return result != null && result != UNCANCELLABLE && !(result instanceof CauseHolder); return result != null && result != UNCANCELLABLE && !(result instanceof CauseHolder);
} }
@Override
public boolean isFailed() {
return result instanceof CauseHolder;
}
@Override @Override
public boolean isCancellable() { public boolean isCancellable() {
return result == null; return result == null;
@ -190,6 +195,9 @@ public class DefaultPromise<V> implements Promise<V> {
} }
private Throwable cause0(Object result) { private Throwable cause0(Object result) {
if (!isDone0(result)) {
throw new IllegalStateException("Cannot call cause() on a future that has not completed.");
}
if (!(result instanceof CauseHolder)) { if (!(result instanceof CauseHolder)) {
return null; return null;
} }
@ -316,7 +324,10 @@ public class DefaultPromise<V> implements Promise<V> {
@Override @Override
public V getNow() { public V getNow() {
Object result = this.result; Object result = this.result;
if (result instanceof CauseHolder || result == SUCCESS || result == UNCANCELLABLE) { if (!isDone0(result)) {
throw new IllegalStateException("Cannot call getNow() on a future that has not completed.");
}
if (result instanceof CauseHolder || result == SUCCESS) {
return null; return null;
} }
return (V) result; return (V) result;

View File

@ -53,9 +53,9 @@ import java.util.concurrent.TimeoutException;
* | isDone() = false | | +---------------------------+ * | isDone() = false | | +---------------------------+
* | isSuccess() = false |----+----> isDone() = true | * | isSuccess() = false |----+----> isDone() = true |
* | isCancelled() = false | | | cause() = non-null | * | isCancelled() = false | | | cause() = non-null |
* | cause() = null | | +===========================+ * | cause() = throws | | +===========================+
* +--------------------------+ | | Completed by cancellation | * | getNow() = throws | | | Completed by cancellation |
* | +---------------------------+ * +--------------------------+ | +---------------------------+
* +----> isDone() = true | * +----> isDone() = true |
* | isCancelled() = true | * | isCancelled() = true |
* +---------------------------+ * +---------------------------+
@ -168,11 +168,15 @@ import java.util.concurrent.TimeoutException;
@SuppressWarnings("ClassNameSameAsAncestorName") @SuppressWarnings("ClassNameSameAsAncestorName")
public interface Future<V> extends java.util.concurrent.Future<V> { public interface Future<V> extends java.util.concurrent.Future<V> {
/** /**
* Returns {@code true} if and only if the I/O operation was completed * Returns {@code true} if and only if the operation was completed successfully.
* successfully.
*/ */
boolean isSuccess(); boolean isSuccess();
/**
* Returns {@code true} if and only if the operation was completed and failed.
*/
boolean isFailed();
/** /**
* returns {@code true} if and only if the operation can be cancelled via {@link #cancel(boolean)}. * returns {@code true} if and only if the operation can be cancelled via {@link #cancel(boolean)}.
*/ */
@ -183,8 +187,8 @@ public interface Future<V> extends java.util.concurrent.Future<V> {
* failed. * failed.
* *
* @return the cause of the failure. * @return the cause of the failure.
* {@code null} if succeeded or this future is not * {@code null} if succeeded.
* completed yet. * @throws IllegalStateException if this {@code Future} has not completed yet.
*/ */
Throwable cause(); Throwable cause();
@ -291,10 +295,9 @@ public interface Future<V> extends java.util.concurrent.Future<V> {
boolean awaitUninterruptibly(long timeoutMillis); boolean awaitUninterruptibly(long timeoutMillis);
/** /**
* Return the result without blocking. If the future is not done yet this will return {@code null}. * Return the result without blocking. If the future is not done yet this will throw {@link IllegalStateException}.
* *
* As it is possible that a {@code null} value is used to mark the future as successful you also need to check * @throws IllegalStateException if this {@code Future} has not completed yet.
* if the future is really done with {@link #isDone()} and not rely on the returned {@code null} value.
*/ */
V getNow(); V getNow();

View File

@ -44,6 +44,11 @@ final class RunnableFutureAdapter<V> implements RunnableFuture<V> {
return promise.isSuccess(); return promise.isSuccess();
} }
@Override
public boolean isFailed() {
return promise.isFailed();
}
@Override @Override
public boolean isCancellable() { public boolean isCancellable() {
return promise.isCancellable(); return promise.isCancellable();

View File

@ -145,6 +145,11 @@ final class RunnableScheduledFutureAdapter<V> implements AbstractScheduledEventE
return promise.isSuccess(); return promise.isSuccess();
} }
@Override
public boolean isFailed() {
return promise.isFailed();
}
@Override @Override
public boolean isCancellable() { public boolean isCancellable() {
return promise.isCancellable(); return promise.isCancellable();

View File

@ -22,7 +22,6 @@ import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.function.Executable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@ -166,6 +165,8 @@ public class DefaultPromiseTest {
Exception cause = new Exception(); Exception cause = new Exception();
DefaultPromise<Void> promise = new DefaultPromise<Void>(executor); DefaultPromise<Void> promise = new DefaultPromise<Void>(executor);
promise.setFailure(cause); promise.setFailure(cause);
assertTrue(promise.isFailed());
assertFalse(promise.isSuccess());
assertSame(cause, promise.cause()); assertSame(cause, promise.cause());
} }
@ -188,6 +189,7 @@ public class DefaultPromiseTest {
DefaultPromise<Void> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE); DefaultPromise<Void> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE);
assertTrue(promise.cancel(false)); assertTrue(promise.cancel(false));
assertThat(promise.cause()).isInstanceOf(CancellationException.class); assertThat(promise.cause()).isInstanceOf(CancellationException.class);
assertTrue(promise.isFailed());
} }
@Test @Test
@ -356,6 +358,7 @@ public class DefaultPromiseTest {
promise.setSuccess(Signal.valueOf(DefaultPromise.class, "UNCANCELLABLE")); promise.setSuccess(Signal.valueOf(DefaultPromise.class, "UNCANCELLABLE"));
assertTrue(promise.isDone()); assertTrue(promise.isDone());
assertTrue(promise.isSuccess()); assertTrue(promise.isSuccess());
assertFalse(promise.isFailed());
} }
@Test @Test
@ -364,21 +367,25 @@ public class DefaultPromiseTest {
promise.setSuccess(Signal.valueOf(DefaultPromise.class, "SUCCESS")); promise.setSuccess(Signal.valueOf(DefaultPromise.class, "SUCCESS"));
assertTrue(promise.isDone()); assertTrue(promise.isDone());
assertTrue(promise.isSuccess()); assertTrue(promise.isSuccess());
assertFalse(promise.isFailed());
} }
@Test @Test
public void setUncancellableGetNow() { public void setUncancellableGetNow() {
DefaultPromise<String> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE); DefaultPromise<String> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE);
assertNull(promise.getNow()); assertThrows(IllegalStateException.class, () -> promise.getNow());
assertFalse(promise.isDone());
assertTrue(promise.setUncancellable()); assertTrue(promise.setUncancellable());
assertNull(promise.getNow()); assertThrows(IllegalStateException.class, () -> promise.getNow());
assertFalse(promise.isDone()); assertFalse(promise.isDone());
assertFalse(promise.isSuccess()); assertFalse(promise.isSuccess());
assertFalse(promise.isFailed());
promise.setSuccess("success"); promise.setSuccess("success");
assertTrue(promise.isDone()); assertTrue(promise.isDone());
assertTrue(promise.isSuccess()); assertTrue(promise.isSuccess());
assertFalse(promise.isFailed());
assertEquals("success", promise.getNow()); assertEquals("success", promise.getNow());
} }
@ -387,6 +394,7 @@ public class DefaultPromiseTest {
Exception exception = new Exception(); Exception exception = new Exception();
DefaultPromise<String> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE); DefaultPromise<String> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE);
promise.setFailure(exception); promise.setFailure(exception);
assertTrue(promise.isFailed());
try { try {
promise.sync(); promise.sync();
@ -400,6 +408,7 @@ public class DefaultPromiseTest {
Exception exception = new Exception(); Exception exception = new Exception();
DefaultPromise<String> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE); DefaultPromise<String> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE);
promise.setFailure(exception); promise.setFailure(exception);
assertTrue(promise.isFailed());
try { try {
promise.syncUninterruptibly(); promise.syncUninterruptibly();
@ -440,6 +449,19 @@ public class DefaultPromiseTest {
promise.setSuccess(result); promise.setSuccess(result);
} }
@Test
public void getNowOnUnfinishedPromiseMustThrow() {
DefaultPromise<Object> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE);
assertThrows(IllegalStateException.class, () -> promise.getNow());
}
@SuppressWarnings("ThrowableNotThrown")
@Test
public void causeOnUnfinishedPromiseMustThrow() {
DefaultPromise<Object> promise = new DefaultPromise<>(ImmediateEventExecutor.INSTANCE);
assertThrows(IllegalStateException.class, () -> promise.cause());
}
private static void testStackOverFlowChainedFuturesA(int promiseChainLength, final EventExecutor executor, private static void testStackOverFlowChainedFuturesA(int promiseChainLength, final EventExecutor executor,
boolean runTestInExecutorThread) boolean runTestInExecutorThread)
throws InterruptedException { throws InterruptedException {

View File

@ -828,9 +828,9 @@ public class SslHandler extends ByteToMessageDecoder {
if (result.getStatus() == Status.CLOSED) { if (result.getStatus() == Status.CLOSED) {
// Make a best effort to preserve any exception that way previously encountered from the handshake // Make a best effort to preserve any exception that way previously encountered from the handshake
// or the transport, else fallback to a general error. // or the transport, else fallback to a general error.
Throwable exception = handshakePromise.cause(); Throwable exception = handshakePromise.isDone() ? handshakePromise.cause() : null;
if (exception == null) { if (exception == null) {
exception = sslClosePromise.cause(); exception = sslClosePromise.isDone() ? sslClosePromise.cause() : null;
if (exception == null) { if (exception == null) {
exception = new SslClosedEngineException("SSLEngine closed already"); exception = new SslClosedEngineException("SSLEngine closed already");
} }
@ -1032,7 +1032,7 @@ public class SslHandler extends ByteToMessageDecoder {
@Override @Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception { public void channelInactive(ChannelHandlerContext ctx) throws Exception {
boolean handshakeFailed = handshakePromise.cause() != null; boolean handshakeFailed = handshakePromise.isFailed();
ClosedChannelException exception = new ClosedChannelException(); ClosedChannelException exception = new ClosedChannelException();
// Make sure to release SSLEngine, // Make sure to release SSLEngine,

View File

@ -487,7 +487,7 @@ public class DnsNameResolver extends InetNameResolver {
try { try {
ch = b.createUnregistered(); ch = b.createUnregistered();
Future<Void> future = localAddress == null ? ch.register() : ch.bind(localAddress); Future<Void> future = localAddress == null ? ch.register() : ch.bind(localAddress);
if (future.cause() != null) { if (future.isFailed()) {
throw future.cause(); throw future.cause();
} }
} catch (Error | RuntimeException e) { } catch (Error | RuntimeException e) {

View File

@ -40,6 +40,7 @@ import java.util.concurrent.TimeUnit;
import static io.netty.testsuite.transport.socket.SocketTestPermutation.BAD_HOST; import static io.netty.testsuite.transport.socket.SocketTestPermutation.BAD_HOST;
import static io.netty.testsuite.transport.socket.SocketTestPermutation.BAD_PORT; import static io.netty.testsuite.transport.socket.SocketTestPermutation.BAD_PORT;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue;
@ -96,7 +97,7 @@ public class SocketConnectionAttemptTest extends AbstractClientSocketTest {
cb.option(ChannelOption.ALLOW_HALF_CLOSURE, halfClosure); cb.option(ChannelOption.ALLOW_HALF_CLOSURE, halfClosure);
Future<Channel> future = cb.connect(NetUtil.LOCALHOST, UNASSIGNED_PORT).awaitUninterruptibly(); Future<Channel> future = cb.connect(NetUtil.LOCALHOST, UNASSIGNED_PORT).awaitUninterruptibly();
assertThat(future.cause()).isInstanceOf(ConnectException.class); assertThat(future.cause()).isInstanceOf(ConnectException.class);
assertThat(errorPromise.cause()).isNull(); assertFalse(errorPromise.isFailed());
} }
@Test @Test

View File

@ -240,7 +240,7 @@ public abstract class AbstractBootstrap<B extends AbstractBootstrap<B, C, F>, C
private Future<Channel> doBind(final SocketAddress localAddress) { private Future<Channel> doBind(final SocketAddress localAddress) {
EventLoop loop = group.next(); EventLoop loop = group.next();
final Future<Channel> regFuture = initAndRegister(loop); final Future<Channel> regFuture = initAndRegister(loop);
if (regFuture.cause() != null) { if (regFuture.isFailed()) {
return regFuture; return regFuture;
} }

View File

@ -1052,7 +1052,7 @@ public class DefaultChannelPipelineTest {
assertTrue(handler.addedHandler.get()); assertTrue(handler.addedHandler.get());
assertTrue(handler.removedHandler.get()); assertTrue(handler.removedHandler.get());
assertTrue(handler2.addedHandler.get()); assertTrue(handler2.addedHandler.get());
assertNull(handler2.removedHandler.getNow()); assertFalse(handler2.removedHandler.isDone());
pipeline.channel().register().syncUninterruptibly(); pipeline.channel().register().syncUninterruptibly();
Throwable cause = handler.error.get(); Throwable cause = handler.error.get();
@ -1065,7 +1065,7 @@ public class DefaultChannelPipelineTest {
throw cause2; throw cause2;
} }
assertNull(handler2.removedHandler.getNow()); assertFalse(handler2.removedHandler.isDone());
pipeline.remove(handler2); pipeline.remove(handler2);
assertTrue(handler2.removedHandler.get()); assertTrue(handler2.removedHandler.get());
pipeline.channel().close().syncUninterruptibly(); pipeline.channel().close().syncUninterruptibly();
@ -1740,7 +1740,7 @@ public class DefaultChannelPipelineTest {
public void handlerAdded(ChannelHandlerContext ctx) throws Exception { public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
if (!addedHandler.trySuccess(true)) { if (!addedHandler.trySuccess(true)) {
error.set(new AssertionError("handlerAdded(...) called multiple times: " + ctx.name())); error.set(new AssertionError("handlerAdded(...) called multiple times: " + ctx.name()));
} else if (removedHandler.getNow() == Boolean.TRUE) { } else if (removedHandler.isDone() && removedHandler.getNow() == Boolean.TRUE) {
error.set(new AssertionError("handlerRemoved(...) called before handlerAdded(...): " + ctx.name())); error.set(new AssertionError("handlerRemoved(...) called before handlerAdded(...): " + ctx.name()));
} }
} }
@ -1749,7 +1749,7 @@ public class DefaultChannelPipelineTest {
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
if (!removedHandler.trySuccess(true)) { if (!removedHandler.trySuccess(true)) {
error.set(new AssertionError("handlerRemoved(...) called multiple times: " + ctx.name())); error.set(new AssertionError("handlerRemoved(...) called multiple times: " + ctx.name()));
} else if (addedHandler.getNow() == Boolean.FALSE) { } else if (addedHandler.isDone() && addedHandler.getNow() == Boolean.FALSE) {
error.set(new AssertionError("handlerRemoved(...) called before handlerAdded(...): " + ctx.name())); error.set(new AssertionError("handlerRemoved(...) called before handlerAdded(...): " + ctx.name()));
} }
} }