This commit is contained in:
Norman Maurer 2021-02-03 14:26:18 +01:00
parent 9b632db772
commit 2398f6557f
3 changed files with 65 additions and 57 deletions

View File

@ -36,13 +36,8 @@ final class OpenSslClientSessionCache extends OpenSslSessionCache {
@Override @Override
protected boolean sessionCreated(OpenSslSession session) { protected boolean sessionCreated(OpenSslSession session) {
assert Thread.holdsLock(this); assert Thread.holdsLock(this);
String host = session.getPeerHost(); HostPort hostPort = keyFor(session.getPeerHost(), session.getPeerPort());
int port = session.getPeerPort(); if (hostPort == null || sessions.containsKey(hostPort)) {
if (host == null || port == -1) {
return false;
}
HostPort hostPort = new HostPort(host, port);
if (sessions.containsKey(hostPort)) {
return false; return false;
} }
sessions.put(hostPort, session); sessions.put(hostPort, session);
@ -52,12 +47,11 @@ final class OpenSslClientSessionCache extends OpenSslSessionCache {
@Override @Override
protected void sessionRemoved(OpenSslSession session) { protected void sessionRemoved(OpenSslSession session) {
assert Thread.holdsLock(this); assert Thread.holdsLock(this);
String host = session.getPeerHost(); HostPort hostPort = keyFor(session.getPeerHost(), session.getPeerPort());
int port = session.getPeerPort(); if (hostPort == null) {
if (host == null || port == -1) {
return; return;
} }
sessions.remove(new HostPort(host, port)); sessions.remove(hostPort);
} }
private static boolean isProtocolEnabled(OpenSslSession session, String[] enabledProtocols) { private static boolean isProtocolEnabled(OpenSslSession session, String[] enabledProtocols) {
@ -79,12 +73,10 @@ final class OpenSslClientSessionCache extends OpenSslSessionCache {
} }
void setSession(ReferenceCountedOpenSslEngine engine) throws SSLException { void setSession(ReferenceCountedOpenSslEngine engine) throws SSLException {
String host = engine.getPeerHost(); HostPort hostPort = keyFor(engine.getPeerHost(), engine.getPeerPort());
int port = engine.getPeerPort(); if (hostPort == null) {
if (host == null || port == -1) {
return; return;
} }
HostPort hostPort = new HostPort(host, port);
final OpenSslSession session; final OpenSslSession session;
synchronized (this) { synchronized (this) {
session = sessions.get(hostPort); session = sessions.get(hostPort);
@ -93,7 +85,7 @@ final class OpenSslClientSessionCache extends OpenSslSessionCache {
} }
assert session.refCnt() >= 1; assert session.refCnt() >= 1;
if (!session.isValid()) { if (!session.isValid()) {
removeSession(session); removeSessionWithId(session.sessionId());
return; return;
} }
} }
@ -115,6 +107,19 @@ final class OpenSslClientSessionCache extends OpenSslSessionCache {
} }
} }
private static HostPort keyFor(String host, int port) {
if (host == null && port < 1) {
return null;
}
return new HostPort(host, port);
}
@Override
synchronized void clear() {
super.clear();
sessions.clear();
}
/** /**
* Host / Port tuple used to find a {@link OpenSslSession} in the cache. * Host / Port tuple used to find a {@link OpenSslSession} in the cache.
*/ */

View File

@ -31,6 +31,8 @@ import java.util.concurrent.atomic.AtomicInteger;
* {@link SSLSessionCache} implementation for our native SSL implementation. * {@link SSLSessionCache} implementation for our native SSL implementation.
*/ */
class OpenSslSessionCache implements SSLSessionCache { class OpenSslSessionCache implements SSLSessionCache {
private static final OpenSslSession[] EMPTY_SESSIONS = new OpenSslSession[0];
private static final int DEFAULT_CACHE_SIZE; private static final int DEFAULT_CACHE_SIZE;
static { static {
// Respect the same system property as the JDK implementation to make it easy to switch between implementations. // Respect the same system property as the JDK implementation to make it easy to switch between implementations.
@ -52,8 +54,7 @@ class OpenSslSessionCache implements SSLSessionCache {
protected boolean removeEldestEntry(Map.Entry<OpenSslSessionId, OpenSslSession> eldest) { protected boolean removeEldestEntry(Map.Entry<OpenSslSessionId, OpenSslSession> eldest) {
int maxSize = maximumCacheSize.get(); int maxSize = maximumCacheSize.get();
if (maxSize >= 0 && this.size() > maxSize) { if (maxSize >= 0 && this.size() > maxSize) {
OpenSslSession session = eldest.getValue(); removeSessionWithId(eldest.getKey());
removeSession(session);
} }
// We always need to return false as we modify the map directly. // We always need to return false as we modify the map directly.
return false; return false;
@ -76,7 +77,7 @@ class OpenSslSessionCache implements SSLSessionCache {
if (oldTimeout > seconds) { if (oldTimeout > seconds) {
// Drain the whole cache as this way we can use the ordering of the LinkedHashMap to detect early // Drain the whole cache as this way we can use the ordering of the LinkedHashMap to detect early
// if there are any other sessions left that are invalid. // if there are any other sessions left that are invalid.
freeSessions(); clear();
} }
} }
@ -105,7 +106,7 @@ class OpenSslSessionCache implements SSLSessionCache {
long oldSize = maximumCacheSize.getAndSet(size); long oldSize = maximumCacheSize.getAndSet(size);
if (oldSize > size) { if (oldSize > size) {
// Just keep it simple for now and drain the whole cache. // Just keep it simple for now and drain the whole cache.
freeSessions(); clear();
} }
} }
@ -113,7 +114,11 @@ class OpenSslSessionCache implements SSLSessionCache {
return maximumCacheSize.get(); return maximumCacheSize.get();
} }
private void expungeInvalidSessions(long now) { private void expungeInvalidSessions() {
if (sessions.isEmpty()) {
return;
}
long now = System.currentTimeMillis();
Iterator<Map.Entry<OpenSslSessionId, OpenSslSession>> iterator = sessions.entrySet().iterator(); Iterator<Map.Entry<OpenSslSessionId, OpenSslSession>> iterator = sessions.entrySet().iterator();
while (iterator.hasNext()) { while (iterator.hasNext()) {
OpenSslSession session = iterator.next().getValue(); OpenSslSession session = iterator.next().getValue();
@ -124,8 +129,8 @@ class OpenSslSessionCache implements SSLSessionCache {
break; break;
} }
iterator.remove(); iterator.remove();
sessionRemoved(session);
session.release(); notifyRemovalAndRelease(session);
} }
} }
@ -133,30 +138,31 @@ class OpenSslSessionCache implements SSLSessionCache {
public final boolean sessionCreated(long ssl, long sslSession) { public final boolean sessionCreated(long ssl, long sslSession) {
ReferenceCountedOpenSslEngine engine = engineMap.get(ssl); ReferenceCountedOpenSslEngine engine = engineMap.get(ssl);
if (engine == null) { if (engine == null) {
// We couldn't find the engine itself.
return false; return false;
} }
final OpenSslSession session = engine.sessionCreated(sslSession); final OpenSslSession session = engine.sessionCreated(sslSession);
if (session == null) {
return false;
}
synchronized (this) { synchronized (this) {
// Mimic what OpenSSL is doing and expunge every 255 new sessions // Mimic what OpenSSL is doing and expunge every 255 new sessions
// See https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_flush_sessions.html // See https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_flush_sessions.html
if (++sessionCounter == 255) { if (++sessionCounter == 255) {
sessionCounter = 0; sessionCounter = 0;
expungeInvalidSessions(System.currentTimeMillis()); expungeInvalidSessions();
} }
if (session == null) {
return false;
}
assert session.refCnt() >= 1; assert session.refCnt() >= 1;
if (!sessionCreated(session)) { if (!sessionCreated(session)) {
// Should not be cached, return false.
return false; return false;
} }
final OpenSslSession old = sessions.put(session.sessionId(), session.retain()); final OpenSslSession old = sessions.put(session.sessionId(), session.retain());
if (old != null) { if (old != null) {
// Let's remove the old session and release it. notifyRemovalAndRelease(old);
sessionRemoved(old);
old.release();
} }
} }
return true; return true;
@ -173,24 +179,26 @@ class OpenSslSessionCache implements SSLSessionCache {
} }
assert session.refCnt() >= 1; assert session.refCnt() >= 1;
if (!session.isValid()) { // If the session is not valid anymore we should remove it from the cache and just signal back
removeSession(session); // that we couldn't find a session that is re-usable.
return -1; if (!session.isValid() ||
} // This needs to happen in the synchronized block so we ensure we never destroy it before we
// incremented the reference count. If we cant increment the reference count there is something
// This needs to happen in the synchronized block so we ensure we never destroy it before we incremented // wrong. In this case just remove the session from the cache and signal back that we couldn't
// the reference count. // find a session for re-use.
if (!session.upRef()) { !session.upRef()) {
// we could not increment the reference count, something is wrong. Let's just drop the session. removeSessionWithId(session.sessionId());
removeSession(session);
return -1; return -1;
} }
} }
// Now that we incremented the reference count of the native SSL_SESSION we should also retain the reference
// count of our java session.
session.retain(); session.retain();
if (session.shouldBeSingleUse()) { if (session.shouldBeSingleUse()) {
// Should only be used once // Should only be used once. In this case invalidate the session which will also ensure we remove it
// from the cache and not use it again in the future.
session.invalidate(); session.invalidate();
} }
session.updateLastAccessedTime(); session.updateLastAccessedTime();
@ -200,13 +208,11 @@ class OpenSslSessionCache implements SSLSessionCache {
final synchronized void removeSessionWithId(OpenSslSessionId id) { final synchronized void removeSessionWithId(OpenSslSessionId id) {
OpenSslSession sslSession = sessions.remove(id); OpenSslSession sslSession = sessions.remove(id);
if (sslSession != null) { if (sslSession != null) {
sessionRemoved(sslSession); notifyRemovalAndRelease(sslSession);
sslSession.release();
} }
} }
protected final void removeSession(OpenSslSession session) { private void notifyRemovalAndRelease(OpenSslSession session) {
sessions.remove(session.sessionId());
sessionRemoved(session); sessionRemoved(session);
session.release(); session.release();
} }
@ -215,11 +221,10 @@ class OpenSslSessionCache implements SSLSessionCache {
OpenSslSessionId id = new OpenSslSessionId(bytes); OpenSslSessionId id = new OpenSslSessionId(bytes);
synchronized (this) { synchronized (this) {
OpenSslSession session = sessions.get(id); OpenSslSession session = sessions.get(id);
if (session == null) { if (session != null && !session.isValid()) {
return null; // The session is not valid anymore, let's remove it and just signale back that there is no session
} // with the given ID in the cache anymore.
if (!session.isValid()) { removeSessionWithId(session.sessionId());
removeSession(session);
return null; return null;
} }
return session; return session;
@ -229,7 +234,7 @@ class OpenSslSessionCache implements SSLSessionCache {
final List<byte[]> getIds() { final List<byte[]> getIds() {
final OpenSslSession[] sessionsArray; final OpenSslSession[] sessionsArray;
synchronized (this) { synchronized (this) {
sessionsArray = sessions.values().toArray(new OpenSslSession[0]); sessionsArray = sessions.values().toArray(EMPTY_SESSIONS);
} }
List<byte[]> ids = new ArrayList<byte[]>(sessionsArray.length); List<byte[]> ids = new ArrayList<byte[]>(sessionsArray.length);
for (OpenSslSession session: sessionsArray) { for (OpenSslSession session: sessionsArray) {
@ -240,14 +245,12 @@ class OpenSslSessionCache implements SSLSessionCache {
return ids; return ids;
} }
final synchronized void freeSessions() { synchronized void clear() {
Iterator<Map.Entry<OpenSslSessionId, OpenSslSession>> iterator = sessions.entrySet().iterator(); Iterator<Map.Entry<OpenSslSessionId, OpenSslSession>> iterator = sessions.entrySet().iterator();
while (iterator.hasNext()) { while (iterator.hasNext()) {
OpenSslSession session = iterator.next().getValue(); OpenSslSession session = iterator.next().getValue();
iterator.remove(); iterator.remove();
sessionRemoved(session); notifyRemovalAndRelease(session);
session.release();
} }
} }
} }

View File

@ -175,7 +175,7 @@ public abstract class OpenSslSessionContext implements SSLSessionContext {
try { try {
SSLContext.setSessionCacheMode(context.ctx, mode); SSLContext.setSessionCacheMode(context.ctx, mode);
if (!enabled) { if (!enabled) {
sessionCache.freeSessions(); sessionCache.clear();
} }
} finally { } finally {
writerLock.unlock(); writerLock.unlock();
@ -213,6 +213,6 @@ public abstract class OpenSslSessionContext implements SSLSessionContext {
if (provider != null) { if (provider != null) {
provider.destroy(); provider.destroy();
} }
sessionCache.freeSessions(); sessionCache.clear();
} }
} }