diff --git a/driver-core/src/main/com/mongodb/internal/TimeoutContext.java b/driver-core/src/main/com/mongodb/internal/TimeoutContext.java index 2a886704cd..ba3b8eb0ac 100644 --- a/driver-core/src/main/com/mongodb/internal/TimeoutContext.java +++ b/driver-core/src/main/com/mongodb/internal/TimeoutContext.java @@ -70,6 +70,9 @@ public static MongoOperationTimeoutException createMongoTimeoutException(final S public static T throwMongoTimeoutException(final String message) { throw new MongoOperationTimeoutException(message); } + public static T throwMongoTimeoutException() { + throw new MongoOperationTimeoutException("The operation exceeded the timeout limit."); + } public static MongoOperationTimeoutException createMongoTimeoutException(final Throwable cause) { return createMongoTimeoutException("Operation exceeded the timeout limit: " + cause.getMessage(), cause); diff --git a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java index 35f9f8120e..294e88b81e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java @@ -68,7 +68,7 @@ public String getMechanismName() { } @Override - protected SaslClient createSaslClient(final ServerAddress serverAddress) { + protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) { return new AwsSaslClient(getMongoCredential()); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/GSSAPIAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/GSSAPIAuthenticator.java index 43d634c199..c3902751ec 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/GSSAPIAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/GSSAPIAuthenticator.java @@ -67,7 +67,7 @@ public String getMechanismName() { } @Override - protected SaslClient createSaslClient(final ServerAddress serverAddress) { + protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) { MongoCredential credential = getMongoCredential(); try { Map saslClientProperties = credential.getMechanismProperty(JAVA_SASL_CLIENT_PROPERTIES_KEY, null); diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 1e67626d60..87f48b3308 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -29,6 +29,7 @@ import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.Locks; +import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.VisibleForTesting; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.authentication.AzureCredentialHelper; @@ -45,10 +46,12 @@ import java.nio.file.Files; import java.nio.file.Paths; import java.time.Duration; +import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; @@ -64,11 +67,14 @@ import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.TimeoutContext.throwMongoTimeoutException; import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; import static java.lang.String.format; /** + * Created per connection, and exists until connection is closed. + * *

This class is not part of the public API and may be removed or changed at any time

*/ public final class OidcAuthenticator extends SaslAuthenticator { @@ -118,8 +124,21 @@ public OidcAuthenticator(final MongoCredentialWithCache credential, } } - private Duration getCallbackTimeout() { - return isHumanCallback() ? HUMAN_CALLBACK_TIMEOUT : CALLBACK_TIMEOUT; + private Duration getCallbackTimeout(final TimeoutContext timeoutContext) { + if (isHumanCallback()) { + return HUMAN_CALLBACK_TIMEOUT; + } + + if (timeoutContext.hasTimeoutMS()) { + return assertNotNull(timeoutContext.getTimeout()).call(TimeUnit.MILLISECONDS, + () -> + // we can get here if server selection timeout was set to infinite. + ChronoUnit.FOREVER.getDuration(), + (renamingMs) -> Duration.ofMillis(renamingMs), + () -> throwMongoTimeoutException()); + + } + return CALLBACK_TIMEOUT; } @Override @@ -128,10 +147,10 @@ public String getMechanismName() { } @Override - protected SaslClient createSaslClient(final ServerAddress serverAddress) { + protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) { this.serverAddress = assertNotNull(serverAddress); MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); - return new OidcSaslClient(mongoCredentialWithCache); + return new OidcSaslClient(mongoCredentialWithCache, operationContext.getTimeoutContext()); } @Override @@ -322,7 +341,7 @@ private void authenticationLoopAsync(final InternalConnection connection, final ).finish(callback); } - private byte[] evaluate(final byte[] challenge) { + private byte[] evaluate(final byte[] challenge, final TimeoutContext timeoutContext) { byte[][] jwt = new byte[1][]; Locks.withInterruptibleLock(getMongoCredentialWithCache().getOidcLock(), () -> { OidcCacheEntry oidcCacheEntry = getMongoCredentialWithCache().getOidcCacheEntry(); @@ -343,7 +362,7 @@ private byte[] evaluate(final byte[] challenge) { // Invoke Callback using cached Refresh Token fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( - getCallbackTimeout(), cachedIdpInfo, cachedRefreshToken, userName)); + getCallbackTimeout(timeoutContext), cachedIdpInfo, cachedRefreshToken, userName)); jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); } else { // cache is empty @@ -352,7 +371,7 @@ private byte[] evaluate(final byte[] challenge) { // no principal request fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( - getCallbackTimeout(), userName)); + getCallbackTimeout(timeoutContext), userName)); jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(null, result); if (result.getRefreshToken() != null) { throw new MongoConfigurationException( @@ -382,7 +401,7 @@ private byte[] evaluate(final byte[] challenge) { // there is no cached refresh token fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( - getCallbackTimeout(), idpInfo, null, userName)); + getCallbackTimeout(timeoutContext), idpInfo, null, userName)); jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); } } @@ -501,14 +520,18 @@ OidcCacheEntry clearRefreshToken() { } private final class OidcSaslClient extends SaslClientImpl { + private final TimeoutContext timeoutContext; - private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache) { + private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache, + final TimeoutContext timeoutContext) { super(mongoCredentialWithCache.getCredential()); + + this.timeoutContext = timeoutContext; } @Override public byte[] evaluateChallenge(final byte[] challenge) { - return evaluate(challenge); + return evaluate(challenge, timeoutContext); } @Override diff --git a/driver-core/src/main/com/mongodb/internal/connection/PlainAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/PlainAuthenticator.java index ff7eacb11d..f075ab154f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/PlainAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/PlainAuthenticator.java @@ -47,7 +47,7 @@ public String getMechanismName() { } @Override - protected SaslClient createSaslClient(final ServerAddress serverAddress) { + protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) { MongoCredential credential = getMongoCredential(); isTrue("mechanism is PLAIN", credential.getAuthenticationMechanism() == PLAIN); try { diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index 900d9a14e1..eeee3a31ab 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -65,7 +65,7 @@ abstract class SaslAuthenticator extends Authenticator implements SpeculativeAut public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription, final OperationContext operationContext) { doAsSubject(() -> { - SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress()); + SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress(), operationContext); throwIfSaslClientIsNull(saslClient); try { BsonDocument responseDocument = getNextSaslResponse(saslClient, connection, operationContext); @@ -105,7 +105,7 @@ void authenticateAsync(final InternalConnection connection, final ConnectionDesc final OperationContext operationContext, final SingleResultCallback callback) { try { doAsSubject(() -> { - SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress()); + SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress(), operationContext); throwIfSaslClientIsNull(saslClient); getNextSaslResponseAsync(saslClient, connection, operationContext, callback); return null; @@ -117,7 +117,7 @@ void authenticateAsync(final InternalConnection connection, final ConnectionDesc public abstract String getMechanismName(); - protected abstract SaslClient createSaslClient(ServerAddress serverAddress); + protected abstract SaslClient createSaslClient(ServerAddress serverAddress, OperationContext operationContext); protected void appendSaslStartOptions(final BsonDocument saslStartCommand) { } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java index 542ce47360..b98b72b3be 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java @@ -90,13 +90,17 @@ protected void appendSaslStartOptions(final BsonDocument saslStartCommand) { @Override - protected SaslClient createSaslClient(final ServerAddress serverAddress) { + protected SaslClient createSaslClient(final ServerAddress serverAddress, @Nullable final OperationContext operationContext) { if (speculativeSaslClient != null) { return speculativeSaslClient; } return new ScramShaSaslClient(getMongoCredentialWithCache().getCredential(), randomStringGenerator, authenticationHashGenerator); } + protected SaslClient createSaslClient(final ServerAddress serverAddress) { + return createSaslClient(serverAddress, null); + } + @Override public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) { try { diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index 2b0544f0c5..4fd3000802 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -42,6 +42,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import java.io.IOException; import java.lang.reflect.Field; @@ -50,6 +54,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.time.Duration; +import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -58,9 +63,11 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; import static com.mongodb.MongoCredential.ENVIRONMENT_KEY; @@ -72,9 +79,12 @@ import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.testing.MongoAssertions.assertCause; +import static java.lang.Math.min; +import static java.lang.String.format; import static java.lang.System.getenv; import static java.util.Arrays.asList; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -198,6 +208,91 @@ public void test2p1ValidCallbackInputs() { } } + // Not a prose test + @ParameterizedTest(name = "{0}. " + + "Parameters: timeoutMs={1}, " + + "serverSelectionTimeoutMS={2}," + + " expectedTimeoutThreshold={3}") + @MethodSource + void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName, + final int timeoutMs, + final int serverSelectionTimeoutMS, + final int expectedTimeoutThreshold) { + TestCallback callback1 = createCallback(); + + OidcCallback callback2 = (context) -> { + assertTrue(context.getTimeout().toMillis() < expectedTimeoutThreshold, + format("Expected timeout to be less than %d, but was %d", + expectedTimeoutThreshold, + context.getTimeout().toMillis())); + return callback1.onRequest(context); + }; + + MongoClientSettings clientSettings = MongoClientSettings.builder(createSettings(callback2)) + .applyToClusterSettings(builder -> + builder.serverSelectionTimeout( + serverSelectionTimeoutMS, + TimeUnit.MILLISECONDS)) + .timeout(timeoutMs, TimeUnit.MILLISECONDS) + .build(); + + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + long start = System.nanoTime(); + performFind(mongoClient); + assertEquals(1, callback1.getInvocations()); + long elapsed = msElapsedSince(start); + + assertFalse(elapsed > (timeoutMs == 0 ? serverSelectionTimeoutMS : min(serverSelectionTimeoutMS, timeoutMs)), + format("Elapsed time %d is greater then minimum of serverSelectionTimeoutMS and timeoutMs, which is %d. " + + "This indicates that the callback was not called with the expected timeout.", + min(serverSelectionTimeoutMS, timeoutMs), + elapsed)); + } + } + + private static Stream testValidCallbackInputsTimeoutWhenTimeoutMsIsSet() { + return Stream.of( + Arguments.of("serverSelectionTimeoutMS honored for oidc callback if it's lower than timeoutMS", + 1000, // timeoutMS + 500, // serverSelectionTimeoutMS + 499), // expectedTimeoutThreshold + Arguments.of("timeoutMS honored for oidc callback if it's lower than serverSelectionTimeoutMS", + 500, // timeoutMS + 1000, // serverSelectionTimeoutMS + 499), // expectedTimeoutThreshold + Arguments.of("serverSelectionTimeoutMS honored for oidc callback if timeoutMS=0", + 0, // infinite timeoutMS + 500, // serverSelectionTimeoutMS + 499) // expectedTimeoutThreshold + ); + } + + // Not a prose test + @ParameterizedTest(name = "test callback timeout when server selection timeout is " + + "infinite and timeoutMs is set to {0}") + @ValueSource(ints = {0, 100}) + void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final int timeoutMs) { + TestCallback callback1 = createCallback(); + + OidcCallback callback2 = (context) -> { + assertEquals(context.getTimeout(), ChronoUnit.FOREVER.getDuration()); + return callback1.onRequest(context); + }; + + MongoClientSettings clientSettings = MongoClientSettings.builder(createSettings(callback2)) + .applyToClusterSettings(builder -> + builder.serverSelectionTimeout( + -1, // -1 means infinite + TimeUnit.MILLISECONDS)) + .timeout(timeoutMs, TimeUnit.MILLISECONDS) + .build(); + + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + assertEquals(1, callback1.getInvocations()); + } + } + @Test public void test2p2RequestCallbackReturnsNull() { //noinspection ConstantConditions @@ -1143,4 +1238,8 @@ public TestCallback createHumanCallback() { .setPathSupplier(() -> oidcTokenDirectory() + "test_user1") .setRefreshToken("refreshToken"); } + + private long msElapsedSince(final long timeOfStart) { + return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - timeOfStart); + } }