diff --git a/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java b/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java index fe91aff654..6db036e8d2 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java @@ -23,6 +23,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.handler.codec.DecoderException; +import io.netty.util.internal.RecyclableArrayList; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -69,6 +70,8 @@ public abstract class ApplicationProtocolNegotiationHandler implements ChannelHa InternalLoggerFactory.getInstance(ApplicationProtocolNegotiationHandler.class); private final String fallbackProtocol; + private final RecyclableArrayList bufferedMessages = RecyclableArrayList.newInstance(); + private ChannelHandlerContext ctx; /** * Creates a new instance with the specified fallback protocol name. @@ -80,6 +83,36 @@ public abstract class ApplicationProtocolNegotiationHandler implements ChannelHa this.fallbackProtocol = requireNonNull(fallbackProtocol, "fallbackProtocol"); } + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + fireBufferedMessages(); + bufferedMessages.recycle(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // Let's buffer all data until this handler will be removed from the pipeline. + bufferedMessages.add(msg); + } + + /** + * Process all backlog into pipeline from List. + */ + private void fireBufferedMessages() { + if (!bufferedMessages.isEmpty()) { + for (int i = 0; i < bufferedMessages.size(); i++) { + ctx.fireChannelRead(bufferedMessages.get(i)); + } + ctx.fireChannelReadComplete(); + bufferedMessages.clear(); + } + } + @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof SslHandshakeCompletionEvent) { @@ -120,6 +153,7 @@ public abstract class ApplicationProtocolNegotiationHandler implements ChannelHa pipeline.remove(this); } } + /** * Invoked on successful initial SSL/TLS handshake. Implement this method to configure your pipeline * for the negotiated application-level protocol. diff --git a/handler/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java index 678d0d22db..bc2470a5d9 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java @@ -24,11 +24,13 @@ import org.junit.Test; import java.security.NoSuchAlgorithmException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLHandshakeException; import static io.netty.handler.ssl.CloseNotifyTest.assertCloseNotify; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -99,4 +101,50 @@ public class ApplicationProtocolNegotiationHandlerTest { assertFalse(channel.finishAndReleaseAll()); } } + + @Test + public void testBufferMessagesUntilHandshakeComplete() throws Exception { + final AtomicReference channelReadData = new AtomicReference(); + final AtomicBoolean channelReadCompleteCalled = new AtomicBoolean(false); + ChannelHandler alpnHandler = new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_1_1) { + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { + assertEquals(ApplicationProtocolNames.HTTP_1_1, protocol); + ctx.pipeline().addLast(new ChannelHandler() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelReadData.set((byte[]) msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadCompleteCalled.set(true); + } + }); + } + }; + + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + // This test is mocked/simulated and doesn't go through full TLS handshake. Currently only JDK SSLEngineImpl + // client mode will generate a close_notify. + engine.setUseClientMode(true); + + final byte[] someBytes = new byte[1024]; + + EmbeddedChannel channel = new EmbeddedChannel(new SslHandler(engine), new ChannelHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == SslHandshakeCompletionEvent.SUCCESS) { + ctx.fireChannelRead(someBytes); + } + ctx.fireUserEventTriggered(evt); + } + }, alpnHandler); + channel.pipeline().fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); + assertNull(channel.pipeline().context(alpnHandler)); + assertArrayEquals(someBytes, channelReadData.get()); + assertTrue(channelReadCompleteCalled.get()); + assertNull(channel.readInbound()); + channel.finishAndReleaseAll(); + } }