Correctly write pending data after ssl handshake completes. Related to [#2437]

Motivation:
When writing data from a server before the ssl handshake completes may not be written at all to the remote peer
if nothing else is written after the handshake was done.

Modification:
Correctly try to write pending data after the handshake was complete

Result:
Correctly write out all pending data
This commit is contained in:
Norman Maurer 2014-04-30 12:31:41 +02:00
parent 79fc0d2408
commit e920dc57db
2 changed files with 170 additions and 3 deletions

View File

@ -177,7 +177,7 @@ public class SslHandler extends ByteToMessageDecoder {
private final boolean startTls;
private boolean sentFirstMessage;
private boolean flushedBeforeHandshakeDone;
private final LazyChannelPromise handshakePromise = new LazyChannelPromise();
private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
private final Deque<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
@ -367,6 +367,9 @@ public class SslHandler extends ByteToMessageDecoder {
if (pendingUnencryptedWrites.isEmpty()) {
pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null));
}
if (!handshakePromise.isDone()) {
flushedBeforeHandshakeDone = true;
}
wrap(ctx, false);
ctx.flush();
}
@ -389,7 +392,6 @@ public class SslHandler extends ByteToMessageDecoder {
pendingUnencryptedWrites.remove();
continue;
}
ByteBuf buf = (ByteBuf) pending.msg();
SSLEngineResult result = wrap(engine, buf, out);
@ -842,7 +844,6 @@ public class SslHandler extends ByteToMessageDecoder {
private void unwrapMultiple(
ChannelHandlerContext ctx, ByteBuffer packet, int totalLength,
int[] recordLengths, int nRecords, List<Object> out) throws SSLException {
for (int i = 0; i < nRecords; i ++) {
packet.limit(packet.position() + recordLengths[i]);
try {
@ -901,6 +902,14 @@ public class SslHandler extends ByteToMessageDecoder {
wrapLater = true;
continue;
}
if (flushedBeforeHandshakeDone) {
// We need to call wrap(...) in case there was a flush done before the handshake completed.
//
// See https://github.com/netty/netty/pull/2437
flushedBeforeHandshakeDone = false;
wrapLater = true;
}
break;
default:
throw new IllegalStateException("Unknown handshake status: " + handshakeStatus);

View File

@ -0,0 +1,158 @@
/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SslHandler;
import io.netty.testsuite.util.BogusSslContextFactory;
import io.netty.util.ReferenceCountUtil;
import org.junit.Test;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertEquals;
public class SocketSslGreetingTest extends AbstractSocketTest {
private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
private final ByteBuf greeting = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeByte('a'));
// Test for https://github.com/netty/netty/pull/2437
@Test(timeout = 30000)
public void testSslGreeting() throws Throwable {
run();
}
public void testSslGreeting(ServerBootstrap sb, Bootstrap cb) throws Throwable {
final ServerHandler sh = new ServerHandler();
final ClientHandler ch = new ClientHandler();
sb.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel sch) throws Exception {
final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine();
sse.setUseClientMode(false);
ChannelPipeline p = sch.pipeline();
p.addLast(new SslHandler(sse));
p.addLast("logger", new LoggingHandler(LOG_LEVEL));
p.addLast(sh);
}
});
cb.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel sch) throws Exception {
final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine();
cse.setUseClientMode(true);
ChannelPipeline p = sch.pipeline();
p.addLast(new SslHandler(cse));
p.addLast("logger", new LoggingHandler(LOG_LEVEL));
p.addLast(ch);
}
});
Channel sc = sb.bind().sync().channel();
Channel cc = cb.connect().sync().channel();
ch.latch.await();
sh.channel.close().awaitUninterruptibly();
cc.close().awaitUninterruptibly();
sc.close().awaitUninterruptibly();
if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
throw sh.exception.get();
}
if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
throw ch.exception.get();
}
if (sh.exception.get() != null) {
throw sh.exception.get();
}
if (ch.exception.get() != null) {
throw ch.exception.get();
}
}
private class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
final CountDownLatch latch = new CountDownLatch(1);
@Override
public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
assertEquals(greeting, buf);
latch.countDown();
ctx.close();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx,
Throwable cause) throws Exception {
if (logger.isWarnEnabled()) {
logger.warn("Unexpected exception from the client side", cause);
}
exception.compareAndSet(null, cause);
ctx.close();
}
}
private class ServerHandler extends SimpleChannelInboundHandler<String> {
volatile Channel channel;
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
@Override
protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
// discard
}
@Override
public void channelActive(ChannelHandlerContext ctx)
throws Exception {
channel = ctx.channel();
channel.writeAndFlush(greeting.duplicate().retain());
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx,
Throwable cause) throws Exception {
if (logger.isWarnEnabled()) {
logger.warn("Unexpected exception from the server side", cause);
}
exception.compareAndSet(null, cause);
ctx.close();
}
}
}