From 3dde52e805c2cba14c4f1895c00dd6d0b3749252 Mon Sep 17 00:00:00 2001 From: Carter Kozak Date: Mon, 2 Oct 2023 16:59:59 -0400 Subject: [PATCH] fix #1170: Use ReentrantLock instead of Object monitors --- .../java/org/conscrypt/ConscryptEngine.java | 139 +++++++++--- .../org/conscrypt/ConscryptEngineSocket.java | 106 +++++++--- .../ConscryptFileDescriptorSocket.java | 198 ++++++++++++++---- 3 files changed, 349 insertions(+), 94 deletions(-) diff --git a/common/src/main/java/org/conscrypt/ConscryptEngine.java b/common/src/main/java/org/conscrypt/ConscryptEngine.java index a58aa73cb..0fd314a8d 100644 --- a/common/src/main/java/org/conscrypt/ConscryptEngine.java +++ b/common/src/main/java/org/conscrypt/ConscryptEngine.java @@ -77,6 +77,8 @@ import java.security.interfaces.ECKey; import java.security.spec.ECParameterSpec; import java.util.Arrays; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import javax.crypto.SecretKey; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; @@ -128,7 +130,7 @@ final class ConscryptEngine extends AbstractConscryptEngine implements NativeCry */ private String peerHostname; - // @GuardedBy("ssl"); + // @GuardedBy("sslLock"); private int state = STATE_NEW; private boolean handshakeFinished; @@ -137,6 +139,11 @@ final class ConscryptEngine extends AbstractConscryptEngine implements NativeCry */ private final NativeSsl ssl; + /** + * Lock used for {@link #ssl} access. + */ + private final Lock sslLock = new ReentrantLock(); + /** * The BIO used for reading/writing encrypted bytes. */ @@ -227,12 +234,15 @@ static BufferAllocator getDefaultBufferAllocator() { @Override void setBufferAllocator(BufferAllocator bufferAllocator) { - synchronized (ssl) { + sslLock.lock(); + try { if (isHandshakeStarted()) { throw new IllegalStateException( "Could not set buffer allocator after the initial handshake has begun."); } this.bufferAllocator = bufferAllocator; + } finally { + sslLock.unlock(); } } @@ -254,7 +264,8 @@ int maxSealOverhead() { */ @Override void setChannelIdEnabled(boolean enabled) { - synchronized (ssl) { + sslLock.lock(); + try { if (getUseClientMode()) { throw new IllegalStateException("Not allowed in client mode"); } @@ -263,6 +274,8 @@ void setChannelIdEnabled(boolean enabled) { "Could not enable/disable Channel ID after the initial handshake has begun."); } sslParameters.channelIdEnabled = enabled; + } finally { + sslLock.unlock(); } } @@ -278,7 +291,8 @@ void setChannelIdEnabled(boolean enabled) { */ @Override byte[] getChannelId() throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { if (getUseClientMode()) { throw new IllegalStateException("Not allowed in client mode"); } @@ -288,6 +302,8 @@ byte[] getChannelId() throws SSLException { "Channel ID is only available after handshake completes"); } return ssl.getTlsChannelId(); + } finally { + sslLock.unlock(); } } @@ -309,7 +325,8 @@ void setChannelIdPrivateKey(PrivateKey privateKey) { throw new IllegalStateException("Not allowed in server mode"); } - synchronized (ssl) { + sslLock.lock(); + try { if (isHandshakeStarted()) { throw new IllegalStateException("Could not change Channel ID private key " + "after the initial handshake has begun."); @@ -337,6 +354,8 @@ void setChannelIdPrivateKey(PrivateKey privateKey) { } catch (InvalidKeyException e) { // Will have error in startHandshake } + } finally { + sslLock.unlock(); } } @@ -345,12 +364,15 @@ void setChannelIdPrivateKey(PrivateKey privateKey) { */ @Override void setHandshakeListener(HandshakeListener handshakeListener) { - synchronized (ssl) { + sslLock.lock(); + try { if (isHandshakeStarted()) { throw new IllegalStateException( "Handshake listener must be set before starting the handshake."); } this.handshakeListener = handshakeListener; + } finally { + sslLock.unlock(); } } @@ -397,8 +419,11 @@ public int getPeerPort() { @Override public void beginHandshake() throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { beginHandshakeInternal(); + } finally { + sslLock.unlock(); } } @@ -452,7 +477,8 @@ private void beginHandshakeInternal() throws SSLException { @Override public void closeInbound() { - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED || state == STATE_CLOSED_INBOUND) { return; } @@ -467,12 +493,15 @@ public void closeInbound() { // Never started the handshake. Just close now. closeAndFreeResources(); } + } finally { + sslLock.unlock(); } } @Override public void closeOutbound() { - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED || state == STATE_CLOSED_OUTBOUND) { return; } @@ -488,6 +517,8 @@ public void closeOutbound() { // Never started the handshake. Just close now. closeAndFreeResources(); } + } finally { + sslLock.unlock(); } } @@ -527,8 +558,11 @@ public void setSSLParameters(SSLParameters p) { @Override public HandshakeStatus getHandshakeStatus() { - synchronized (ssl) { + sslLock.lock(); + try { return getHandshakeStatusInternal(); + } finally { + sslLock.unlock(); } } @@ -578,7 +612,8 @@ public boolean getNeedClientAuth() { */ @Override SSLSession handshakeSession() { - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_HANDSHAKE_STARTED) { return Platform.wrapSSLSession(new ExternalSession(new ExternalSession.Provider() { @Override @@ -588,6 +623,8 @@ public ConscryptSession provideSession() { })); } return null; + } finally { + sslLock.unlock(); } } @@ -597,7 +634,8 @@ public SSLSession getSession() { } private ConscryptSession provideSession() { - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { return closedSession != null ? closedSession : SSLNullSession.getNullSession(); } @@ -606,13 +644,18 @@ private ConscryptSession provideSession() { return SSLNullSession.getNullSession(); } return activeSession; + } finally { + sslLock.unlock(); } } private ConscryptSession provideHandshakeSession() { - synchronized (ssl) { + sslLock.lock(); + try { return state == STATE_HANDSHAKE_STARTED ? activeSession : SSLNullSession.getNullSession(); + } finally { + sslLock.unlock(); } } @@ -646,21 +689,27 @@ public boolean getWantClientAuth() { @Override public boolean isInboundDone() { - synchronized (ssl) { + sslLock.lock(); + try { return (state == STATE_CLOSED || state == STATE_CLOSED_INBOUND || ssl.wasShutdownReceived()) && (pendingInboundCleartextBytes() == 0); + } finally { + sslLock.unlock(); } } @Override public boolean isOutboundDone() { - synchronized (ssl) { + sslLock.lock(); + try { return (state == STATE_CLOSED || state == STATE_CLOSED_OUTBOUND || ssl.wasShutdownSent()) && (pendingOutboundEncryptedBytes() == 0); + } finally { + sslLock.unlock(); } } @@ -686,13 +735,16 @@ public void setNeedClientAuth(boolean need) { @Override public void setUseClientMode(boolean mode) { - synchronized (ssl) { + sslLock.lock(); + try { if (isHandshakeStarted()) { throw new IllegalArgumentException( "Can not change mode after handshake: state == " + state); } transitionTo(STATE_MODE_SET); sslParameters.setUseClientMode(mode); + } finally { + sslLock.unlock(); } } @@ -703,36 +755,45 @@ public void setWantClientAuth(boolean want) { @Override public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer dst) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { try { return unwrap(singleSrcBuffer(src), singleDstBuffer(dst)); } finally { resetSingleSrcBuffer(); resetSingleDstBuffer(); } + } finally { + sslLock.unlock(); } } @Override public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { try { return unwrap(singleSrcBuffer(src), dsts); } finally { resetSingleSrcBuffer(); } + } finally { + sslLock.unlock(); } } @Override public SSLEngineResult unwrap(final ByteBuffer src, final ByteBuffer[] dsts, final int offset, final int length) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { try { return unwrap(singleSrcBuffer(src), 0, 1, dsts, offset, length); } finally { resetSingleSrcBuffer(); } + } finally { + sslLock.unlock(); } } @@ -759,7 +820,8 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe final int srcsEndOffset = srcsOffset + srcsLength; final long srcLength = calcSrcsLength(srcs, srcsOffset, srcsEndOffset); - synchronized (ssl) { + sslLock.lock(); + try { switch (state) { case STATE_MODE_SET: // Begin the handshake implicitly. @@ -930,6 +992,8 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe } return newResult(bytesConsumed, bytesProduced, handshakeStatus); + } finally { + sslLock.unlock(); } } @@ -1366,12 +1430,15 @@ private SSLEngineResult newResult(int bytesConsumed, int bytesProduced, @Override public SSLEngineResult wrap(ByteBuffer src, ByteBuffer dst) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { try { return wrap(singleSrcBuffer(src), dst); } finally { resetSingleSrcBuffer(); } + } finally { + sslLock.unlock(); } } @@ -1390,7 +1457,8 @@ public SSLEngineResult wrap(ByteBuffer[] srcs, int srcsOffset, int srcsLength, B } BufferUtils.checkNotNull(srcs); - synchronized (ssl) { + sslLock.lock(); + try { switch (state) { case STATE_MODE_SET: // Begin the handshake implicitly. @@ -1542,6 +1610,8 @@ public SSLEngineResult wrap(ByteBuffer[] srcs, int srcsOffset, int srcsLength, B } } return newResult(bytesConsumed, bytesProduced, handshakeStatus); + } finally { + sslLock.unlock(); } } @@ -1557,7 +1627,8 @@ public int serverPSKKeyRequested(String identityHint, String identity, byte[] ke @Override public void onSSLStateChange(int type, int val) { - synchronized (ssl) { + sslLock.lock(); + try { switch (type) { case SSL_CB_HANDSHAKE_START: { // For clients, this will allow the NEED_UNWRAP status to be @@ -1577,13 +1648,18 @@ public void onSSLStateChange(int type, int val) { default: // Ignore } + } finally { + sslLock.unlock(); } } @Override public void serverCertificateRequested() throws IOException { - synchronized (ssl) { + sslLock.lock(); + try { ssl.configureServerCertificate(); + } finally { + sslLock.unlock(); } } @@ -1677,8 +1753,11 @@ protected void finalize() throws Throwable { // If ssl is null, object must not be fully constructed so nothing for us to do here. if (ssl != null) { // Otherwise closeAndFreeResources() and callees expect to synchronize on ssl. - synchronized (ssl) { + sslLock.lock(); + try { closeAndFreeResources(); + } finally { + sslLock.unlock(); } } } finally { @@ -1758,10 +1837,13 @@ byte[] getTlsUnique() { @Override byte[] exportKeyingMaterial(String label, byte[] context, int length) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { if (state < STATE_HANDSHAKE_COMPLETED || state == STATE_CLOSED) { return null; } + } finally { + sslLock.unlock(); } return ssl.exportKeyingMaterial(label, context, length); } @@ -1786,8 +1868,11 @@ public String getApplicationProtocol() { @Override public String getHandshakeApplicationProtocol() { - synchronized (ssl) { + sslLock.lock(); + try { return state >= STATE_HANDSHAKE_STARTED ? getApplicationProtocol() : null; + } finally { + sslLock.unlock(); } } diff --git a/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java b/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java index f05fe25aa..1063cc9e1 100644 --- a/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java +++ b/common/src/main/java/org/conscrypt/ConscryptEngineSocket.java @@ -36,6 +36,9 @@ import java.security.PrivateKey; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; @@ -54,8 +57,9 @@ class ConscryptEngineSocket extends OpenSSLSocketImpl implements SSLParametersIm private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); private final ConscryptEngine engine; - private final Object stateLock = new Object(); - private final Object handshakeLock = new Object(); + private final Lock stateLock = new ReentrantLock(); + private final Condition stateLockCondition = stateLock.newCondition(); + private final Lock handshakeLock = new ReentrantLock(); private SSLOutputStream out; private SSLInputStream in; @@ -188,10 +192,12 @@ public final void startHandshake() throws IOException { checkOpen(); try { - synchronized (handshakeLock) { + handshakeLock.lock(); + try { // Only lock stateLock when we begin the handshake. This is done so that we don't // hold the stateLock when we invoke the handshake completion listeners. - synchronized (stateLock) { + stateLock.lock(); + try { // Initialize the handshake if we haven't already. if (state == STATE_NEW) { transitionTo(STATE_HANDSHAKE_STARTED); @@ -206,8 +212,12 @@ public final void startHandshake() throws IOException { // ignore addition handshake calls. return; } + } finally { + stateLock.unlock(); } doHandshake(); + } finally { + handshakeLock.unlock(); } } catch (IOException e) { close(); @@ -277,13 +287,17 @@ private void doHandshake() throws IOException { } private boolean isState(int desiredState) { - synchronized (stateLock) { + stateLock.lock(); + try { return state == desiredState; + } finally { + stateLock.unlock(); } } private int transitionTo(int newState) { - synchronized (stateLock) { + stateLock.lock(); + try { if (state == newState) { return state; } @@ -328,9 +342,11 @@ private int transitionTo(int newState) { state = newState; if (notify) { - stateLock.notifyAll(); + stateLockCondition.signalAll(); } return previousState; + } finally { + stateLock.unlock(); } } @@ -341,10 +357,13 @@ public final InputStream getInputStream() throws IOException { } private SSLInputStream createInputStream() { - synchronized (stateLock) { + stateLock.lock(); + try { if (in == null) { in = new SSLInputStream(); } + } finally { + stateLock.unlock(); } return in; } @@ -356,10 +375,13 @@ public final OutputStream getOutputStream() throws IOException { } private SSLOutputStream createOutputStream() { - synchronized (stateLock) { + stateLock.lock(); + try { if (out == null) { out = new SSLOutputStream(); } + } finally { + stateLock.unlock(); } return out; } @@ -594,13 +616,14 @@ private void onEngineHandshakeFinished() { private void waitForHandshake() throws IOException { startHandshake(); - synchronized (stateLock) { + stateLock.lock(); + try { while (state != STATE_READY // Waiting threads are allowed to compete with handshake listeners for access. && state != STATE_READY_HANDSHAKE_CUT_THROUGH && state != STATE_CLOSED) { try { - stateLock.wait(); + stateLockCondition.await(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException("Interrupted waiting for handshake", e); @@ -610,6 +633,8 @@ private void waitForHandshake() throws IOException { if (state == STATE_CLOSED) { throw new SocketException("Socket is closed"); } + } finally { + stateLock.unlock(); } } @@ -648,7 +673,7 @@ public final String chooseClientAlias(X509KeyManager keyManager, X500Principal[] * Wrap bytes written to the underlying socket. */ private final class SSLOutputStream extends OutputStream { - private final Object writeLock = new Object(); + private final ReentrantLock writeLock = new ReentrantLock(); private final ByteBuffer target; private final int targetArrayOffset; private OutputStream socketOutputStream; @@ -666,24 +691,33 @@ public void close() throws IOException { @Override public void write(int b) throws IOException { waitForHandshake(); - synchronized (writeLock) { + writeLock.lock(); + try { write(new byte[] {(byte) b}); + } finally { + writeLock.unlock(); } } @Override public void write(byte[] b) throws IOException { waitForHandshake(); - synchronized (writeLock) { + writeLock.lock(); + try { writeInternal(ByteBuffer.wrap(b)); + } finally { + writeLock.unlock(); } } @Override public void write(byte[] b, int off, int len) throws IOException { waitForHandshake(); - synchronized (writeLock) { + writeLock.lock(); + try { writeInternal(ByteBuffer.wrap(b, off, len)); + } finally { + writeLock.unlock(); } } @@ -727,8 +761,11 @@ private void writeInternal(ByteBuffer buffer) throws IOException { @Override public void flush() throws IOException { waitForHandshake(); - synchronized (writeLock) { + writeLock.lock(); + try { flushInternal(); + } finally { + writeLock.unlock(); } } @@ -754,7 +791,7 @@ private void writeToSocket() throws IOException { * Unwrap bytes read from the underlying socket. */ private final class SSLInputStream extends InputStream { - private final Object readLock = new Object(); + private final ReentrantLock readLock = new ReentrantLock(); private final byte[] singleByte = new byte[1]; private final ByteBuffer fromEngine; private final ByteBuffer fromSocket; @@ -783,17 +820,21 @@ public void close() throws IOException { } void release() { - synchronized (readLock) { + readLock.lock(); + try { if (allocatedBuffer != null) { allocatedBuffer.release(); } + } finally { + readLock.unlock(); } } @Override public int read() throws IOException { waitForHandshake(); - synchronized (readLock) { + readLock.lock(); + try { // Handle returning of -1 if EOF is reached. int count = read(singleByte, 0, 1); if (count == -1) { @@ -804,31 +845,42 @@ public int read() throws IOException { throw new SSLException("read incorrect number of bytes " + count); } return singleByte[0] & 0xff; + } finally { + readLock.unlock(); } } @Override public int read(byte[] b) throws IOException { waitForHandshake(); - synchronized (readLock) { + readLock.lock(); + try { return read(b, 0, b.length); + } finally { + readLock.unlock(); } } @Override public int read(byte[] b, int off, int len) throws IOException { waitForHandshake(); - synchronized (readLock) { + readLock.lock(); + try { return readUntilDataAvailable(b, off, len); + } finally { + readLock.unlock(); } } @Override public int available() throws IOException { waitForHandshake(); - synchronized (readLock) { + readLock.lock(); + try { init(); return fromEngine.remaining(); + } finally { + readLock.unlock(); } } @@ -941,8 +993,11 @@ && isHandshakeFinished()) { } private boolean isHandshakeFinished() { - synchronized (stateLock) { + stateLock.lock(); + try { return state > STATE_HANDSHAKE_STARTED; + } finally { + stateLock.unlock(); } } @@ -950,8 +1005,11 @@ private boolean isHandshakeFinished() { * Processes a renegotiation received from the remote peer. */ private void renegotiate() throws IOException { - synchronized (handshakeLock) { + handshakeLock.lock(); + try { doHandshake(); + } finally { + handshakeLock.unlock(); } } diff --git a/common/src/main/java/org/conscrypt/ConscryptFileDescriptorSocket.java b/common/src/main/java/org/conscrypt/ConscryptFileDescriptorSocket.java index f5ef00d74..92a76601f 100644 --- a/common/src/main/java/org/conscrypt/ConscryptFileDescriptorSocket.java +++ b/common/src/main/java/org/conscrypt/ConscryptFileDescriptorSocket.java @@ -36,6 +36,9 @@ import java.security.cert.X509Certificate; import java.security.interfaces.ECKey; import java.security.spec.ECParameterSpec; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import javax.crypto.SecretKey; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; @@ -64,7 +67,7 @@ class ConscryptFileDescriptorSocket extends OpenSSLSocketImpl SSLParametersImpl.AliasChooser { private static final boolean DBG_STATE = false; - // @GuardedBy("ssl"); + // @GuardedBy("sslLock") private int state = STATE_NEW; /** @@ -72,6 +75,12 @@ class ConscryptFileDescriptorSocket extends OpenSSLSocketImpl */ private final NativeSsl ssl; + /** + * Lock used for {@link #ssl} access. + */ + private final Lock sslLock = new ReentrantLock(); + private final Condition sslLockCondition = sslLock.newCondition(); + /** * Protected by synchronizing on ssl. Starts as null, set by * getInputStream. @@ -183,7 +192,8 @@ private static NativeSsl newSsl(SSLParametersImpl sslParameters, @Override public final void startHandshake() throws IOException { checkOpen(); - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_NEW) { transitionTo(STATE_HANDSHAKE_STARTED); } else { @@ -191,6 +201,8 @@ public final void startHandshake() throws IOException { // Do nothing in both cases. return; } + } finally { + sslLock.unlock(); } boolean releaseResources = true; @@ -218,10 +230,13 @@ public final void startHandshake() throws IOException { setSoWriteTimeout(handshakeTimeoutMilliseconds); } - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { return; } + } finally { + sslLock.unlock(); } try { @@ -241,18 +256,24 @@ public final void startHandshake() throws IOException { // (or WANT_WRITE). Catching that exception here doesn't seem much worse than // changing the native code to return a "special" native pointer value when that // happens. - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { return; } + } finally { + sslLock.unlock(); } throw e; } - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { return; } + } finally { + sslLock.unlock(); } // Restore the original timeout now that the handshake is complete @@ -261,7 +282,8 @@ public final void startHandshake() throws IOException { setSoWriteTimeout(savedWriteTimeoutMilliseconds); } - synchronized (ssl) { + sslLock.lock(); + try { releaseResources = (state == STATE_CLOSED); if (state == STATE_HANDSHAKE_STARTED) { @@ -273,22 +295,27 @@ public final void startHandshake() throws IOException { if (!releaseResources) { // Unblock threads that are waiting for our state to transition // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH. - ssl.notifyAll(); + sslLockCondition.signalAll(); } + } finally { + sslLock.unlock(); } } catch (SSLProtocolException e) { throw(SSLHandshakeException) new SSLHandshakeException("Handshake failed").initCause(e); } finally { // on exceptional exit, treat the socket as closed if (releaseResources) { - synchronized (ssl) { + sslLock.lock(); + try { // Mark the socket as closed since we might have reached this as // a result on an exception thrown by the handshake process. // // The state will already be set to closed if we reach this as a result of // an early return or an interruption due to a concurrent call to close(). transitionTo(STATE_CLOSED); - ssl.notifyAll(); + sslLockCondition.signalAll(); + } finally { + sslLock.unlock(); } try { @@ -329,7 +356,8 @@ public final void onSSLStateChange(int type, int val) { } // First, update the state. - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { // Someone called "close" but the handshake hasn't been interrupted yet. return; @@ -338,14 +366,19 @@ public final void onSSLStateChange(int type, int val) { // Now that we've fixed up our state, we can tell waiting threads that // we're ready. transitionTo(STATE_READY); + } finally { + sslLock.unlock(); } // Let listeners know we are finally done notifyHandshakeCompletedListeners(); - synchronized (ssl) { + sslLock.lock(); + try { // Notify all threads waiting for the handshake to complete. - ssl.notifyAll(); + sslLockCondition.signalAll(); + } finally { + sslLock.unlock(); } } @@ -379,8 +412,11 @@ public final long serverSessionRequested(byte[] id) { @Override public final void serverCertificateRequested() throws IOException { - synchronized (ssl) { + sslLock.lock(); + try { ssl.configureServerCertificate(); + } finally { + sslLock.unlock(); } } @@ -418,7 +454,8 @@ public final InputStream getInputStream() throws IOException { checkOpen(); InputStream returnVal; - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { throw new SocketException("Socket is closed."); } @@ -428,6 +465,8 @@ public final InputStream getInputStream() throws IOException { } returnVal = is; + } finally { + sslLock.unlock(); } // Block waiting for a handshake without a lock held. It's possible that the socket @@ -442,7 +481,8 @@ public final OutputStream getOutputStream() throws IOException { checkOpen(); OutputStream returnVal; - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { throw new SocketException("Socket is closed."); } @@ -452,6 +492,8 @@ public final OutputStream getOutputStream() throws IOException { } returnVal = os; + } finally { + sslLock.unlock(); } // Block waiting for a handshake without a lock held. It's possible that the socket @@ -472,12 +514,13 @@ private void assertReadableOrWriteableState() { private void waitForHandshake() throws IOException { startHandshake(); - synchronized (ssl) { + sslLock.lock(); + try { while (state != STATE_READY && state != STATE_READY_HANDSHAKE_CUT_THROUGH && state != STATE_CLOSED) { try { - ssl.wait(); + sslLockCondition.await(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException("Interrupted waiting for handshake", e); @@ -487,6 +530,8 @@ private void waitForHandshake() throws IOException { if (state == STATE_CLOSED) { throw new SocketException("Socket is closed"); } + } finally { + sslLock.unlock(); } } @@ -501,7 +546,7 @@ private class SSLInputStream extends InputStream { * make sure we serialize callers of SSL_read. Thread is already * expected to have completed handshaking. */ - private final Object readLock = new Object(); + private final Lock readLock = new ReentrantLock(); SSLInputStream() { } @@ -532,8 +577,10 @@ public int read(byte[] buf, int offset, int byteCount) throws IOException { return 0; } - synchronized (readLock) { - synchronized (ssl) { + readLock.lock(); + try { + sslLock.lock(); + try { if (state == STATE_CLOSED) { throw new SocketException("socket is closed"); } @@ -541,18 +588,25 @@ public int read(byte[] buf, int offset, int byteCount) throws IOException { if (DBG_STATE) { assertReadableOrWriteableState(); } + } finally { + sslLock.unlock(); } int ret = ssl.read( Platform.getFileDescriptor(socket), buf, offset, byteCount, getSoTimeout()); if (ret == -1) { - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { throw new SocketException("socket is closed"); } + } finally { + sslLock.unlock(); } } return ret; + } finally { + readLock.unlock(); } } @@ -563,14 +617,22 @@ public int available() { void awaitPendingOps() { if (DBG_STATE) { - synchronized (ssl) { + sslLock.lock(); + try { if (state != STATE_CLOSED) { throw new AssertionError("State is: " + state); } + } finally { + sslLock.unlock(); } } - synchronized (readLock) {} + readLock.lock(); + try { + // Satisfy https://errorprone.info/bugpattern/LockNotBeforeTry + } finally { + readLock.unlock(); + } } } @@ -585,7 +647,7 @@ private class SSLOutputStream extends OutputStream { * to make sure we serialize callers of SSL_write. Thread is * already expected to have completed handshaking. */ - private final Object writeLock = new Object(); + private final Lock writeLock = new ReentrantLock(); SSLOutputStream() { } @@ -614,8 +676,10 @@ public void write(byte[] buf, int offset, int byteCount) throws IOException { return; } - synchronized (writeLock) { - synchronized (ssl) { + writeLock.lock(); + try { + sslLock.lock(); + try { if (state == STATE_CLOSED) { throw new SocketException("socket is closed"); } @@ -623,29 +687,44 @@ public void write(byte[] buf, int offset, int byteCount) throws IOException { if (DBG_STATE) { assertReadableOrWriteableState(); } + } finally { + sslLock.unlock(); } ssl.write(Platform.getFileDescriptor(socket), buf, offset, byteCount, writeTimeoutMilliseconds); - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { throw new SocketException("socket is closed"); } + } finally { + sslLock.unlock(); } + } finally { + writeLock.unlock(); } } void awaitPendingOps() { if (DBG_STATE) { - synchronized (ssl) { + sslLock.lock(); + try { if (state != STATE_CLOSED) { throw new AssertionError("State is: " + state); } + } finally { + sslLock.unlock(); } } - synchronized (writeLock) {} + writeLock.lock(); + try { + // Satisfy https://errorprone.info/bugpattern/LockNotBeforeTry + } finally { + writeLock.unlock(); + } } } @@ -656,7 +735,8 @@ public final SSLSession getSession() { private ConscryptSession provideSession() { boolean handshakeCompleted = false; - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { return closedSession != null ? closedSession : SSLNullSession.getNullSession(); } @@ -670,6 +750,8 @@ private ConscryptSession provideSession() { } catch (IOException e) { // Fall through. } + } finally { + sslLock.unlock(); } if (!handshakeCompleted) { @@ -691,9 +773,12 @@ private ConscryptSession provideAfterHandshakeSession() { // If handshake is in progress, provide active session otherwise a null session. private ConscryptSession provideHandshakeSession() { - synchronized (ssl) { + sslLock.lock(); + try { return state >= STATE_HANDSHAKE_STARTED && state < STATE_READY ? activeSession : SSLNullSession.getNullSession(); + } finally { + sslLock.unlock(); } } @@ -704,7 +789,8 @@ final SSLSession getActiveSession() { @Override public final SSLSession getHandshakeSession() { - synchronized (ssl) { + sslLock.lock(); + try { if (state >= STATE_HANDSHAKE_STARTED && state < STATE_READY) { return Platform.wrapSSLSession(new ExternalSession(new ExternalSession.Provider() { @Override @@ -714,6 +800,8 @@ public ConscryptSession provideSession() { })); } return null; + } finally { + sslLock.unlock(); } } @@ -793,12 +881,15 @@ public final void setChannelIdEnabled(boolean enabled) { throw new IllegalStateException("Client mode"); } - synchronized (ssl) { + sslLock.lock(); + try { if (state != STATE_NEW) { throw new IllegalStateException( "Could not enable/disable Channel ID after the initial handshake has" + " begun."); } + } finally { + sslLock.unlock(); } sslParameters.channelIdEnabled = enabled; } @@ -819,11 +910,14 @@ public final byte[] getChannelId() throws SSLException { throw new IllegalStateException("Client mode"); } - synchronized (ssl) { + sslLock.lock(); + try { if (state != STATE_READY) { throw new IllegalStateException( "Channel ID is only available after handshake completes"); } + } finally { + sslLock.unlock(); } return ssl.getTlsChannelId(); } @@ -846,12 +940,15 @@ public final void setChannelIdPrivateKey(PrivateKey privateKey) { throw new IllegalStateException("Server mode"); } - synchronized (ssl) { + sslLock.lock(); + try { if (state != STATE_NEW) { throw new IllegalStateException( "Could not change Channel ID private key after the initial handshake has" + " begun."); } + } finally { + sslLock.unlock(); } if (privateKey == null) { @@ -884,10 +981,13 @@ byte[] getTlsUnique() { @Override byte[] exportKeyingMaterial(String label, byte[] context, int length) throws SSLException { - synchronized (ssl) { + sslLock.lock(); + try { if (state < STATE_HANDSHAKE_COMPLETED || state == STATE_CLOSED) { return null; } + } finally { + sslLock.unlock(); } return ssl.exportKeyingMaterial(label, context, length); } @@ -899,11 +999,14 @@ public final boolean getUseClientMode() { @Override public final void setUseClientMode(boolean mode) { - synchronized (ssl) { + sslLock.lock(); + try { if (state != STATE_NEW) { throw new IllegalArgumentException( "Could not change the mode after the initial handshake has begun."); } + } finally { + sslLock.unlock(); } sslParameters.setUseClientMode(mode); } @@ -969,7 +1072,8 @@ public final void close() throws IOException { return; } - synchronized (ssl) { + sslLock.lock(); + try { if (state == STATE_CLOSED) { // close() has already been called, so do nothing and return. return; @@ -985,7 +1089,7 @@ public final void close() throws IOException { free(); closeUnderlyingSocket(); - ssl.notifyAll(); + sslLockCondition.signalAll(); return; } @@ -996,15 +1100,17 @@ public final void close() throws IOException { // after SSL_do_handshake returns, so we don't have anything to do here. ssl.interrupt(); - ssl.notifyAll(); + sslLockCondition.signalAll(); return; } - ssl.notifyAll(); + sslLockCondition.signalAll(); // We've already returned from startHandshake, so we potentially have // input and output streams to clean up. sslInputStream = is; sslOutputStream = os; + } finally { + sslLock.unlock(); } // Don't bother interrupting unless we have something to interrupt. @@ -1077,8 +1183,11 @@ protected final void finalize() throws Throwable { Platform.closeGuardWarnIfOpen(guard); } if (ssl != null) { - synchronized (ssl) { + sslLock.lock(); + try { transitionTo(STATE_CLOSED); + } finally { + sslLock.unlock(); } } } finally { @@ -1123,9 +1232,12 @@ public final String getApplicationProtocol() { @Override public final String getHandshakeApplicationProtocol() { - synchronized (ssl) { + sslLock.lock(); + try { return state >= STATE_HANDSHAKE_STARTED && state < STATE_READY ? getApplicationProtocol() : null; + } finally { + sslLock.unlock(); } }