diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/AccessToken.java b/mcp/src/main/java/io/modelcontextprotocol/auth/AccessToken.java new file mode 100644 index 00000000..1c9f8954 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/AccessToken.java @@ -0,0 +1,60 @@ +package io.modelcontextprotocol.auth; + +import java.util.List; + +/** + * Represents an OAuth access token. + */ +public class AccessToken { + + private String token; + + private String clientId; + + private List scopes; + + private Integer expiresAt; + + public AccessToken() { + } + + public AccessToken(String token, String clientId, List scopes, Integer expiresAt) { + this.token = token; + this.clientId = clientId; + this.scopes = scopes; + this.expiresAt = expiresAt; + } + + public String getToken() { + return token; + } + + public void setToken(String token) { + this.token = token; + } + + public String getClientId() { + return clientId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public List getScopes() { + return scopes; + } + + public void setScopes(List scopes) { + this.scopes = scopes; + } + + public Integer getExpiresAt() { + return expiresAt; + } + + public void setExpiresAt(Integer expiresAt) { + this.expiresAt = expiresAt; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationCode.java b/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationCode.java new file mode 100644 index 00000000..fdd5c14c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationCode.java @@ -0,0 +1,95 @@ +package io.modelcontextprotocol.auth; + +import java.net.URI; +import java.util.List; + +/** + * Represents an OAuth authorization code. + */ +public class AuthorizationCode { + + private String code; + + private List scopes; + + private double expiresAt; + + private String clientId; + + private String codeChallenge; + + private URI redirectUri; + + private boolean redirectUriProvidedExplicitly; + + public AuthorizationCode() { + } + + public AuthorizationCode(String code, List scopes, double expiresAt, String clientId, String codeChallenge, + URI redirectUri, boolean redirectUriProvidedExplicitly) { + this.code = code; + this.scopes = scopes; + this.expiresAt = expiresAt; + this.clientId = clientId; + this.codeChallenge = codeChallenge; + this.redirectUri = redirectUri; + this.redirectUriProvidedExplicitly = redirectUriProvidedExplicitly; + } + + public String getCode() { + return code; + } + + public void setCode(String code) { + this.code = code; + } + + public List getScopes() { + return scopes; + } + + public void setScopes(List scopes) { + this.scopes = scopes; + } + + public double getExpiresAt() { + return expiresAt; + } + + public void setExpiresAt(double expiresAt) { + this.expiresAt = expiresAt; + } + + public String getClientId() { + return clientId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public String getCodeChallenge() { + return codeChallenge; + } + + public void setCodeChallenge(String codeChallenge) { + this.codeChallenge = codeChallenge; + } + + public URI getRedirectUri() { + return redirectUri; + } + + public void setRedirectUri(URI redirectUri) { + this.redirectUri = redirectUri; + } + + public boolean isRedirectUriProvidedExplicitly() { + return redirectUriProvidedExplicitly; + } + + public void setRedirectUriProvidedExplicitly(boolean redirectUriProvidedExplicitly) { + this.redirectUriProvidedExplicitly = redirectUriProvidedExplicitly; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationParams.java b/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationParams.java new file mode 100644 index 00000000..996c48bd --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/AuthorizationParams.java @@ -0,0 +1,73 @@ +package io.modelcontextprotocol.auth; + +import java.net.URI; +import java.util.List; + +/** + * Parameters for an authorization request. + */ +public class AuthorizationParams { + + private String state; + + private List scopes; + + private String codeChallenge; + + private URI redirectUri; + + private boolean redirectUriProvidedExplicitly; + + public AuthorizationParams() { + } + + public AuthorizationParams(String state, List scopes, String codeChallenge, URI redirectUri, + boolean redirectUriProvidedExplicitly) { + this.state = state; + this.scopes = scopes; + this.codeChallenge = codeChallenge; + this.redirectUri = redirectUri; + this.redirectUriProvidedExplicitly = redirectUriProvidedExplicitly; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public List getScopes() { + return scopes; + } + + public void setScopes(List scopes) { + this.scopes = scopes; + } + + public String getCodeChallenge() { + return codeChallenge; + } + + public void setCodeChallenge(String codeChallenge) { + this.codeChallenge = codeChallenge; + } + + public URI getRedirectUri() { + return redirectUri; + } + + public void setRedirectUri(URI redirectUri) { + this.redirectUri = redirectUri; + } + + public boolean isRedirectUriProvidedExplicitly() { + return redirectUriProvidedExplicitly; + } + + public void setRedirectUriProvidedExplicitly(boolean redirectUriProvidedExplicitly) { + this.redirectUriProvidedExplicitly = redirectUriProvidedExplicitly; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidRedirectUriException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidRedirectUriException.java new file mode 100644 index 00000000..91a6fd32 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidRedirectUriException.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.auth; + +/** + * Exception thrown when a redirect URI is invalid. + */ +public class InvalidRedirectUriException extends Exception { + + public InvalidRedirectUriException(String message) { + super(message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidScopeException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidScopeException.java new file mode 100644 index 00000000..620a1f1e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/InvalidScopeException.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.auth; + +/** + * Exception thrown when a requested scope is invalid. + */ +public class InvalidScopeException extends Exception { + + public InvalidScopeException(String message) { + super(message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthAuthorizationServerProvider.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthAuthorizationServerProvider.java new file mode 100644 index 00000000..cefe1498 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthAuthorizationServerProvider.java @@ -0,0 +1,99 @@ +package io.modelcontextprotocol.auth; + +import io.modelcontextprotocol.auth.exception.AuthorizeException; +import io.modelcontextprotocol.auth.exception.RegistrationException; +import io.modelcontextprotocol.auth.exception.TokenException; + +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * Interface for OAuth authorization server providers. + */ +public interface OAuthAuthorizationServerProvider { + + /** + * Retrieves client information by client ID. + * @param clientId The ID of the client to retrieve. + * @return A CompletableFuture that resolves to the client information, or null if the + * client does not exist. + */ + CompletableFuture getClient(String clientId); + + /** + * Saves client information as part of registering it. + * @param clientInfo The client metadata to register. + * @return A CompletableFuture that completes when the registration is done. + * @throws RegistrationException If the client metadata is invalid. + */ + CompletableFuture registerClient(OAuthClientInformation clientInfo) throws RegistrationException; + + /** + * Called as part of the /authorize endpoint, and returns a URL that the client will + * be redirected to. + * @param client The client requesting authorization. + * @param params The parameters of the authorization request. + * @return A CompletableFuture that resolves to a URL to redirect the client to for + * authorization. + * @throws AuthorizeException If the authorization request is invalid. + */ + CompletableFuture authorize(OAuthClientInformation client, AuthorizationParams params) + throws AuthorizeException; + + /** + * Loads an AuthorizationCode by its code. + * @param client The client that requested the authorization code. + * @param authorizationCode The authorization code to get the challenge for. + * @return A CompletableFuture that resolves to the AuthorizationCode, or null if not + * found. + */ + CompletableFuture loadAuthorizationCode(OAuthClientInformation client, String authorizationCode); + + /** + * Exchanges an authorization code for an access token and refresh token. + * @param client The client exchanging the authorization code. + * @param authorizationCode The authorization code to exchange. + * @return A CompletableFuture that resolves to the OAuth token, containing access and + * refresh tokens. + * @throws TokenException If the request is invalid. + */ + CompletableFuture exchangeAuthorizationCode(OAuthClientInformation client, + AuthorizationCode authorizationCode) throws TokenException; + + /** + * Loads a RefreshToken by its token string. + * @param client The client that is requesting to load the refresh token. + * @param refreshToken The refresh token string to load. + * @return A CompletableFuture that resolves to the RefreshToken object if found, or + * null if not found. + */ + CompletableFuture loadRefreshToken(OAuthClientInformation client, String refreshToken); + + /** + * Exchanges a refresh token for an access token and refresh token. + * @param client The client exchanging the refresh token. + * @param refreshToken The refresh token to exchange. + * @param scopes Optional scopes to request with the new access token. + * @return A CompletableFuture that resolves to the OAuth token, containing access and + * refresh tokens. + * @throws TokenException If the request is invalid. + */ + CompletableFuture exchangeRefreshToken(OAuthClientInformation client, RefreshToken refreshToken, + List scopes) throws TokenException; + + /** + * Loads an access token by its token. + * @param token The access token to verify. + * @return A CompletableFuture that resolves to the AccessToken, or null if the token + * is invalid. + */ + CompletableFuture loadAccessToken(String token); + + /** + * Revokes an access or refresh token. + * @param token The token to revoke. + * @return A CompletableFuture that completes when the token is revoked. + */ + CompletableFuture revokeToken(Object token); + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientInformation.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientInformation.java new file mode 100644 index 00000000..31e48eeb --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientInformation.java @@ -0,0 +1,53 @@ +package io.modelcontextprotocol.auth; + +/** + * RFC 7591 OAuth 2.0 Dynamic Client Registration full response (client information plus + * metadata). + */ +public class OAuthClientInformation extends OAuthClientMetadata { + + private String clientId; + + private String clientSecret; + + private Long clientIdIssuedAt; + + private Long clientSecretExpiresAt; + + public OAuthClientInformation() { + super(); + } + + public String getClientId() { + return clientId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public String getClientSecret() { + return clientSecret; + } + + public void setClientSecret(String clientSecret) { + this.clientSecret = clientSecret; + } + + public Long getClientIdIssuedAt() { + return clientIdIssuedAt; + } + + public void setClientIdIssuedAt(Long clientIdIssuedAt) { + this.clientIdIssuedAt = clientIdIssuedAt; + } + + public Long getClientSecretExpiresAt() { + return clientSecretExpiresAt; + } + + public void setClientSecretExpiresAt(Long clientSecretExpiresAt) { + this.clientSecretExpiresAt = clientSecretExpiresAt; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientMetadata.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientMetadata.java new file mode 100644 index 00000000..a54fb966 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthClientMetadata.java @@ -0,0 +1,217 @@ +package io.modelcontextprotocol.auth; + +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. See + * https://datatracker.ietf.org/doc/html/rfc7591#section-2 for the full specification. + */ +public class OAuthClientMetadata { + + private List redirectUris; + + private String tokenEndpointAuthMethod; + + private List grantTypes; + + private List responseTypes; + + private String scope; + + // Optional metadata fields + private String clientName; + + private URI clientUri; + + private URI logoUri; + + private List contacts; + + private URI tosUri; + + private URI policyUri; + + private URI jwksUri; + + private Object jwks; + + private String softwareId; + + private String softwareVersion; + + public OAuthClientMetadata() { + this.tokenEndpointAuthMethod = "client_secret_post"; + this.grantTypes = Arrays.asList("authorization_code", "refresh_token"); + this.responseTypes = Arrays.asList("code"); + } + + /** + * Validates the requested scope against the client's allowed scopes. + * @param requestedScope The scope requested by the client + * @return List of validated scopes or null if no scope was requested + * @throws InvalidScopeException if the requested scope is not allowed + */ + public List validateScope(String requestedScope) throws InvalidScopeException { + if (requestedScope == null) { + return null; + } + + List requestedScopes = Arrays.asList(requestedScope.split(" ")); + List allowedScopes = scope == null ? new ArrayList<>() : Arrays.asList(scope.split(" ")); + + for (String scope : requestedScopes) { + if (!allowedScopes.contains(scope)) { + throw new InvalidScopeException("Client was not registered with scope " + scope); + } + } + + return requestedScopes; + } + + /** + * Validates the redirect URI against the client's registered redirect URIs. + * @param redirectUri The redirect URI to validate + * @return The validated redirect URI + * @throws InvalidRedirectUriException if the redirect URI is invalid + */ + public URI validateRedirectUri(URI redirectUri) throws InvalidRedirectUriException { + if (redirectUri != null) { + if (!redirectUris.contains(redirectUri)) { + throw new InvalidRedirectUriException("Redirect URI '" + redirectUri + "' not registered for client"); + } + return redirectUri; + } + else if (redirectUris.size() == 1) { + return redirectUris.get(0); + } + else { + throw new InvalidRedirectUriException( + "redirect_uri must be specified when client has multiple registered URIs"); + } + } + + // Getters and setters + public List getRedirectUris() { + return redirectUris; + } + + public void setRedirectUris(List redirectUris) { + this.redirectUris = redirectUris; + } + + public String getTokenEndpointAuthMethod() { + return tokenEndpointAuthMethod; + } + + public void setTokenEndpointAuthMethod(String tokenEndpointAuthMethod) { + this.tokenEndpointAuthMethod = tokenEndpointAuthMethod; + } + + public List getGrantTypes() { + return grantTypes; + } + + public void setGrantTypes(List grantTypes) { + this.grantTypes = grantTypes; + } + + public List getResponseTypes() { + return responseTypes; + } + + public void setResponseTypes(List responseTypes) { + this.responseTypes = responseTypes; + } + + public String getScope() { + return scope; + } + + public void setScope(String scope) { + this.scope = scope; + } + + public String getClientName() { + return clientName; + } + + public void setClientName(String clientName) { + this.clientName = clientName; + } + + public URI getClientUri() { + return clientUri; + } + + public void setClientUri(URI clientUri) { + this.clientUri = clientUri; + } + + public URI getLogoUri() { + return logoUri; + } + + public void setLogoUri(URI logoUri) { + this.logoUri = logoUri; + } + + public List getContacts() { + return contacts; + } + + public void setContacts(List contacts) { + this.contacts = contacts; + } + + public URI getTosUri() { + return tosUri; + } + + public void setTosUri(URI tosUri) { + this.tosUri = tosUri; + } + + public URI getPolicyUri() { + return policyUri; + } + + public void setPolicyUri(URI policyUri) { + this.policyUri = policyUri; + } + + public URI getJwksUri() { + return jwksUri; + } + + public void setJwksUri(URI jwksUri) { + this.jwksUri = jwksUri; + } + + public Object getJwks() { + return jwks; + } + + public void setJwks(Object jwks) { + this.jwks = jwks; + } + + public String getSoftwareId() { + return softwareId; + } + + public void setSoftwareId(String softwareId) { + this.softwareId = softwareId; + } + + public String getSoftwareVersion() { + return softwareVersion; + } + + public void setSoftwareVersion(String softwareVersion) { + this.softwareVersion = softwareVersion; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthMetadata.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthMetadata.java new file mode 100644 index 00000000..6c9c7b59 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthMetadata.java @@ -0,0 +1,230 @@ +package io.modelcontextprotocol.auth; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; + +/** + * RFC 8414 OAuth 2.0 Authorization Server Metadata. See + * https://datatracker.ietf.org/doc/html/rfc8414#section-2 + */ +public class OAuthMetadata { + + private URI issuer; + + private URI authorizationEndpoint; + + private URI tokenEndpoint; + + private URI registrationEndpoint; + + private List scopesSupported; + + private List responseTypesSupported; + + private List responseModesSupported; + + private List grantTypesSupported; + + private List tokenEndpointAuthMethodsSupported; + + private List tokenEndpointAuthSigningAlgValuesSupported; + + private URI serviceDocumentation; + + private List uiLocalesSupported; + + private URI opPolicyUri; + + private URI opTosUri; + + private URI revocationEndpoint; + + private List revocationEndpointAuthMethodsSupported; + + private List revocationEndpointAuthSigningAlgValuesSupported; + + private URI introspectionEndpoint; + + private List introspectionEndpointAuthMethodsSupported; + + private List introspectionEndpointAuthSigningAlgValuesSupported; + + private List codeChallengeMethodsSupported; + + public OAuthMetadata() { + this.responseTypesSupported = Arrays.asList("code"); + } + + // Getters and setters + public URI getIssuer() { + return issuer; + } + + public void setIssuer(URI issuer) { + this.issuer = issuer; + } + + public URI getAuthorizationEndpoint() { + return authorizationEndpoint; + } + + public void setAuthorizationEndpoint(URI authorizationEndpoint) { + this.authorizationEndpoint = authorizationEndpoint; + } + + public URI getTokenEndpoint() { + return tokenEndpoint; + } + + public void setTokenEndpoint(URI tokenEndpoint) { + this.tokenEndpoint = tokenEndpoint; + } + + public URI getRegistrationEndpoint() { + return registrationEndpoint; + } + + public void setRegistrationEndpoint(URI registrationEndpoint) { + this.registrationEndpoint = registrationEndpoint; + } + + public List getScopesSupported() { + return scopesSupported; + } + + public void setScopesSupported(List scopesSupported) { + this.scopesSupported = scopesSupported; + } + + public List getResponseTypesSupported() { + return responseTypesSupported; + } + + public void setResponseTypesSupported(List responseTypesSupported) { + this.responseTypesSupported = responseTypesSupported; + } + + public List getResponseModesSupported() { + return responseModesSupported; + } + + public void setResponseModesSupported(List responseModesSupported) { + this.responseModesSupported = responseModesSupported; + } + + public List getGrantTypesSupported() { + return grantTypesSupported; + } + + public void setGrantTypesSupported(List grantTypesSupported) { + this.grantTypesSupported = grantTypesSupported; + } + + public List getTokenEndpointAuthMethodsSupported() { + return tokenEndpointAuthMethodsSupported; + } + + public void setTokenEndpointAuthMethodsSupported(List tokenEndpointAuthMethodsSupported) { + this.tokenEndpointAuthMethodsSupported = tokenEndpointAuthMethodsSupported; + } + + public List getTokenEndpointAuthSigningAlgValuesSupported() { + return tokenEndpointAuthSigningAlgValuesSupported; + } + + public void setTokenEndpointAuthSigningAlgValuesSupported(List tokenEndpointAuthSigningAlgValuesSupported) { + this.tokenEndpointAuthSigningAlgValuesSupported = tokenEndpointAuthSigningAlgValuesSupported; + } + + public URI getServiceDocumentation() { + return serviceDocumentation; + } + + public void setServiceDocumentation(URI serviceDocumentation) { + this.serviceDocumentation = serviceDocumentation; + } + + public List getUiLocalesSupported() { + return uiLocalesSupported; + } + + public void setUiLocalesSupported(List uiLocalesSupported) { + this.uiLocalesSupported = uiLocalesSupported; + } + + public URI getOpPolicyUri() { + return opPolicyUri; + } + + public void setOpPolicyUri(URI opPolicyUri) { + this.opPolicyUri = opPolicyUri; + } + + public URI getOpTosUri() { + return opTosUri; + } + + public void setOpTosUri(URI opTosUri) { + this.opTosUri = opTosUri; + } + + public URI getRevocationEndpoint() { + return revocationEndpoint; + } + + public void setRevocationEndpoint(URI revocationEndpoint) { + this.revocationEndpoint = revocationEndpoint; + } + + public List getRevocationEndpointAuthMethodsSupported() { + return revocationEndpointAuthMethodsSupported; + } + + public void setRevocationEndpointAuthMethodsSupported(List revocationEndpointAuthMethodsSupported) { + this.revocationEndpointAuthMethodsSupported = revocationEndpointAuthMethodsSupported; + } + + public List getRevocationEndpointAuthSigningAlgValuesSupported() { + return revocationEndpointAuthSigningAlgValuesSupported; + } + + public void setRevocationEndpointAuthSigningAlgValuesSupported( + List revocationEndpointAuthSigningAlgValuesSupported) { + this.revocationEndpointAuthSigningAlgValuesSupported = revocationEndpointAuthSigningAlgValuesSupported; + } + + public URI getIntrospectionEndpoint() { + return introspectionEndpoint; + } + + public void setIntrospectionEndpoint(URI introspectionEndpoint) { + this.introspectionEndpoint = introspectionEndpoint; + } + + public List getIntrospectionEndpointAuthMethodsSupported() { + return introspectionEndpointAuthMethodsSupported; + } + + public void setIntrospectionEndpointAuthMethodsSupported(List introspectionEndpointAuthMethodsSupported) { + this.introspectionEndpointAuthMethodsSupported = introspectionEndpointAuthMethodsSupported; + } + + public List getIntrospectionEndpointAuthSigningAlgValuesSupported() { + return introspectionEndpointAuthSigningAlgValuesSupported; + } + + public void setIntrospectionEndpointAuthSigningAlgValuesSupported( + List introspectionEndpointAuthSigningAlgValuesSupported) { + this.introspectionEndpointAuthSigningAlgValuesSupported = introspectionEndpointAuthSigningAlgValuesSupported; + } + + public List getCodeChallengeMethodsSupported() { + return codeChallengeMethodsSupported; + } + + public void setCodeChallengeMethodsSupported(List codeChallengeMethodsSupported) { + this.codeChallengeMethodsSupported = codeChallengeMethodsSupported; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthToken.java b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthToken.java new file mode 100644 index 00000000..993010d4 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/OAuthToken.java @@ -0,0 +1,71 @@ +package io.modelcontextprotocol.auth; + +/** + * OAuth token as defined in RFC 6749 section 5.1 + * https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 + */ +public class OAuthToken { + + private String accessToken; + + private String tokenType; + + private Integer expiresIn; + + private String scope; + + private String refreshToken; + + public OAuthToken() { + this.tokenType = "bearer"; + } + + public OAuthToken(String accessToken, Integer expiresIn, String scope, String refreshToken) { + this.accessToken = accessToken; + this.tokenType = "bearer"; + this.expiresIn = expiresIn; + this.scope = scope; + this.refreshToken = refreshToken; + } + + public String getAccessToken() { + return accessToken; + } + + public void setAccessToken(String accessToken) { + this.accessToken = accessToken; + } + + public String getTokenType() { + return tokenType; + } + + public void setTokenType(String tokenType) { + this.tokenType = tokenType; + } + + public Integer getExpiresIn() { + return expiresIn; + } + + public void setExpiresIn(Integer expiresIn) { + this.expiresIn = expiresIn; + } + + public String getScope() { + return scope; + } + + public void setScope(String scope) { + this.scope = scope; + } + + public String getRefreshToken() { + return refreshToken; + } + + public void setRefreshToken(String refreshToken) { + this.refreshToken = refreshToken; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/README.md b/mcp/src/main/java/io/modelcontextprotocol/auth/README.md new file mode 100644 index 00000000..e8a77c51 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/README.md @@ -0,0 +1,86 @@ +# Authentication Implementation for Java SDK + +This package provides OAuth 2.0 authentication functionality for the Java SDK, based on the implementation in the Python SDK. + +## Overview + +The authentication implementation follows the OAuth 2.0 specification and includes: + +1. **Core OAuth Models**: + - `OAuthToken`: Represents an OAuth token with access and refresh tokens + - `OAuthClientMetadata`: Client registration metadata + - `OAuthClientInformation`: Client information including credentials + - `OAuthMetadata`: Authorization server metadata + +2. **Token Models**: + - `AccessToken`: Represents an OAuth access token + - `RefreshToken`: Represents an OAuth refresh token + - `AuthorizationCode`: Represents an OAuth authorization code + +3. **Authentication Middleware**: + - `BearerAuthenticator`: Validates Bearer tokens in Authorization headers + - `ClientAuthenticator`: Validates client credentials + - `AuthContext`: Holds authentication context for a request + +4. **Provider Interface**: + - `OAuthAuthorizationServerProvider`: Interface for OAuth authorization server providers + +5. **Exceptions**: + - `RegistrationException`: Thrown during client registration errors + - `AuthorizeException`: Thrown during authorization errors + - `TokenException`: Thrown during token operations errors + - `InvalidScopeException`: Thrown when a requested scope is invalid + - `InvalidRedirectUriException`: Thrown when a redirect URI is invalid + +## Usage + +To use the authentication functionality: + +1. Implement the `OAuthAuthorizationServerProvider` interface +2. Use the `BearerAuthenticator` to validate Bearer tokens +3. Use the `ClientAuthenticator` to validate client credentials + +Example: + +```java +// Create an OAuth provider implementation +OAuthAuthorizationServerProvider provider = new MyOAuthProvider(); + +// Create authenticators +BearerAuthenticator bearerAuth = new BearerAuthenticator(provider); +ClientAuthenticator clientAuth = new ClientAuthenticator(provider); + +// Authenticate a request with a Bearer token +String authHeader = "Bearer abc123"; +bearerAuth.authenticate(authHeader) + .thenAccept(user -> { + if (user != null) { + // User is authenticated + String clientId = user.getClientId(); + // ... + } else { + // Authentication failed + } + }); + +// Authenticate a client +String clientId = "client123"; +String clientSecret = "secret456"; +clientAuth.authenticate(clientId, clientSecret) + .thenAccept(client -> { + // Client is authenticated + // ... + }) + .exceptionally(ex -> { + // Authentication failed + // ... + return null; + }); +``` + +## Implementation Notes + +- The implementation uses CompletableFuture for asynchronous operations +- Token validation includes expiration checks +- Client authentication supports both secret and no-secret modes +- URI utilities are provided for constructing redirect URIs \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/RefreshToken.java b/mcp/src/main/java/io/modelcontextprotocol/auth/RefreshToken.java new file mode 100644 index 00000000..31014515 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/RefreshToken.java @@ -0,0 +1,60 @@ +package io.modelcontextprotocol.auth; + +import java.util.List; + +/** + * Represents an OAuth refresh token. + */ +public class RefreshToken { + + private String token; + + private String clientId; + + private List scopes; + + private Integer expiresAt; + + public RefreshToken() { + } + + public RefreshToken(String token, String clientId, List scopes, Integer expiresAt) { + this.token = token; + this.clientId = clientId; + this.scopes = scopes; + this.expiresAt = expiresAt; + } + + public String getToken() { + return token; + } + + public void setToken(String token) { + this.token = token; + } + + public String getClientId() { + return clientId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public List getScopes() { + return scopes; + } + + public void setScopes(List scopes) { + this.scopes = scopes; + } + + public Integer getExpiresAt() { + return expiresAt; + } + + public void setExpiresAt(Integer expiresAt) { + this.expiresAt = expiresAt; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/exception/AuthorizeException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/AuthorizeException.java new file mode 100644 index 00000000..c16d2759 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/AuthorizeException.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.auth.exception; + +/** + * Exception thrown during authorization. + */ +public class AuthorizeException extends Exception { + + private final String error; + + private final String errorDescription; + + public AuthorizeException(String error, String errorDescription) { + super(errorDescription != null ? errorDescription : error); + this.error = error; + this.errorDescription = errorDescription; + } + + public String getError() { + return error; + } + + public String getErrorDescription() { + return errorDescription; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/exception/RegistrationException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/RegistrationException.java new file mode 100644 index 00000000..0f4a1a10 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/RegistrationException.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.auth.exception; + +/** + * Exception thrown during client registration. + */ +public class RegistrationException extends Exception { + + private final String error; + + private final String errorDescription; + + public RegistrationException(String error, String errorDescription) { + super(errorDescription != null ? errorDescription : error); + this.error = error; + this.errorDescription = errorDescription; + } + + public String getError() { + return error; + } + + public String getErrorDescription() { + return errorDescription; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/auth/exception/TokenException.java b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/TokenException.java new file mode 100644 index 00000000..e0fa8ca6 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/auth/exception/TokenException.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.auth.exception; + +/** + * Exception thrown during token operations. + */ +public class TokenException extends Exception { + + private final String error; + + private final String errorDescription; + + public TokenException(String error, String errorDescription) { + super(errorDescription != null ? errorDescription : error); + this.error = error; + this.errorDescription = errorDescription; + } + + public String getError() { + return error; + } + + public String getErrorDescription() { + return errorDescription; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/AuthCallbackResult.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/AuthCallbackResult.java new file mode 100644 index 00000000..07ac9d2d --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/AuthCallbackResult.java @@ -0,0 +1,38 @@ +package io.modelcontextprotocol.client.auth; + +/** + * Result of an OAuth authorization callback. + */ +public class AuthCallbackResult { + + private final String code; + + private final String state; + + /** + * Creates a new AuthCallbackResult. + * @param code The authorization code. + * @param state The state parameter. + */ + public AuthCallbackResult(String code, String state) { + this.code = code; + this.state = state; + } + + /** + * Get the authorization code. + * @return The authorization code. + */ + public String getCode() { + return code; + } + + /** + * Get the state parameter. + * @return The state parameter. + */ + public String getState() { + return state; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/HttpClientAuthenticator.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/HttpClientAuthenticator.java new file mode 100644 index 00000000..d54c80b1 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/HttpClientAuthenticator.java @@ -0,0 +1,52 @@ +package io.modelcontextprotocol.client.auth; + +import java.io.IOException; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.Builder; +import java.util.concurrent.CompletableFuture; + +/** + * Authenticator for HTTP requests using OAuth. + */ +public class HttpClientAuthenticator { + + private final OAuthClientProvider oauthProvider; + + /** + * Creates a new HttpClientAuthenticator. + * @param oauthProvider The OAuth client provider. + */ + public HttpClientAuthenticator(OAuthClientProvider oauthProvider) { + this.oauthProvider = oauthProvider; + } + + /** + * Authenticate an HTTP request by adding an Authorization header with the OAuth + * token. + * @param requestBuilder The HTTP request builder. + * @return A CompletableFuture that completes with the authenticated request builder. + */ + public CompletableFuture authenticate(HttpRequest.Builder requestBuilder) { + return oauthProvider.ensureToken().thenApply(v -> { + String accessToken = oauthProvider.getAccessToken(); + if (accessToken != null) { + return requestBuilder.header("Authorization", "Bearer " + accessToken); + } + return requestBuilder; + }); + } + + /** + * Handle an HTTP response, refreshing the token if needed. + * @param statusCode The HTTP status code. + * @return A CompletableFuture that completes when the response is handled. + */ + public CompletableFuture handleResponse(int statusCode) { + if (statusCode == 401) { + // Force token refresh on 401 Unauthorized + return oauthProvider.ensureToken(); + } + return CompletableFuture.completedFuture(null); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/InMemoryTokenStorage.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/InMemoryTokenStorage.java new file mode 100644 index 00000000..54cff786 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/InMemoryTokenStorage.java @@ -0,0 +1,39 @@ +package io.modelcontextprotocol.client.auth; + +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthToken; + +import java.util.concurrent.CompletableFuture; + +/** + * In-memory implementation of TokenStorage. + */ +public class InMemoryTokenStorage implements TokenStorage { + + private OAuthToken tokens; + + private OAuthClientInformation clientInfo; + + @Override + public CompletableFuture getTokens() { + return CompletableFuture.completedFuture(tokens); + } + + @Override + public CompletableFuture setTokens(OAuthToken tokens) { + this.tokens = tokens; + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture getClientInfo() { + return CompletableFuture.completedFuture(clientInfo); + } + + @Override + public CompletableFuture setClientInfo(OAuthClientInformation clientInfo) { + this.clientInfo = clientInfo; + return CompletableFuture.completedFuture(null); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/OAuthClientProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/OAuthClientProvider.java new file mode 100644 index 00000000..2a44195e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/OAuthClientProvider.java @@ -0,0 +1,594 @@ +package io.modelcontextprotocol.client.auth; + +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthClientMetadata; +import io.modelcontextprotocol.auth.OAuthMetadata; +import io.modelcontextprotocol.auth.OAuthToken; + +import java.io.IOException; +import java.net.URI; +import java.net.URLEncoder; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * OAuth client provider that handles the OAuth 2.0 authorization code flow with PKCE. + */ +public class OAuthClientProvider { + + private final String serverUrl; + + private final OAuthClientMetadata clientMetadata; + + private final TokenStorage storage; + + private final Function> redirectHandler; + + private final Function> callbackHandler; + + private final Duration timeout; + + private final HttpClient httpClient; + + private final ObjectMapper objectMapper; + + // Cached authentication state + private OAuthToken currentTokens; + + private OAuthMetadata metadata; + + private OAuthClientInformation clientInfo; + + private Long tokenExpiryTime; + + // PKCE flow parameters + private String codeVerifier; + + private String codeChallenge; + + // State parameter for CSRF protection + private String authState; + + // Thread safety lock + private final ReentrantLock tokenLock = new ReentrantLock(); + + /** + * Creates a new OAuthClientProvider. + * @param serverUrl Base URL of the OAuth server + * @param clientMetadata OAuth client metadata + * @param storage Token storage implementation + * @param redirectHandler Function to handle authorization URL (e.g., opening a + * browser) + * @param callbackHandler Function to wait for callback and return auth code and state + * @param timeout Timeout for OAuth flow + */ + public OAuthClientProvider(String serverUrl, OAuthClientMetadata clientMetadata, TokenStorage storage, + Function> redirectHandler, + Function> callbackHandler, Duration timeout) { + + this.serverUrl = serverUrl; + this.clientMetadata = clientMetadata; + this.storage = storage; + this.redirectHandler = redirectHandler; + this.callbackHandler = callbackHandler; + this.timeout = timeout; + this.httpClient = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(30)).build(); + this.objectMapper = new ObjectMapper(); + } + + /** + * Initialize the provider by loading stored tokens and client info. + * @return A CompletableFuture that completes when initialization is done. + */ + public CompletableFuture initialize() { + return storage.getTokens() + .thenAccept(tokens -> this.currentTokens = tokens) + .thenCompose(v -> storage.getClientInfo()) + .thenAccept(clientInfo -> this.clientInfo = clientInfo); + } + + /** + * Ensure a valid access token is available, refreshing or re-authenticating as + * needed. + * @return A CompletableFuture that completes when a valid token is available. + */ + public CompletableFuture ensureToken() { + if (hasValidToken()) { + return CompletableFuture.completedFuture(null); + } + + tokenLock.lock(); + try { + // Check again after acquiring lock + if (hasValidToken()) { + return CompletableFuture.completedFuture(null); + } + + // Try refreshing existing token + if (currentTokens != null && currentTokens.getRefreshToken() != null) { + return refreshAccessToken().thenCompose(refreshed -> { + if (Boolean.TRUE.equals(refreshed)) { + return CompletableFuture.completedFuture(null); + } + else { + // Fall back to full OAuth flow if refresh fails + return performOAuthFlow(); + } + }); + } + else { + // No refresh token, perform full OAuth flow + return performOAuthFlow(); + } + } + finally { + tokenLock.unlock(); + } + } + + /** + * Check if the current token is valid. + * @return true if a valid token exists, false otherwise. + */ + private boolean hasValidToken() { + if (currentTokens == null || currentTokens.getAccessToken() == null) { + return false; + } + + // Check expiry time + return tokenExpiryTime == null || System.currentTimeMillis() < tokenExpiryTime; + } + + /** + * Perform the OAuth 2.0 authorization code flow with PKCE. + * @return A CompletableFuture that completes when the flow is done. + */ + private CompletableFuture performOAuthFlow() { + // Discover OAuth metadata + return discoverOAuthMetadata(serverUrl).thenCompose(metadata -> { + this.metadata = metadata; + return getOrRegisterClient(); + }).thenCompose(clientInfo -> { + // Generate PKCE challenge + this.codeVerifier = PkceUtils.generateCodeVerifier(); + this.codeChallenge = PkceUtils.generateCodeChallenge(codeVerifier); + + // Generate state for CSRF protection + byte[] stateBytes = new byte[32]; + new java.security.SecureRandom().nextBytes(stateBytes); + this.authState = java.util.Base64.getUrlEncoder().withoutPadding().encodeToString(stateBytes); + + // Build authorization URL + String authUrl = buildAuthorizationUrl(clientInfo); + + // Redirect user for authorization + return redirectHandler.apply(authUrl) + .thenCompose(v -> callbackHandler.apply(null)) + .thenCompose(callbackResult -> { + // Validate state parameter + if (callbackResult.getState() == null || !callbackResult.getState().equals(authState)) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally( + new SecurityException("State parameter mismatch: possible CSRF attack")); + return future; + } + + // Clear state after validation + authState = null; + + if (callbackResult.getCode() == null) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new IllegalStateException("No authorization code received")); + return future; + } + + // Exchange authorization code for tokens + return exchangeCodeForToken(callbackResult.getCode(), clientInfo); + }); + }); + } + + /** + * Discover OAuth metadata from server's well-known endpoint. + * @param serverUrl The server URL. + * @return A CompletableFuture that resolves to the OAuth metadata. + */ + private CompletableFuture discoverOAuthMetadata(String serverUrl) { + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + String url = authBaseUrl + "/.well-known/oauth-authorization-server"; + + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("MCP-Protocol-Version", "0.1") + .GET() + .build(); + + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenApply(response -> { + if (response.statusCode() == 404) { + return null; + } + if (response.statusCode() != 200) { + throw new RuntimeException("Failed to discover OAuth metadata: " + response.statusCode()); + } + try { + return objectMapper.readValue(response.body(), OAuthMetadata.class); + } + catch (IOException e) { + throw new RuntimeException("Failed to parse OAuth metadata", e); + } + }).exceptionally(ex -> { + // Try again without MCP header + HttpRequest retryRequest = HttpRequest.newBuilder().uri(URI.create(url)).GET().build(); + + try { + HttpResponse response = httpClient.send(retryRequest, HttpResponse.BodyHandlers.ofString()); + if (response.statusCode() == 404) { + return null; + } + if (response.statusCode() != 200) { + return null; + } + return objectMapper.readValue(response.body(), OAuthMetadata.class); + } + catch (Exception e) { + return null; + } + }); + } + + /** + * Get or register client with server. + * @return A CompletableFuture that resolves to the client information. + */ + private CompletableFuture getOrRegisterClient() { + if (clientInfo != null) { + return CompletableFuture.completedFuture(clientInfo); + } + + return registerOAuthClient(serverUrl, clientMetadata, metadata).thenCompose(registeredClient -> { + this.clientInfo = registeredClient; + return storage.setClientInfo(registeredClient).thenApply(v -> registeredClient); + }); + } + + /** + * Register OAuth client with server. + * @param serverUrl The server URL. + * @param clientMetadata The client metadata. + * @param metadata The OAuth metadata. + * @return A CompletableFuture that resolves to the registered client information. + */ + private CompletableFuture registerOAuthClient(String serverUrl, + OAuthClientMetadata clientMetadata, OAuthMetadata metadata) { + + String registrationUrl; + if (metadata != null && metadata.getRegistrationEndpoint() != null) { + registrationUrl = metadata.getRegistrationEndpoint().toString(); + } + else { + // Use fallback registration endpoint + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + registrationUrl = authBaseUrl + "/register"; + } + + // Handle default scope + if (clientMetadata.getScope() == null && metadata != null && metadata.getScopesSupported() != null + && !metadata.getScopesSupported().isEmpty()) { + clientMetadata.setScope(String.join(" ", metadata.getScopesSupported())); + } + + try { + String requestBody = objectMapper.writeValueAsString(clientMetadata); + + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(registrationUrl)) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(requestBody)) + .build(); + + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenApply(response -> { + if (response.statusCode() != 200 && response.statusCode() != 201) { + throw new RuntimeException("Registration failed: " + response.statusCode()); + } + try { + return objectMapper.readValue(response.body(), OAuthClientInformation.class); + } + catch (IOException e) { + throw new RuntimeException("Failed to parse client information", e); + } + }); + } + catch (Exception e) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(e); + return future; + } + } + + /** + * Build authorization URL for the OAuth flow. + * @param clientInfo The client information. + * @return The authorization URL. + */ + private String buildAuthorizationUrl(OAuthClientInformation clientInfo) { + String authUrlBase; + if (metadata != null && metadata.getAuthorizationEndpoint() != null) { + authUrlBase = metadata.getAuthorizationEndpoint().toString(); + } + else { + // Use fallback authorization endpoint + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + authUrlBase = authBaseUrl + "/authorize"; + } + + Map params = new HashMap<>(); + params.put("response_type", "code"); + params.put("client_id", clientInfo.getClientId()); + params.put("redirect_uri", clientInfo.getRedirectUris().get(0).toString()); + params.put("state", authState); + params.put("code_challenge", codeChallenge); + params.put("code_challenge_method", "S256"); + + // Include explicit scopes only + if (clientMetadata.getScope() != null) { + params.put("scope", clientMetadata.getScope()); + } + + return authUrlBase + "?" + formatQueryParams(params); + } + + /** + * Exchange authorization code for access token. + * @param authCode The authorization code. + * @param clientInfo The client information. + * @return A CompletableFuture that completes when the exchange is done. + */ + private CompletableFuture exchangeCodeForToken(String authCode, OAuthClientInformation clientInfo) { + String tokenUrl; + if (metadata != null && metadata.getTokenEndpoint() != null) { + tokenUrl = metadata.getTokenEndpoint().toString(); + } + else { + // Use fallback token endpoint + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + tokenUrl = authBaseUrl + "/token"; + } + + Map formData = new HashMap<>(); + formData.put("grant_type", "authorization_code"); + formData.put("code", authCode); + formData.put("redirect_uri", clientInfo.getRedirectUris().get(0).toString()); + formData.put("client_id", clientInfo.getClientId()); + formData.put("code_verifier", codeVerifier); + + if (clientInfo.getClientSecret() != null) { + formData.put("client_secret", clientInfo.getClientSecret()); + } + + String requestBody = formatQueryParams(formData); + + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(tokenUrl)) + .header("Content-Type", "application/x-www-form-urlencoded") + .timeout(Duration.ofSeconds(30)) + .POST(HttpRequest.BodyPublishers.ofString(requestBody)) + .build(); + + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenCompose(response -> { + if (response.statusCode() != 200) { + try { + Map errorData = objectMapper.readValue(response.body(), Map.class); + Object errorDesc = errorData.get("error_description"); + if (errorDesc == null) { + errorDesc = errorData.get("error"); + } + if (errorDesc == null) { + errorDesc = "Unknown error"; + } + String errorMsg = errorDesc.toString(); + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new RuntimeException( + "Token exchange failed: " + errorMsg + " (HTTP " + response.statusCode() + ")")); + return future; + } + catch (Exception e) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new RuntimeException( + "Token exchange failed: " + response.statusCode() + " " + response.body())); + return future; + } + } + + try { + OAuthToken tokenResponse = objectMapper.readValue(response.body(), OAuthToken.class); + + // Validate token scopes + validateTokenScopes(tokenResponse); + + // Calculate token expiry + if (tokenResponse.getExpiresIn() != null) { + tokenExpiryTime = System.currentTimeMillis() + (tokenResponse.getExpiresIn() * 1000L); + } + else { + tokenExpiryTime = null; + } + + // Store tokens + currentTokens = tokenResponse; + return storage.setTokens(tokenResponse); + } + catch (Exception e) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(e); + return future; + } + }); + } + + /** + * Refresh access token using refresh token. + * @return A CompletableFuture that resolves to true if refresh was successful, false + * otherwise. + */ + private CompletableFuture refreshAccessToken() { + if (currentTokens == null || currentTokens.getRefreshToken() == null) { + return CompletableFuture.completedFuture(false); + } + + return getOrRegisterClient().thenCompose(clientInfo -> { + String tokenUrl; + if (metadata != null && metadata.getTokenEndpoint() != null) { + tokenUrl = metadata.getTokenEndpoint().toString(); + } + else { + // Use fallback token endpoint + String authBaseUrl = getAuthorizationBaseUrl(serverUrl); + tokenUrl = authBaseUrl + "/token"; + } + + Map formData = new HashMap<>(); + formData.put("grant_type", "refresh_token"); + formData.put("refresh_token", currentTokens.getRefreshToken()); + formData.put("client_id", clientInfo.getClientId()); + + if (clientInfo.getClientSecret() != null) { + formData.put("client_secret", clientInfo.getClientSecret()); + } + + String requestBody = formatQueryParams(formData); + + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(tokenUrl)) + .header("Content-Type", "application/x-www-form-urlencoded") + .timeout(Duration.ofSeconds(30)) + .POST(HttpRequest.BodyPublishers.ofString(requestBody)) + .build(); + + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenCompose(response -> { + if (response.statusCode() != 200) { + return CompletableFuture.completedFuture(false); + } + + try { + OAuthToken tokenResponse = objectMapper.readValue(response.body(), OAuthToken.class); + + // Validate token scopes + validateTokenScopes(tokenResponse); + + // Calculate token expiry + if (tokenResponse.getExpiresIn() != null) { + tokenExpiryTime = System.currentTimeMillis() + (tokenResponse.getExpiresIn() * 1000L); + } + else { + tokenExpiryTime = null; + } + + // Store refreshed tokens + currentTokens = tokenResponse; + return storage.setTokens(tokenResponse).thenApply(v -> true); + } + catch (Exception e) { + return CompletableFuture.completedFuture(false); + } + }).exceptionally(ex -> false); + }); + } + + /** + * Validate returned scopes against requested scopes. + * @param tokenResponse The token response. + */ + private void validateTokenScopes(OAuthToken tokenResponse) { + if (tokenResponse.getScope() == null) { + // No scope returned = validation passes + return; + } + + // Check explicitly requested scopes only + if (clientMetadata.getScope() != null) { + // Validate against explicit scope request + String[] requestedScopes = clientMetadata.getScope().split(" "); + String[] returnedScopes = tokenResponse.getScope().split(" "); + + // Check for unauthorized scopes + for (String returnedScope : returnedScopes) { + boolean found = false; + for (String requestedScope : requestedScopes) { + if (returnedScope.equals(requestedScope)) { + found = true; + break; + } + } + + if (!found) { + throw new IllegalStateException("Server granted unauthorized scope: " + returnedScope); + } + } + } + } + + /** + * Extract base URL by removing path component. + * @param serverUrl The server URL. + * @return The base URL. + */ + private String getAuthorizationBaseUrl(String serverUrl) { + try { + URI uri = new URI(serverUrl); + return new URI(uri.getScheme(), uri.getAuthority(), null, null, null).toString(); + } + catch (Exception e) { + throw new IllegalArgumentException("Invalid server URL: " + serverUrl, e); + } + } + + /** + * Format query parameters for URL or form data. + * @param params The parameters. + * @return The formatted query string. + */ + private String formatQueryParams(Map params) { + StringBuilder result = new StringBuilder(); + boolean first = true; + + for (Map.Entry entry : params.entrySet()) { + if (!first) { + result.append("&"); + } + first = false; + + result.append(URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8)); + result.append("="); + result.append(URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8)); + } + + return result.toString(); + } + + /** + * Get the current access token. + * @return The access token, or null if none exists. + */ + public String getAccessToken() { + return currentTokens != null ? currentTokens.getAccessToken() : null; + } + + /** + * Get the current OAuth tokens. + * @return The OAuth tokens, or null if none exist. + */ + public OAuthToken getCurrentTokens() { + return currentTokens; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/PkceUtils.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/PkceUtils.java new file mode 100644 index 00000000..3c33e599 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/PkceUtils.java @@ -0,0 +1,46 @@ +package io.modelcontextprotocol.client.auth; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.Base64; + +/** + * Utility class for PKCE (Proof Key for Code Exchange) operations. + */ +public class PkceUtils { + + private static final SecureRandom secureRandom = new SecureRandom(); + + private static final String ALLOWED_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"; + + /** + * Generates a cryptographically random code verifier for PKCE. + * @return A random code verifier string. + */ + public static String generateCodeVerifier() { + StringBuilder codeVerifier = new StringBuilder(128); + for (int i = 0; i < 128; i++) { + codeVerifier.append(ALLOWED_CHARS.charAt(secureRandom.nextInt(ALLOWED_CHARS.length()))); + } + return codeVerifier.toString(); + } + + /** + * Generates a code challenge from a code verifier using SHA-256. + * @param codeVerifier The code verifier to hash. + * @return The code challenge string. + */ + public static String generateCodeChallenge(String codeVerifier) { + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(codeVerifier.getBytes(StandardCharsets.UTF_8)); + return Base64.getUrlEncoder().withoutPadding().encodeToString(hash); + } + catch (NoSuchAlgorithmException e) { + throw new RuntimeException("SHA-256 algorithm not available", e); + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/README.md b/mcp/src/main/java/io/modelcontextprotocol/client/auth/README.md new file mode 100644 index 00000000..7fbfedf1 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/README.md @@ -0,0 +1,97 @@ +# OAuth 2.0 Client Implementation + +This package provides an OAuth 2.0 client implementation for the MCP Java SDK, supporting the Authorization Code flow with PKCE (Proof Key for Code Exchange). + +## Components + +- `OAuthClientProvider`: Main class that handles the OAuth 2.0 flow +- `TokenStorage`: Interface for storing OAuth tokens and client information +- `InMemoryTokenStorage`: Simple in-memory implementation of TokenStorage +- `HttpClientAuthenticator`: Authenticator for HTTP requests using OAuth +- `PkceUtils`: Utility class for PKCE operations +- `AuthCallbackResult`: Class to hold the result of an OAuth authorization callback + +## Usage Example + +```java +// Create client metadata +OAuthClientMetadata clientMetadata = new OAuthClientMetadata(); +clientMetadata.setRedirectUris(List.of(URI.create("http://localhost:8080/callback"))); +clientMetadata.setScope("read write"); + +// Create token storage +TokenStorage storage = new InMemoryTokenStorage(); + +// Create redirect handler (e.g., open browser) +Function> redirectHandler = url -> { + try { + Desktop.getDesktop().browse(URI.create(url)); + return CompletableFuture.completedFuture(null); + } catch (IOException e) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(e); + return future; + } +}; + +// Create callback handler (e.g., start local server to receive callback) +Function> callbackHandler = v -> { + // Implementation to start a local server and wait for callback + // Return CompletableFuture with code and state +}; + +// Create OAuth client provider +OAuthClientProvider provider = new OAuthClientProvider( + "https://api.example.com", + clientMetadata, + storage, + redirectHandler, + callbackHandler, + Duration.ofMinutes(5) +); + +// Initialize provider +provider.initialize() + .thenCompose(v -> provider.ensureToken()) + .thenRun(() -> { + // Now you have a valid token + String accessToken = provider.getAccessToken(); + System.out.println("Access token: " + accessToken); + }) + .exceptionally(ex -> { + System.err.println("Authentication failed: " + ex.getMessage()); + return null; + }); +``` + +## HTTP Client Integration + +```java +HttpClientAuthenticator authenticator = new HttpClientAuthenticator(provider); + +HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() + .uri(URI.create("https://api.example.com/resource")) + .GET(); + +authenticator.authenticate(requestBuilder) + .thenCompose(builder -> { + HttpRequest request = builder.build(); + return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()); + }) + .thenCompose(response -> { + // Handle 401 responses by refreshing the token + return authenticator.handleResponse(response.statusCode()) + .thenApply(v -> response); + }) + .thenAccept(response -> { + System.out.println("Response: " + response.body()); + }); +``` + +## Security Features + +- PKCE support to prevent authorization code interception attacks +- State parameter to prevent CSRF attacks +- Automatic token refresh +- Thread-safe token management +- Scope validation \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/auth/TokenStorage.java b/mcp/src/main/java/io/modelcontextprotocol/client/auth/TokenStorage.java new file mode 100644 index 00000000..1b67b6d0 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/auth/TokenStorage.java @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.client.auth; + +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthToken; + +import java.util.concurrent.CompletableFuture; + +/** + * Interface for token storage implementations. + */ +public interface TokenStorage { + + /** + * Get stored tokens. + * @return A CompletableFuture that resolves to the stored tokens, or null if none + * exist. + */ + CompletableFuture getTokens(); + + /** + * Store tokens. + * @param tokens The tokens to store. + * @return A CompletableFuture that completes when the tokens are stored. + */ + CompletableFuture setTokens(OAuthToken tokens); + + /** + * Get stored client information. + * @return A CompletableFuture that resolves to the stored client information, or null + * if none exists. + */ + CompletableFuture getClientInfo(); + + /** + * Store client information. + * @param clientInfo The client information to store. + * @return A CompletableFuture that completes when the client information is stored. + */ + CompletableFuture setClientInfo(OAuthClientInformation clientInfo); + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/OAuthRoutes.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/OAuthRoutes.java new file mode 100644 index 00000000..0e14def6 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/OAuthRoutes.java @@ -0,0 +1,161 @@ +package io.modelcontextprotocol.server.auth; + +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthMetadata; +import io.modelcontextprotocol.server.auth.handlers.AuthorizationHandler; +import io.modelcontextprotocol.server.auth.handlers.MetadataHandler; +import io.modelcontextprotocol.server.auth.handlers.RegistrationHandler; +import io.modelcontextprotocol.server.auth.handlers.RevocationHandler; +import io.modelcontextprotocol.server.auth.handlers.TokenHandler; +import io.modelcontextprotocol.server.auth.middleware.ClientAuthenticator; +import io.modelcontextprotocol.server.auth.settings.ClientRegistrationOptions; +import io.modelcontextprotocol.server.auth.settings.RevocationOptions; +import io.modelcontextprotocol.server.auth.util.UriUtils; + +import java.net.URI; +import java.util.ArrayList; +import java.util.List; + +/** + * Helper class for creating OAuth routes. + */ +public class OAuthRoutes { + + public static final String AUTHORIZATION_PATH = "/authorize"; + + public static final String TOKEN_PATH = "/token"; + + public static final String REGISTRATION_PATH = "/register"; + + public static final String REVOCATION_PATH = "/revoke"; + + public static final String METADATA_PATH = "/.well-known/oauth-authorization-server"; + + /** + * Create OAuth metadata for the server. + * @param issuerUrl The issuer URL + * @param serviceDocumentationUrl The service documentation URL + * @param clientRegistrationOptions The client registration options + * @param revocationOptions The revocation options + * @return The OAuth metadata + */ + public static OAuthMetadata buildMetadata(URI issuerUrl, URI serviceDocumentationUrl, + ClientRegistrationOptions clientRegistrationOptions, RevocationOptions revocationOptions) { + + UriUtils.validateIssuerUrl(issuerUrl); + + URI authorizationUrl = UriUtils.buildEndpointUrl(issuerUrl, AUTHORIZATION_PATH); + URI tokenUrl = UriUtils.buildEndpointUrl(issuerUrl, TOKEN_PATH); + + OAuthMetadata metadata = new OAuthMetadata(); + metadata.setIssuer(issuerUrl); + metadata.setAuthorizationEndpoint(authorizationUrl); + metadata.setTokenEndpoint(tokenUrl); + metadata.setScopesSupported(clientRegistrationOptions.getValidScopes()); + metadata.setResponseTypesSupported(List.of("code")); + metadata.setGrantTypesSupported(List.of("authorization_code", "refresh_token")); + metadata.setTokenEndpointAuthMethodsSupported(List.of("client_secret_post")); + metadata.setServiceDocumentation(serviceDocumentationUrl); + metadata.setCodeChallengeMethodsSupported(List.of("S256")); + + // Add registration endpoint if supported + if (clientRegistrationOptions.isEnabled()) { + metadata.setRegistrationEndpoint(UriUtils.buildEndpointUrl(issuerUrl, REGISTRATION_PATH)); + } + + // Add revocation endpoint if supported + if (revocationOptions.isEnabled()) { + metadata.setRevocationEndpoint(UriUtils.buildEndpointUrl(issuerUrl, REVOCATION_PATH)); + metadata.setRevocationEndpointAuthMethodsSupported(List.of("client_secret_post")); + } + + return metadata; + } + + /** + * Create handlers for OAuth routes. + * @param provider The OAuth authorization server provider + * @param metadata The OAuth metadata + * @param clientRegistrationOptions The client registration options + * @param revocationOptions The revocation options + * @return A map of route handlers + */ + public static OAuthHandlers createHandlers(OAuthAuthorizationServerProvider provider, OAuthMetadata metadata, + ClientRegistrationOptions clientRegistrationOptions, RevocationOptions revocationOptions) { + + ClientAuthenticator clientAuthenticator = new ClientAuthenticator(provider); + + OAuthHandlers handlers = new OAuthHandlers(); + handlers.setMetadataHandler(new MetadataHandler(metadata)); + handlers.setAuthorizationHandler(new AuthorizationHandler(provider)); + handlers.setTokenHandler(new TokenHandler(provider, clientAuthenticator)); + + if (clientRegistrationOptions.isEnabled()) { + handlers.setRegistrationHandler(new RegistrationHandler(provider, clientRegistrationOptions)); + } + + if (revocationOptions.isEnabled()) { + handlers.setRevocationHandler(new RevocationHandler(provider, clientAuthenticator)); + } + + return handlers; + } + + /** + * Container for OAuth route handlers. + */ + public static class OAuthHandlers { + + private MetadataHandler metadataHandler; + + private AuthorizationHandler authorizationHandler; + + private TokenHandler tokenHandler; + + private RegistrationHandler registrationHandler; + + private RevocationHandler revocationHandler; + + public MetadataHandler getMetadataHandler() { + return metadataHandler; + } + + public void setMetadataHandler(MetadataHandler metadataHandler) { + this.metadataHandler = metadataHandler; + } + + public AuthorizationHandler getAuthorizationHandler() { + return authorizationHandler; + } + + public void setAuthorizationHandler(AuthorizationHandler authorizationHandler) { + this.authorizationHandler = authorizationHandler; + } + + public TokenHandler getTokenHandler() { + return tokenHandler; + } + + public void setTokenHandler(TokenHandler tokenHandler) { + this.tokenHandler = tokenHandler; + } + + public RegistrationHandler getRegistrationHandler() { + return registrationHandler; + } + + public void setRegistrationHandler(RegistrationHandler registrationHandler) { + this.registrationHandler = registrationHandler; + } + + public RevocationHandler getRevocationHandler() { + return revocationHandler; + } + + public void setRevocationHandler(RevocationHandler revocationHandler) { + this.revocationHandler = revocationHandler; + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/README.md b/mcp/src/main/java/io/modelcontextprotocol/server/auth/README.md new file mode 100644 index 00000000..85544132 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/README.md @@ -0,0 +1,167 @@ +# OAuth 2.0 Server Implementation + +This package provides an OAuth 2.0 server implementation for the MCP Java SDK, supporting the Authorization Code flow with PKCE (Proof Key for Code Exchange). + +## Components + +### Handlers +- `AuthorizationHandler`: Handles OAuth authorization requests +- `TokenHandler`: Handles OAuth token requests +- `RegistrationHandler`: Handles OAuth client registration requests +- `RevocationHandler`: Handles OAuth token revocation requests +- `MetadataHandler`: Handles OAuth metadata requests + +### Middleware +- `ClientAuthenticator`: Authenticates OAuth clients +- `BearerAuthenticator`: Authenticates requests with bearer tokens + +### Settings +- `ClientRegistrationOptions`: Options for OAuth client registration +- `RevocationOptions`: Options for OAuth token revocation + +### Utilities +- `OAuthRoutes`: Helper class for creating OAuth routes and metadata + +## Usage Example + +```java +// Create provider implementation +OAuthAuthorizationServerProvider provider = new MyOAuthProvider(); + +// Create options +ClientRegistrationOptions registrationOptions = new ClientRegistrationOptions(); +registrationOptions.setValidScopes(List.of("read", "write")); + +RevocationOptions revocationOptions = new RevocationOptions(); + +// Create metadata +URI issuerUrl = URI.create("https://api.example.com"); +URI docsUrl = URI.create("https://docs.example.com"); +OAuthMetadata metadata = OAuthRoutes.buildMetadata( + issuerUrl, + docsUrl, + registrationOptions, + revocationOptions +); + +// Create handlers +OAuthRoutes.OAuthHandlers handlers = OAuthRoutes.createHandlers( + provider, + metadata, + registrationOptions, + revocationOptions +); + +// Use handlers in your web framework +// For example, with Spring MVC: + +@RestController +public class OAuthController { + + private final OAuthRoutes.OAuthHandlers handlers; + + public OAuthController(OAuthRoutes.OAuthHandlers handlers) { + this.handlers = handlers; + } + + @GetMapping("/.well-known/oauth-authorization-server") + public OAuthMetadata getMetadata() { + return handlers.getMetadataHandler().handle().join(); + } + + @GetMapping("/authorize") + public ResponseEntity authorize(@RequestParam Map params) { + try { + String redirectUrl = handlers.getAuthorizationHandler().handle(params).join(); + return ResponseEntity.status(HttpStatus.FOUND) + .header("Location", redirectUrl) + .header("Cache-Control", "no-store") + .build(); + } catch (CompletionException e) { + // Handle errors + return ResponseEntity.badRequest().body("Error: " + e.getCause().getMessage()); + } + } + + @PostMapping("/token") + public ResponseEntity token(@RequestParam Map params) { + try { + OAuthToken token = handlers.getTokenHandler().handle(params).join(); + return ResponseEntity.ok() + .header("Cache-Control", "no-store") + .header("Pragma", "no-cache") + .body(token); + } catch (CompletionException e) { + // Handle errors + return ResponseEntity.badRequest().body(null); + } + } + + // Add other endpoints for registration and revocation +} +``` + +## Provider Implementation + +You need to implement the `OAuthAuthorizationServerProvider` interface to provide the actual OAuth functionality: + +```java +public class MyOAuthProvider implements OAuthAuthorizationServerProvider { + + // Store clients, authorization codes, tokens, etc. + private final Map clients = new ConcurrentHashMap<>(); + private final Map authCodes = new ConcurrentHashMap<>(); + private final Map accessTokens = new ConcurrentHashMap<>(); + private final Map refreshTokens = new ConcurrentHashMap<>(); + + @Override + public CompletableFuture getClient(String clientId) { + return CompletableFuture.completedFuture(clients.get(clientId)); + } + + @Override + public CompletableFuture registerClient(OAuthClientInformation clientInfo) { + clients.put(clientInfo.getClientId(), clientInfo); + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture authorize(OAuthClientInformation client, AuthorizationParams params) { + // In a real implementation, you would show a UI to the user + // and get their consent before generating an authorization code + + // For this example, we'll just generate a code immediately + String code = generateRandomCode(); + + AuthorizationCode authCode = new AuthorizationCode(); + authCode.setClientId(client.getClientId()); + authCode.setCodeChallenge(params.getCodeChallenge()); + authCode.setRedirectUri(params.getRedirectUri()); + authCode.setRedirectUriProvidedExplicitly(params.isRedirectUriProvidedExplicitly()); + authCode.setScopes(params.getScopes()); + authCode.setExpiresAt(Instant.now().plusSeconds(600).getEpochSecond()); // 10 minutes + + authCodes.put(code, authCode); + + // Build redirect URI with code and state + String redirectUri = params.getRedirectUri().toString(); + redirectUri += "?code=" + code; + if (params.getState() != null) { + redirectUri += "&state=" + params.getState(); + } + + return CompletableFuture.completedFuture(redirectUri); + } + + // Implement other methods... +} +``` + +## Security Features + +- PKCE support to prevent authorization code interception attacks +- State parameter to prevent CSRF attacks +- Strict redirect URI validation +- Token expiration +- Scope validation +- HTTPS requirement (with localhost exception for testing) \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/AuthorizationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/AuthorizationHandler.java new file mode 100644 index 00000000..e635a05e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/AuthorizationHandler.java @@ -0,0 +1,181 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import io.modelcontextprotocol.auth.AuthorizationParams; +import io.modelcontextprotocol.auth.InvalidRedirectUriException; +import io.modelcontextprotocol.auth.InvalidScopeException; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.exception.AuthorizeException; +import io.modelcontextprotocol.server.auth.model.AuthorizationErrorResponse; +import io.modelcontextprotocol.server.auth.util.UriUtils; + +import java.net.URI; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Handler for OAuth authorization requests. + */ +public class AuthorizationHandler { + + private final OAuthAuthorizationServerProvider provider; + + public AuthorizationHandler(OAuthAuthorizationServerProvider provider) { + this.provider = provider; + } + + /** + * Handle an authorization request. + * @param params The request parameters + * @return A CompletableFuture that resolves to a response object containing either a + * redirect URL or an error + */ + public CompletableFuture handle(Map params) { + String clientId = params.get("client_id"); + String redirectUriStr = params.get("redirect_uri"); + String responseType = params.get("response_type"); + String codeChallenge = params.get("code_challenge"); + String codeChallengeMethod = params.get("code_challenge_method"); + String state = params.get("state"); + String scope = params.get("scope"); + + // Validate required parameters + if (clientId == null || responseType == null || codeChallenge == null) { + return CompletableFuture + .completedFuture(createErrorResponse("invalid_request", "Missing required parameters", state, null)); + } + + // Validate response type + if (!"code".equals(responseType)) { + return CompletableFuture.completedFuture(createErrorResponse("unsupported_response_type", + "Only 'code' response type is supported", state, null)); + } + + // Validate code challenge method + if (codeChallengeMethod != null && !"S256".equals(codeChallengeMethod)) { + return CompletableFuture.completedFuture(createErrorResponse("invalid_request", + "Only 'S256' code challenge method is supported", state, null)); + } + + // Get client information + return provider.getClient(clientId).thenCompose(client -> { + if (client == null) { + return CompletableFuture + .completedFuture(createErrorResponse("invalid_request", "Client ID not found", state, null)); + } + + // Validate redirect URI + URI redirectUri; + try { + URI tempUri = redirectUriStr != null ? URI.create(redirectUriStr) : null; + redirectUri = client.validateRedirectUri(tempUri); + } + catch (InvalidRedirectUriException e) { + return CompletableFuture + .completedFuture(createErrorResponse("invalid_request", e.getMessage(), state, null)); + } + + // Validate scope + List scopes; + try { + scopes = client.validateScope(scope); + } + catch (InvalidScopeException e) { + return CompletableFuture + .completedFuture(createErrorResponse("invalid_scope", e.getMessage(), state, redirectUri)); + } + + // Setup authorization parameters + AuthorizationParams authParams = new AuthorizationParams(); + authParams.setState(state); + authParams.setScopes(scopes); + authParams.setCodeChallenge(codeChallenge); + authParams.setRedirectUri(redirectUri); + authParams.setRedirectUriProvidedExplicitly(redirectUriStr != null); + + // Let the provider handle the authorization + try { + return provider.authorize(client, authParams) + .thenApply(url -> new AuthorizationResponse(url, true, null)) + .exceptionally(ex -> { + if (ex.getCause() instanceof AuthorizeException) { + AuthorizeException authEx = (AuthorizeException) ex.getCause(); + return createErrorResponse(authEx.getError(), authEx.getErrorDescription(), state, + redirectUri); + } + else { + return createErrorResponse("server_error", "An unexpected error occurred", state, + redirectUri); + } + }); + } + catch (AuthorizeException e) { + return CompletableFuture + .completedFuture(createErrorResponse(e.getError(), e.getErrorDescription(), state, redirectUri)); + } + }); + } + + /** + * Create an error response. + * @param error The error code + * @param errorDescription The error description + * @param state The state parameter + * @param redirectUri The redirect URI, or null if not available + * @return An AuthorizationResponse containing the error + */ + private AuthorizationResponse createErrorResponse(String error, String errorDescription, String state, + URI redirectUri) { + + AuthorizationErrorResponse errorResponse = new AuthorizationErrorResponse(error, errorDescription, state); + + if (redirectUri != null) { + // Redirect with error parameters + String redirectUrl = UriUtils.constructRedirectUri(redirectUri.toString(), errorResponse.toQueryParams()); + + return new AuthorizationResponse(redirectUrl, true, errorResponse); + } + else { + // Direct error response + return new AuthorizationResponse(null, false, errorResponse); + } + } + + /** + * Response object for authorization requests. + */ + public static class AuthorizationResponse { + + private final String redirectUrl; + + private final boolean isRedirect; + + private final AuthorizationErrorResponse error; + + public AuthorizationResponse(String redirectUrl, boolean isRedirect, AuthorizationErrorResponse error) { + this.redirectUrl = redirectUrl; + this.isRedirect = isRedirect; + this.error = error; + } + + public String getRedirectUrl() { + return redirectUrl; + } + + public boolean isRedirect() { + return isRedirect; + } + + public AuthorizationErrorResponse getError() { + return error; + } + + public boolean isSuccess() { + return error == null; + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/MetadataHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/MetadataHandler.java new file mode 100644 index 00000000..f9efc2ca --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/MetadataHandler.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import io.modelcontextprotocol.auth.OAuthMetadata; + +import java.util.concurrent.CompletableFuture; + +/** + * Handler for OAuth metadata requests. + */ +public class MetadataHandler { + + private final OAuthMetadata metadata; + + public MetadataHandler(OAuthMetadata metadata) { + this.metadata = metadata; + } + + /** + * Handle a metadata request. + * @return A CompletableFuture that resolves to the OAuth metadata + */ + public CompletableFuture handle() { + return CompletableFuture.completedFuture(metadata); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RegistrationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RegistrationHandler.java new file mode 100644 index 00000000..3c5f3d16 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RegistrationHandler.java @@ -0,0 +1,135 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthClientMetadata; +import io.modelcontextprotocol.auth.exception.RegistrationException; +import io.modelcontextprotocol.server.auth.settings.ClientRegistrationOptions; + +import java.net.URI; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +/** + * Handler for OAuth client registration requests. + */ +public class RegistrationHandler { + + private final OAuthAuthorizationServerProvider provider; + + private final ClientRegistrationOptions options; + + public RegistrationHandler(OAuthAuthorizationServerProvider provider, ClientRegistrationOptions options) { + this.provider = provider; + this.options = options; + } + + /** + * Handle a client registration request. + * @param clientMetadata The client metadata + * @return A CompletableFuture that resolves to the registered client information + */ + public CompletableFuture handle(OAuthClientMetadata clientMetadata) { + // Validate client metadata + if (clientMetadata.getRedirectUris() == null || clientMetadata.getRedirectUris().isEmpty()) { + return CompletableFuture.failedFuture( + new RegistrationException("invalid_redirect_uri", "At least one redirect URI is required")); + } + + // Validate redirect URIs + for (URI redirectUri : clientMetadata.getRedirectUris()) { + if (!isValidRedirectUri(redirectUri)) { + return CompletableFuture.failedFuture( + new RegistrationException("invalid_redirect_uri", "Invalid redirect URI: " + redirectUri)); + } + } + + // Validate scopes if provided + if (clientMetadata.getScope() != null && options.getValidScopes() != null) { + String[] requestedScopes = clientMetadata.getScope().split(" "); + for (String scope : requestedScopes) { + if (!options.getValidScopes().contains(scope)) { + return CompletableFuture + .failedFuture(new RegistrationException("invalid_scope", "Invalid scope: " + scope)); + } + } + } + + // Create client information + OAuthClientInformation clientInfo = new OAuthClientInformation(); + + // Copy metadata fields + clientInfo.setRedirectUris(clientMetadata.getRedirectUris()); + clientInfo.setTokenEndpointAuthMethod(clientMetadata.getTokenEndpointAuthMethod()); + clientInfo.setGrantTypes(clientMetadata.getGrantTypes()); + clientInfo.setResponseTypes(clientMetadata.getResponseTypes()); + clientInfo.setScope(clientMetadata.getScope()); + clientInfo.setClientName(clientMetadata.getClientName()); + clientInfo.setClientUri(clientMetadata.getClientUri()); + clientInfo.setLogoUri(clientMetadata.getLogoUri()); + clientInfo.setContacts(clientMetadata.getContacts()); + clientInfo.setTosUri(clientMetadata.getTosUri()); + clientInfo.setPolicyUri(clientMetadata.getPolicyUri()); + clientInfo.setJwksUri(clientMetadata.getJwksUri()); + clientInfo.setJwks(clientMetadata.getJwks()); + clientInfo.setSoftwareId(clientMetadata.getSoftwareId()); + clientInfo.setSoftwareVersion(clientMetadata.getSoftwareVersion()); + + // Generate client ID and secret + clientInfo.setClientId(generateClientId()); + + // Generate client secret if using client_secret_post auth method + if ("client_secret_post".equals(clientMetadata.getTokenEndpointAuthMethod())) { + clientInfo.setClientSecret(generateClientSecret()); + } + + // Set issuance time + clientInfo.setClientIdIssuedAt(System.currentTimeMillis() / 1000); + + // Register client with provider + try { + return provider.registerClient(clientInfo).thenApply(v -> clientInfo); + } + catch (RegistrationException e) { + return CompletableFuture.failedFuture(e); + } + } + + /** + * Validate a redirect URI. + * @param redirectUri The redirect URI to validate + * @return true if the redirect URI is valid, false otherwise + */ + private boolean isValidRedirectUri(URI redirectUri) { + String scheme = redirectUri.getScheme(); + + // Check if localhost is allowed for non-HTTPS URIs + if (options.isAllowLocalhostRedirect() && ("http".equals(scheme) || "custom".equals(scheme))) { + String host = redirectUri.getHost(); + if ("localhost".equals(host) || host.startsWith("127.0.0.1")) { + return true; + } + } + + // Require HTTPS for all other URIs + return "https".equals(scheme); + } + + /** + * Generate a random client ID. + * @return A random client ID + */ + private String generateClientId() { + return UUID.randomUUID().toString(); + } + + /** + * Generate a random client secret. + * @return A random client secret + */ + private String generateClientSecret() { + return UUID.randomUUID().toString() + UUID.randomUUID().toString(); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RevocationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RevocationHandler.java new file mode 100644 index 00000000..d356cf0b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/RevocationHandler.java @@ -0,0 +1,68 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import io.modelcontextprotocol.auth.AccessToken; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.RefreshToken; +import io.modelcontextprotocol.server.auth.middleware.ClientAuthenticator; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Handler for OAuth token revocation requests. + */ +public class RevocationHandler { + + private final OAuthAuthorizationServerProvider provider; + + private final ClientAuthenticator clientAuthenticator; + + public RevocationHandler(OAuthAuthorizationServerProvider provider, ClientAuthenticator clientAuthenticator) { + this.provider = provider; + this.clientAuthenticator = clientAuthenticator; + } + + /** + * Handle a token revocation request. + * @param params The request parameters + * @return A CompletableFuture that completes when the token is revoked + */ + public CompletableFuture handle(Map params) { + String token = params.get("token"); + String tokenTypeHint = params.get("token_type_hint"); + String clientId = params.get("client_id"); + String clientSecret = params.get("client_secret"); + + if (token == null || clientId == null) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new IllegalArgumentException("Missing required parameters")); + return future; + } + + // Authenticate client + return clientAuthenticator.authenticate(clientId, clientSecret).thenCompose(client -> { + // Try to load token based on token_type_hint + if ("refresh_token".equals(tokenTypeHint)) { + return provider.loadRefreshToken(client, token).thenCompose(refreshToken -> { + if (refreshToken != null && refreshToken.getClientId().equals(client.getClientId())) { + return provider.revokeToken(refreshToken); + } + return CompletableFuture.completedFuture(null); + }); + } + else if ("access_token".equals(tokenTypeHint) || tokenTypeHint == null) { + return provider.loadAccessToken(token).thenCompose(accessToken -> { + if (accessToken != null && accessToken.getClientId().equals(client.getClientId())) { + return provider.revokeToken(accessToken); + } + return CompletableFuture.completedFuture(null); + }); + } + else { + // Unknown token type hint + return CompletableFuture.completedFuture(null); + } + }); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/TokenHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/TokenHandler.java new file mode 100644 index 00000000..6bebe1a0 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/handlers/TokenHandler.java @@ -0,0 +1,188 @@ +package io.modelcontextprotocol.server.auth.handlers; + +import io.modelcontextprotocol.auth.AccessToken; +import io.modelcontextprotocol.auth.AuthorizationCode; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthClientInformation; +import io.modelcontextprotocol.auth.OAuthToken; +import io.modelcontextprotocol.auth.RefreshToken; +import io.modelcontextprotocol.auth.exception.TokenException; +import io.modelcontextprotocol.server.auth.middleware.ClientAuthenticator; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Handler for OAuth token requests. + */ +public class TokenHandler { + + private final OAuthAuthorizationServerProvider provider; + + private final ClientAuthenticator clientAuthenticator; + + public TokenHandler(OAuthAuthorizationServerProvider provider, ClientAuthenticator clientAuthenticator) { + this.provider = provider; + this.clientAuthenticator = clientAuthenticator; + } + + /** + * Handle a token request. + * @param params The request parameters + * @return A CompletableFuture that resolves to an OAuth token + */ + public CompletableFuture handle(Map params) { + String grantType = params.get("grant_type"); + String clientId = params.get("client_id"); + String clientSecret = params.get("client_secret"); + + if (grantType == null || clientId == null) { + return CompletableFuture.failedFuture(new TokenException("invalid_request", "Missing required parameters")); + } + + // Authenticate client + return clientAuthenticator.authenticate(clientId, clientSecret).thenCompose(client -> { + // Check if grant type is supported + if (!client.getGrantTypes().contains(grantType)) { + return CompletableFuture.failedFuture(new TokenException("unsupported_grant_type", + "Unsupported grant type (supported grant types are " + client.getGrantTypes() + ")")); + } + + // Handle different grant types + if ("authorization_code".equals(grantType)) { + return handleAuthorizationCode(client, params); + } + else if ("refresh_token".equals(grantType)) { + return handleRefreshToken(client, params); + } + else { + return CompletableFuture + .failedFuture(new TokenException("unsupported_grant_type", "Unsupported grant type")); + } + }); + } + + /** + * Handle authorization code grant type. + * @param client The authenticated client + * @param params The request parameters + * @return A CompletableFuture that resolves to an OAuth token + */ + private CompletableFuture handleAuthorizationCode(OAuthClientInformation client, + Map params) { + + String code = params.get("code"); + String redirectUri = params.get("redirect_uri"); + String codeVerifier = params.get("code_verifier"); + + if (code == null || codeVerifier == null) { + return CompletableFuture.failedFuture(new TokenException("invalid_request", "Missing required parameters")); + } + + // Load authorization code + return provider.loadAuthorizationCode(client, code).thenCompose(authCode -> { + if (authCode == null || !authCode.getClientId().equals(client.getClientId())) { + return CompletableFuture + .failedFuture(new TokenException("invalid_grant", "Authorization code does not exist")); + } + + // Check if code has expired + if (authCode.getExpiresAt() < Instant.now().getEpochSecond()) { + return CompletableFuture + .failedFuture(new TokenException("invalid_grant", "Authorization code has expired")); + } + + // Verify redirect URI matches + if (authCode.isRedirectUriProvidedExplicitly()) { + if (redirectUri == null || !redirectUri.equals(authCode.getRedirectUri().toString())) { + return CompletableFuture.failedFuture(new TokenException("invalid_request", + "Redirect URI did not match the one used when creating auth code")); + } + } + + // Verify PKCE code verifier + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(codeVerifier.getBytes(StandardCharsets.UTF_8)); + String hashedCodeVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(hash); + + if (!hashedCodeVerifier.equals(authCode.getCodeChallenge())) { + return CompletableFuture + .failedFuture(new TokenException("invalid_grant", "Incorrect code_verifier")); + } + } + catch (NoSuchAlgorithmException e) { + return CompletableFuture + .failedFuture(new TokenException("server_error", "Failed to verify code challenge")); + } + + // Exchange authorization code for tokens + try { + return provider.exchangeAuthorizationCode(client, authCode); + } + catch (TokenException e) { + return CompletableFuture.failedFuture(e); + } + }); + } + + /** + * Handle refresh token grant type. + * @param client The authenticated client + * @param params The request parameters + * @return A CompletableFuture that resolves to an OAuth token + */ + private CompletableFuture handleRefreshToken(OAuthClientInformation client, + Map params) { + + String refreshTokenStr = params.get("refresh_token"); + String scope = params.get("scope"); + + if (refreshTokenStr == null) { + return CompletableFuture + .failedFuture(new TokenException("invalid_request", "Missing refresh_token parameter")); + } + + // Load refresh token + return provider.loadRefreshToken(client, refreshTokenStr).thenCompose(refreshToken -> { + if (refreshToken == null || !refreshToken.getClientId().equals(client.getClientId())) { + return CompletableFuture + .failedFuture(new TokenException("invalid_grant", "Refresh token does not exist")); + } + + // Check if token has expired + if (refreshToken.getExpiresAt() != null && refreshToken.getExpiresAt() < Instant.now().getEpochSecond()) { + return CompletableFuture.failedFuture(new TokenException("invalid_grant", "Refresh token has expired")); + } + + // Parse scopes if provided + List scopes = scope != null ? Arrays.asList(scope.split(" ")) : refreshToken.getScopes(); + + // Validate requested scopes against refresh token scopes + if (scopes != null) { + for (String s : scopes) { + if (!refreshToken.getScopes().contains(s)) { + return CompletableFuture.failedFuture(new TokenException("invalid_scope", + "Cannot request scope `" + s + "` not provided by refresh token")); + } + } + } + + // Exchange refresh token for new tokens + try { + return provider.exchangeRefreshToken(client, refreshToken, scopes); + } + catch (TokenException e) { + return CompletableFuture.failedFuture(e); + } + }); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/AuthContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/AuthContext.java new file mode 100644 index 00000000..b64a6234 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/AuthContext.java @@ -0,0 +1,45 @@ +package io.modelcontextprotocol.server.auth.middleware; + +import io.modelcontextprotocol.auth.AccessToken; + +/** + * Holds authentication context for a request. + */ +public class AuthContext { + + private final AccessToken accessToken; + + /** + * Creates a new AuthContext. + * @param accessToken The authenticated access token. + */ + public AuthContext(AccessToken accessToken) { + this.accessToken = accessToken; + } + + /** + * Gets the access token. + * @return The access token. + */ + public AccessToken getAccessToken() { + return accessToken; + } + + /** + * Gets the client ID. + * @return The client ID. + */ + public String getClientId() { + return accessToken != null ? accessToken.getClientId() : null; + } + + /** + * Checks if the user has the specified scope. + * @param scope The scope to check. + * @return True if the user has the scope, false otherwise. + */ + public boolean hasScope(String scope) { + return accessToken != null && accessToken.getScopes().contains(scope); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/BearerAuthenticator.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/BearerAuthenticator.java new file mode 100644 index 00000000..d3302539 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/BearerAuthenticator.java @@ -0,0 +1,60 @@ +package io.modelcontextprotocol.server.auth.middleware; + +import io.modelcontextprotocol.auth.AccessToken; +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; + +import java.util.concurrent.CompletableFuture; + +/** + * Authenticator for OAuth bearer tokens. + */ +public class BearerAuthenticator { + + private final OAuthAuthorizationServerProvider provider; + + public BearerAuthenticator(OAuthAuthorizationServerProvider provider) { + this.provider = provider; + } + + /** + * Authenticate a request using a bearer token. + * @param authHeader The Authorization header value + * @return A CompletableFuture that resolves to the authenticated access token + */ + public CompletableFuture authenticate(String authHeader) { + if (authHeader == null || !authHeader.startsWith("Bearer ")) { + return CompletableFuture + .failedFuture(new AuthenticationException("Missing or invalid Authorization header")); + } + + String token = authHeader.substring("Bearer ".length()).trim(); + if (token.isEmpty()) { + return CompletableFuture.failedFuture(new AuthenticationException("Empty bearer token")); + } + + return provider.loadAccessToken(token).thenCompose(accessToken -> { + if (accessToken == null) { + return CompletableFuture.failedFuture(new AuthenticationException("Invalid access token")); + } + + // Check if token has expired + if (accessToken.getExpiresAt() != null && accessToken.getExpiresAt() < System.currentTimeMillis() / 1000) { + return CompletableFuture.failedFuture(new AuthenticationException("Access token has expired")); + } + + return CompletableFuture.completedFuture(accessToken); + }); + } + + /** + * Exception thrown when bearer authentication fails. + */ + public static class AuthenticationException extends Exception { + + public AuthenticationException(String message) { + super(message); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/ClientAuthenticator.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/ClientAuthenticator.java new file mode 100644 index 00000000..b5198626 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/middleware/ClientAuthenticator.java @@ -0,0 +1,61 @@ +package io.modelcontextprotocol.server.auth.middleware; + +import io.modelcontextprotocol.auth.OAuthAuthorizationServerProvider; +import io.modelcontextprotocol.auth.OAuthClientInformation; + +import java.util.concurrent.CompletableFuture; + +/** + * Authenticator for OAuth clients. + */ +public class ClientAuthenticator { + + private final OAuthAuthorizationServerProvider provider; + + public ClientAuthenticator(OAuthAuthorizationServerProvider provider) { + this.provider = provider; + } + + /** + * Authenticate a client using client ID and optional client secret. + * @param clientId The client ID + * @param clientSecret The client secret (may be null) + * @return A CompletableFuture that resolves to the authenticated client information + */ + public CompletableFuture authenticate(String clientId, String clientSecret) { + if (clientId == null) { + return CompletableFuture.failedFuture(new AuthenticationException("Missing client_id parameter")); + } + + return provider.getClient(clientId).thenCompose(client -> { + if (client == null) { + return CompletableFuture.failedFuture(new AuthenticationException("Client not found")); + } + + // If client has a secret, verify it + if (client.getClientSecret() != null) { + if (clientSecret == null) { + return CompletableFuture.failedFuture(new AuthenticationException("Client secret required")); + } + + if (!client.getClientSecret().equals(clientSecret)) { + return CompletableFuture.failedFuture(new AuthenticationException("Invalid client secret")); + } + } + + return CompletableFuture.completedFuture(client); + }); + } + + /** + * Exception thrown when client authentication fails. + */ + public static class AuthenticationException extends Exception { + + public AuthenticationException(String message) { + super(message); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/model/AuthorizationErrorResponse.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/model/AuthorizationErrorResponse.java new file mode 100644 index 00000000..cccfe594 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/model/AuthorizationErrorResponse.java @@ -0,0 +1,117 @@ +package io.modelcontextprotocol.server.auth.model; + +import java.net.URI; + +/** + * OAuth authorization error response as defined in RFC 6749 Section 4.1.2.1. + */ +public class AuthorizationErrorResponse { + + private String error; + + private String errorDescription; + + private URI errorUri; + + private String state; + + /** + * Creates a new AuthorizationErrorResponse. + * @param error The error code + * @param errorDescription The error description + * @param state The state parameter from the request + */ + public AuthorizationErrorResponse(String error, String errorDescription, String state) { + this.error = error; + this.errorDescription = errorDescription; + this.state = state; + } + + /** + * Gets the error code. + * @return The error code + */ + public String getError() { + return error; + } + + /** + * Sets the error code. + * @param error The error code + */ + public void setError(String error) { + this.error = error; + } + + /** + * Gets the error description. + * @return The error description + */ + public String getErrorDescription() { + return errorDescription; + } + + /** + * Sets the error description. + * @param errorDescription The error description + */ + public void setErrorDescription(String errorDescription) { + this.errorDescription = errorDescription; + } + + /** + * Gets the error URI. + * @return The error URI + */ + public URI getErrorUri() { + return errorUri; + } + + /** + * Sets the error URI. + * @param errorUri The error URI + */ + public void setErrorUri(URI errorUri) { + this.errorUri = errorUri; + } + + /** + * Gets the state parameter. + * @return The state parameter + */ + public String getState() { + return state; + } + + /** + * Sets the state parameter. + * @param state The state parameter + */ + public void setState(String state) { + this.state = state; + } + + /** + * Converts the error response to a map of query parameters. + * @return A map of query parameters + */ + public java.util.Map toQueryParams() { + java.util.Map params = new java.util.HashMap<>(); + params.put("error", error); + + if (errorDescription != null) { + params.put("error_description", errorDescription); + } + + if (errorUri != null) { + params.put("error_uri", errorUri.toString()); + } + + if (state != null) { + params.put("state", state); + } + + return params; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/ClientRegistrationOptions.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/ClientRegistrationOptions.java new file mode 100644 index 00000000..7842472a --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/ClientRegistrationOptions.java @@ -0,0 +1,65 @@ +package io.modelcontextprotocol.server.auth.settings; + +import java.util.List; + +/** + * Options for OAuth client registration. + */ +public class ClientRegistrationOptions { + + private boolean enabled = true; + + private boolean allowLocalhostRedirect = true; + + private List validScopes; + + /** + * Check if client registration is enabled. + * @return true if client registration is enabled, false otherwise + */ + public boolean isEnabled() { + return enabled; + } + + /** + * Set whether client registration is enabled. + * @param enabled true to enable client registration, false to disable + */ + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + /** + * Check if localhost redirect URIs are allowed. + * @return true if localhost redirect URIs are allowed, false otherwise + */ + public boolean isAllowLocalhostRedirect() { + return allowLocalhostRedirect; + } + + /** + * Set whether localhost redirect URIs are allowed. + * @param allowLocalhostRedirect true to allow localhost redirect URIs, false to + * disallow + */ + public void setAllowLocalhostRedirect(boolean allowLocalhostRedirect) { + this.allowLocalhostRedirect = allowLocalhostRedirect; + } + + /** + * Get the list of valid scopes. + * @return the list of valid scopes + */ + public List getValidScopes() { + return validScopes; + } + + /** + * Set the list of valid scopes. + * @param validScopes the list of valid scopes + */ + public void setValidScopes(List validScopes) { + this.validScopes = validScopes; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/RevocationOptions.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/RevocationOptions.java new file mode 100644 index 00000000..6e180556 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/settings/RevocationOptions.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.server.auth.settings; + +/** + * Options for OAuth token revocation. + */ +public class RevocationOptions { + + private boolean enabled = true; + + /** + * Check if token revocation is enabled. + * @return true if token revocation is enabled, false otherwise + */ + public boolean isEnabled() { + return enabled; + } + + /** + * Set whether token revocation is enabled. + * @param enabled true to enable token revocation, false to disable + */ + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/util/UriUtils.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/util/UriUtils.java new file mode 100644 index 00000000..464246a9 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/util/UriUtils.java @@ -0,0 +1,123 @@ +package io.modelcontextprotocol.server.auth.util; + +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Utility class for URI operations. + */ +public class UriUtils { + + /** + * Constructs a redirect URI with query parameters. + * @param redirectUriBase The base redirect URI. + * @param params The parameters to add to the query string. + * @return The constructed redirect URI. + */ + public static String constructRedirectUri(String redirectUriBase, Map params) { + try { + URI uri = new URI(redirectUriBase); + + // Get existing query + String query = uri.getQuery(); + StringBuilder queryBuilder = new StringBuilder(); + + // Append existing query parameters if any + if (query != null && !query.isEmpty()) { + queryBuilder.append(query); + if (!params.isEmpty()) { + queryBuilder.append("&"); + } + } + + // Append new parameters + if (!params.isEmpty()) { + String newParams = params.entrySet() + .stream() + .filter(entry -> entry.getValue() != null) + .map(entry -> URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8) + "=" + + URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8)) + .collect(Collectors.joining("&")); + queryBuilder.append(newParams); + } + + // Create new URI with updated query + return new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), queryBuilder.toString(), + uri.getFragment()) + .toString(); + } + catch (URISyntaxException e) { + throw new IllegalArgumentException("Invalid redirect URI: " + redirectUriBase, e); + } + } + + /** + * Modify a URI's path using the provided mapper function. + * @param uri The URI to modify + * @param pathMapper Function to transform the path + * @return The modified URI + */ + public static URI modifyUriPath(URI uri, Function pathMapper) { + String path = uri.getPath(); + if (path == null) { + path = ""; + } + + String newPath = pathMapper.apply(path); + + try { + return new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(), newPath, uri.getQuery(), + uri.getFragment()); + } + catch (Exception e) { + throw new IllegalArgumentException("Failed to modify URI path", e); + } + } + + /** + * Validate that the issuer URL meets OAuth 2.0 requirements. + * @param url The issuer URL to validate + */ + public static void validateIssuerUrl(URI url) { + // RFC 8414 requires HTTPS, but we allow localhost HTTP for testing + String scheme = url.getScheme(); + String host = url.getHost(); + + if (!"https".equals(scheme) && !"localhost".equals(host) && !host.startsWith("127.0.0.1")) { + throw new IllegalArgumentException("Issuer URL must be HTTPS"); + } + + // No fragments or query parameters allowed + if (url.getFragment() != null) { + throw new IllegalArgumentException("Issuer URL must not have a fragment"); + } + if (url.getQuery() != null) { + throw new IllegalArgumentException("Issuer URL must not have a query string"); + } + } + + /** + * Build an endpoint URL by appending a path to the issuer URL. + * @param issuerUrl The issuer URL + * @param path The path to append + * @return The endpoint URL + */ + public static URI buildEndpointUrl(URI issuerUrl, String path) { + String baseUrl = issuerUrl.toString(); + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length() - 1); + } + + if (!path.startsWith("/")) { + path = "/" + path; + } + + return URI.create(baseUrl + path); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthClientProviderTest.java b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthClientProviderTest.java new file mode 100644 index 00000000..398cde5c --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthClientProviderTest.java @@ -0,0 +1,120 @@ +package io.modelcontextprotocol.auth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.net.URI; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.client.auth.AuthCallbackResult; +import io.modelcontextprotocol.client.auth.OAuthClientProvider; +import io.modelcontextprotocol.client.auth.TokenStorage; + +/** + * Tests for the OAuthClientProvider class. + */ +public class OAuthClientProviderTest { + + private OAuthClientMetadata clientMetadata; + + private TokenStorage mockStorage; + + private Function> mockRedirectHandler; + + private Function> mockCallbackHandler; + + private OAuthClientProvider clientProvider; + + private OAuthToken token; + + private OAuthClientInformation clientInfo; + + @SuppressWarnings("unchecked") + @BeforeEach + public void setup() throws Exception { + // Setup client metadata + clientMetadata = new OAuthClientMetadata(); + clientMetadata.setRedirectUris(List.of(new URI("https://example.com/callback"))); + clientMetadata.setScope("read write"); + + // Setup mock storage + mockStorage = mock(TokenStorage.class); + + // Setup mock handlers + mockRedirectHandler = mock(Function.class); + mockCallbackHandler = mock(Function.class); + + // Setup token and client info + token = new OAuthToken(); + token.setAccessToken("test-access-token"); + token.setRefreshToken("test-refresh-token"); + token.setExpiresIn(3600); + token.setScope("read write"); + + clientInfo = new OAuthClientInformation(); + clientInfo.setClientId("test-client-id"); + clientInfo.setClientSecret("test-client-secret"); + clientInfo.setRedirectUris(List.of(new URI("https://example.com/callback"))); + clientInfo.setScope("read write"); + + // Configure mocks + when(mockStorage.getTokens()).thenReturn(CompletableFuture.completedFuture(token)); + when(mockStorage.getClientInfo()).thenReturn(CompletableFuture.completedFuture(clientInfo)); + when(mockStorage.setTokens(any())).thenReturn(CompletableFuture.completedFuture(null)); + when(mockStorage.setClientInfo(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(mockRedirectHandler.apply(anyString())).thenReturn(CompletableFuture.completedFuture(null)); + + AuthCallbackResult callbackResult = new AuthCallbackResult("test-auth-code", "test-state"); + when(mockCallbackHandler.apply(any())).thenReturn(CompletableFuture.completedFuture(callbackResult)); + + // Create client provider + clientProvider = new OAuthClientProvider("https://auth.example.com", clientMetadata, mockStorage, + mockRedirectHandler, mockCallbackHandler, Duration.ofSeconds(30)); + } + + @Test + public void testInitialize() throws Exception { + // Test initialization + CompletableFuture initFuture = clientProvider.initialize(); + initFuture.get(); + + // Test access token retrieval + String accessToken = clientProvider.getAccessToken(); + assertNotNull(accessToken); + assertEquals("test-access-token", accessToken); + + // Test token retrieval + OAuthToken retrievedToken = clientProvider.getCurrentTokens(); + assertNotNull(retrievedToken); + assertEquals(token.getAccessToken(), retrievedToken.getAccessToken()); + assertEquals(token.getRefreshToken(), retrievedToken.getRefreshToken()); + } + + @Test + public void testEnsureToken() throws Exception { + // Initialize first + clientProvider.initialize().get(); + + // Test token validation + CompletableFuture tokenFuture = clientProvider.ensureToken(); + tokenFuture.get(); + + // Token should be valid and accessible + String accessToken = clientProvider.getAccessToken(); + assertNotNull(accessToken); + assertEquals("test-access-token", accessToken); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthFlowTest.java b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthFlowTest.java new file mode 100644 index 00000000..6f84c66e --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/auth/OAuthFlowTest.java @@ -0,0 +1,117 @@ +package io.modelcontextprotocol.auth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.net.URI; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.auth.exception.AuthorizeException; +import io.modelcontextprotocol.auth.exception.TokenException; + +/** + * Tests for the OAuth authentication flow. + */ +public class OAuthFlowTest { + + private OAuthAuthorizationServerProvider mockProvider; + + private OAuthClientInformation clientInfo; + + private AuthorizationCode authCode; + + private OAuthToken token; + + @BeforeEach + public void setup() throws Exception { + // Setup mock provider + mockProvider = mock(OAuthAuthorizationServerProvider.class); + + // Setup test client + clientInfo = new OAuthClientInformation(); + clientInfo.setClientId("test-client-id"); + clientInfo.setClientSecret("test-client-secret"); + clientInfo.setRedirectUris(List.of(new URI("https://example.com/callback"))); + clientInfo.setScope("read write"); + + // Setup test auth code + authCode = new AuthorizationCode(); + authCode.setCode("test-auth-code"); + authCode.setClientId(clientInfo.getClientId()); + authCode.setScopes(Arrays.asList("read", "write")); + authCode.setExpiresAt(Instant.now().plusSeconds(600).getEpochSecond()); + authCode.setCodeChallenge("test-code-challenge"); + authCode.setRedirectUri(clientInfo.getRedirectUris().get(0)); + authCode.setRedirectUriProvidedExplicitly(true); + + // Setup test token + token = new OAuthToken(); + token.setAccessToken("test-access-token"); + token.setRefreshToken("test-refresh-token"); + token.setExpiresIn(3600); + token.setScope("read write"); + + // Configure mock provider + when(mockProvider.getClient(clientInfo.getClientId())) + .thenReturn(CompletableFuture.completedFuture(clientInfo)); + + when(mockProvider.authorize(any(), any())) + .thenReturn(CompletableFuture.completedFuture("https://example.com/auth?code=test-auth-code")); + + when(mockProvider.loadAuthorizationCode(any(), any())).thenReturn(CompletableFuture.completedFuture(authCode)); + + when(mockProvider.exchangeAuthorizationCode(any(), any())).thenReturn(CompletableFuture.completedFuture(token)); + } + + @Test + public void testAuthorizationCodeFlow() throws Exception { + // Test client lookup + CompletableFuture clientFuture = mockProvider.getClient(clientInfo.getClientId()); + OAuthClientInformation retrievedClient = clientFuture.get(); + + assertNotNull(retrievedClient); + assertEquals(clientInfo.getClientId(), retrievedClient.getClientId()); + + // Test authorization + AuthorizationParams params = new AuthorizationParams(); + params.setState(UUID.randomUUID().toString()); + params.setScopes(Arrays.asList("read", "write")); + params.setCodeChallenge("test-code-challenge"); + params.setRedirectUri(clientInfo.getRedirectUris().get(0)); + params.setRedirectUriProvidedExplicitly(true); + + CompletableFuture authUrlFuture = mockProvider.authorize(clientInfo, params); + String authUrl = authUrlFuture.get(); + + assertNotNull(authUrl); + assertTrue(authUrl.startsWith("https://example.com/auth?code=")); + + // Test code exchange + CompletableFuture codeFuture = mockProvider.loadAuthorizationCode(clientInfo, + "test-auth-code"); + AuthorizationCode retrievedCode = codeFuture.get(); + + assertNotNull(retrievedCode); + assertEquals(authCode.getCode(), retrievedCode.getCode()); + + CompletableFuture tokenFuture = mockProvider.exchangeAuthorizationCode(clientInfo, retrievedCode); + OAuthToken retrievedToken = tokenFuture.get(); + + assertNotNull(retrievedToken); + assertEquals(token.getAccessToken(), retrievedToken.getAccessToken()); + assertEquals(token.getRefreshToken(), retrievedToken.getRefreshToken()); + assertEquals(token.getExpiresIn(), retrievedToken.getExpiresIn()); + } + +} \ No newline at end of file