From 602c9305ece56f8b95f5cbf1c86fff43269a5b83 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 27 May 2025 15:26:44 -0700 Subject: [PATCH 1/5] feat: implement support for elicitation --- .../WebFluxSseIntegrationTests.java | 225 ++++++++++ .../server/WebMvcSseIntegrationTests.java | 216 +++++++++ .../client/McpAsyncClient.java | 32 ++ .../client/McpClient.java | 40 +- .../client/McpClientFeatures.java | 31 +- .../server/McpAsyncServerExchange.java | 28 ++ .../server/McpSyncServerExchange.java | 18 + .../modelcontextprotocol/spec/McpSchema.java | 417 +++++++++++++++++- .../spec/McpSchemaObjectDeserializer.java | 53 +++ .../client/AbstractMcpAsyncClientTests.java | 22 +- .../McpAsyncClientResponseHandlerTests.java | 159 +++++++ ...rverTransportProviderIntegrationTests.java | 216 +++++++++ .../spec/McpSchemaTests.java | 99 +++++ 13 files changed, 1525 insertions(+), 31 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpSchemaObjectDeserializer.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 03fbc996..b329de80 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -33,6 +33,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -41,6 +42,7 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -331,6 +333,229 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt mcpServer.closeGracefully().block(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index b12d6843..a9e5623e 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -357,6 +357,222 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e3a997ba..a22ef6b5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -23,6 +23,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; @@ -141,6 +143,15 @@ public class McpAsyncClient { */ private Function> samplingHandler; + /** + * MCP provides a standardized way for servers to request additional information from + * users through the client during interactions. This flow allows clients to maintain + * control over user interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured data from users + * with optional JSON schemas to validate responses. + */ + private Function> elicitationHandler; + /** * Client transport implementation. */ @@ -189,6 +200,15 @@ public class McpAsyncClient { requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); } + // Elicitation Handler + if (this.clientCapabilities.elicitation() != null) { + if (features.elicitationHandler() == null) { + throw new McpError("Elicitation handler must not be null when client capabilities include elicitation"); + } + this.elicitationHandler = features.elicitationHandler(); + requestHandlers.put(McpSchema.METHOD_ELICITATION_CREATE, elicitationCreateHandler()); + } + // Notification Handlers Map notificationHandlers = new HashMap<>(); @@ -500,6 +520,18 @@ private RequestHandler samplingCreateMessageHandler() { }; } + // -------------------------- + // Elicitation + // -------------------------- + private RequestHandler elicitationCreateHandler() { + return params -> { + ElicitRequest request = transport.unmarshalFrom(params, new TypeReference<>() { + }); + + return this.elicitationHandler.apply(request); + }; + } + // -------------------------- // Tools // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index a1dc1168..280906cf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -18,6 +18,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.util.Assert; @@ -175,6 +177,8 @@ class SyncSpec { private Function samplingHandler; + private Function elicitationHandler; + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -283,6 +287,21 @@ public SyncSpec sampling(Function sam return this; } + /** + * Sets a custom elicitation handler for processing elicitation message requests. + * The elicitation handler can modify or validate messages before they are sent to + * the server, enabling custom processing logic. + * @param elicitationHandler A function that processes elicitation requests and + * returns results. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if elicitationHandler is null + */ + public SyncSpec elicitation(Function elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -364,7 +383,7 @@ public SyncSpec loggingConsumers(List> samplingHandler; + private Function> elicitationHandler; + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -522,6 +543,21 @@ public AsyncSpec sampling(Function> elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -606,7 +642,7 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, - this.loggingConsumers, this.samplingHandler)); + this.loggingConsumers, this.samplingHandler, this.elicitationHandler)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 284b93f8..23d7c6a6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -60,13 +60,15 @@ class McpClientFeatures { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { /** * Create an instance and validate the arguments. @@ -77,6 +79,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, @@ -84,14 +87,16 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new ConcurrentHashMap<>(roots) : new ConcurrentHashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -99,6 +104,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } /** @@ -138,9 +144,14 @@ public static Async fromSync(Sync syncSpec) { Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); + + Function> elicitationHandler = r -> Mono + .fromCallable(() -> syncSpec.elicitationHandler().apply(r)) + .subscribeOn(Schedulers.boundedElastic()); + return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, promptsChangeConsumers, loggingConsumers, - samplingHandler); + samplingHandler, elicitationHandler); } } @@ -156,13 +167,15 @@ public static Async fromSync(Sync syncSpec) { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { /** * Create an instance and validate the arguments. @@ -174,20 +187,23 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new HashMap<>(roots) : new HashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -195,6 +211,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 889dc66d..cfb07d26 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -36,6 +36,9 @@ public class McpAsyncServerExchange { private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { }; + private static final TypeReference ELICITATION_RESULT_TYPE_REF = new TypeReference<>() { + }; + /** * Create a new asynchronous exchange with the client. * @param session The server session representing a 1-1 interaction. @@ -93,6 +96,31 @@ public Mono createMessage(McpSchema.CreateMessage CREATE_MESSAGE_RESULT_TYPE_REF); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A Mono that completes when the elicitation has been resolved. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public Mono createElicitation(McpSchema.ElicitRequest elicitRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.elicitation() == null) { + return Mono.error(new McpError("Client must be configured with elicitation capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, + ELICITATION_RESULT_TYPE_REF); + } + /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 52360e54..084412b9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -64,6 +64,24 @@ public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageReques return this.exchange.createMessage(createMessageRequest).block(); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A result containing the elicitation response. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public McpSchema.ElicitResult createElicitation(McpSchema.ElicitRequest elicitRequest) { + return this.exchange.createElicitation(elicitRequest).block(); + } + /** * Retrieves the list of all roots provided by the client. * @return The list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8df8a158..14fa163b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -10,14 +10,12 @@ import java.util.List; import java.util.Map; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.*; import com.fasterxml.jackson.annotation.JsonTypeInfo.As; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -94,6 +92,9 @@ private McpSchema() { // Sampling Methods public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; + // Elicitation Methods + public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // --------------------------- @@ -131,8 +132,8 @@ public static final class ErrorCodes { } - public sealed interface Request - permits InitializeRequest, CallToolRequest, CreateMessageRequest, CompleteRequest, GetPromptRequest { + public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, + CompleteRequest, GetPromptRequest { } @@ -221,7 +222,7 @@ public record JSONRPCError( public record InitializeRequest( // @formatter:off @JsonProperty("protocolVersion") String protocolVersion, @JsonProperty("capabilities") ClientCapabilities capabilities, - @JsonProperty("clientInfo") Implementation clientInfo) implements Request { + @JsonProperty("clientInfo") Implementation clientInfo) implements Request { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -245,6 +246,8 @@ public record InitializeResult( // @formatter:off * access to. * @param sampling Provides a standardized way for servers to request LLM sampling * (“completions” or “generations”) from language models via clients. + * @param elicitation Provides a standardized way for servers to request additional + * information from users through the client during interactions. * */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -252,7 +255,8 @@ public record InitializeResult( // @formatter:off public record ClientCapabilities( // @formatter:off @JsonProperty("experimental") Map experimental, @JsonProperty("roots") RootCapabilities roots, - @JsonProperty("sampling") Sampling sampling) { + @JsonProperty("sampling") Sampling sampling, + @JsonProperty("elicitation") Elicitation elicitation) { /** * Roots define the boundaries of where servers can operate within the filesystem, @@ -264,7 +268,7 @@ public record ClientCapabilities( // @formatter:off * has changed since the last time the server checked. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) + @JsonIgnoreProperties(ignoreUnknown = true) public record RootCapabilities( @JsonProperty("listChanged") Boolean listChanged) { } @@ -279,10 +283,22 @@ public record RootCapabilities( * image-based interactions and optionally include context * from MCP servers in their prompts. */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record Sampling() { } + /** + * Provides a standardized way for servers to request additional + * information from users through the client during interactions. + * This flow allows clients to maintain control over user + * interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured + * data from users with optional JSON schemas to validate responses. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Elicitation() { + } + public static Builder builder() { return new Builder(); } @@ -291,6 +307,7 @@ public static class Builder { private Map experimental; private RootCapabilities roots; private Sampling sampling; + private Elicitation elicitation; public Builder experimental(Map experimental) { this.experimental = experimental; @@ -307,8 +324,13 @@ public Builder sampling() { return this; } + public Builder elicitation() { + this.elicitation = new Elicitation(); + return this; + } + public ClientCapabilities build() { - return new ClientCapabilities(experimental, roots, sampling); + return new ClientCapabilities(experimental, roots, sampling, elicitation); } } }// @formatter:on @@ -326,11 +348,11 @@ public record ServerCapabilities( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) public record CompletionCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record LoggingCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record PromptCapabilities( @JsonProperty("listChanged") Boolean listChanged) { @@ -727,11 +749,11 @@ public record Tool( // @formatter:off @JsonProperty("name") String name, @JsonProperty("description") String description, @JsonProperty("inputSchema") JsonSchema inputSchema) { - + public Tool(String name, String description, String schema) { this(name, description, parseSchema(schema)); } - + } // @formatter:on private static JsonSchema parseSchema(String schema) { @@ -758,7 +780,7 @@ public record CallToolRequest(// @formatter:off @JsonProperty("arguments") Map arguments) implements Request { public CallToolRequest(String name, String jsonArguments) { - this(name, parseJsonArguments(jsonArguments)); + this(name, parseJsonArguments(jsonArguments)); } private static Map parseJsonArguments(String jsonArguments) { @@ -893,7 +915,7 @@ public record ModelPreferences(// @formatter:off @JsonProperty("costPriority") Double costPriority, @JsonProperty("speedPriority") Double speedPriority, @JsonProperty("intelligencePriority") Double intelligencePriority) { - + public static Builder builder() { return new Builder(); } @@ -963,7 +985,7 @@ public record CreateMessageRequest(// @formatter:off @JsonProperty("includeContext") ContextInclusionStrategy includeContext, @JsonProperty("temperature") Double temperature, @JsonProperty("maxTokens") int maxTokens, - @JsonProperty("stopSequences") List stopSequences, + @JsonProperty("stopSequences") List stopSequences, @JsonProperty("metadata") Map metadata) implements Request { public enum ContextInclusionStrategy { @@ -971,7 +993,7 @@ public enum ContextInclusionStrategy { @JsonProperty("thisServer") THIS_SERVER, @JsonProperty("allServers") ALL_SERVERS } - + public static Builder builder() { return new Builder(); } @@ -1040,7 +1062,7 @@ public record CreateMessageResult(// @formatter:off @JsonProperty("content") Content content, @JsonProperty("model") String model, @JsonProperty("stopReason") StopReason stopReason) { - + public enum StopReason { @JsonProperty("endTurn") END_TURN, @JsonProperty("stopSequence") STOP_SEQUENCE, @@ -1088,6 +1110,359 @@ public CreateMessageResult build() { } }// @formatter:on + // Elicitation + /** + * Used by the server to send an elicitation to the client. + * + * @param message The body of the elicitation message. + * @param requestedSchema The elicitation response schema that must be satisfied. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitRequest(// @formatter:off + @JsonProperty("message") String message, + @JsonProperty("requestedSchema") PrimitiveSchemaDefinition requestedSchema) implements Request { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String message; + private PrimitiveSchemaDefinition requestedSchema; + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder requestedSchema(PrimitiveSchemaDefinition requestedSchema) { + this.requestedSchema = requestedSchema; + return this; + } + + public ElicitRequest build() { + return new ElicitRequest(message, requestedSchema); + } + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitResult(// @formatter:off + @JsonProperty("action") Action action, + @JsonProperty("content") Map content) { + + public enum Action { + @JsonProperty("accept") ACCEPT, + @JsonProperty("decline") DECLINE, + @JsonProperty("cancel") CANCEL + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Action action; + private Map content; + + public Builder message(Action action) { + this.action = action; + return this; + } + + public Builder content(Map content) { + this.content = content; + return this; + } + + public ElicitResult build() { + return new ElicitResult(action, content); + } + } + }// @formatter:on + + // Schema objects + + // We use a custom deserializer for this due to complications around handling the + // "string" type, + // which can either be used for StringSchema or EnumSchema, depending on if the "enum" + // field is present + // in the schema object. We can't cleanly combine tagged union deserialization with + // custom deserializer + // logic, so we just use the custom deserializer for this entire hierarchy. + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonDeserialize(using = McpSchemaObjectDeserializer.class) + public sealed interface Schema permits StringSchema, EnumSchema, BooleanSchema, NumberSchema { + + @JsonProperty("type") + default String type() { + if (this instanceof StringSchema || this instanceof EnumSchema) { + return "string"; + } + + if (this instanceof BooleanSchema) { + return "boolean"; + } + + if (this instanceof NumberSchema ns) { + // NumberSchema keeps track of if it was created with "number" or + // "integer" + return ns.typeVariant.toString(); + } + + throw new IllegalArgumentException("Unknown schema type: " + this); + } + + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PrimitiveSchemaDefinition(// @formatter:off + @JsonProperty("properties") Map properties, + @JsonProperty("required") List required) { + + @JsonProperty("type") + public String type() { + return "object"; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Map properties; + private List required; + + public Builder properties(Map properties) { + this.properties = properties; + return this; + } + + public Builder required(List required) { + this.required = required; + return this; + } + + public PrimitiveSchemaDefinition build() { + return new PrimitiveSchemaDefinition(properties, required); + } + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonDeserialize // Override to default deserializer to avoid recursion + public record StringSchema(// @formatter:off + @JsonProperty("description") String description, + @JsonProperty("format") Format format, + @JsonProperty("maxLength") Integer maxLength, + @JsonProperty("minLength") Integer minLength, + @JsonProperty("title") String title) implements Schema { + + public enum Format { + @JsonProperty("date") DATE, + @JsonProperty("date-time") DATE_TIME, + @JsonProperty("email") EMAIL, + @JsonProperty("uri") URI + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String description; + private Format format; + private Integer maxLength; + private Integer minLength; + private String title; + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder format(Format format) { + this.format = format; + return this; + } + + public Builder maxLength(Integer maxLength) { + this.maxLength = maxLength; + return this; + } + + public Builder minLength(Integer minLength) { + this.minLength = minLength; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public StringSchema build() { + return new StringSchema(description, format, maxLength, minLength, title); + } + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonDeserialize // Override to default deserializer to avoid recursion + public record EnumSchema(// @formatter:off + @JsonProperty("description") String description, + @JsonProperty("enum") List enumValues, + @JsonProperty("enumNames") List enumNames, + @JsonProperty("title") String title) implements Schema { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String description; + private List enumValues; + private List enumNames; + private String title; + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder enumValues(List enumValues) { + this.enumValues = enumValues; + return this; + } + + public Builder enumNames(List enumNames) { + this.enumNames = enumNames; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public EnumSchema build() { + return new EnumSchema(description, enumValues, enumNames, title); + } + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonDeserialize // Override to default deserializer to avoid recursion + public record BooleanSchema(// @formatter:off + @JsonProperty("default") boolean defaultValue, + @JsonProperty("description") String description, + @JsonProperty("title") String title) implements Schema { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private boolean defaultValue; + private String description; + private String title; + + public Builder defaultValue(boolean defaultValue) { + this.defaultValue = defaultValue; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public BooleanSchema build() { + return new BooleanSchema(defaultValue, description, title); + } + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonDeserialize // Override to default deserializer to avoid recursion + public record NumberSchema(// @formatter:off + @JsonProperty("description") String description, + @JsonProperty("minimum") double minimum, + @JsonProperty("maximum") double maximum, + @JsonProperty("title") String title, + @JsonProperty("type") TypeVariant typeVariant) implements Schema { + + public enum TypeVariant { + @JsonProperty("number") NUMBER("number"), + @JsonProperty("integer") INTEGER("integer"); + + private final String name; + + TypeVariant(String value) { + name = value; + } + + @Override + public String toString() { + return this.name; + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String description; + private double minimum; + private double maximum; + private String title; + private TypeVariant type; + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder minimum(double minimum) { + this.minimum = minimum; + return this; + } + + public Builder maximum(double maximum) { + this.maximum = maximum; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder type(TypeVariant type) { + this.type = type; + return this; + } + + public NumberSchema build() { + return new NumberSchema(description, minimum, maximum, title, type); + } + } + }// @formatter:on + // --------------------------- // Pagination Interfaces // --------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchemaObjectDeserializer.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchemaObjectDeserializer.java new file mode 100644 index 00000000..8d41e1c9 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchemaObjectDeserializer.java @@ -0,0 +1,53 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectReader; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import com.fasterxml.jackson.databind.node.TreeTraversingParser; + +import java.io.IOException; + +public class McpSchemaObjectDeserializer extends StdDeserializer { + + public McpSchemaObjectDeserializer() { + this(null); + } + + public McpSchemaObjectDeserializer(Class vc) { + super(vc); + } + + @Override + public McpSchema.Schema deserialize(JsonParser jp, DeserializationContext ctx) throws IOException { + ObjectMapper mapper = (ObjectMapper) jp.getCodec(); + JsonNode root = mapper.readTree(jp); + + String type = root.path("type").asText(); + if ("string".equals(type)) { + if (root.has("enum")) { + return readValue(mapper, root, McpSchema.EnumSchema.class); + } + else { + return readValue(mapper, root, McpSchema.StringSchema.class); + } + } + else if ("boolean".equals(type)) { + return readValue(mapper, root, McpSchema.BooleanSchema.class); + } + else if ("number".equals(type) || "integer".equals(type)) { + return readValue(mapper, root, McpSchema.NumberSchema.class); + } + + throw new RuntimeException("Unknown schema type: " + type); + } + + private T readValue(ObjectMapper mapper, JsonNode node, Class clazz) throws IOException { + ObjectReader reader = mapper.readerFor(clazz); + TreeTraversingParser treeParser = new TreeTraversingParser(node, mapper); + return reader.readValue(treeParser); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 72b409af..d1a2581e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.Prompt; import io.modelcontextprotocol.spec.McpSchema.Resource; @@ -422,6 +424,20 @@ void testInitializeWithSamplingCapability() { }); } + @Test + void testInitializeWithElicitationCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build(); + ElicitResult elicitResult = ElicitResult.builder() + .message(ElicitResult.Action.ACCEPT) + .content(Map.of("foo", "bar")) + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() @@ -433,7 +449,11 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + Function> elicitationHandler = request -> Mono + .just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build()); + + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler), client -> StepVerifier.create(client.initialize()).assertNext(result -> { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 4510b152..0d1bef1b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.Root; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import reactor.core.publisher.Mono; import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; @@ -349,4 +351,161 @@ void testSamplingCreateMessageRequestHandlingWithNullHandler() { .hasMessage("Sampling handler must not be null when client capabilities include sampling"); } + @Test + void testElicitationCreateRequestHandling() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler that echoes back the input + Function> elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isInstanceOf(McpSchema.PrimitiveSchemaDefinition.class); + assertThat(request.requestedSchema().type()).isEqualTo("object"); + + var properties = request.requestedSchema().properties(); + assertThat(properties).isNotNull(); + assertThat(properties.get("message")).isInstanceOf(McpSchema.StringSchema.class); + + return Mono.just(McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build()); + }; + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isEqualTo(Map.of("message", "Test message")); + + asyncMcpClient.closeGracefully(); + } + + @ParameterizedTest + @EnumSource(value = McpSchema.ElicitResult.Action.class, names = { "DECLINE", "CANCEL" }) + void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler to decline the request + Function> elicitationHandler = request -> Mono + .just(McpSchema.ElicitResult.builder().message(action).build()); + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(action); + assertThat(result.content()).isNull(); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithoutCapability() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client without elicitation capability + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().build()) // No elicitation + // capability + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = new McpSchema.ElicitRequest("test", + McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("test", + McpSchema.BooleanSchema.builder() + .defaultValue(true) + .description("test-description") + .title("test-title") + .build())) + .build()); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.result()).isNull(); + assertThat(response.error()).isNotNull(); + assertThat(response.error().message()).contains("Method not found: elicitation/create"); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithNullHandler() { + MockMcpClientTransport transport = new MockMcpClientTransport(); + + // Create client with elicitation capability but null handler + assertThatThrownBy(() -> McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .build()).isInstanceOf(McpError.class) + .hasMessage("Elicitation handler must not be null when client capabilities include elicitation"); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 2ff6325a..8d1a1360 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -24,6 +24,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; import io.modelcontextprotocol.spec.McpSchema.Role; @@ -339,6 +341,220 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + @Disabled + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("message", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index ff78c1bf..0db6cac2 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -807,6 +807,42 @@ void testCreateMessageResult() throws Exception { {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"endTurn"}""")); } + // Elicitation Tests + + @Test + void testCreateElicitationRequest() throws Exception { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() + .required(List.of("a")) + .properties(Map.of("foo", McpSchema.StringSchema.builder().build())) + .build()) + .build(); + + String value = mapper.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"requestedSchema":{"properties":{"foo":{"type":"string"}},"required":["a"],"type":"object"}}""")); + } + + @Test + void testCreateElicitationResult() throws Exception { + McpSchema.ElicitResult result = McpSchema.ElicitResult.builder() + .content(Map.of("foo", "bar")) + .message(McpSchema.ElicitResult.Action.ACCEPT) + .build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"action":"accept","content":{"foo":"bar"}}""")); + } + // Roots Tests @Test @@ -840,4 +876,67 @@ void testListRootsResult() throws Exception { } + // Schema Tests + + @Test + void testSchema() throws Exception { + McpSchema.PrimitiveSchemaDefinition schemaDefinition = McpSchema.PrimitiveSchemaDefinition.builder() + .properties(Map.of("foo", + McpSchema.StringSchema.builder() + .title("title") + .description("description") + .format(McpSchema.StringSchema.Format.URI) + .maxLength(10) + .minLength(1) + .build(), + "bar", + McpSchema.EnumSchema.builder() + .title("title") + .description("description") + .enumNames(List.of("A", "B", "C")) + .enumValues(List.of("a", "b", "c")) + .build(), + "baz", + McpSchema.NumberSchema.builder() + .title("title") + .description("description") + .maximum(10) + .minimum(1) + .type(McpSchema.NumberSchema.TypeVariant.INTEGER) + .build(), + "baz2", + McpSchema.NumberSchema.builder() + .title("title") + .description("description") + .maximum(0.2) + .minimum(0.1) + .type(McpSchema.NumberSchema.TypeVariant.NUMBER) + .build(), + "buz", + McpSchema.BooleanSchema.builder() + .title("title") + .description("description") + .defaultValue(true) + .build())) + .required(List.of("foo")) + .build(); + + String value = mapper.writeValueAsString(schemaDefinition); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"properties":{"bar":{"description":"description","enum":["a", "b", "c"],"enumNames":["A", "B", "C"],"title":"title","type":"string"},"baz":{"description":"description","maximum":10.0,"minimum":1.0,"title":"title","type":"integer"},"baz2":{"description":"description","maximum":0.2,"minimum":0.1,"title":"title","type":"number"},"buz":{"default":true,"description":"description","title":"title","type":"boolean"},"foo":{"description":"description","format":"uri","maxLength":10,"minLength":1,"title":"title","type":"string"}},"required":["foo"],"type":"object"}""")); + + // Attempt to go the other way, since Schema is a complex type to (de)serialize + // and behaves differently when + // serialized vs. deserialized due to the string/number types being used for + // multiple concrete types + McpSchema.PrimitiveSchemaDefinition schemaDefinition2 = mapper.readValue(value, + McpSchema.PrimitiveSchemaDefinition.class); + assertThat(schemaDefinition2).isEqualTo(schemaDefinition); + } + } From c9ae30ef671a3cd1e5987cd1e4155e43f89fa93a Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 5 Jun 2025 09:28:56 -0700 Subject: [PATCH 2/5] chore: fix sampling->elicitation in copied integration tests --- .../io/modelcontextprotocol/WebFluxSseIntegrationTests.java | 2 +- .../modelcontextprotocol/server/WebMvcSseIntegrationTests.java | 2 +- .../HttpServletSseServerTransportProviderIntegrationTests.java | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index b329de80..0c970c61 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -353,7 +353,7 @@ void testCreateElicitationWithoutElicitationCapabilities(String clientType) { var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); try ( - // Create client without sampling capabilities + // Create client without elicitation capabilities var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { assertThat(client.initialize()).isNotNull(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index a9e5623e..70d0cbd2 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -374,7 +374,7 @@ void testCreateElicitationWithoutElicitationCapabilities() { var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); try ( - // Create client without sampling capabilities + // Create client without elicitation capabilities var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { assertThat(client.initialize()).isNotNull(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 8d1a1360..f296cb9c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -359,7 +359,7 @@ void testCreateElicitationWithoutElicitationCapabilities() { var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); try ( - // Create client without sampling capabilities + // Create client without elicitation capabilities var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { assertThat(client.initialize()).isNotNull(); From b1b179f7c6c9d2a66fa9b952339d98ca248a7376 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 5 Jun 2025 09:31:43 -0700 Subject: [PATCH 3/5] chore: remove wildcard import in McpSchema --- .../main/java/io/modelcontextprotocol/spec/McpSchema.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 14fa163b..34b342fd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -10,10 +10,13 @@ import java.util.List; import java.util.Map; -import com.fasterxml.jackson.annotation.*; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeInfo.As; import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import io.modelcontextprotocol.util.Assert; From 5c02b513f76309f70b73fd2f95e0d5ecdcdae7f5 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 9 Jun 2025 10:26:08 -0700 Subject: [PATCH 4/5] chore: remove elicitation schema --- .../WebFluxSseIntegrationTests.java | 17 +- .../server/WebMvcSseIntegrationTests.java | 15 +- .../modelcontextprotocol/spec/McpSchema.java | 286 +----------------- .../spec/McpSchemaObjectDeserializer.java | 53 ---- .../McpAsyncClientResponseHandlerTests.java | 27 +- ...rverTransportProviderIntegrationTests.java | 15 +- .../spec/McpSchemaTests.java | 69 +---- 7 files changed, 32 insertions(+), 450 deletions(-) delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpSchemaObjectDeserializer.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 0c970c61..2f85654e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -28,7 +27,6 @@ import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.*; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; @@ -390,9 +388,8 @@ void testCreateElicitationSuccess(String clientType) { var elicitationRequest = ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { @@ -459,9 +456,8 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { var elicitationRequest = ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { @@ -525,9 +521,8 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) { var elicitationRequest = ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 70d0cbd2..3f3f7be6 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -409,9 +409,8 @@ void testCreateElicitationSuccess() { var elicitationRequest = McpSchema.ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { @@ -477,9 +476,8 @@ void testCreateElicitationWithRequestTimeoutSuccess() { var elicitationRequest = McpSchema.ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { @@ -542,9 +540,8 @@ void testCreateElicitationWithRequestTimeoutFail() { var elicitationRequest = McpSchema.ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 34b342fd..a3e82953 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1124,7 +1124,7 @@ public CreateMessageResult build() { @JsonIgnoreProperties(ignoreUnknown = true) public record ElicitRequest(// @formatter:off @JsonProperty("message") String message, - @JsonProperty("requestedSchema") PrimitiveSchemaDefinition requestedSchema) implements Request { + @JsonProperty("requestedSchema") Map requestedSchema) implements Request { public static Builder builder() { return new Builder(); @@ -1132,14 +1132,14 @@ public static Builder builder() { public static class Builder { private String message; - private PrimitiveSchemaDefinition requestedSchema; + private Map requestedSchema; public Builder message(String message) { this.message = message; return this; } - public Builder requestedSchema(PrimitiveSchemaDefinition requestedSchema) { + public Builder requestedSchema(Map requestedSchema) { this.requestedSchema = requestedSchema; return this; } @@ -1186,286 +1186,6 @@ public ElicitResult build() { } }// @formatter:on - // Schema objects - - // We use a custom deserializer for this due to complications around handling the - // "string" type, - // which can either be used for StringSchema or EnumSchema, depending on if the "enum" - // field is present - // in the schema object. We can't cleanly combine tagged union deserialization with - // custom deserializer - // logic, so we just use the custom deserializer for this entire hierarchy. - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - @JsonDeserialize(using = McpSchemaObjectDeserializer.class) - public sealed interface Schema permits StringSchema, EnumSchema, BooleanSchema, NumberSchema { - - @JsonProperty("type") - default String type() { - if (this instanceof StringSchema || this instanceof EnumSchema) { - return "string"; - } - - if (this instanceof BooleanSchema) { - return "boolean"; - } - - if (this instanceof NumberSchema ns) { - // NumberSchema keeps track of if it was created with "number" or - // "integer" - return ns.typeVariant.toString(); - } - - throw new IllegalArgumentException("Unknown schema type: " + this); - } - - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public record PrimitiveSchemaDefinition(// @formatter:off - @JsonProperty("properties") Map properties, - @JsonProperty("required") List required) { - - @JsonProperty("type") - public String type() { - return "object"; - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private Map properties; - private List required; - - public Builder properties(Map properties) { - this.properties = properties; - return this; - } - - public Builder required(List required) { - this.required = required; - return this; - } - - public PrimitiveSchemaDefinition build() { - return new PrimitiveSchemaDefinition(properties, required); - } - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - @JsonDeserialize // Override to default deserializer to avoid recursion - public record StringSchema(// @formatter:off - @JsonProperty("description") String description, - @JsonProperty("format") Format format, - @JsonProperty("maxLength") Integer maxLength, - @JsonProperty("minLength") Integer minLength, - @JsonProperty("title") String title) implements Schema { - - public enum Format { - @JsonProperty("date") DATE, - @JsonProperty("date-time") DATE_TIME, - @JsonProperty("email") EMAIL, - @JsonProperty("uri") URI - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private String description; - private Format format; - private Integer maxLength; - private Integer minLength; - private String title; - - public Builder description(String description) { - this.description = description; - return this; - } - - public Builder format(Format format) { - this.format = format; - return this; - } - - public Builder maxLength(Integer maxLength) { - this.maxLength = maxLength; - return this; - } - - public Builder minLength(Integer minLength) { - this.minLength = minLength; - return this; - } - - public Builder title(String title) { - this.title = title; - return this; - } - - public StringSchema build() { - return new StringSchema(description, format, maxLength, minLength, title); - } - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - @JsonDeserialize // Override to default deserializer to avoid recursion - public record EnumSchema(// @formatter:off - @JsonProperty("description") String description, - @JsonProperty("enum") List enumValues, - @JsonProperty("enumNames") List enumNames, - @JsonProperty("title") String title) implements Schema { - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private String description; - private List enumValues; - private List enumNames; - private String title; - - public Builder description(String description) { - this.description = description; - return this; - } - - public Builder enumValues(List enumValues) { - this.enumValues = enumValues; - return this; - } - - public Builder enumNames(List enumNames) { - this.enumNames = enumNames; - return this; - } - - public Builder title(String title) { - this.title = title; - return this; - } - - public EnumSchema build() { - return new EnumSchema(description, enumValues, enumNames, title); - } - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - @JsonDeserialize // Override to default deserializer to avoid recursion - public record BooleanSchema(// @formatter:off - @JsonProperty("default") boolean defaultValue, - @JsonProperty("description") String description, - @JsonProperty("title") String title) implements Schema { - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private boolean defaultValue; - private String description; - private String title; - - public Builder defaultValue(boolean defaultValue) { - this.defaultValue = defaultValue; - return this; - } - - public Builder description(String description) { - this.description = description; - return this; - } - - public Builder title(String title) { - this.title = title; - return this; - } - - public BooleanSchema build() { - return new BooleanSchema(defaultValue, description, title); - } - } - }// @formatter:on - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - @JsonDeserialize // Override to default deserializer to avoid recursion - public record NumberSchema(// @formatter:off - @JsonProperty("description") String description, - @JsonProperty("minimum") double minimum, - @JsonProperty("maximum") double maximum, - @JsonProperty("title") String title, - @JsonProperty("type") TypeVariant typeVariant) implements Schema { - - public enum TypeVariant { - @JsonProperty("number") NUMBER("number"), - @JsonProperty("integer") INTEGER("integer"); - - private final String name; - - TypeVariant(String value) { - name = value; - } - - @Override - public String toString() { - return this.name; - } - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private String description; - private double minimum; - private double maximum; - private String title; - private TypeVariant type; - - public Builder description(String description) { - this.description = description; - return this; - } - - public Builder minimum(double minimum) { - this.minimum = minimum; - return this; - } - - public Builder maximum(double maximum) { - this.maximum = maximum; - return this; - } - - public Builder title(String title) { - this.title = title; - return this; - } - - public Builder type(TypeVariant type) { - this.type = type; - return this; - } - - public NumberSchema build() { - return new NumberSchema(description, minimum, maximum, title, type); - } - } - }// @formatter:on - // --------------------------- // Pagination Interfaces // --------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchemaObjectDeserializer.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchemaObjectDeserializer.java deleted file mode 100644 index 8d41e1c9..00000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchemaObjectDeserializer.java +++ /dev/null @@ -1,53 +0,0 @@ -package io.modelcontextprotocol.spec; - -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.ObjectReader; -import com.fasterxml.jackson.databind.deser.std.StdDeserializer; -import com.fasterxml.jackson.databind.node.TreeTraversingParser; - -import java.io.IOException; - -public class McpSchemaObjectDeserializer extends StdDeserializer { - - public McpSchemaObjectDeserializer() { - this(null); - } - - public McpSchemaObjectDeserializer(Class vc) { - super(vc); - } - - @Override - public McpSchema.Schema deserialize(JsonParser jp, DeserializationContext ctx) throws IOException { - ObjectMapper mapper = (ObjectMapper) jp.getCodec(); - JsonNode root = mapper.readTree(jp); - - String type = root.path("type").asText(); - if ("string".equals(type)) { - if (root.has("enum")) { - return readValue(mapper, root, McpSchema.EnumSchema.class); - } - else { - return readValue(mapper, root, McpSchema.StringSchema.class); - } - } - else if ("boolean".equals(type)) { - return readValue(mapper, root, McpSchema.BooleanSchema.class); - } - else if ("number".equals(type) || "integer".equals(type)) { - return readValue(mapper, root, McpSchema.NumberSchema.class); - } - - throw new RuntimeException("Unknown schema type: " + type); - } - - private T readValue(ObjectMapper mapper, JsonNode node, Class clazz) throws IOException { - ObjectReader reader = mapper.readerFor(clazz); - TreeTraversingParser treeParser = new TreeTraversingParser(node, mapper); - return reader.readValue(treeParser); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 0d1bef1b..e6cde8e3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -352,18 +352,19 @@ void testSamplingCreateMessageRequestHandlingWithNullHandler() { } @Test + @SuppressWarnings("unchecked") void testElicitationCreateRequestHandling() { MockMcpClientTransport transport = initializationEnabledTransport(); // Create a test elicitation handler that echoes back the input Function> elicitationHandler = request -> { assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isInstanceOf(McpSchema.PrimitiveSchemaDefinition.class); - assertThat(request.requestedSchema().type()).isEqualTo("object"); + assertThat(request.requestedSchema()).isInstanceOf(Map.class); + assertThat(request.requestedSchema().get("type")).isEqualTo("object"); - var properties = request.requestedSchema().properties(); + var properties = request.requestedSchema().get("properties"); assertThat(properties).isNotNull(); - assertThat(properties.get("message")).isInstanceOf(McpSchema.StringSchema.class); + assertThat(((Map) properties).get("message")).isInstanceOf(Map.class); return Mono.just(McpSchema.ElicitResult.builder() .message(McpSchema.ElicitResult.Action.ACCEPT) @@ -382,9 +383,7 @@ void testElicitationCreateRequestHandling() { // Create a mock elicitation var elicitRequest = McpSchema.ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); // Simulate incoming request @@ -429,9 +428,7 @@ void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { // Create a mock elicitation var elicitRequest = McpSchema.ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); // Simulate incoming request @@ -470,14 +467,8 @@ void testElicitationCreateRequestHandlingWithoutCapability() { // Create a mock elicitation var elicitRequest = new McpSchema.ElicitRequest("test", - McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("test", - McpSchema.BooleanSchema.builder() - .defaultValue(true) - .description("test-description") - .title("test-title") - .build())) - .build()); + Map.of("type", "object", "properties", Map.of("test", Map.of("type", "boolean", "defaultValue", true, + "description", "test-description", "title", "test-title")))); // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index f296cb9c..dc9d1cfa 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -393,9 +393,8 @@ void testCreateElicitationSuccess() { var elicitationRequest = ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { @@ -460,9 +459,8 @@ void testCreateElicitationWithRequestTimeoutSuccess() { var elicitationRequest = ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { @@ -524,9 +522,8 @@ void testCreateElicitationWithRequestTimeoutFail() { var elicitationRequest = ElicitRequest.builder() .message("Test message") - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("message", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index 0db6cac2..99015d8c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -812,10 +812,8 @@ void testCreateMessageResult() throws Exception { @Test void testCreateElicitationRequest() throws Exception { McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() - .requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder() - .required(List.of("a")) - .properties(Map.of("foo", McpSchema.StringSchema.builder().build())) - .build()) + .requestedSchema(Map.of("type", "object", "required", List.of("a"), "properties", + Map.of("foo", Map.of("type", "string")))) .build(); String value = mapper.writeValueAsString(request); @@ -876,67 +874,4 @@ void testListRootsResult() throws Exception { } - // Schema Tests - - @Test - void testSchema() throws Exception { - McpSchema.PrimitiveSchemaDefinition schemaDefinition = McpSchema.PrimitiveSchemaDefinition.builder() - .properties(Map.of("foo", - McpSchema.StringSchema.builder() - .title("title") - .description("description") - .format(McpSchema.StringSchema.Format.URI) - .maxLength(10) - .minLength(1) - .build(), - "bar", - McpSchema.EnumSchema.builder() - .title("title") - .description("description") - .enumNames(List.of("A", "B", "C")) - .enumValues(List.of("a", "b", "c")) - .build(), - "baz", - McpSchema.NumberSchema.builder() - .title("title") - .description("description") - .maximum(10) - .minimum(1) - .type(McpSchema.NumberSchema.TypeVariant.INTEGER) - .build(), - "baz2", - McpSchema.NumberSchema.builder() - .title("title") - .description("description") - .maximum(0.2) - .minimum(0.1) - .type(McpSchema.NumberSchema.TypeVariant.NUMBER) - .build(), - "buz", - McpSchema.BooleanSchema.builder() - .title("title") - .description("description") - .defaultValue(true) - .build())) - .required(List.of("foo")) - .build(); - - String value = mapper.writeValueAsString(schemaDefinition); - - assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo( - json(""" - {"properties":{"bar":{"description":"description","enum":["a", "b", "c"],"enumNames":["A", "B", "C"],"title":"title","type":"string"},"baz":{"description":"description","maximum":10.0,"minimum":1.0,"title":"title","type":"integer"},"baz2":{"description":"description","maximum":0.2,"minimum":0.1,"title":"title","type":"number"},"buz":{"default":true,"description":"description","title":"title","type":"boolean"},"foo":{"description":"description","format":"uri","maxLength":10,"minLength":1,"title":"title","type":"string"}},"required":["foo"],"type":"object"}""")); - - // Attempt to go the other way, since Schema is a complex type to (de)serialize - // and behaves differently when - // serialized vs. deserialized due to the string/number types being used for - // multiple concrete types - McpSchema.PrimitiveSchemaDefinition schemaDefinition2 = mapper.readValue(value, - McpSchema.PrimitiveSchemaDefinition.class); - assertThat(schemaDefinition2).isEqualTo(schemaDefinition); - } - } From a775f6503b61ce45f9072d8b12a83b9564b412d6 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 9 Jun 2025 10:27:44 -0700 Subject: [PATCH 5/5] chore: remove unused import --- mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index a3e82953..9dae0826 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -18,7 +18,6 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo.As; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory;