diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..f43d93ad 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,6 +1,7 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -110,6 +111,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private volatile boolean isClosing = false; + /** + * DNS rebinding protection configuration. + */ + private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Constructs a new WebFlux SSE server transport provider instance with the default * SSE endpoint. @@ -134,7 +140,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * @throws IllegalArgumentException if either parameter is null */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null); } /** @@ -149,6 +155,24 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance with optional DNS + * rebinding protection. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null). + * @throws IllegalArgumentException if required parameters are null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -158,6 +182,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -256,6 +281,16 @@ private Mono handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + // Validate headers + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed"); + } + } + return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { @@ -300,6 +335,25 @@ private Mono handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + // Always validate Content-Type for POST requests + String contentType = request.headers().contentType() + .map(MediaType::toString) + .orElse(null); + if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { + logger.warn("Invalid Content-Type header: '{}'", contentType); + return ServerResponse.badRequest().bodyValue(new McpError("Content-Type must be application/json")); + } + + // Validate headers for POST requests if DNS rebinding protection is configured + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed"); + } + } + if (request.queryParam("sessionId").isEmpty()) { return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); } @@ -397,6 +451,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -447,6 +503,23 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + + /** + * Sets the DNS rebinding protection configuration. + *

+ * When set, this configuration will be used to create a header validator that + * enforces DNS rebinding protection rules. This will override any previously set + * header validator. + * @param config The DNS rebinding protection configuration + * @return this builder instance + * @throws IllegalArgumentException if config is null + */ + public Builder dnsRebindingProtectionConfig(DnsRebindingProtectionConfig config) { + Assert.notNull(config, "DNS rebinding protection config must not be null"); + this.dnsRebindingProtectionConfig = config; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -457,7 +530,8 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + dnsRebindingProtectionConfig); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..d350d9ab 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -107,6 +108,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private volatile boolean isClosing = false; + /** + * DNS rebinding protection configuration. + */ + private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE * endpoint. @@ -132,7 +138,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * @throws IllegalArgumentException if any parameter is null */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, "", messageEndpoint, sseEndpoint); + this(objectMapper, "", messageEndpoint, sseEndpoint, null); } /** @@ -149,6 +155,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance with DNS rebinding protection. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null). + * @throws IllegalArgumentException if any required parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -158,6 +182,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -247,6 +272,16 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + // Validate headers + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed"); + } + } + String sessionId = UUID.randomUUID().toString(); logger.debug("Creating new SSE connection for session: {}", sessionId); @@ -300,6 +335,23 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + // Always validate Content-Type for POST requests + String contentType = request.headers().asHttpHeaders().getFirst("Content-Type"); + if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { + logger.warn("Invalid Content-Type header: '{}'", contentType); + return ServerResponse.badRequest().body(new McpError("Content-Type must be application/json")); + } + + // Validate headers for POST requests if DNS rebinding protection is configured + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed"); + } + } + if (request.param("sessionId").isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java new file mode 100644 index 00000000..e03052db --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java @@ -0,0 +1,108 @@ +package io.modelcontextprotocol.server.transport; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * Configuration for DNS rebinding protection in SSE server transports. Provides + * validation for Host and Origin headers to prevent DNS rebinding attacks. + */ +public class DnsRebindingProtectionConfig { + + private final Set allowedHosts; + + private final Set allowedOrigins; + + private final boolean enableDnsRebindingProtection; + + private DnsRebindingProtectionConfig(Builder builder) { + this.allowedHosts = Collections.unmodifiableSet(new HashSet<>(builder.allowedHosts)); + this.allowedOrigins = Collections.unmodifiableSet(new HashSet<>(builder.allowedOrigins)); + this.enableDnsRebindingProtection = builder.enableDnsRebindingProtection; + } + + /** + * Validates Host and Origin headers for DNS rebinding protection. Returns true if the + * headers are valid, false otherwise. + * @param hostHeader The value of the Host header (may be null) + * @param originHeader The value of the Origin header (may be null) + * @return true if the headers are valid, false otherwise + */ + public boolean validate(String hostHeader, String originHeader) { + // Skip validation if protection is not enabled + if (!enableDnsRebindingProtection) { + return true; + } + + // Validate Host header + if (hostHeader != null) { + String lowerHost = hostHeader.toLowerCase(); + if (!allowedHosts.contains(lowerHost)) { + return false; + } + } + + // Validate Origin header + if (originHeader != null) { + String lowerOrigin = originHeader.toLowerCase(); + if (!allowedOrigins.contains(lowerOrigin)) { + return false; + } + } + + return true; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final Set allowedHosts = new HashSet<>(); + + private final Set allowedOrigins = new HashSet<>(); + + private boolean enableDnsRebindingProtection = true; + + public Builder allowedHost(String host) { + if (host != null) { + this.allowedHosts.add(host.toLowerCase()); + } + return this; + } + + public Builder allowedHosts(Set hosts) { + if (hosts != null) { + hosts.forEach(this::allowedHost); + } + return this; + } + + public Builder allowedOrigin(String origin) { + if (origin != null) { + this.allowedOrigins.add(origin.toLowerCase()); + } + return this; + } + + public Builder allowedOrigins(Set origins) { + if (origins != null) { + origins.forEach(this::allowedOrigin); + } + return this; + } + + public Builder enableDnsRebindingProtection(boolean enable) { + this.enableDnsRebindingProtection = enable; + return this; + } + + public DnsRebindingProtectionConfig build() { + return new DnsRebindingProtectionConfig(this); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff47..1ce7ee0c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -103,6 +103,9 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Session factory for creating new sessions */ private McpServerSession.Factory sessionFactory; + /** DNS rebinding protection configuration */ + private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. @@ -113,7 +116,7 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement */ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null); } /** @@ -127,10 +130,27 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m */ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with optional DNS + * rebinding protection. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may + * be null) + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) { this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig; } /** @@ -202,6 +222,18 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } + // Validate headers if DNS rebinding protection is configured + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.getHeader("Host"); + String originHeader = request.getHeader("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, + originHeader); + response.sendError(HttpServletResponse.SC_FORBIDDEN, "DNS rebinding protection validation failed"); + return; + } + } + response.setContentType("text/event-stream"); response.setCharacterEncoding(UTF_8); response.setHeader("Cache-Control", "no-cache"); @@ -252,6 +284,26 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + // Always validate Content-Type for POST requests + String contentType = request.getContentType(); + if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { + logger.warn("Invalid Content-Type header: '{}'", contentType); + response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Content-Type must be application/json"); + return; + } + + // Validate headers for POST requests if DNS rebinding protection is configured + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.getHeader("Host"); + String originHeader = request.getHeader("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, + originHeader); + response.sendError(HttpServletResponse.SC_FORBIDDEN, "DNS rebinding protection validation failed"); + return; + } + } + // Get the session ID from the request parameter String sessionId = request.getParameter("sessionId"); if (sessionId == null) { @@ -475,6 +527,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Sets the JSON object mapper to use for message serialization/deserialization. * @param objectMapper The object mapper to use @@ -522,6 +576,17 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the DNS rebinding protection configuration. + * @param config The DNS rebinding protection configuration + * @return This builder instance for method chaining + */ + public Builder dnsRebindingProtectionConfig(DnsRebindingProtectionConfig config) { + Assert.notNull(config, "DNS rebinding protection config must not be null"); + this.dnsRebindingProtectionConfig = config; + return this; + } + /** * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. @@ -535,7 +600,8 @@ public HttpServletSseServerTransportProvider build() { if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + dnsRebindingProtectionConfig); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java new file mode 100644 index 00000000..388a48cf --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java @@ -0,0 +1,157 @@ +package io.modelcontextprotocol.server.transport; + +import org.junit.jupiter.api.Test; + +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for DNS rebinding protection configuration. + */ +public class DnsRebindingProtectionConfigTests { + + @Test + void testDefaultConfiguration() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder().build(); + + // Test default behavior - when allowed lists are empty and headers are provided, + // validation fails because the headers are not in the (empty) allowed lists + assertThat(config.validate("any.host.com", "http://any.origin.com")).isFalse(); + assertThat(config.validate("localhost", null)).isFalse(); + assertThat(config.validate(null, "http://example.com")).isFalse(); + // Null values are allowed when lists are empty + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testDisableDnsRebindingProtection() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .enableDnsRebindingProtection(false) + .allowedHost("localhost") // Should be ignored when protection is disabled + .allowedOrigin("http://localhost") // Should be ignored when protection is + // disabled + .build(); + + // When protection is disabled, all hosts and origins should be allowed + assertThat(config.validate("evil.com", "http://evil.com")).isTrue(); + assertThat(config.validate("any.host", "http://any.origin")).isTrue(); + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testHostValidation() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHost("localhost") + .allowedHost("127.0.0.1") + .build(); + + // Valid hosts + assertThat(config.validate("localhost", null)).isTrue(); + assertThat(config.validate("127.0.0.1", null)).isTrue(); + + // Invalid hosts + assertThat(config.validate("evil.com", null)).isFalse(); + + // Null host is allowed when no specific hosts are being checked + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testOriginValidation() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigin("https://app.example.com") + .build(); + + // Valid origins + assertThat(config.validate(null, "http://localhost:8080")).isTrue(); + assertThat(config.validate(null, "https://app.example.com")).isTrue(); + + // Invalid origins + assertThat(config.validate(null, "http://evil.com")).isFalse(); + + // Null origin is allowed when no specific origins are being checked + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testCombinedHostAndOriginValidation() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHost("localhost") + .allowedOrigin("http://localhost:8080") + .build(); + + // Both valid + assertThat(config.validate("localhost", "http://localhost:8080")).isTrue(); + + // Host valid, origin invalid + assertThat(config.validate("localhost", "http://evil.com")).isFalse(); + + // Host invalid, origin valid + assertThat(config.validate("evil.com", "http://localhost:8080")).isFalse(); + + // Both invalid + assertThat(config.validate("evil.com", "http://evil.com")).isFalse(); + } + + @Test + void testCaseInsensitiveHostAndOrigin() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHost("LOCALHOST") + .allowedOrigin("HTTP://LOCALHOST:8080") + .build(); + + // Case insensitive matching + assertThat(config.validate("localhost", null)).isTrue(); + assertThat(config.validate("LOCALHOST", null)).isTrue(); + assertThat(config.validate("LoCaLhOsT", null)).isTrue(); + + assertThat(config.validate(null, "http://localhost:8080")).isTrue(); + assertThat(config.validate(null, "HTTP://LOCALHOST:8080")).isTrue(); + } + + @Test + void testEmptyAllowedListsDenyNonNull() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder().build(); + + // When allowed lists are empty and headers are provided, validation fails + assertThat(config.validate("any.host.com", "http://any.origin.com")).isFalse(); + assertThat(config.validate("random.host", "http://random.origin")).isFalse(); + // But null values are allowed + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testBuilderWithSets() { + Set hosts = Set.of("host1.com", "host2.com"); + Set origins = Set.of("http://origin1.com", "http://origin2.com"); + + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHosts(hosts) + .allowedOrigins(origins) + .build(); + + assertThat(config.validate("host1.com", null)).isTrue(); + assertThat(config.validate("host2.com", null)).isTrue(); + assertThat(config.validate("host3.com", null)).isFalse(); + + assertThat(config.validate(null, "http://origin1.com")).isTrue(); + assertThat(config.validate(null, "http://origin2.com")).isTrue(); + assertThat(config.validate(null, "http://origin3.com")).isFalse(); + } + + @Test + void testNullValuesWithConfiguredLists() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHost("localhost") + .allowedOrigin("http://localhost") + .build(); + + // Null values should be allowed when no check is needed for that header + assertThat(config.validate(null, "http://localhost")).isTrue(); + assertThat(config.validate("localhost", null)).isTrue(); + assertThat(config.validate(null, null)).isTrue(); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java new file mode 100644 index 00000000..68ec9ec0 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java @@ -0,0 +1,240 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.concurrent.CompletionException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Integration tests for header validation in + * {@link HttpServletSseServerTransportProvider}. + */ +class HttpServletSseHeaderValidationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final String SSE_ENDPOINT = "/sse"; + + private Tomcat tomcat; + + private HttpServletSseServerTransportProvider transportProvider; + + private McpSyncServer server; + + @AfterEach + void tearDown() { + if (server != null) { + server.close(); + } + if (transportProvider != null) { + transportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testConnectionSucceedsWithValidHeaders() { + // Create DNS rebinding protection config that validates API key + DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + .enableDnsRebindingProtection(false) // Disable Host/Origin validation for + // this test + .build(); + + // For this test, we'll need to use a custom transport provider implementation + // since DnsRebindingProtectionConfig doesn't support custom header validation + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .build(); + + startServer(); + + // Create client - should succeed since DNS rebinding protection is disabled + try (var client = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(SSE_ENDPOINT).build()) + .build()) { + + // Connection should succeed + McpSchema.InitializeResult result = client.initialize(); + assertThat(result).isNotNull(); + assertThat(result.serverInfo().name()).isEqualTo("test-server"); + } + } + + @Test + void testConnectionFailsWithInvalidHeaders() { + // Create DNS rebinding protection config with restricted hosts + DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + .allowedHost("valid-host.com") + .build(); + + // Create server with header validation + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .build(); + + startServer(); + + // Create client with localhost which won't match the allowed host + // The Host header will be "localhost:PORT" which won't match "valid-host.com" + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + // Connection should fail during initialization + assertThatThrownBy(() -> { + try (var client = McpClient.sync(clientTransport).build()) { + client.initialize(); + } + }).isInstanceOf(RuntimeException.class); + } + + @Test + void testConnectionFailsWithEmptyAllowedHostsButProvidedHost() { + // Create DNS rebinding protection config with specific allowed origin but no + // allowed hosts + // This means any non-null host will be rejected + DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + .allowedOrigin("http://allowed-origin.com") + .build(); + + // Create server with header validation + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .build(); + + startServer(); + + // Create client - the client will send a Host header like "localhost:PORT" + // Since allowedHosts is empty, any non-null host will be rejected + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + // With the new behavior, a non-null Host header is rejected when allowedHosts is + // empty + assertThatThrownBy(() -> { + try (var client = McpClient.sync(clientTransport).build()) { + client.initialize(); + } + }).isInstanceOf(RuntimeException.class); + } + + @Test + void testComplexHeaderValidation() { + // Create DNS rebinding protection config with specific allowed hosts and origins + // Note: The Host header will include the port, so we need to allow + // "localhost:PORT" + DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + .allowedHost("localhost:" + PORT) + .allowedOrigin("http://localhost:" + PORT) + .build(); + + // Create server with DNS rebinding protection + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .build(); + + startServer(); + + // Test with valid headers (localhost is allowed) + try (var client = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(SSE_ENDPOINT).build()) + .build()) { + + McpSchema.InitializeResult result = client.initialize(); + assertThat(result).isNotNull(); + } + + // Test with different host (should fail) + var invalidHostTransport = HttpClientSseClientTransport.builder("http://127.0.0.1:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + assertThatThrownBy(() -> { + try (var client = McpClient.sync(invalidHostTransport).build()) { + client.initialize(); + } + }).isInstanceOf(RuntimeException.class); + } + + @Test + void testDefaultValidatorAllowsAllHeaders() { + // Create server without specifying a DNS rebinding protection config (no + // validation) + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + startServer(); + + // Create client with arbitrary headers + try (var client = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .customizeRequest(requestBuilder -> { + requestBuilder.header("X-Random-Header", "random-value"); + requestBuilder.header("X-Another-Header", "another-value"); + }) + .build()).build()) { + + // Connection should succeed with any headers + McpSchema.InitializeResult result = client.initialize(); + assertThat(result).isNotNull(); + } + } + + private void startServer() { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + server = McpServer.sync(transportProvider).serverInfo("test-server", "1.0.0").build(); + } + +}