Allow rejection of remote initiated renegotiation

Motivation:

To prevent from DOS attacks it can be useful to disable remote initiated renegotiation.

Modifications:

Add new flag to OpenSslContext that can be used to disable it
Adding a testcase

Result:

Remote initiated renegotion requests can be disabled now.
This commit is contained in:
Norman Maurer 2015-05-05 15:14:07 +02:00
parent 04396acb75
commit edf1eb10d5
4 changed files with 279 additions and 5 deletions

View File

@ -17,6 +17,7 @@ package io.netty.handler.ssl;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.SystemPropertyUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.apache.tomcat.jni.CertificateVerifier;
@ -45,6 +46,16 @@ import static io.netty.handler.ssl.ApplicationProtocolConfig.SelectedListenerFai
public abstract class OpenSslContext extends SslContext {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(OpenSslContext.class);
/**
* To make it easier for users to replace JDK implemention with OpenSsl version we also use
* {@code jdk.tls.rejectClientInitiatedRenegotiation} to allow disabling client initiated renegotiation.
* Java8+ uses this system property as well.
*
* See also <a href="http://blog.ivanristic.com/2014/03/ssl-tls-improvements-in-java-8.html">
* Significant SSL/TLS improvements in Java 8</a>
*/
private static final boolean JDK_REJECT_CLIENT_INITIATED_RENEGOTIATION =
SystemPropertyUtil.getBoolean("jdk.tls.rejectClientInitiatedRenegotiation", false);
private static final List<String> DEFAULT_CIPHERS;
private static final AtomicIntegerFieldUpdater<OpenSslContext> DESTROY_UPDATER;
@ -54,6 +65,7 @@ public abstract class OpenSslContext extends SslContext {
private final long aprPool;
@SuppressWarnings({ "unused", "FieldMayBeFinal" })
private volatile int aprPoolDestroyed;
private volatile boolean rejectRemoteInitiatedRenegotiation;
private final List<String> unmodifiableCiphers;
private final long sessionCacheSize;
private final long sessionTimeout;
@ -129,6 +141,10 @@ public abstract class OpenSslContext extends SslContext {
}
this.mode = mode;
if (mode == SSL.SSL_MODE_SERVER) {
rejectRemoteInitiatedRenegotiation =
JDK_REJECT_CLIENT_INITIATED_RENEGOTIATION;
}
final List<String> convertedCiphers;
if (ciphers == null) {
convertedCiphers = null;
@ -280,7 +296,8 @@ public abstract class OpenSslContext extends SslContext {
*/
@Override
public final SSLEngine newEngine(ByteBufAllocator alloc) {
final OpenSslEngine engine = new OpenSslEngine(ctx, alloc, isClient(), sessionContext(), apn, engineMap);
final OpenSslEngine engine = new OpenSslEngine(
ctx, alloc, isClient(), sessionContext(), apn, engineMap, rejectRemoteInitiatedRenegotiation);
engineMap.add(engine);
return engine;
}
@ -301,6 +318,14 @@ public abstract class OpenSslContext extends SslContext {
return sessionContext().stats();
}
/**
* Specify if remote initiated renegotiation is supported or not. If not supported and the remote side tries
* to initiate a renegotiation a {@link SSLHandshakeException} will be thrown during decoding.
*/
public void setRejectRemoteInitiatedRenegotiation(boolean rejectRemoteInitiatedRenegotiation) {
this.rejectRemoteInitiatedRenegotiation = rejectRemoteInitiatedRenegotiation;
}
@Override
@SuppressWarnings("FinalizeDeclaration")
protected final void finalize() throws Throwable {

View File

@ -159,6 +159,7 @@ public final class OpenSslEngine extends SSLEngine {
private final OpenSslSessionContext sessionContext;
private final OpenSslEngineMap engineMap;
private final OpenSslApplicationProtocolNegotiator apn;
private final boolean rejectRemoteInitiatedRenegation;
private final SSLSession session = new OpenSslSession();
// This is package-private as we set it from OpenSslContext if an exception is thrown during
@ -174,7 +175,7 @@ public final class OpenSslEngine extends SSLEngine {
@Deprecated
public OpenSslEngine(long sslCtx, ByteBufAllocator alloc,
@SuppressWarnings("unused") String fallbackApplicationProtocol) {
this(sslCtx, alloc, false, null, OpenSslContext.NONE_PROTOCOL_NEGOTIATOR, OpenSslEngineMap.EMPTY);
this(sslCtx, alloc, false, null, OpenSslContext.NONE_PROTOCOL_NEGOTIATOR, OpenSslEngineMap.EMPTY, false);
}
/**
@ -187,7 +188,8 @@ public final class OpenSslEngine extends SSLEngine {
*/
OpenSslEngine(long sslCtx, ByteBufAllocator alloc,
boolean clientMode, OpenSslSessionContext sessionContext,
OpenSslApplicationProtocolNegotiator apn, OpenSslEngineMap engineMap) {
OpenSslApplicationProtocolNegotiator apn, OpenSslEngineMap engineMap,
boolean rejectRemoteInitiatedRenegation) {
OpenSsl.ensureAvailability();
if (sslCtx == 0) {
throw new NullPointerException("sslCtx");
@ -200,6 +202,7 @@ public final class OpenSslEngine extends SSLEngine {
this.clientMode = clientMode;
this.sessionContext = sessionContext;
this.engineMap = engineMap;
this.rejectRemoteInitiatedRenegation = rejectRemoteInitiatedRenegation;
}
@Override
@ -634,6 +637,8 @@ public final class OpenSslEngine extends SSLEngine {
checkPendingHandshakeException();
}
}
rejectRemoteInitiatedRenegation();
} else {
// Reset to 0 as -1 is used to signal that nothing was written and no priming read needs to be done
bytesConsumed = 0;
@ -671,10 +676,11 @@ public final class OpenSslEngine extends SSLEngine {
throw new SSLException(e);
}
rejectRemoteInitiatedRenegation();
if (bytesRead == 0) {
break;
}
bytesProduced += bytesRead;
pendingApp -= bytesRead;
@ -694,6 +700,15 @@ public final class OpenSslEngine extends SSLEngine {
return new SSLEngineResult(getEngineStatus(), handshakeStatus0(), bytesConsumed, bytesProduced);
}
private void rejectRemoteInitiatedRenegation() throws SSLHandshakeException {
if (rejectRemoteInitiatedRenegation && SSL.getHandshakeCount(ssl) > 1) {
// TODO: In future versions me may also want to send a fatal_alert to the client and so notify it
// that the renegotiation failed.
shutdown();
throw new SSLHandshakeException("remote-initiated renegotation not allowed");
}
}
public SSLEngineResult unwrap(final ByteBuffer[] srcs, final ByteBuffer[] dsts) throws SSLException {
return unwrap(srcs, 0, srcs.length, dsts, 0, dsts.length);
}

View File

@ -681,7 +681,7 @@
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>netty-tcnative</artifactId>
<version>1.1.33.Fork1</version>
<version>1.1.33.Fork2</version>
<classifier>${os.detected.classifier}</classifier>
<scope>compile</scope>
<optional>true</optional>

View File

@ -0,0 +1,234 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.ssl.JdkSslClientContext;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.OpenSslServerContext;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.concurrent.Future;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.Assume;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import javax.net.ssl.SSLHandshakeException;
import java.io.File;
import java.nio.channels.ClosedChannelException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.*;
@RunWith(Parameterized.class)
public class SocketSslClientRenegotiateTest extends AbstractSocketTest {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(
SocketSslClientRenegotiateTest.class);
private static final File CERT_FILE;
private static final File KEY_FILE;
static {
SelfSignedCertificate ssc;
try {
ssc = new SelfSignedCertificate();
} catch (CertificateException e) {
throw new Error(e);
}
CERT_FILE = ssc.certificate();
KEY_FILE = ssc.privateKey();
}
@Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}")
public static Collection<Object[]> data() throws Exception {
List<SslContext> serverContexts = new ArrayList<SslContext>();
List<SslContext> clientContexts = new ArrayList<SslContext>();
clientContexts.add(new JdkSslClientContext(CERT_FILE));
boolean hasOpenSsl = OpenSsl.isAvailable();
if (hasOpenSsl) {
OpenSslServerContext context = new OpenSslServerContext(CERT_FILE, KEY_FILE);
context.setRejectRemoteInitiatedRenegotiation(true);
serverContexts.add(context);
} else {
logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
}
List<Object[]> params = new ArrayList<Object[]>();
for (SslContext sc: serverContexts) {
for (SslContext cc: clientContexts) {
for (int i = 0; i < 32; i++) {
params.add(new Object[] { sc, cc});
}
}
}
return params;
}
private final SslContext serverCtx;
private final SslContext clientCtx;
private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
private volatile Channel clientChannel;
private volatile Channel serverChannel;
private volatile SslHandler clientSslHandler;
private volatile SslHandler serverSslHandler;
private final TestHandler clientHandler = new TestHandler(clientException);
private final TestHandler serverHandler = new TestHandler(serverException);
public SocketSslClientRenegotiateTest(
SslContext serverCtx, SslContext clientCtx) {
this.serverCtx = serverCtx;
this.clientCtx = clientCtx;
}
@Test(timeout = 30000)
public void testSslRenegotiationRejected() throws Throwable {
Assume.assumeTrue(OpenSsl.isAvailable());
run();
}
public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb) throws Throwable {
reset();
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
@SuppressWarnings("deprecation")
public void initChannel(Channel sch) throws Exception {
serverChannel = sch;
serverSslHandler = serverCtx.newHandler(sch.alloc());
sch.pipeline().addLast("ssl", serverSslHandler);
sch.pipeline().addLast("handler", serverHandler);
}
});
cb.handler(new ChannelInitializer<Channel>() {
@Override
@SuppressWarnings("deprecation")
public void initChannel(Channel sch) throws Exception {
clientChannel = sch;
clientSslHandler = clientCtx.newHandler(sch.alloc());
sch.pipeline().addLast("ssl", clientSslHandler);
sch.pipeline().addLast("handler", clientHandler);
}
});
Channel sc = sb.bind().sync().channel();
cb.connect().sync();
Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
clientHandshakeFuture.sync();
String renegotiation = "SSL_RSA_WITH_RC4_128_SHA";
clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation });
clientSslHandler.renegotiate().await();
serverChannel.close().awaitUninterruptibly();
clientChannel.close().awaitUninterruptibly();
sc.close().awaitUninterruptibly();
try {
if (serverException.get() != null) {
throw serverException.get();
}
fail();
} catch (DecoderException e) {
assertTrue(e.getCause() instanceof SSLHandshakeException);
}
if (clientException.get() != null) {
throw clientException.get();
}
}
private void reset() {
clientException.set(null);
serverException.set(null);
clientHandler.handshakeCounter = 0;
serverHandler.handshakeCounter = 0;
clientChannel = null;
serverChannel = null;
clientSslHandler = null;
serverSslHandler = null;
}
@Sharable
private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
protected final AtomicReference<Throwable> exception;
private int handshakeCounter;
TestHandler(AtomicReference<Throwable> exception) {
this.exception = exception;
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
ctx.flush();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
exception.compareAndSet(null, cause);
ctx.close();
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
if (handshakeCounter == 0) {
handshakeCounter++;
if (handshakeEvt.cause() != null) {
logger.warn("Handshake failed:", handshakeEvt.cause());
}
assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
} else {
if (ctx.channel().parent() == null) {
assertTrue(handshakeEvt.cause() instanceof ClosedChannelException);
}
}
}
}
@Override
public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { }
}
}