Skip to content

Add CSOT to OIDC. #1741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public static MongoOperationTimeoutException createMongoTimeoutException(final S
public static <T> T throwMongoTimeoutException(final String message) {
throw new MongoOperationTimeoutException(message);
}
public static <T> T throwMongoTimeoutException() {
throw new MongoOperationTimeoutException("The operation exceeded the timeout limit.");
}

public static MongoOperationTimeoutException createMongoTimeoutException(final Throwable cause) {
return createMongoTimeoutException("Operation exceeded the timeout limit: " + cause.getMessage(), cause);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public String getMechanismName() {
}

@Override
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) {
return new AwsSaslClient(getMongoCredential());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public String getMechanismName() {
}

@Override
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) {
MongoCredential credential = getMongoCredential();
try {
Map<String, Object> saslClientProperties = credential.getMechanismProperty(JAVA_SASL_CLIENT_PROPERTIES_KEY, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
import com.mongodb.MongoSecurityException;
import com.mongodb.ServerAddress;
import com.mongodb.ServerApi;
import com.mongodb.assertions.Assertions;
import com.mongodb.connection.ClusterConnectionMode;
import com.mongodb.connection.ConnectionDescription;
import com.mongodb.internal.Locks;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.VisibleForTesting;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.internal.authentication.AzureCredentialHelper;
Expand All @@ -49,6 +51,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC;
Expand All @@ -64,11 +67,14 @@
import static com.mongodb.assertions.Assertions.assertFalse;
import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.assertions.Assertions.assertTrue;
import static com.mongodb.internal.TimeoutContext.throwMongoTimeoutException;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse;
import static java.lang.String.format;

/**
* Created per connection, and exists until connection is closed.
*
* <p>This class is not part of the public API and may be removed or changed at any time</p>
*/
public final class OidcAuthenticator extends SaslAuthenticator {
Expand Down Expand Up @@ -118,8 +124,22 @@ public OidcAuthenticator(final MongoCredentialWithCache credential,
}
}

private Duration getCallbackTimeout() {
return isHumanCallback() ? HUMAN_CALLBACK_TIMEOUT : CALLBACK_TIMEOUT;
private Duration getCallbackTimeout(final TimeoutContext timeoutContext) {
if (isHumanCallback()) {
return HUMAN_CALLBACK_TIMEOUT;
}

if (timeoutContext.hasTimeoutMS()) {
return assertNotNull(timeoutContext.getTimeout()).call(TimeUnit.MILLISECONDS,
() -> {
Assertions.fail();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the tests it looks like it should fall back to serverSelectionTimeoutMS - if theres an infinite timeout - so how does that logic work?

Also it might be best to put an assertion message incase a future regression means this code path is hit - potentially would make debugging it easier.

return null;
},
(renamingMs) -> Duration.ofMillis(renamingMs),
() -> throwMongoTimeoutException());

}
return CALLBACK_TIMEOUT;
}

@Override
Expand All @@ -128,10 +148,10 @@ public String getMechanismName() {
}

@Override
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) {
this.serverAddress = assertNotNull(serverAddress);
MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache();
return new OidcSaslClient(mongoCredentialWithCache);
return new OidcSaslClient(mongoCredentialWithCache, operationContext.getTimeoutContext());
}

@Override
Expand Down Expand Up @@ -322,7 +342,7 @@ private void authenticationLoopAsync(final InternalConnection connection, final
).finish(callback);
}

private byte[] evaluate(final byte[] challenge) {
private byte[] evaluate(final byte[] challenge, final TimeoutContext timeoutContext) {
byte[][] jwt = new byte[1][];
Locks.withInterruptibleLock(getMongoCredentialWithCache().getOidcLock(), () -> {
OidcCacheEntry oidcCacheEntry = getMongoCredentialWithCache().getOidcCacheEntry();
Expand All @@ -343,7 +363,7 @@ private byte[] evaluate(final byte[] challenge) {
// Invoke Callback using cached Refresh Token
fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN;
OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl(
getCallbackTimeout(), cachedIdpInfo, cachedRefreshToken, userName));
getCallbackTimeout(timeoutContext), cachedIdpInfo, cachedRefreshToken, userName));
jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result);
} else {
// cache is empty
Expand All @@ -352,7 +372,7 @@ private byte[] evaluate(final byte[] challenge) {
// no principal request
fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN;
OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl(
getCallbackTimeout(), userName));
getCallbackTimeout(timeoutContext), userName));
jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(null, result);
if (result.getRefreshToken() != null) {
throw new MongoConfigurationException(
Expand Down Expand Up @@ -382,7 +402,7 @@ private byte[] evaluate(final byte[] challenge) {
// there is no cached refresh token
fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN;
OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl(
getCallbackTimeout(), idpInfo, null, userName));
getCallbackTimeout(timeoutContext), idpInfo, null, userName));
jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result);
}
}
Expand Down Expand Up @@ -501,14 +521,18 @@ OidcCacheEntry clearRefreshToken() {
}

private final class OidcSaslClient extends SaslClientImpl {
private final TimeoutContext timeoutContext;

private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache) {
private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache,
final TimeoutContext timeoutContext) {
super(mongoCredentialWithCache.getCredential());

this.timeoutContext = timeoutContext;
}

@Override
public byte[] evaluateChallenge(final byte[] challenge) {
return evaluate(challenge);
return evaluate(challenge, timeoutContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public String getMechanismName() {
}

@Override
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
protected SaslClient createSaslClient(final ServerAddress serverAddress, final OperationContext operationContext) {
MongoCredential credential = getMongoCredential();
isTrue("mechanism is PLAIN", credential.getAuthenticationMechanism() == PLAIN);
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ abstract class SaslAuthenticator extends Authenticator implements SpeculativeAut
public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription,
final OperationContext operationContext) {
doAsSubject(() -> {
SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress());
SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress(), operationContext);
throwIfSaslClientIsNull(saslClient);
try {
BsonDocument responseDocument = getNextSaslResponse(saslClient, connection, operationContext);
Expand Down Expand Up @@ -105,7 +105,7 @@ void authenticateAsync(final InternalConnection connection, final ConnectionDesc
final OperationContext operationContext, final SingleResultCallback<Void> callback) {
try {
doAsSubject(() -> {
SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress());
SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress(), operationContext);
throwIfSaslClientIsNull(saslClient);
getNextSaslResponseAsync(saslClient, connection, operationContext, callback);
return null;
Expand All @@ -117,7 +117,7 @@ void authenticateAsync(final InternalConnection connection, final ConnectionDesc

public abstract String getMechanismName();

protected abstract SaslClient createSaslClient(ServerAddress serverAddress);
protected abstract SaslClient createSaslClient(ServerAddress serverAddress, OperationContext operationContext);

protected void appendSaslStartOptions(final BsonDocument saslStartCommand) {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,17 @@ protected void appendSaslStartOptions(final BsonDocument saslStartCommand) {


@Override
protected SaslClient createSaslClient(final ServerAddress serverAddress) {
protected SaslClient createSaslClient(final ServerAddress serverAddress, @Nullable final OperationContext operationContext) {
if (speculativeSaslClient != null) {
return speculativeSaslClient;
}
return new ScramShaSaslClient(getMongoCredentialWithCache().getCredential(), randomStringGenerator, authenticationHashGenerator);
}

protected SaslClient createSaslClient(final ServerAddress serverAddress) {
return createSaslClient(serverAddress, null);
}

@Override
public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@
import org.bson.Document;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.io.IOException;
import java.lang.reflect.Field;
Expand All @@ -58,9 +62,11 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY;
import static com.mongodb.MongoCredential.ENVIRONMENT_KEY;
Expand All @@ -72,9 +78,12 @@
import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY;
import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.testing.MongoAssertions.assertCause;
import static java.lang.Math.min;
import static java.lang.String.format;
import static java.lang.System.getenv;
import static java.util.Arrays.asList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -198,6 +207,63 @@ public void test2p1ValidCallbackInputs() {
}
}

// Not a prose test
@ParameterizedTest
@MethodSource
@DisplayName("{testName}")
void testValidCallbackInputsTimeout(final String testName,
final int timeoutMs,
final int serverSelectionTimeoutMS,
final int expectedTimeoutThreshold) {
TestCallback callback1 = createCallback();

OidcCallback callback2 = (context) -> {
assertTrue(context.getTimeout().toMillis() < expectedTimeoutThreshold,
format("Expected timeout to be less than %d, but was %d",
expectedTimeoutThreshold,
context.getTimeout().toMillis()));
return callback1.onRequest(context);
};

MongoClientSettings clientSettings = MongoClientSettings.builder(createSettings(callback2))
.applyToClusterSettings(builder ->
builder.serverSelectionTimeout(
serverSelectionTimeoutMS,
TimeUnit.MILLISECONDS))
.timeout(timeoutMs, TimeUnit.MILLISECONDS)
.build();

try (MongoClient mongoClient = createMongoClient(clientSettings)) {
long start = System.nanoTime();
performFind(mongoClient);
assertEquals(1, callback1.getInvocations());
long elapsed = msElapsedSince(start);

assertFalse(elapsed > (timeoutMs == 0 ? serverSelectionTimeoutMS : min(serverSelectionTimeoutMS, timeoutMs)),
format("Elapsed time %d is greater then minimum of serverSelectionTimeoutMS and timeoutMs, which is %d. "
+ "This indicates that the callback was not called with the expected timeout.",
min(serverSelectionTimeoutMS, timeoutMs),
elapsed));
}
}

private static Stream<Arguments> testValidCallbackInputsTimeout() {
return Stream.of(
Arguments.of("serverSelectionTimeoutMS honored for oidc callback if it's lower than timeoutMS",
1000, // timeoutMS
500, // serverSelectionTimeoutMS
499), // expectedTimeoutThreshold
Arguments.of("timeoutMS honored for oidc callback if it's lower than serverSelectionTimeoutMS",
500, // timeoutMS
1000, // serverSelectionTimeoutMS
499), // expectedTimeoutThreshold
Arguments.of("serverSelectionTimeoutMS honored for oidc callback if timeoutMS=0",
0, // infinite timeoutMS
500, // serverSelectionTimeoutMS
499) // expectedTimeoutThreshold
);
}

@Test
public void test2p2RequestCallbackReturnsNull() {
//noinspection ConstantConditions
Expand Down Expand Up @@ -1143,4 +1209,8 @@ public TestCallback createHumanCallback() {
.setPathSupplier(() -> oidcTokenDirectory() + "test_user1")
.setRefreshToken("refreshToken");
}

private long msElapsedSince(final long timeOfStart) {
return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - timeOfStart);
}
}