Correctly write pending data after ssl handshake completes. Related to [#2437]
Motivation: When writing data from a server before the ssl handshake completes may not be written at all to the remote peer if nothing else is written after the handshake was done. Modification: Correctly try to write pending data after the handshake was complete Result: Correctly write out all pending data
This commit is contained in:
parent
3ba856a5f4
commit
0aea9eaed5
@ -183,7 +183,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
|
||||
private final boolean startTls;
|
||||
private boolean sentFirstMessage;
|
||||
|
||||
private boolean flushedBeforeHandshakeDone;
|
||||
private final LazyChannelPromise handshakePromise = new LazyChannelPromise();
|
||||
private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
|
||||
private final Deque<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
|
||||
@ -415,6 +415,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
if (pendingUnencryptedWrites.isEmpty()) {
|
||||
pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null));
|
||||
}
|
||||
if (!handshakePromise.isDone()) {
|
||||
flushedBeforeHandshakeDone = true;
|
||||
}
|
||||
wrap(ctx, false);
|
||||
ctx.flush();
|
||||
}
|
||||
@ -437,7 +440,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
pendingUnencryptedWrites.remove();
|
||||
continue;
|
||||
}
|
||||
|
||||
ByteBuf buf = (ByteBuf) pending.msg();
|
||||
SSLEngineResult result = wrap(engine, buf, out);
|
||||
|
||||
@ -890,7 +892,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
private void unwrapMultiple(
|
||||
ChannelHandlerContext ctx, ByteBuffer packet, int totalLength,
|
||||
int[] recordLengths, int nRecords, List<Object> out) throws SSLException {
|
||||
|
||||
for (int i = 0; i < nRecords; i ++) {
|
||||
packet.limit(packet.position() + recordLengths[i]);
|
||||
try {
|
||||
@ -949,6 +950,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
||||
wrapLater = true;
|
||||
continue;
|
||||
}
|
||||
if (flushedBeforeHandshakeDone) {
|
||||
// We need to call wrap(...) in case there was a flush done before the handshake completed.
|
||||
//
|
||||
// See https://github.com/netty/netty/pull/2437
|
||||
flushedBeforeHandshakeDone = false;
|
||||
wrapLater = true;
|
||||
}
|
||||
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("Unknown handshake status: " + handshakeStatus);
|
||||
|
@ -0,0 +1,158 @@
|
||||
/*
|
||||
* Copyright 2014 The Netty Project
|
||||
*
|
||||
* The Netty Project licenses this file to you under the Apache License,
|
||||
* version 2.0 (the "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at:
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package io.netty.testsuite.transport.socket;
|
||||
|
||||
import io.netty.bootstrap.Bootstrap;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.ChannelPipeline;
|
||||
import io.netty.channel.SimpleChannelInboundHandler;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.handler.logging.LogLevel;
|
||||
import io.netty.handler.logging.LoggingHandler;
|
||||
import io.netty.handler.ssl.SslHandler;
|
||||
import io.netty.testsuite.util.BogusSslContextFactory;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.junit.Test;
|
||||
|
||||
import javax.net.ssl.SSLEngine;
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
|
||||
public class SocketSslGreetingTest extends AbstractSocketTest {
|
||||
|
||||
private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
|
||||
private final ByteBuf greeting = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeByte('a'));
|
||||
|
||||
// Test for https://github.com/netty/netty/pull/2437
|
||||
@Test(timeout = 30000)
|
||||
public void testSslGreeting() throws Throwable {
|
||||
run();
|
||||
}
|
||||
|
||||
public void testSslGreeting(ServerBootstrap sb, Bootstrap cb) throws Throwable {
|
||||
final ServerHandler sh = new ServerHandler();
|
||||
final ClientHandler ch = new ClientHandler();
|
||||
|
||||
sb.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
public void initChannel(SocketChannel sch) throws Exception {
|
||||
final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine();
|
||||
sse.setUseClientMode(false);
|
||||
|
||||
ChannelPipeline p = sch.pipeline();
|
||||
p.addLast(new SslHandler(sse));
|
||||
p.addLast("logger", new LoggingHandler(LOG_LEVEL));
|
||||
p.addLast(sh);
|
||||
}
|
||||
});
|
||||
|
||||
cb.handler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
public void initChannel(SocketChannel sch) throws Exception {
|
||||
final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine();
|
||||
cse.setUseClientMode(true);
|
||||
|
||||
ChannelPipeline p = sch.pipeline();
|
||||
p.addLast(new SslHandler(cse));
|
||||
p.addLast("logger", new LoggingHandler(LOG_LEVEL));
|
||||
p.addLast(ch);
|
||||
}
|
||||
});
|
||||
|
||||
Channel sc = sb.bind().sync().channel();
|
||||
Channel cc = cb.connect().sync().channel();
|
||||
|
||||
ch.latch.await();
|
||||
|
||||
sh.channel.close().awaitUninterruptibly();
|
||||
cc.close().awaitUninterruptibly();
|
||||
sc.close().awaitUninterruptibly();
|
||||
|
||||
if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
|
||||
throw sh.exception.get();
|
||||
}
|
||||
if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
|
||||
throw ch.exception.get();
|
||||
}
|
||||
if (sh.exception.get() != null) {
|
||||
throw sh.exception.get();
|
||||
}
|
||||
if (ch.exception.get() != null) {
|
||||
throw ch.exception.get();
|
||||
}
|
||||
}
|
||||
|
||||
private class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
|
||||
|
||||
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
|
||||
final CountDownLatch latch = new CountDownLatch(1);
|
||||
|
||||
@Override
|
||||
public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
|
||||
assertEquals(greeting, buf);
|
||||
latch.countDown();
|
||||
ctx.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx,
|
||||
Throwable cause) throws Exception {
|
||||
if (logger.isWarnEnabled()) {
|
||||
logger.warn("Unexpected exception from the client side", cause);
|
||||
}
|
||||
|
||||
exception.compareAndSet(null, cause);
|
||||
ctx.close();
|
||||
}
|
||||
}
|
||||
|
||||
private class ServerHandler extends SimpleChannelInboundHandler<String> {
|
||||
volatile Channel channel;
|
||||
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
|
||||
|
||||
@Override
|
||||
protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
|
||||
// discard
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(ChannelHandlerContext ctx)
|
||||
throws Exception {
|
||||
channel = ctx.channel();
|
||||
channel.writeAndFlush(greeting.duplicate().retain());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx,
|
||||
Throwable cause) throws Exception {
|
||||
if (logger.isWarnEnabled()) {
|
||||
logger.warn("Unexpected exception from the server side", cause);
|
||||
}
|
||||
|
||||
exception.compareAndSet(null, cause);
|
||||
ctx.close();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user