From 53012ad9557cdf57dfc2118d4aff9bfef1b45b54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 27 May 2025 15:34:06 +0200 Subject: [PATCH 01/20] WIP --- .../client/transport/Main.java | 50 +++ .../WebClientStreamableHttpTransport.java | 291 ++++++++++++++++++ ...bClientStreamableHttpAsyncClientTests.java | 40 +++ .../MockMcpTransport.java | 13 + .../spec/McpClientTransport.java | 4 + .../modelcontextprotocol/spec/McpSchema.java | 12 +- .../spec/McpSessionNotFoundException.java | 7 + 7 files changed, 412 insertions(+), 5 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java create mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java new file mode 100644 index 00000000..0c85a551 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java @@ -0,0 +1,50 @@ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.spec.McpSchema; +import org.springframework.web.reactive.function.client.WebClient; + +public class Main { + public static void main(String[] args) { + McpAsyncClient client = McpClient.async( + new WebClientStreamableHttpTransport(new ObjectMapper(), + WebClient.builder().baseUrl("http://localhost:3001"), + "/mcp", true, false) + ).build(); + + /* + Inspector does this: + 1. -> POST initialize request + 2. <- capabilities response (with sessionId) + 3. -> POST initialized notification + 4. -> GET initialize SSE connection (with sessionId) + + VS + + 1. -> GET initialize SSE connection + 2. <- 2xx ok with sessionId + 3. -> POST initialize request + 4. <- capabilities response + 5. -> POST initialized notification + + + SERVER-A + SERVER-B + LOAD BALANCING between SERVER-A and SERVER-B + STATELESS SERVER + + 1. -> (A) POST initialize request + 2. <- (A) 2xx ok with capabilities + 3. -> (B) POST initialized notification + 4. -> (B) 2xx ok + 5. -> (A or B) POST request tools + 6. -> 2xx response + */ + + client.initialize().flatMap(r -> client.listTools()) + .map(McpSchema.ListToolsResult::tools) + .doOnNext(System.out::println) + .block(); + } +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java new file mode 100644 index 00000000..e9c0e432 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -0,0 +1,291 @@ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSessionNotFoundException; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.context.ContextView; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +public class WebClientStreamableHttpTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private final ObjectMapper objectMapper; + private final WebClient webClient; + private final String endpoint; + private final boolean openConnectionOnStartup; + private final boolean resumableStreams; + + private AtomicReference, + Mono>> handler = new AtomicReference<>(); + + private final Disposable.Composite openConnections = Disposables.composite(); + private final AtomicBoolean initialized = new AtomicBoolean(); + private final AtomicReference sessionId = new AtomicReference<>(); + + public WebClientStreamableHttpTransport( + ObjectMapper objectMapper, + WebClient.Builder webClientBuilder, + String endpoint, + boolean resumableStreams, + boolean openConnectionOnStartup) { + this.objectMapper = objectMapper; + this.webClient = webClientBuilder.build(); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + } + + @Override + public Mono connect(Function, Mono> handler) { + if (this.openConnections.isDisposed()) { + return Mono.error(new RuntimeException("Transport already disposed")); + } + this.handler.set(handler); + return openConnectionOnStartup ? startOrResumeSession(null) : Mono.empty(); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(this.openConnections::dispose); + } + + private void reconnect(McpStream stream, ContextView ctx) { + Disposable connection = this.startOrResumeSession(stream) + .contextWrite(ctx) + .subscribe(); + this.openConnections.add(connection); + } + + private Mono startOrResumeSession(McpStream stream) { + return Mono.create(sink -> { + // Here we attempt to initialize the client. + // In case the server supports SSE, we will establish a long-running session here and + // listen for messages. + // If it doesn't, nothing actually happens here, that's just the way it is... + + Disposable connection = webClient.get() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers(httpHeaders -> { + if (sessionId.get() != null) { + httpHeaders.add("mcp-session-id", sessionId.get()); + } + if (stream != null && stream.lastId() != null) { + httpHeaders.add("last-event-id", stream.lastId()); + } + }) + .exchangeToFlux(response -> { + // Per spec, we are not checking whether it's 2xx, but only if the Accept header is proper. + if (response.headers().contentType().isPresent() + && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + + sink.success(); + + McpStream sessionStream = stream != null ? stream : new McpStream(this.resumableStreams); + + Flux, Iterable>> idWithMessages = + response.bodyToFlux(new ParameterizedTypeReference>() { + }).map(this::parse); + + return sessionStream.consumeSseStream(idWithMessages); + } else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) { + sink.success(); + logger.info("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } else { + return response.createError().doOnError(e -> { + sink.error(new RuntimeException("Connection on client startup failed", e)); + }).flux(); + } + }) + // TODO: Consider retries - examine cause to decide whether a retry is needed. + .contextWrite(sink.contextView()) + .subscribe(); + this.openConnections.add(connection); + }); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.create(sink -> { + System.out.println("Sending message " + message); + // Here we attempt to initialize the client. + // In case the server supports SSE, we will establish a long-running session here and + // listen for messages. + // If it doesn't, nothing actually happens here, that's just the way it is... + Disposable connection = webClient.post() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) + .headers(httpHeaders -> { + if (sessionId.get() != null) { + httpHeaders.add("mcp-session-id", sessionId.get()); + } + }) + .bodyValue(message) + .exchangeToFlux(response -> { + // TODO: this goes into the request phase + if (!initialized.compareAndExchange(false, true)) { + if (!response.headers().header("mcp-session-id").isEmpty()) { + sessionId.set(response.headers().asHttpHeaders().getFirst("mcp-session-id")); + // Once we have a session, we try to open an async stream for the server to send notifications and requests out-of-band. + startOrResumeSession(null) + .contextWrite(sink.contextView()) + .subscribe(); + } + } + + // The spec mentions only ACCEPTED, but the existing SDKs can return 200 OK for notifications +// if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) { + if (!response.statusCode().is2xxSuccessful()) { + if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { + logger.info("Session {} was not found on the MCP server", sessionId.get()); + + McpSessionNotFoundException notFoundException = new McpSessionNotFoundException("Session " + sessionId.get() + " not found"); + // inform the caller of sendMessage + sink.error(notFoundException); + // inform the stream/connection subscriber + return Flux.error(notFoundException); + } + return response.createError().doOnError(e -> { + sink.error(new RuntimeException("Sending request failed", e)); + }).flux(); + } + + // Existing SDKs consume notifications with no response body nor content type + if (response.headers().contentType().isEmpty()) { + sink.success(); + return Flux.empty(); +// return response.createError().doOnError(e -> { +//// sink.error(new RuntimeException("Response has no content type")); +// }).flux(); + } + + MediaType contentType = response.headers().contentType().get(); + + if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + sink.success(); + McpStream sessionStream = new McpStream(this.resumableStreams); + + Flux, Iterable>> idWithMessages = + response.bodyToFlux(new ParameterizedTypeReference>() { + }).map(this::parse); + + return sessionStream.consumeSseStream(idWithMessages); + } else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + sink.success(); +// return response.bodyToMono(new ParameterizedTypeReference>() {}); + return response.bodyToMono(String.class) + .>handle((responseMessage, s) -> { + try { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, responseMessage); + s.next(List.of(jsonRpcResponse)); + } catch (IOException e) { + s.error(e); + } + }) + .flatMapIterable(Function.identity()); +// .map(Mono::just) +// .flatMap(this.handler.get()); + } else { + sink.error(new RuntimeException("Unknown media type")); + return Flux.empty(); + } + }) + .map(Mono::just) + .flatMap(this.handler.get()) + // TODO: Consider retries - examine cause to decide whether a retry is needed. + .contextWrite(sink.contextView()) + .subscribe(); + this.openConnections.add(connection); + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + + private Tuple2, Iterable> parse(ServerSentEvent event) { + if (MESSAGE_EVENT_TYPE.equals(event.event())) { + try { + // TODO: support batching + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); + } + catch (IOException ioException) { + throw new McpError("Error parsing JSON-RPC message: " + event.data()); + } + } + else { + throw new McpError("Received unrecognized SSE event type: " + event.event()); + } + } + + private class McpStream { + + private static final AtomicLong counter = new AtomicLong(); + + private final AtomicReference lastId = new AtomicReference<>(); + + private final long streamId; + private final boolean resumable; + + McpStream(boolean resumable) { + this.streamId = counter.getAndIncrement(); + this.resumable = resumable; + } + + String lastId() { + return this.lastId.get(); + } + + Flux consumeSseStream(Publisher, Iterable>> eventStream) { + return Flux.deferContextual(ctx -> + Flux.from(eventStream) + .doOnError(e -> { + // TODO: examine which error :) + if (resumable) { + Disposable connection = WebClientStreamableHttpTransport.this.startOrResumeSession(this) + .contextWrite(ctx) + .subscribe(); + WebClientStreamableHttpTransport.this.openConnections.add(connection); + } + }) + .doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(this.lastId::set)) + .flatMapIterable(Tuple2::getT2) + ); + } + + } +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java new file mode 100644 index 00000000..bf5440cb --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,40 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.images.builder.ImageFromDockerfile; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server-streamable:v2") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return new WebClientStreamableHttpTransport(new ObjectMapper(), WebClient.builder(), "/mcp", true, false); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java index 5484a63c..0809ae72 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -7,6 +7,7 @@ import java.util.ArrayList; import java.util.List; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; @@ -70,6 +71,18 @@ public McpSchema.JSONRPCMessage getLastSentMessage() { private volatile boolean connected = false; +// @Override +// public Mono connect(Consumer consumer) { +// if (connected) { +// return Mono.error(new IllegalStateException("Already connected")); +// } +// connected = true; +// return inbound.asFlux() +// .doOnNext(consumer) +// .doFinally(signal -> connected = false) +// .then(); +// } + @Override public Mono connect(Function, Mono> handler) { if (connected) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index f2909124..400e1bd4 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -3,6 +3,7 @@ */ package io.modelcontextprotocol.spec; +import java.util.function.Consumer; import java.util.function.Function; import reactor.core.publisher.Mono; @@ -15,6 +16,9 @@ */ public interface McpClientTransport extends McpTransport { +// @Deprecated Mono connect(Function, Mono> handler); +// Mono connect(Consumer consumer); + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8df8a158..33ab8571 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -10,11 +10,7 @@ import java.util.List; import java.util.Map; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.*; import com.fasterxml.jackson.annotation.JsonTypeInfo.As; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -181,6 +177,8 @@ public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotificati @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCRequest( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, @@ -190,6 +188,8 @@ public record JSONRPCRequest( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCNotification( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, @@ -198,6 +198,8 @@ public record JSONRPCNotification( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) + // TODO: batching support + // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCResponse( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("id") Object id, diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java new file mode 100644 index 00000000..b2c3554e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.spec; + +public class McpSessionNotFoundException extends RuntimeException { + public McpSessionNotFoundException(String message) { + super(message); + } +} From 26e8af0f50f69ea044302d3b0f6c128a5231d53a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 28 May 2025 09:37:45 +0200 Subject: [PATCH 02/20] Use new docker images --- .../WebClientStreamableHttpTransport.java | 311 ++++++++++-------- ...bClientStreamableHttpAsyncClientTests.java | 53 +-- ...ebClientStreamableHttpSyncClientTests.java | 42 +++ .../client/WebFluxSseMcpAsyncClientTests.java | 3 +- .../client/WebFluxSseMcpSyncClientTests.java | 3 +- .../WebFluxSseClientTransportTests.java | 3 +- .../MockMcpTransport.java | 22 +- .../spec/McpClientTransport.java | 4 +- .../spec/McpSessionNotFoundException.java | 8 +- .../client/HttpSseMcpAsyncClientTests.java | 3 +- .../client/HttpSseMcpSyncClientTests.java | 3 +- .../HttpClientSseClientTransportTests.java | 3 +- 12 files changed, 265 insertions(+), 193 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index e9c0e432..c37a3b21 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -41,24 +41,25 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private static final String MESSAGE_EVENT_TYPE = "message"; private final ObjectMapper objectMapper; + private final WebClient webClient; + private final String endpoint; + private final boolean openConnectionOnStartup; + private final boolean resumableStreams; - private AtomicReference, - Mono>> handler = new AtomicReference<>(); + private AtomicReference, Mono>> handler = new AtomicReference<>(); private final Disposable.Composite openConnections = Disposables.composite(); + private final AtomicBoolean initialized = new AtomicBoolean(); + private final AtomicReference sessionId = new AtomicReference<>(); - public WebClientStreamableHttpTransport( - ObjectMapper objectMapper, - WebClient.Builder webClientBuilder, - String endpoint, - boolean resumableStreams, - boolean openConnectionOnStartup) { + public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, + String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { this.objectMapper = objectMapper; this.webClient = webClientBuilder.build(); this.endpoint = endpoint; @@ -81,57 +82,61 @@ public Mono closeGracefully() { } private void reconnect(McpStream stream, ContextView ctx) { - Disposable connection = this.startOrResumeSession(stream) - .contextWrite(ctx) - .subscribe(); + Disposable connection = this.startOrResumeSession(stream).contextWrite(ctx).subscribe(); this.openConnections.add(connection); } private Mono startOrResumeSession(McpStream stream) { return Mono.create(sink -> { // Here we attempt to initialize the client. - // In case the server supports SSE, we will establish a long-running session here and + // In case the server supports SSE, we will establish a long-running session + // here and // listen for messages. // If it doesn't, nothing actually happens here, that's just the way it is... Disposable connection = webClient.get() - .uri(this.endpoint) - .accept(MediaType.TEXT_EVENT_STREAM) - .headers(httpHeaders -> { - if (sessionId.get() != null) { - httpHeaders.add("mcp-session-id", sessionId.get()); - } - if (stream != null && stream.lastId() != null) { - httpHeaders.add("last-event-id", stream.lastId()); - } - }) - .exchangeToFlux(response -> { - // Per spec, we are not checking whether it's 2xx, but only if the Accept header is proper. - if (response.headers().contentType().isPresent() - && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { - - sink.success(); - - McpStream sessionStream = stream != null ? stream : new McpStream(this.resumableStreams); - - Flux, Iterable>> idWithMessages = - response.bodyToFlux(new ParameterizedTypeReference>() { - }).map(this::parse); - - return sessionStream.consumeSseStream(idWithMessages); - } else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) { - sink.success(); - logger.info("The server does not support SSE streams, using request-response mode."); - return Flux.empty(); - } else { - return response.createError().doOnError(e -> { - sink.error(new RuntimeException("Connection on client startup failed", e)); - }).flux(); - } - }) - // TODO: Consider retries - examine cause to decide whether a retry is needed. - .contextWrite(sink.contextView()) - .subscribe(); + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers(httpHeaders -> { + if (sessionId.get() != null) { + httpHeaders.add("mcp-session-id", sessionId.get()); + } + if (stream != null && stream.lastId() != null) { + httpHeaders.add("last-event-id", stream.lastId()); + } + }) + .exchangeToFlux(response -> { + // Per spec, we are not checking whether it's 2xx, but only if the + // Accept header is proper. + if (response.headers().contentType().isPresent() + && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + + sink.success(); + + McpStream sessionStream = stream != null ? stream : new McpStream(this.resumableStreams); + + Flux, Iterable>> idWithMessages = response + .bodyToFlux(new ParameterizedTypeReference>() { + }) + .map(this::parse); + + return sessionStream.consumeSseStream(idWithMessages); + } + else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) { + sink.success(); + logger.info("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else { + return response.createError().doOnError(e -> { + sink.error(new RuntimeException("Connection on client startup failed", e)); + }).flux(); + } + }) + // TODO: Consider retries - examine cause to decide whether a retry is + // needed. + .contextWrite(sink.contextView()) + .subscribe(); this.openConnections.add(connection); }); } @@ -141,92 +146,106 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.create(sink -> { System.out.println("Sending message " + message); // Here we attempt to initialize the client. - // In case the server supports SSE, we will establish a long-running session here and + // In case the server supports SSE, we will establish a long-running session + // here and // listen for messages. // If it doesn't, nothing actually happens here, that's just the way it is... Disposable connection = webClient.post() - .uri(this.endpoint) - .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) - .headers(httpHeaders -> { - if (sessionId.get() != null) { - httpHeaders.add("mcp-session-id", sessionId.get()); - } - }) - .bodyValue(message) - .exchangeToFlux(response -> { - // TODO: this goes into the request phase - if (!initialized.compareAndExchange(false, true)) { - if (!response.headers().header("mcp-session-id").isEmpty()) { - sessionId.set(response.headers().asHttpHeaders().getFirst("mcp-session-id")); - // Once we have a session, we try to open an async stream for the server to send notifications and requests out-of-band. - startOrResumeSession(null) - .contextWrite(sink.contextView()) - .subscribe(); - } - } - - // The spec mentions only ACCEPTED, but the existing SDKs can return 200 OK for notifications -// if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) { - if (!response.statusCode().is2xxSuccessful()) { - if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { - logger.info("Session {} was not found on the MCP server", sessionId.get()); - - McpSessionNotFoundException notFoundException = new McpSessionNotFoundException("Session " + sessionId.get() + " not found"); - // inform the caller of sendMessage - sink.error(notFoundException); - // inform the stream/connection subscriber - return Flux.error(notFoundException); - } - return response.createError().doOnError(e -> { - sink.error(new RuntimeException("Sending request failed", e)); - }).flux(); + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) + .headers(httpHeaders -> { + if (sessionId.get() != null) { + httpHeaders.add("mcp-session-id", sessionId.get()); + } + }) + .bodyValue(message) + .exchangeToFlux(response -> { + // TODO: this goes into the request phase + if (!initialized.compareAndExchange(false, true)) { + if (!response.headers().header("mcp-session-id").isEmpty()) { + sessionId.set(response.headers().asHttpHeaders().getFirst("mcp-session-id")); + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + startOrResumeSession(null).contextWrite(sink.contextView()).subscribe(); } - - // Existing SDKs consume notifications with no response body nor content type - if (response.headers().contentType().isEmpty()) { - sink.success(); - return Flux.empty(); -// return response.createError().doOnError(e -> { -//// sink.error(new RuntimeException("Response has no content type")); -// }).flux(); + } + + // The spec mentions only ACCEPTED, but the existing SDKs can return + // 200 OK for notifications + // if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) { + if (!response.statusCode().is2xxSuccessful()) { + if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { + logger.info("Session {} was not found on the MCP server", sessionId.get()); + + McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( + "Session " + sessionId.get() + " not found"); + // inform the caller of sendMessage + sink.error(notFoundException); + // inform the stream/connection subscriber + return Flux.error(notFoundException); } - - MediaType contentType = response.headers().contentType().get(); - - if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { - sink.success(); - McpStream sessionStream = new McpStream(this.resumableStreams); - - Flux, Iterable>> idWithMessages = - response.bodyToFlux(new ParameterizedTypeReference>() { - }).map(this::parse); - - return sessionStream.consumeSseStream(idWithMessages); - } else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { - sink.success(); -// return response.bodyToMono(new ParameterizedTypeReference>() {}); - return response.bodyToMono(String.class) - .>handle((responseMessage, s) -> { - try { - McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, responseMessage); - s.next(List.of(jsonRpcResponse)); - } catch (IOException e) { - s.error(e); - } + return response.createError().doOnError(e -> { + sink.error(new RuntimeException("Sending request failed", e)); + }).flux(); + } + + // Existing SDKs consume notifications with no response body nor + // content type + if (response.headers().contentType().isEmpty()) { + sink.success(); + return Flux.empty(); + // return + // response.createError().doOnError(e -> + // { + //// sink.error(new RuntimeException("Response has no content + // type")); + // }).flux(); + } + + MediaType contentType = response.headers().contentType().get(); + + if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + sink.success(); + McpStream sessionStream = new McpStream(this.resumableStreams); + + Flux, Iterable>> idWithMessages = response + .bodyToFlux(new ParameterizedTypeReference>() { }) - .flatMapIterable(Function.identity()); -// .map(Mono::just) -// .flatMap(this.handler.get()); - } else { - sink.error(new RuntimeException("Unknown media type")); - return Flux.empty(); - } - }) - .map(Mono::just) - .flatMap(this.handler.get()) - // TODO: Consider retries - examine cause to decide whether a retry is needed. - .contextWrite(sink.contextView()) - .subscribe(); + .map(this::parse); + + return sessionStream.consumeSseStream(idWithMessages); + } + else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + sink.success(); + // return response.bodyToMono(new + // ParameterizedTypeReference>() + // {}); + return response.bodyToMono( + String.class).>handle((responseMessage, s) -> { + try { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema + .deserializeJsonRpcMessage(objectMapper, responseMessage); + s.next(List.of(jsonRpcResponse)); + } + catch (IOException e) { + s.error(e); + } + }) + .flatMapIterable(Function.identity()); + // .map(Mono::just) + // .flatMap(this.handler.get()); + } + else { + sink.error(new RuntimeException("Unknown media type")); + return Flux.empty(); + } + }) + .map(Mono::just) + .flatMap(this.handler.get()) + // TODO: Consider retries - examine cause to decide whether a retry is + // needed. + .contextWrite(sink.contextView()) + .subscribe(); this.openConnections.add(connection); }); } @@ -259,10 +278,11 @@ private class McpStream { private final AtomicReference lastId = new AtomicReference<>(); private final long streamId; + private final boolean resumable; - McpStream(boolean resumable) { - this.streamId = counter.getAndIncrement(); + McpStream(boolean resumable) { + this.streamId = counter.getAndIncrement(); this.resumable = resumable; } @@ -270,22 +290,21 @@ String lastId() { return this.lastId.get(); } - Flux consumeSseStream(Publisher, Iterable>> eventStream) { - return Flux.deferContextual(ctx -> - Flux.from(eventStream) - .doOnError(e -> { - // TODO: examine which error :) - if (resumable) { - Disposable connection = WebClientStreamableHttpTransport.this.startOrResumeSession(this) - .contextWrite(ctx) - .subscribe(); - WebClientStreamableHttpTransport.this.openConnections.add(connection); - } - }) - .doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(this.lastId::set)) - .flatMapIterable(Tuple2::getT2) - ); + Flux consumeSseStream( + Publisher, Iterable>> eventStream) { + return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { + // TODO: examine which error :) + if (resumable) { + Disposable connection = WebClientStreamableHttpTransport.this.startOrResumeSession(this) + .contextWrite(ctx) + .subscribe(); + WebClientStreamableHttpTransport.this.openConnections.add(connection); + } + }) + .doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(this.lastId::set)) + .flatMapIterable(Tuple2::getT2)); } } + } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java index bf5440cb..94a035c0 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -12,29 +12,32 @@ @Timeout(15) public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { - static String host = "http://localhost:3001"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server-streamable:v2") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return new WebClientStreamableHttpTransport(new ObjectMapper(), WebClient.builder(), "/mcp", true, false); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - public void onClose() { - container.stop(); - } + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return new WebClientStreamableHttpTransport(new ObjectMapper(), WebClient.builder().baseUrl(host), "/mcp", true, + false); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java new file mode 100644 index 00000000..0764d348 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -0,0 +1,42 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +@Timeout(15) +public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return new WebClientStreamableHttpTransport(new ObjectMapper(), WebClient.builder().baseUrl(host), "/mcp", true, + false); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index b43c1449..f0533cb4 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 66ac8a6d..9b0959a3 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index c757d3da..42b91d14 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -41,7 +41,8 @@ class WebFluxSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java index 0809ae72..06ce2804 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -71,17 +71,17 @@ public McpSchema.JSONRPCMessage getLastSentMessage() { private volatile boolean connected = false; -// @Override -// public Mono connect(Consumer consumer) { -// if (connected) { -// return Mono.error(new IllegalStateException("Already connected")); -// } -// connected = true; -// return inbound.asFlux() -// .doOnNext(consumer) -// .doFinally(signal -> connected = false) -// .then(); -// } + // @Override + // public Mono connect(Consumer consumer) { + // if (connected) { + // return Mono.error(new IllegalStateException("Already connected")); + // } + // connected = true; + // return inbound.asFlux() + // .doOnNext(consumer) + // .doFinally(signal -> connected = false) + // .then(); + // } @Override public Mono connect(Function, Mono> handler) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index 400e1bd4..5fb4cfc8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -16,9 +16,9 @@ */ public interface McpClientTransport extends McpTransport { -// @Deprecated + // @Deprecated Mono connect(Function, Mono> handler); -// Mono connect(Consumer consumer); + // Mono connect(Consumer consumer); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java index b2c3554e..c351620b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java @@ -1,7 +1,9 @@ package io.modelcontextprotocol.spec; public class McpSessionNotFoundException extends RuntimeException { - public McpSessionNotFoundException(String message) { - super(message); - } + + public McpSessionNotFoundException(String message) { + super(message); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index fdff4b77..1b66a98c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -22,7 +22,8 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 204cf298..8646c1b4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -22,7 +22,8 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 762264de..1b1c7201 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -51,7 +51,8 @@ class HttpClientSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); From 40bc356910261617b3b8cd9bd6890a63c532cd6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 28 May 2025 10:41:47 +0200 Subject: [PATCH 03/20] Provide disposables cleanup --- .../WebClientStreamableHttpTransport.java | 135 +++++++++--------- 1 file changed, 69 insertions(+), 66 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index c37a3b21..19ee52fc 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -69,11 +69,16 @@ public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Bui @Override public Mono connect(Function, Mono> handler) { - if (this.openConnections.isDisposed()) { - return Mono.error(new RuntimeException("Transport already disposed")); - } - this.handler.set(handler); - return openConnectionOnStartup ? startOrResumeSession(null) : Mono.empty(); + return Mono.deferContextual(ctx -> { + if (this.openConnections.isDisposed()) { + return Mono.error(new RuntimeException("Transport already disposed")); + } + this.handler.set(handler); + if (openConnectionOnStartup) { + this.reconnect(null, ctx); + } + return Mono.empty(); + }); } @Override @@ -82,63 +87,58 @@ public Mono closeGracefully() { } private void reconnect(McpStream stream, ContextView ctx) { - Disposable connection = this.startOrResumeSession(stream).contextWrite(ctx).subscribe(); - this.openConnections.add(connection); - } - - private Mono startOrResumeSession(McpStream stream) { - return Mono.create(sink -> { - // Here we attempt to initialize the client. - // In case the server supports SSE, we will establish a long-running session - // here and - // listen for messages. - // If it doesn't, nothing actually happens here, that's just the way it is... - - Disposable connection = webClient.get() - .uri(this.endpoint) - .accept(MediaType.TEXT_EVENT_STREAM) - .headers(httpHeaders -> { - if (sessionId.get() != null) { - httpHeaders.add("mcp-session-id", sessionId.get()); - } - if (stream != null && stream.lastId() != null) { - httpHeaders.add("last-event-id", stream.lastId()); - } - }) - .exchangeToFlux(response -> { - // Per spec, we are not checking whether it's 2xx, but only if the - // Accept header is proper. - if (response.headers().contentType().isPresent() - && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { - - sink.success(); + // Here we attempt to initialize the client. + // In case the server supports SSE, we will establish a long-running session + // here and + // listen for messages. + // If it doesn't, nothing actually happens here, that's just the way it is... + final AtomicReference disposableRef = new AtomicReference<>(); + Disposable connection = webClient.get() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers(httpHeaders -> { + if (sessionId.get() != null) { + httpHeaders.add("mcp-session-id", sessionId.get()); + } + if (stream != null && stream.lastId() != null) { + httpHeaders.add("last-event-id", stream.lastId()); + } + }) + .exchangeToFlux(response -> { + // Per spec, we are not checking whether it's 2xx, but only if the + // Accept header is proper. + if (response.headers().contentType().isPresent() + && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { - McpStream sessionStream = stream != null ? stream : new McpStream(this.resumableStreams); + McpStream sessionStream = stream != null ? stream : new McpStream(this.resumableStreams); - Flux, Iterable>> idWithMessages = response - .bodyToFlux(new ParameterizedTypeReference>() { - }) - .map(this::parse); + Flux, Iterable>> idWithMessages = response + .bodyToFlux(new ParameterizedTypeReference>() { + }) + .map(this::parse); - return sessionStream.consumeSseStream(idWithMessages); - } - else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) { - sink.success(); - logger.info("The server does not support SSE streams, using request-response mode."); - return Flux.empty(); - } - else { - return response.createError().doOnError(e -> { - sink.error(new RuntimeException("Connection on client startup failed", e)); - }).flux(); - } - }) - // TODO: Consider retries - examine cause to decide whether a retry is - // needed. - .contextWrite(sink.contextView()) - .subscribe(); - this.openConnections.add(connection); - }); + return sessionStream.consumeSseStream(idWithMessages); + } + else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) { + logger.info("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else { + return response.createError().doOnError(e -> { + logger.info("Opening an SSE stream failed. This can be safely ignored.", e); + }).flux(); + } + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + this.openConnections.remove(ref); + } + }) + .contextWrite(ctx) + .subscribe(); + disposableRef.set(connection); + this.openConnections.add(connection); } @Override @@ -150,6 +150,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { // here and // listen for messages. // If it doesn't, nothing actually happens here, that's just the way it is... + final AtomicReference disposableRef = new AtomicReference<>(); Disposable connection = webClient.post() .uri(this.endpoint) .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) @@ -166,7 +167,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { sessionId.set(response.headers().asHttpHeaders().getFirst("mcp-session-id")); // Once we have a session, we try to open an async stream for // the server to send notifications and requests out-of-band. - startOrResumeSession(null).contextWrite(sink.contextView()).subscribe(); + reconnect(null, sink.contextView()); } } @@ -242,10 +243,15 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { }) .map(Mono::just) .flatMap(this.handler.get()) - // TODO: Consider retries - examine cause to decide whether a retry is - // needed. + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + this.openConnections.remove(ref); + } + }) .contextWrite(sink.contextView()) .subscribe(); + disposableRef.set(connection); this.openConnections.add(connection); }); } @@ -295,10 +301,7 @@ Flux consumeSseStream( return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { // TODO: examine which error :) if (resumable) { - Disposable connection = WebClientStreamableHttpTransport.this.startOrResumeSession(this) - .contextWrite(ctx) - .subscribe(); - WebClientStreamableHttpTransport.this.openConnections.add(connection); + reconnect(this, ctx); } }) .doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(this.lastId::set)) From e249145305564c79472a8c75a42bde59f851d013 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 28 May 2025 18:06:39 +0200 Subject: [PATCH 04/20] WIP.. more resiliency --- .../client/transport/Main.java | 70 +++++++-------- .../WebClientStreamableHttpTransport.java | 71 +++++++++------ .../client/McpAsyncClient.java | 88 +++++++++---------- .../spec/McpClientSession.java | 20 ++++- .../spec/McpClientTransport.java | 8 +- .../spec/McpTransportSession.java | 49 +++++++++++ 6 files changed, 194 insertions(+), 112 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java index 0c85a551..e501ca1c 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java @@ -7,44 +7,38 @@ import org.springframework.web.reactive.function.client.WebClient; public class Main { - public static void main(String[] args) { - McpAsyncClient client = McpClient.async( - new WebClientStreamableHttpTransport(new ObjectMapper(), - WebClient.builder().baseUrl("http://localhost:3001"), - "/mcp", true, false) - ).build(); - /* - Inspector does this: - 1. -> POST initialize request - 2. <- capabilities response (with sessionId) - 3. -> POST initialized notification - 4. -> GET initialize SSE connection (with sessionId) + public static void main(String[] args) { + McpAsyncClient client = McpClient + .async(new WebClientStreamableHttpTransport(new ObjectMapper(), + WebClient.builder().baseUrl("http://localhost:3001"), "/mcp", true, false)) + .build(); + + /* + * Inspector does this: 1. -> POST initialize request 2. <- capabilities response + * (with sessionId) 3. -> POST initialized notification 4. -> GET initialize SSE + * connection (with sessionId) + * + * VS + * + * 1. -> GET initialize SSE connection 2. <- 2xx ok with sessionId 3. -> POST + * initialize request 4. <- capabilities response 5. -> POST initialized + * notification + * + * + * SERVER-A + SERVER-B LOAD BALANCING between SERVER-A and SERVER-B STATELESS + * SERVER + * + * 1. -> (A) POST initialize request 2. <- (A) 2xx ok with capabilities 3. -> (B) + * POST initialized notification 4. -> (B) 2xx ok 5. -> (A or B) POST request + * tools 6. -> 2xx response + */ + + client.initialize() + .flatMap(r -> client.listTools()) + .map(McpSchema.ListToolsResult::tools) + .doOnNext(System.out::println) + .block(); + } - VS - - 1. -> GET initialize SSE connection - 2. <- 2xx ok with sessionId - 3. -> POST initialize request - 4. <- capabilities response - 5. -> POST initialized notification - - - SERVER-A + SERVER-B - LOAD BALANCING between SERVER-A and SERVER-B - STATELESS SERVER - - 1. -> (A) POST initialize request - 2. <- (A) 2xx ok with capabilities - 3. -> (B) POST initialized notification - 4. -> (B) 2xx ok - 5. -> (A or B) POST request tools - 6. -> 2xx response - */ - - client.initialize().flatMap(r -> client.listTools()) - .map(McpSchema.ListToolsResult::tools) - .doOnNext(System.out::println) - .block(); - } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 19ee52fc..55240dd5 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -2,10 +2,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSessionNotFoundException; +import io.modelcontextprotocol.spec.*; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,6 +25,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; public class WebClientStreamableHttpTransport implements McpClientTransport { @@ -52,11 +50,9 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private AtomicReference, Mono>> handler = new AtomicReference<>(); - private final Disposable.Composite openConnections = Disposables.composite(); + private final AtomicReference activeSession = new AtomicReference<>(); - private final AtomicBoolean initialized = new AtomicBoolean(); - - private final AtomicReference sessionId = new AtomicReference<>(); + private final AtomicReference> exceptionHandler = new AtomicReference<>(); public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { @@ -65,14 +61,12 @@ public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Bui this.endpoint = endpoint; this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(new McpTransportSession()); } @Override public Mono connect(Function, Mono> handler) { return Mono.deferContextual(ctx -> { - if (this.openConnections.isDisposed()) { - return Mono.error(new RuntimeException("Transport already disposed")); - } this.handler.set(handler); if (openConnectionOnStartup) { this.reconnect(null, ctx); @@ -81,9 +75,20 @@ public Mono connect(Function, Mono handler) { + this.exceptionHandler.set(handler); + } + @Override public Mono closeGracefully() { - return Mono.fromRunnable(this.openConnections::dispose); + return Mono.defer(() -> { + McpTransportSession currentSession = this.activeSession.get(); + if (currentSession != null) { + return currentSession.closeGracefully(); + } + return Mono.empty(); + }); } private void reconnect(McpStream stream, ContextView ctx) { @@ -93,12 +98,13 @@ private void reconnect(McpStream stream, ContextView ctx) { // listen for messages. // If it doesn't, nothing actually happens here, that's just the way it is... final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); Disposable connection = webClient.get() .uri(this.endpoint) .accept(MediaType.TEXT_EVENT_STREAM) .headers(httpHeaders -> { - if (sessionId.get() != null) { - httpHeaders.add("mcp-session-id", sessionId.get()); + if (transportSession.sessionId() != null) { + httpHeaders.add("mcp-session-id", transportSession.sessionId()); } if (stream != null && stream.lastId() != null) { httpHeaders.add("last-event-id", stream.lastId()); @@ -123,22 +129,33 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) { logger.info("The server does not support SSE streams, using request-response mode."); return Flux.empty(); } + else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { + logger.info("Session {} was not found on the MCP server", transportSession.sessionId()); + + McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( + "Session " + transportSession.sessionId() + " not found"); + // inform the stream/connection subscriber + return Flux.error(notFoundException); + } else { return response.createError().doOnError(e -> { logger.info("Opening an SSE stream failed. This can be safely ignored.", e); }).flux(); } }) + .doOnError(e -> { + this.exceptionHandler.get().accept(e); + }) .doFinally(s -> { Disposable ref = disposableRef.getAndSet(null); if (ref != null) { - this.openConnections.remove(ref); + transportSession.removeConnection(ref); } }) .contextWrite(ctx) .subscribe(); disposableRef.set(connection); - this.openConnections.add(connection); + transportSession.addConnection(connection); } @Override @@ -151,20 +168,22 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { // listen for messages. // If it doesn't, nothing actually happens here, that's just the way it is... final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + Disposable connection = webClient.post() .uri(this.endpoint) .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) .headers(httpHeaders -> { - if (sessionId.get() != null) { - httpHeaders.add("mcp-session-id", sessionId.get()); + if (transportSession.sessionId() != null) { + httpHeaders.add("mcp-session-id", transportSession.sessionId()); } }) .bodyValue(message) .exchangeToFlux(response -> { - // TODO: this goes into the request phase - if (!initialized.compareAndExchange(false, true)) { + if (transportSession.markInitialized()) { if (!response.headers().header("mcp-session-id").isEmpty()) { - sessionId.set(response.headers().asHttpHeaders().getFirst("mcp-session-id")); + transportSession + .setSessionId(response.headers().asHttpHeaders().getFirst("mcp-session-id")); // Once we have a session, we try to open an async stream for // the server to send notifications and requests out-of-band. reconnect(null, sink.contextView()); @@ -176,10 +195,10 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { // if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) { if (!response.statusCode().is2xxSuccessful()) { if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { - logger.info("Session {} was not found on the MCP server", sessionId.get()); + logger.info("Session {} was not found on the MCP server", transportSession.sessionId()); McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( - "Session " + sessionId.get() + " not found"); + "Session " + transportSession.sessionId() + " not found"); // inform the caller of sendMessage sink.error(notFoundException); // inform the stream/connection subscriber @@ -233,8 +252,6 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { } }) .flatMapIterable(Function.identity()); - // .map(Mono::just) - // .flatMap(this.handler.get()); } else { sink.error(new RuntimeException("Unknown media type")); @@ -246,13 +263,13 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { .doFinally(s -> { Disposable ref = disposableRef.getAndSet(null); if (ref != null) { - this.openConnections.remove(ref); + transportSession.removeConnection(ref); } }) .contextWrite(sink.contextView()) .subscribe(); disposableRef.set(connection); - this.openConnections.add(connection); + transportSession.addConnection(connection); }); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e3a997ba..9ef900b0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -11,15 +11,13 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; @@ -30,7 +28,6 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; @@ -83,9 +80,7 @@ public class McpAsyncClient { private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { }; - protected final Sinks.One initializedSink = Sinks.one(); - - private AtomicBoolean initialized = new AtomicBoolean(false); + private final AtomicReference initialization = new AtomicReference<>(); /** * The max timeout to await for the client-server connection to be initialized. @@ -108,21 +103,6 @@ public class McpAsyncClient { */ private final McpSchema.Implementation clientInfo; - /** - * Server capabilities. - */ - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * Server instructions. - */ - private String serverInstructions; - - /** - * Server implementation information. - */ - private McpSchema.Implementation serverInfo; - /** * Roots define the boundaries of where servers can operate within the filesystem, * allowing them to understand which directories and files they have access to. @@ -234,16 +214,24 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); + this.mcpSession = new McpClientSession(requestTimeout, transport, this::handleException, requestHandlers, + notificationHandlers); } + private void handleException(Throwable t) { + if (t instanceof McpSessionNotFoundException) { + this.initialize().subscribe(); + } + } + /** * Get the server capabilities that define the supported features and functionality. * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; + McpSchema.InitializeResult initializeResult = this.initialization.get(); + return initializeResult != null ? initializeResult.capabilities() : null; } /** @@ -252,7 +240,8 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server instructions */ public String getServerInstructions() { - return this.serverInstructions; + McpSchema.InitializeResult initializeResult = this.initialization.get(); + return initializeResult != null ? initializeResult.instructions() : null; } /** @@ -260,7 +249,8 @@ public String getServerInstructions() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.serverInfo; + McpSchema.InitializeResult initializeResult = this.initialization.get(); + return initializeResult != null ? initializeResult.serverInfo() : null; } /** @@ -268,7 +258,7 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - return this.initialized.get(); + return this.initialization.get() != null; } /** @@ -341,11 +331,6 @@ public Mono initialize() { }); return result.flatMap(initializeResult -> { - - this.serverCapabilities = initializeResult.capabilities(); - this.serverInstructions = initializeResult.instructions(); - this.serverInfo = initializeResult.serverInfo(); - logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), initializeResult.instructions()); @@ -356,8 +341,7 @@ public Mono initialize() { } return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null).doOnSuccess(v -> { - this.initialized.set(true); - this.initializedSink.tryEmitValue(initializeResult); + this.initialization.set(initializeResult); }).thenReturn(initializeResult); }); } @@ -372,11 +356,25 @@ public Mono initialize() { */ private Mono withInitializationCheck(String actionName, Function> operation) { - return this.initializedSink.asMono() - .timeout(this.initializationTimeout) - .onErrorResume(TimeoutException.class, - ex -> Mono.error(new McpError("Client must be initialized before " + actionName))) - .flatMap(operation); + return Mono.defer(() -> { + McpSchema.InitializeResult initializeResult = this.initialization.get(); + // FIXME: in case of bursts this will trigger multiple inits, we have to batch + // requests + // and dispatch once a single init finishes + if (initializeResult != null) { + return operation.apply(initializeResult); + } + else { + return this.initialize() + .timeout(this.initializationTimeout) + // TODO: McpError should be used when communicating over the wire, not + // to + // the user of the client API + .onErrorResume(TimeoutException.class, + ex -> Mono.error(new McpError("Client must be initialized before " + actionName))) + .flatMap(operation); + } + }); } // -------------------------- @@ -522,7 +520,7 @@ private RequestHandler samplingCreateMessageHandler() { */ public Mono callTool(McpSchema.CallToolRequest callToolRequest) { return this.withInitializationCheck("calling tools", initializedResult -> { - if (this.serverCapabilities.tools() == null) { + if (initializedResult.capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); @@ -544,7 +542,7 @@ public Mono listTools() { */ public Mono listTools(String cursor) { return this.withInitializationCheck("listing tools", initializedResult -> { - if (this.serverCapabilities.tools() == null) { + if (initializedResult.capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), @@ -601,7 +599,7 @@ public Mono listResources() { */ public Mono listResources(String cursor) { return this.withInitializationCheck("listing resources", initializedResult -> { - if (this.serverCapabilities.resources() == null) { + if (initializedResult.capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), @@ -632,7 +630,7 @@ public Mono readResource(McpSchema.Resource resour */ public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { return this.withInitializationCheck("reading resources", initializedResult -> { - if (this.serverCapabilities.resources() == null) { + if (initializedResult.capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, @@ -661,7 +659,7 @@ public Mono listResourceTemplates() { */ public Mono listResourceTemplates(String cursor) { return this.withInitializationCheck("listing resource templates", initializedResult -> { - if (this.serverCapabilities.resources() == null) { + if (initializedResult.capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index f577b493..2d57b891 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -9,6 +9,7 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.util.Assert; @@ -104,7 +105,7 @@ public interface NotificationHandler { * @param requestHandlers Map of method names to request handlers * @param notificationHandlers Map of method names to notification handlers */ - public McpClientSession(Duration requestTimeout, McpClientTransport transport, + public McpClientSession(Duration requestTimeout, McpClientTransport transport, Consumer exceptionHandler, Map> requestHandlers, Map notificationHandlers) { Assert.notNull(requestTimeout, "The requestTimeout can not be null"); @@ -123,6 +124,23 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, // create child Observation and emit it together with the message to the // consumer this.connection = this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); + this.transport.handleException(t -> { + this.pendingResponses.clear(); + exceptionHandler.accept(t); + }); + } + + /** + * Creates a new McpClientSession with the specified configuration and handlers. + * @param requestTimeout Duration to wait for responses + * @param transport Transport implementation for message exchange + * @param requestHandlers Map of method names to request handlers + * @param notificationHandlers Map of method names to notification handlers + */ + public McpClientSession(Duration requestTimeout, McpClientTransport transport, + Map> requestHandlers, Map notificationHandlers) { + this(requestTimeout, transport, e -> { + }, requestHandlers, notificationHandlers); } private void handle(McpSchema.JSONRPCMessage message) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index 5fb4cfc8..4c784ce5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -19,6 +19,12 @@ public interface McpClientTransport extends McpTransport { // @Deprecated Mono connect(Function, Mono> handler); - // Mono connect(Consumer consumer); + default void handleException(Consumer handler) { + } + + // default void connect(Consumer consumer) { + // this.connect((Function, + // Mono>) mono -> mono.doOnNext(consumer)).subscribe(); + // } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java new file mode 100644 index 00000000..dd5c108e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java @@ -0,0 +1,49 @@ +package io.modelcontextprotocol.spec; + +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.Mono; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +public class McpTransportSession { + + private final Disposable.Composite openConnections = Disposables.composite(); + + private final AtomicBoolean initialized = new AtomicBoolean(false); + + private final AtomicReference sessionId = new AtomicReference<>(); + + public McpTransportSession() { + } + + public String sessionId() { + return this.sessionId.get(); + } + + public void setSessionId(String sessionId) { + this.sessionId.set(sessionId); + } + + public boolean markInitialized() { + return this.initialized.compareAndSet(false, true); + } + + public void addConnection(Disposable connection) { + this.openConnections.add(connection); + } + + public void removeConnection(Disposable connection) { + this.openConnections.remove(connection); + } + + public void close() { + this.closeGracefully().subscribe(); + } + + public Mono closeGracefully() { + return Mono.fromRunnable(this.openConnections::dispose); + } + +} From 84f7ae05db073249957918a90c9abee6b035840c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 29 May 2025 14:25:58 +0200 Subject: [PATCH 05/20] Starting with resiliency tests --- mcp-spring/mcp-spring-webflux/pom.xml | 6 + .../transport/WebFluxSseClientTransport.java | 3 + ...eamableHttpAsyncClientResiliencyTests.java | 20 +++ mcp-test/pom.xml | 5 + ...AbstractMcpAsyncClientResiliencyTests.java | 141 ++++++++++++++++++ .../spec/McpClientSession.java | 5 +- pom.xml | 3 +- 7 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index a8b92bd0..26452fe9 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -99,6 +99,12 @@ ${testcontainers.version} test + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + org.awaitility diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 37abe295..128cda4c 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -190,6 +190,9 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe */ @Override public Mono connect(Function, Mono> handler) { + // TODO: Avoid eager connection opening and enable resilience + // -> upon disconnects, re-establish connection + // -> allow optimizing for eager connection start using a constructor flag Flux> events = eventStream(); this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> { if (ENDPOINT_EVENT_TYPE.equals(event.event())) { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 00000000..4f7bbf37 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,20 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return new WebClientStreamableHttpTransport(new ObjectMapper(), WebClient.builder().baseUrl(host), "/mcp", true, + false); + } + +} diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index a6e5bdb0..9998569d 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -68,6 +68,11 @@ junit-jupiter ${testcontainers.version} + + org.testcontainers + toxiproxy + ${toxiproxy.version} + org.awaitility diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 00000000..5bd388c7 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,141 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +public abstract class AbstractMcpAsyncClientResiliencyTests { + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + // int port = container.getMappedPort(3001); + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + try { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 2d57b891..c816eae9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -125,7 +125,10 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, C // consumer this.connection = this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); this.transport.handleException(t -> { - this.pendingResponses.clear(); + // 🤔 let's think for a moment - we only clear when the session is invalidated + if (t instanceof McpSessionNotFoundException) { + this.pendingResponses.clear(); + } exceptionHandler.accept(t); }); } diff --git a/pom.xml b/pom.xml index c2327ee8..3fd0857e 100644 --- a/pom.xml +++ b/pom.xml @@ -63,7 +63,8 @@ 5.10.2 5.17.0 1.20.4 - 1.17.5 + 1.17.5 + 1.21.0 2.0.16 1.5.15 From 7b7fa87aa66d1f779a4562809e39f756332fc22f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 29 May 2025 19:24:27 +0200 Subject: [PATCH 06/20] More logs, resilience tests improved --- .../WebClientStreamableHttpTransport.java | 166 +++++++++++------- ...eamableHttpAsyncClientResiliencyTests.java | 4 - .../src/test/resources/logback.xml | 6 +- ...AbstractMcpAsyncClientResiliencyTests.java | 67 ++++--- .../spec/McpClientSession.java | 2 +- .../spec/McpClientTransport.java | 2 +- .../spec/McpSessionNotFoundException.java | 4 +- 7 files changed, 154 insertions(+), 97 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 55240dd5..671f35aa 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -69,6 +69,7 @@ public Mono connect(Function, Mono { this.handler.set(handler); if (openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); this.reconnect(null, ctx); } return Mono.empty(); @@ -76,13 +77,23 @@ public Mono connect(Function, Mono handler) { + public void registerExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); this.exceptionHandler.set(handler); } + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", activeSession.get().sessionId(), t); + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + @Override public Mono closeGracefully() { return Mono.defer(() -> { + logger.debug("Graceful close triggered"); McpTransportSession currentSession = this.activeSession.get(); if (currentSession != null) { return currentSession.closeGracefully(); @@ -92,6 +103,12 @@ public Mono closeGracefully() { } private void reconnect(McpStream stream, ContextView ctx) { + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } // Here we attempt to initialize the client. // In case the server supports SSE, we will establish a long-running session // here and @@ -113,10 +130,11 @@ private void reconnect(McpStream stream, ContextView ctx) { .exchangeToFlux(response -> { // Per spec, we are not checking whether it's 2xx, but only if the // Accept header is proper. - if (response.headers().contentType().isPresent() + if (response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent() && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { McpStream sessionStream = stream != null ? stream : new McpStream(this.resumableStreams); + logger.debug("Established stream {}", sessionStream.streamId()); Flux, Iterable>> idWithMessages = response .bodyToFlux(new ParameterizedTypeReference>() { @@ -126,14 +144,14 @@ private void reconnect(McpStream stream, ContextView ctx) { return sessionStream.consumeSseStream(idWithMessages); } else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) { - logger.info("The server does not support SSE streams, using request-response mode."); + logger.debug("The server does not support SSE streams, using request-response mode."); return Flux.empty(); } else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { - logger.info("Session {} was not found on the MCP server", transportSession.sessionId()); + logger.warn("Session {} was not found on the MCP server", transportSession.sessionId()); McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( - "Session " + transportSession.sessionId() + " not found"); + transportSession.sessionId()); // inform the stream/connection subscriber return Flux.error(notFoundException); } @@ -143,8 +161,9 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { }).flux(); } }) - .doOnError(e -> { - this.exceptionHandler.get().accept(e); + .onErrorResume(t -> { + this.handleException(t); + return Flux.empty(); }) .doFinally(s -> { Disposable ref = disposableRef.getAndSet(null); @@ -161,7 +180,7 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.create(sink -> { - System.out.println("Sending message " + message); + logger.debug("Sending message {}", message); // Here we attempt to initialize the client. // In case the server supports SSE, we will establish a long-running session // here and @@ -182,8 +201,9 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { .exchangeToFlux(response -> { if (transportSession.markInitialized()) { if (!response.headers().header("mcp-session-id").isEmpty()) { - transportSession - .setSessionId(response.headers().asHttpHeaders().getFirst("mcp-session-id")); + String sessionId = response.headers().asHttpHeaders().getFirst("mcp-session-id"); + logger.debug("Established session with id {}", sessionId); + transportSession.setSessionId(sessionId); // Once we have a session, we try to open an async stream for // the server to send notifications and requests out-of-band. reconnect(null, sink.contextView()); @@ -193,12 +213,72 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { // The spec mentions only ACCEPTED, but the existing SDKs can return // 200 OK for notifications // if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) { - if (!response.statusCode().is2xxSuccessful()) { + if (response.statusCode().is2xxSuccessful()) { + // Existing SDKs consume notifications with no response body nor + // content type + if (response.headers().contentType().isEmpty()) { + logger.trace("Message was successfuly sent via POST for session {}", + transportSession.sessionId()); + // signal the caller that the message was successfully + // delivered + sink.success(); + // communicate to downstream there is no streamed data coming + return Flux.empty(); + } + + MediaType contentType = response.headers().contentType().get(); + + if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + // communicate to caller that the message was delivered + sink.success(); + + // starting a stream + McpStream sessionStream = new McpStream(this.resumableStreams); + + logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), + transportSession.sessionId()); + + Flux, Iterable>> idWithMessages = response + .bodyToFlux(new ParameterizedTypeReference>() { + }) + .map(this::parse); + + return sessionStream.consumeSseStream(idWithMessages); + } + else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + logger.trace("Received response to POST for session {}", transportSession.sessionId()); + + // communicate to caller the message was delivered + sink.success(); + + // provide the response body as a stream of a single response + // to consume + return response.bodyToMono( + String.class).>handle((responseMessage, s) -> { + try { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema + .deserializeJsonRpcMessage(objectMapper, responseMessage); + s.next(List.of(jsonRpcResponse)); + } + catch (IOException e) { + s.error(e); + } + }) + .flatMapIterable(Function.identity()); + } + else { + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + transportSession.sessionId()); + sink.error(new RuntimeException("Unknown media type returned: " + contentType)); + return Flux.empty(); + } + } + else { if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { - logger.info("Session {} was not found on the MCP server", transportSession.sessionId()); + logger.warn("Session {} was not found on the MCP server", transportSession.sessionId()); McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( - "Session " + transportSession.sessionId() + " not found"); + transportSession.sessionId()); // inform the caller of sendMessage sink.error(notFoundException); // inform the stream/connection subscriber @@ -208,58 +288,14 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { sink.error(new RuntimeException("Sending request failed", e)); }).flux(); } - - // Existing SDKs consume notifications with no response body nor - // content type - if (response.headers().contentType().isEmpty()) { - sink.success(); - return Flux.empty(); - // return - // response.createError().doOnError(e -> - // { - //// sink.error(new RuntimeException("Response has no content - // type")); - // }).flux(); - } - - MediaType contentType = response.headers().contentType().get(); - - if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { - sink.success(); - McpStream sessionStream = new McpStream(this.resumableStreams); - - Flux, Iterable>> idWithMessages = response - .bodyToFlux(new ParameterizedTypeReference>() { - }) - .map(this::parse); - - return sessionStream.consumeSseStream(idWithMessages); - } - else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { - sink.success(); - // return response.bodyToMono(new - // ParameterizedTypeReference>() - // {}); - return response.bodyToMono( - String.class).>handle((responseMessage, s) -> { - try { - McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema - .deserializeJsonRpcMessage(objectMapper, responseMessage); - s.next(List.of(jsonRpcResponse)); - } - catch (IOException e) { - s.error(e); - } - }) - .flatMapIterable(Function.identity()); - } - else { - sink.error(new RuntimeException("Unknown media type")); - return Flux.empty(); - } }) .map(Mono::just) .flatMap(this.handler.get()) + .onErrorResume(t -> { + this.handleException(t); + sink.error(t); + return Flux.empty(); + }) .doFinally(s -> { Disposable ref = disposableRef.getAndSet(null); if (ref != null) { @@ -281,7 +317,7 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { private Tuple2, Iterable> parse(ServerSentEvent event) { if (MESSAGE_EVENT_TYPE.equals(event.event())) { try { - // TODO: support batching + // TODO: support batching? McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); } @@ -313,6 +349,10 @@ String lastId() { return this.lastId.get(); } + long streamId() { + return this.streamId; + } + Flux consumeSseStream( Publisher, Iterable>> eventStream) { return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java index 4f7bbf37..58ca5e00 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -3,11 +3,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; import io.modelcontextprotocol.spec.McpClientTransport; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; import org.springframework.web.reactive.function.client.WebClient; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml index 5ad73374..80730429 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml @@ -9,13 +9,13 @@ - + - + - + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index 5bd388c7..204dee13 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -10,6 +10,8 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.Network; import org.testcontainers.containers.ToxiproxyContainer; @@ -27,6 +29,8 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + static Network network = Network.newNetwork(); static String host = "http://localhost:3001"; @@ -65,6 +69,37 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; } + void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + abstract McpClientTransport createMcpTransport(); protected Duration getRequestTimeout() { @@ -112,29 +147,15 @@ void withClient(McpClientTransport transport, Function { - try { - StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); - - // disconnect - // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", - // ToxicDirection.DOWNSTREAM, 0); - // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", - // ToxicDirection.UPSTREAM, 0); - proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); - proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); - - StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); - - proxy.toxics().get("RESET_UPSTREAM").remove(); - proxy.toxics().get("RESET_DOWNSTREAM").remove(); - // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); - // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); - - StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); - } - catch (IOException e) { - throw new RuntimeException(e); - } + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); }); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index c816eae9..534ae01c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -124,7 +124,7 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, C // create child Observation and emit it together with the message to the // consumer this.connection = this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); - this.transport.handleException(t -> { + this.transport.registerExceptionHandler(t -> { // 🤔 let's think for a moment - we only clear when the session is invalidated if (t instanceof McpSessionNotFoundException) { this.pendingResponses.clear(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index 4c784ce5..6550121c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -19,7 +19,7 @@ public interface McpClientTransport extends McpTransport { // @Deprecated Mono connect(Function, Mono> handler); - default void handleException(Consumer handler) { + default void registerExceptionHandler(Consumer handler) { } // default void connect(Consumer consumer) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java index c351620b..3f017de6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java @@ -2,8 +2,8 @@ public class McpSessionNotFoundException extends RuntimeException { - public McpSessionNotFoundException(String message) { - super(message); + public McpSessionNotFoundException(String sessionId) { + super("Session " + sessionId + " not found on the server"); } } From a42f2bfe63b136093f5ea8683a747d9b86844f3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 29 May 2025 19:27:34 +0200 Subject: [PATCH 07/20] Imports --- .../client/transport/WebClientStreamableHttpTransport.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 671f35aa..0c9ddcfe 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -2,7 +2,11 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.*; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportSession; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; From 122a06ac30ccdeec042ebe4589e9065a25dc305a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 30 May 2025 14:03:56 +0200 Subject: [PATCH 08/20] Properly handling session invalidation --- .../WebClientStreamableHttpTransport.java | 57 ++++++++++++++----- .../src/test/resources/logback.xml | 2 +- ...AbstractMcpAsyncClientResiliencyTests.java | 14 +++++ .../client/McpAsyncClient.java | 1 + .../spec/McpClientSession.java | 1 + .../spec/McpSessionNotFoundException.java | 5 ++ 6 files changed, 66 insertions(+), 14 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 0c9ddcfe..1c3656d8 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -15,8 +15,8 @@ import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; import reactor.core.Disposable; -import reactor.core.Disposables; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.util.context.ContextView; @@ -26,7 +26,6 @@ import java.io.IOException; import java.util.List; import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -52,10 +51,10 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private final boolean resumableStreams; - private AtomicReference, Mono>> handler = new AtomicReference<>(); - private final AtomicReference activeSession = new AtomicReference<>(); + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + private final AtomicReference> exceptionHandler = new AtomicReference<>(); public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, @@ -88,6 +87,11 @@ public void registerExceptionHandler(Consumer handler) { private void handleException(Throwable t) { logger.debug("Handling exception for session {}", activeSession.get().sessionId(), t); + if (t instanceof McpSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(new McpTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } Consumer handler = this.exceptionHandler.get(); if (handler != null) { handler.accept(t); @@ -106,6 +110,8 @@ public Mono closeGracefully() { }); } + // FIXME: Avoid passing the ContextView - add hook allowing the Reactor Context to be + // attached to the chain? private void reconnect(McpStream stream, ContextView ctx) { if (stream != null) { logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); @@ -273,8 +279,7 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { else { logger.warn("Unknown media type {} returned for POST in session {}", contentType, transportSession.sessionId()); - sink.error(new RuntimeException("Unknown media type returned: " + contentType)); - return Flux.empty(); + return Flux.error(new RuntimeException("Unknown media type returned: " + contentType)); } } else { @@ -283,20 +288,45 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( transportSession.sessionId()); - // inform the caller of sendMessage - sink.error(notFoundException); // inform the stream/connection subscriber return Flux.error(notFoundException); } - return response.createError().doOnError(e -> { - sink.error(new RuntimeException("Sending request failed", e)); + return response.createError().onErrorResume(e -> { + WebClientResponseException responseException = (WebClientResponseException) e; + byte[] body = responseException.getResponseBodyAsByteArray(); + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; + Exception toPropagate; + try { + McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, + McpSchema.JSONRPCResponse.class); + jsonRpcError = jsonRpcResponse.error(); + toPropagate = new McpError(jsonRpcError); + } + catch (IOException ex) { + toPropagate = new RuntimeException("Sending request failed", e); + logger.debug("Received content together with {} HTTP code response: {}", + response.statusCode(), body); + } + + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { + return Mono.error(new McpSessionNotFoundException(this.activeSession.get().sessionId(), + toPropagate)); + } + return Mono.empty(); }).flux(); } }) .map(Mono::just) .flatMap(this.handler.get()) .onErrorResume(t -> { + // handle the error first this.handleException(t); + + // inform the caller of sendMessage sink.error(t); return Flux.empty(); }) @@ -321,7 +351,8 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { private Tuple2, Iterable> parse(ServerSentEvent event) { if (MESSAGE_EVENT_TYPE.equals(event.event())) { try { - // TODO: support batching? + // We don't support batching ATM and probably won't since the next version + // considers removing it. McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); } @@ -340,6 +371,7 @@ private class McpStream { private final AtomicReference lastId = new AtomicReference<>(); + // Used only for internal accounting private final long streamId; private final boolean resumable; @@ -360,8 +392,7 @@ long streamId() { Flux consumeSseStream( Publisher, Iterable>> eventStream) { return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { - // TODO: examine which error :) - if (resumable) { + if (resumable && !(e instanceof McpSessionNotFoundException)) { reconnect(this, ctx); } }) diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml index 80730429..2652e2ee 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml @@ -9,7 +9,7 @@ - + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index 204dee13..9abfa345 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -159,4 +159,18 @@ void testPing() { }); } + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + container.stop(); + container.start(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 9ef900b0..0925d476 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -221,6 +221,7 @@ public class McpAsyncClient { private void handleException(Throwable t) { if (t instanceof McpSessionNotFoundException) { + this.initialization.set(null); this.initialize().subscribe(); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 534ae01c..9ed45766 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -252,6 +252,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc String requestId = this.generateRequestId(); return Mono.deferContextual(ctx -> Mono.create(sink -> { + logger.debug("Sending message for method {}", method); this.pendingResponses.put(requestId, sink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java index 3f017de6..1281d087 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java @@ -2,6 +2,11 @@ public class McpSessionNotFoundException extends RuntimeException { + public McpSessionNotFoundException(String sessionId, Exception cause) { + super("Session " + sessionId + " not found on the server", cause); + + } + public McpSessionNotFoundException(String sessionId) { super("Session " + sessionId + " not found on the server"); } From f7ea062d90f72042693369a7b8fd5b70216d1594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Mon, 2 Jun 2025 13:27:29 +0200 Subject: [PATCH 09/20] Automatic initialization and burst protection --- ...AbstractMcpAsyncClientResiliencyTests.java | 17 ++--- .../client/AbstractMcpAsyncClientTests.java | 37 ++++----- .../client/AbstractMcpSyncClientTests.java | 60 +++++++-------- .../client/McpAsyncClient.java | 75 +++++++++++++------ .../client/AbstractMcpAsyncClientTests.java | 37 ++++----- .../client/AbstractMcpSyncClientTests.java | 60 +++++++-------- 6 files changed, 153 insertions(+), 133 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index 9abfa345..39f601d7 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -1,14 +1,10 @@ package io.modelcontextprotocol.client; -import com.fasterxml.jackson.databind.ObjectMapper; import eu.rekawek.toxiproxy.Proxy; import eu.rekawek.toxiproxy.ToxiproxyClient; import eu.rekawek.toxiproxy.model.ToxicDirection; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import org.awaitility.Awaitility; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,7 +20,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; public abstract class AbstractMcpAsyncClientResiliencyTests { @@ -69,7 +64,7 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; } - void disconnect() { + private static void disconnect() { long start = System.nanoTime(); try { // disconnect @@ -86,7 +81,7 @@ void disconnect() { } } - void reconnect() { + private static void reconnect() { long start = System.nanoTime(); try { proxy.toxics().get("RESET_UPSTREAM").remove(); @@ -100,6 +95,11 @@ void reconnect() { } } + private static void restartMcpServer() { + container.stop(); + container.start(); + } + abstract McpClientTransport createMcpTransport(); protected Duration getRequestTimeout() { @@ -164,8 +164,7 @@ void testSessionInvalidation() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); - container.stop(); - container.start(); + restartMcpServer(); // The first try will face the session mismatch exception and the second one // will go through the re-initialization process. diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 5452c8ea..049bea00 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -110,14 +110,16 @@ void tearDown() { onClose(); } - void verifyInitializationTimeout(Function> operation, String action) { + void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, + String action) { withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); + StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); + }); + } + + void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); }); } @@ -133,7 +135,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -153,7 +155,7 @@ void testListTools() { @Test void testPingWithoutInitialization() { - verifyInitializationTimeout(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -168,7 +170,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -202,7 +204,7 @@ void testCallToolWithInvalidTool() { @Test void testListResourcesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -233,7 +235,7 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(null), "listing " + "prompts"); } @Test @@ -258,7 +260,7 @@ void testListPrompts() { @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); } @Test @@ -279,7 +281,7 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -354,7 +356,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(), + "listing resource templates"); } @Test @@ -447,8 +450,8 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 128441f8..3785fd64 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -12,7 +13,6 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -112,33 +112,18 @@ void tearDown() { static final Object DUMMY_RETURN_VALUE = new Object(); - void verifyNotificationTimesOut(Consumer operation, String action) { - verifyCallTimesOut(client -> { + void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { + verifyCallSucceedsWithImplicitInitialization(client -> { operation.accept(client); return DUMMY_RETURN_VALUE; }, action); } - void verifyCallTimesOut(Function blockingOperation, String action) { + void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { - // This scheduler is not replaced by virtual time scheduler - Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); - - StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) - // Offload the blocking call to the real scheduler - .subscribeOn(customScheduler)) - .expectSubscription() - // This works without actually waiting but executes all the - // tasks pending execution on the VirtualTimeScheduler. - // It is possible to execute the blocking code from the operation - // because it is blocked on a dedicated Scheduler and the main - // flow is not blocked and uses the VirtualTimeScheduler. - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); - - customScheduler.dispose(); + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))) + .expectNextCount(1) + .verifyComplete(); }); } @@ -154,7 +139,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -175,8 +160,8 @@ void testListTools() { @Test void testCallToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), - "calling tools"); + verifyCallSucceedsWithImplicitInitialization( + client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); } @Test @@ -200,7 +185,7 @@ void testCallTools() { @Test void testPingWithoutInitialization() { - verifyCallTimesOut(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -214,7 +199,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -243,7 +228,7 @@ void testCallToolWithInvalidTool() { @Test void testRootsListChangedWithoutInitialization() { - verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -257,7 +242,7 @@ void testRootsListChanged() { @Test void testListResourcesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -333,8 +318,14 @@ void testRemoveNonExistentRoot() { @Test void testReadResourceWithoutInitialization() { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); + AtomicReference> resources = new AtomicReference<>(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + resources.set(mcpSyncClient.listResources().resources()); + }); + + verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), + "reading resources"); } @Test @@ -355,7 +346,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(null), + "listing resource templates"); } @Test @@ -413,8 +405,8 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 0925d476..ec7104b4 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -34,6 +34,7 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; /** @@ -80,7 +81,9 @@ public class McpAsyncClient { private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { }; - private final AtomicReference initialization = new AtomicReference<>(); + // private final AtomicReference initialization = new + // AtomicReference<>(); + private final AtomicReference initialization = new AtomicReference<>(); /** * The max timeout to await for the client-server connection to be initialized. @@ -220,9 +223,10 @@ public class McpAsyncClient { } private void handleException(Throwable t) { + logger.warn("Handling exception", t); if (t instanceof McpSessionNotFoundException) { this.initialization.set(null); - this.initialize().subscribe(); + withInitializationCheck("re-initializing", result -> Mono.empty()).subscribe(); } } @@ -231,7 +235,8 @@ private void handleException(Throwable t) { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - McpSchema.InitializeResult initializeResult = this.initialization.get(); + Initialization current = this.initialization.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; return initializeResult != null ? initializeResult.capabilities() : null; } @@ -241,7 +246,8 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server instructions */ public String getServerInstructions() { - McpSchema.InitializeResult initializeResult = this.initialization.get(); + Initialization current = this.initialization.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; return initializeResult != null ? initializeResult.instructions() : null; } @@ -250,7 +256,8 @@ public String getServerInstructions() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - McpSchema.InitializeResult initializeResult = this.initialization.get(); + Initialization current = this.initialization.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; return initializeResult != null ? initializeResult.serverInfo() : null; } @@ -319,6 +326,10 @@ public Mono closeGracefully() { * Initialization Spec */ public Mono initialize() { + return withInitializationCheck("initialize", Mono::just); + } + + private Mono initialize0() { String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); @@ -341,12 +352,27 @@ public Mono initialize() { "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); } - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null).doOnSuccess(v -> { - this.initialization.set(initializeResult); - }).thenReturn(initializeResult); + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) + .thenReturn(initializeResult); }); } + private static class Initialization { + + private final Sinks.One initSink; + + private final AtomicReference result = new AtomicReference<>(); + + Initialization(Sinks.One initSink) { + this.initSink = initSink; + } + + static Initialization create() { + return new Initialization(Sinks.one()); + } + + } + /** * Utility method to handle the common pattern of checking initialization before * executing an operation. @@ -358,23 +384,28 @@ public Mono initialize() { private Mono withInitializationCheck(String actionName, Function> operation) { return Mono.defer(() -> { - McpSchema.InitializeResult initializeResult = this.initialization.get(); - // FIXME: in case of bursts this will trigger multiple inits, we have to batch - // requests - // and dispatch once a single init finishes - if (initializeResult != null) { - return operation.apply(initializeResult); + Initialization newInit = Initialization.create(); + Initialization current = this.initialization.compareAndExchange(null, newInit); + if (current == null) { + logger.info("Initialization process started"); } else { - return this.initialize() - .timeout(this.initializationTimeout) - // TODO: McpError should be used when communicating over the wire, not - // to - // the user of the client API - .onErrorResume(TimeoutException.class, - ex -> Mono.error(new McpError("Client must be initialized before " + actionName))) - .flatMap(operation); + logger.info("Joining previous initialization"); } + return (current != null ? current.initSink.asMono() : this.initialize0().doOnNext(result -> { + // first ensure the state is persisted + newInit.result.set(result); + // inform all the subscribers + newInit.initSink.emitValue(result, Sinks.EmitFailureHandler.FAIL_FAST); + }).onErrorResume(ex -> { + Initialization ongoing = this.initialization.getAndSet(null); + if (ongoing != null) { + ongoing.initSink.emitError(ex, Sinks.EmitFailureHandler.FAIL_FAST); + } + return Mono.error(ex); + })).timeout(this.initializationTimeout).onErrorResume(ex -> { + return Mono.error(new McpError("Client failed to initialize " + actionName)); + }).flatMap(operation); }); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 72b409af..14a8bcc8 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -111,14 +111,16 @@ void tearDown() { onClose(); } - void verifyInitializationTimeout(Function> operation, String action) { + void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, + String action) { withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); + StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); + }); + } + + void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); }); } @@ -134,7 +136,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -154,7 +156,7 @@ void testListTools() { @Test void testPingWithoutInitialization() { - verifyInitializationTimeout(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -169,7 +171,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -203,7 +205,7 @@ void testCallToolWithInvalidTool() { @Test void testListResourcesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -234,7 +236,7 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(null), "listing " + "prompts"); } @Test @@ -259,7 +261,7 @@ void testListPrompts() { @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); } @Test @@ -280,7 +282,7 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -355,7 +357,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(), + "listing resource templates"); } @Test @@ -448,8 +451,8 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 24c161eb..c5c9ed2e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -12,7 +13,6 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -113,33 +113,18 @@ void tearDown() { static final Object DUMMY_RETURN_VALUE = new Object(); - void verifyNotificationTimesOut(Consumer operation, String action) { - verifyCallTimesOut(client -> { + void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { + verifyCallSucceedsWithImplicitInitialization(client -> { operation.accept(client); return DUMMY_RETURN_VALUE; }, action); } - void verifyCallTimesOut(Function blockingOperation, String action) { + void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { - // This scheduler is not replaced by virtual time scheduler - Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); - - StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) - // Offload the blocking call to the real scheduler - .subscribeOn(customScheduler)) - .expectSubscription() - // This works without actually waiting but executes all the - // tasks pending execution on the VirtualTimeScheduler. - // It is possible to execute the blocking code from the operation - // because it is blocked on a dedicated Scheduler and the main - // flow is not blocked and uses the VirtualTimeScheduler. - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); - - customScheduler.dispose(); + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))) + .expectNextCount(1) + .verifyComplete(); }); } @@ -155,7 +140,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -176,8 +161,8 @@ void testListTools() { @Test void testCallToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), - "calling tools"); + verifyCallSucceedsWithImplicitInitialization( + client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); } @Test @@ -201,7 +186,7 @@ void testCallTools() { @Test void testPingWithoutInitialization() { - verifyCallTimesOut(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -215,7 +200,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -244,7 +229,7 @@ void testCallToolWithInvalidTool() { @Test void testRootsListChangedWithoutInitialization() { - verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -258,7 +243,7 @@ void testRootsListChanged() { @Test void testListResourcesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -334,8 +319,14 @@ void testRemoveNonExistentRoot() { @Test void testReadResourceWithoutInitialization() { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); + AtomicReference> resources = new AtomicReference<>(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + resources.set(mcpSyncClient.listResources().resources()); + }); + + verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), + "reading resources"); } @Test @@ -356,7 +347,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(null), + "listing resource templates"); } @Test @@ -414,8 +406,8 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test From b91380e9b80e89afba97811de0e44ec386ab1adf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 3 Jun 2025 19:46:13 +0200 Subject: [PATCH 10/20] Cleanup, javadoc, refactor --- mcp-spring/mcp-spring-webflux/pom.xml | 11 +- .../client/transport/Main.java | 57 +++++++-- .../WebClientStreamableHttpTransport.java | 53 +++++---- .../src/main/resources/logback.xml | 18 +++ .../MockMcpTransport.java | 13 --- ...AbstractMcpAsyncClientResiliencyTests.java | 25 +++- .../client/McpAsyncClient.java | 109 +++++++++--------- .../spec/DefaultMcpTransportSession.java | 57 +++++++++ .../spec/McpClientSession.java | 56 ++++----- .../spec/McpClientTransport.java | 26 +++-- .../modelcontextprotocol/spec/McpSchema.java | 6 +- .../spec/McpSessionNotFoundException.java | 6 + .../spec/McpTransportSession.java | 88 ++++++++------ 13 files changed, 347 insertions(+), 178 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/main/resources/logback.xml create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 26452fe9..4da36f11 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -88,6 +88,15 @@ ${byte-buddy.version} test + + org.springframework + spring-context + 6.2.6 + + + io.projectreactor.netty + reactor-netty-http + io.projectreactor reactor-test @@ -117,7 +126,7 @@ ch.qos.logback logback-classic ${logback.version} - test + diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java index e501ca1c..f7c4b30f 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java @@ -3,15 +3,29 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema; +import io.netty.channel.socket.nio.NioChannelOption; +import jdk.net.ExtendedSocketOptions; +import org.springframework.http.client.reactive.ReactorClientHttpConnector; import org.springframework.web.reactive.function.client.WebClient; +import reactor.netty.http.client.HttpClient; + +import java.util.List; +import java.util.Map; +import java.util.Scanner; +import java.util.concurrent.atomic.AtomicReference; public class Main { - public static void main(String[] args) { - McpAsyncClient client = McpClient - .async(new WebClientStreamableHttpTransport(new ObjectMapper(), - WebClient.builder().baseUrl("http://localhost:3001"), "/mcp", true, false)) + public static void main(String[] args) throws InterruptedException { + McpSyncClient client = McpClient + .sync(new WebClientStreamableHttpTransport(new ObjectMapper(), + WebClient.builder() + .clientConnector(new ReactorClientHttpConnector( + HttpClient.create().option(NioChannelOption.of(ExtendedSocketOptions.TCP_KEEPIDLE), 5))) + .baseUrl("http://localhost:3001"), + "/mcp", true, false)) .build(); /* @@ -34,11 +48,36 @@ public static void main(String[] args) { * tools 6. -> 2xx response */ - client.initialize() - .flatMap(r -> client.listTools()) - .map(McpSchema.ListToolsResult::tools) - .doOnNext(System.out::println) - .block(); + List tools = null; + while (tools == null) { + try { + client.initialize(); + tools = client.listTools().tools(); + } + catch (Exception e) { + System.out.println("Got exception. Retrying in 5s. " + e); + Thread.sleep(5000); + } + } + + Scanner scanner = new Scanner(System.in); + while (scanner.hasNext()) { + String text = scanner.nextLine(); + if (text == null || text.isEmpty()) { + System.out.println("Done"); + break; + } + try { + McpSchema.CallToolResult result = client + .callTool(new McpSchema.CallToolRequest(tools.get(0).name(), Map.of("message", text))); + System.out.println("Tool call result: " + result); + } + catch (Exception e) { + System.out.println("Error calling tool " + e); + } + } + + client.closeGracefully(); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 1c3656d8..22dbdb0b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -2,6 +2,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -51,12 +52,13 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private final boolean resumableStreams; - private final AtomicReference activeSession = new AtomicReference<>(); + private final AtomicReference activeSession = new AtomicReference<>(); private final AtomicReference, Mono>> handler = new AtomicReference<>(); private final AtomicReference> exceptionHandler = new AtomicReference<>(); + // TODO: builder public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { this.objectMapper = objectMapper; @@ -64,7 +66,7 @@ public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Bui this.endpoint = endpoint; this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; - this.activeSession.set(new McpTransportSession()); + this.activeSession.set(new DefaultMcpTransportSession()); } @Override @@ -80,15 +82,15 @@ public Mono connect(Function, Mono handler) { + public void setExceptionHandler(Consumer handler) { logger.debug("Exception handler registered"); this.exceptionHandler.set(handler); } private void handleException(Throwable t) { - logger.debug("Handling exception for session {}", activeSession.get().sessionId(), t); + logger.debug("Handling exception for session {}", sessionIdRepresentation(this.activeSession.get()), t); if (t instanceof McpSessionNotFoundException) { - McpTransportSession invalidSession = this.activeSession.getAndSet(new McpTransportSession()); + McpTransportSession invalidSession = this.activeSession.getAndSet(new DefaultMcpTransportSession()); logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); invalidSession.close(); } @@ -102,7 +104,7 @@ private void handleException(Throwable t) { public Mono closeGracefully() { return Mono.defer(() -> { logger.debug("Graceful close triggered"); - McpTransportSession currentSession = this.activeSession.get(); + DefaultMcpTransportSession currentSession = this.activeSession.get(); if (currentSession != null) { return currentSession.closeGracefully(); } @@ -125,16 +127,14 @@ private void reconnect(McpStream stream, ContextView ctx) { // listen for messages. // If it doesn't, nothing actually happens here, that's just the way it is... final AtomicReference disposableRef = new AtomicReference<>(); - final McpTransportSession transportSession = this.activeSession.get(); + final McpTransportSession transportSession = this.activeSession.get(); Disposable connection = webClient.get() .uri(this.endpoint) .accept(MediaType.TEXT_EVENT_STREAM) .headers(httpHeaders -> { - if (transportSession.sessionId() != null) { - httpHeaders.add("mcp-session-id", transportSession.sessionId()); - } - if (stream != null && stream.lastId() != null) { - httpHeaders.add("last-event-id", stream.lastId()); + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + if (stream != null) { + stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id)); } }) .exchangeToFlux(response -> { @@ -161,7 +161,7 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { logger.warn("Session {} was not found on the MCP server", transportSession.sessionId()); McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( - transportSession.sessionId()); + sessionIdRepresentation(transportSession)); // inform the stream/connection subscriber return Flux.error(notFoundException); } @@ -187,6 +187,10 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { transportSession.addConnection(connection); } + private static String sessionIdRepresentation(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.create(sink -> { @@ -197,15 +201,13 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { // listen for messages. // If it doesn't, nothing actually happens here, that's just the way it is... final AtomicReference disposableRef = new AtomicReference<>(); - final McpTransportSession transportSession = this.activeSession.get(); + final McpTransportSession transportSession = this.activeSession.get(); Disposable connection = webClient.post() .uri(this.endpoint) .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) .headers(httpHeaders -> { - if (transportSession.sessionId() != null) { - httpHeaders.add("mcp-session-id", transportSession.sessionId()); - } + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); }) .bodyValue(message) .exchangeToFlux(response -> { @@ -287,7 +289,7 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { logger.warn("Session {} was not found on the MCP server", transportSession.sessionId()); McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( - transportSession.sessionId()); + sessionIdRepresentation(transportSession)); // inform the stream/connection subscriber return Flux.error(notFoundException); } @@ -313,8 +315,8 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { // invalidate the session // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { - return Mono.error(new McpSessionNotFoundException(this.activeSession.get().sessionId(), - toPropagate)); + return Mono.error(new McpSessionNotFoundException( + sessionIdRepresentation(this.activeSession.get()), toPropagate)); } return Mono.empty(); }).flux(); @@ -381,8 +383,8 @@ private class McpStream { this.resumable = resumable; } - String lastId() { - return this.lastId.get(); + Optional lastId() { + return Optional.ofNullable(this.lastId.get()); } long streamId() { @@ -395,9 +397,10 @@ Flux consumeSseStream( if (resumable && !(e instanceof McpSessionNotFoundException)) { reconnect(this, ctx); } - }) - .doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(this.lastId::set)) - .flatMapIterable(Tuple2::getT2)); + }).doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { + String previousId = this.lastId.getAndSet(id); + logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); + })).flatMapIterable(Tuple2::getT2)); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/main/resources/logback.xml new file mode 100644 index 00000000..e38239e7 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/resources/logback.xml @@ -0,0 +1,18 @@ + + + + + + + %d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java index 06ce2804..5484a63c 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -7,7 +7,6 @@ import java.util.ArrayList; import java.util.List; import java.util.function.BiConsumer; -import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; @@ -71,18 +70,6 @@ public McpSchema.JSONRPCMessage getLastSentMessage() { private volatile boolean connected = false; - // @Override - // public Mono connect(Consumer consumer) { - // if (connected) { - // return Mono.error(new IllegalStateException("Already connected")); - // } - // connected = true; - // return inbound.asFlux() - // .doOnNext(consumer) - // .doFinally(signal -> connected = false) - // .then(); - // } - @Override public Mono connect(Function, Mono> handler) { if (connected) { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index 39f601d7..9990b8a3 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -16,6 +16,8 @@ import java.io.IOException; import java.time.Duration; +import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -60,7 +62,6 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { final String ipAddressViaToxiproxy = toxiproxy.getHost(); final int portViaToxiproxy = toxiproxy.getMappedPort(3000); - // int port = container.getMappedPort(3001); host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; } @@ -172,4 +173,26 @@ void testSessionInvalidation() { }); } + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index ec7104b4..38eaef97 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -9,8 +9,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -34,7 +32,6 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; /** @@ -78,11 +75,9 @@ public class McpAsyncClient { private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { }; - // private final AtomicReference initialization = new - // AtomicReference<>(); private final AtomicReference initialization = new AtomicReference<>(); /** @@ -226,7 +221,7 @@ private void handleException(Throwable t) { logger.warn("Handling exception", t); if (t instanceof McpSessionNotFoundException) { this.initialization.set(null); - withInitializationCheck("re-initializing", result -> Mono.empty()).subscribe(); + withInitialization("re-initializing", result -> Mono.empty()).subscribe(); } } @@ -266,7 +261,8 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - return this.initialization.get() != null; + Initialization current = this.initialization.get(); + return current != null && (current.result.get() != null); } /** @@ -304,7 +300,10 @@ public Mono closeGracefully() { // Initialization // -------------------------- /** - * The initialization phase MUST be the first interaction between client and server. + * The initialization phase should be the first interaction between client and server. + * The client will ensure it happens in case it has not been explicitly called and in + * case of transport session invalidation. + *

* During this phase, the client and server: *

    *
  • Establish protocol version compatibility
  • @@ -324,13 +323,13 @@ public Mono closeGracefully() { * @see MCP * Initialization Spec + *

    */ public Mono initialize() { - return withInitializationCheck("initialize", Mono::just); + return withInitialization("by explicit API call", Mono::just); } - private Mono initialize0() { - + private Mono doInitialize() { String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off @@ -359,53 +358,49 @@ private Mono initialize0() { private static class Initialization { - private final Sinks.One initSink; + private final Sinks.One initSink = Sinks.one(); private final AtomicReference result = new AtomicReference<>(); - Initialization(Sinks.One initSink) { - this.initSink = initSink; - } - static Initialization create() { - return new Initialization(Sinks.one()); + return new Initialization(); } } /** - * Utility method to handle the common pattern of checking initialization before + * Utility method to handle the common pattern of ensuring initialization before * executing an operation. * @param The type of the result Mono - * @param actionName The action to perform if the client is initialized - * @param operation The operation to execute if the client is initialized + * @param actionName The action to perform when the client is initialized + * @param operation The operation to execute when the client is initialized * @return A Mono that completes with the result of the operation */ - private Mono withInitializationCheck(String actionName, - Function> operation) { + private Mono withInitialization(String actionName, Function> operation) { return Mono.defer(() -> { Initialization newInit = Initialization.create(); - Initialization current = this.initialization.compareAndExchange(null, newInit); - if (current == null) { - logger.info("Initialization process started"); - } - else { - logger.info("Joining previous initialization"); - } - return (current != null ? current.initSink.asMono() : this.initialize0().doOnNext(result -> { - // first ensure the state is persisted - newInit.result.set(result); - // inform all the subscribers - newInit.initSink.emitValue(result, Sinks.EmitFailureHandler.FAIL_FAST); - }).onErrorResume(ex -> { - Initialization ongoing = this.initialization.getAndSet(null); - if (ongoing != null) { - ongoing.initSink.emitError(ex, Sinks.EmitFailureHandler.FAIL_FAST); - } - return Mono.error(ex); - })).timeout(this.initializationTimeout).onErrorResume(ex -> { - return Mono.error(new McpError("Client failed to initialize " + actionName)); - }).flatMap(operation); + Initialization previous = this.initialization.compareAndExchange(null, newInit); + + boolean needsToInitialize = previous == null; + logger.info(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); + + Mono initializationJob = needsToInitialize + ? this.doInitialize().doOnNext(result -> { + // first ensure the result is cached + newInit.result.set(result); + // inform all the subscribers waiting for the initialization + newInit.initSink.emitValue(result, Sinks.EmitFailureHandler.FAIL_FAST); + }).onErrorResume(ex -> { + Initialization ongoing = this.initialization.getAndSet(null); + if (ongoing != null) { + ongoing.initSink.emitError(ex, Sinks.EmitFailureHandler.FAIL_FAST); + } + return Mono.error(ex); + }) : previous.initSink.asMono(); + + return initializationJob.timeout(this.initializationTimeout) + .onErrorResume(ex -> Mono.error(new McpError("Client failed to initialize " + actionName))) + .flatMap(operation); }); } @@ -418,7 +413,7 @@ private Mono withInitializationCheck(String actionName, * @return A Mono that completes with the server's ping response */ public Mono ping() { - return this.withInitializationCheck("pinging the server", initializedResult -> this.mcpSession + return this.withInitialization("pinging the server", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { })); } @@ -500,7 +495,7 @@ public Mono removeRoot(String rootUri) { * @return A Mono that completes when the notification is sent. */ public Mono rootsListChangedNotification() { - return this.withInitializationCheck("sending roots list changed notification", + return this.withInitialization("sending roots list changed notification", initResult -> this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); } @@ -551,7 +546,7 @@ private RequestHandler samplingCreateMessageHandler() { * @see #listTools() */ public Mono callTool(McpSchema.CallToolRequest callToolRequest) { - return this.withInitializationCheck("calling tools", initializedResult -> { + return this.withInitialization("calling tools", initializedResult -> { if (initializedResult.capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } @@ -573,7 +568,7 @@ public Mono listTools() { * @return A Mono that emits the list of tools result */ public Mono listTools(String cursor) { - return this.withInitializationCheck("listing tools", initializedResult -> { + return this.withInitialization("listing tools", initializedResult -> { if (initializedResult.capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } @@ -630,7 +625,7 @@ public Mono listResources() { * @see #readResource(McpSchema.Resource) */ public Mono listResources(String cursor) { - return this.withInitializationCheck("listing resources", initializedResult -> { + return this.withInitialization("listing resources", initializedResult -> { if (initializedResult.capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } @@ -661,7 +656,7 @@ public Mono readResource(McpSchema.Resource resour * @see McpSchema.ReadResourceResult */ public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.withInitializationCheck("reading resources", initializedResult -> { + return this.withInitialization("reading resources", initializedResult -> { if (initializedResult.capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } @@ -690,7 +685,7 @@ public Mono listResourceTemplates() { * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates(String cursor) { - return this.withInitializationCheck("listing resource templates", initializedResult -> { + return this.withInitialization("listing resource templates", initializedResult -> { if (initializedResult.capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } @@ -709,7 +704,7 @@ public Mono listResourceTemplates(String * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) */ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.withInitializationCheck("subscribing to resources", initializedResult -> this.mcpSession + return this.withInitialization("subscribing to resources", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); } @@ -723,7 +718,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) * @see #subscribeResource(McpSchema.SubscribeRequest) */ public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.withInitializationCheck("unsubscribing from resources", initializedResult -> this.mcpSession + return this.withInitialization("unsubscribing from resources", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); } @@ -765,7 +760,7 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.withInitializationCheck("listing prompts", initializedResult -> this.mcpSession + return this.withInitialization("listing prompts", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); } @@ -779,7 +774,7 @@ public Mono listPrompts(String cursor) { * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.withInitializationCheck("getting prompts", initializedResult -> this.mcpSession + return this.withInitialization("getting prompts", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); } @@ -831,7 +826,7 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { return Mono.error(new McpError("Logging level must not be null")); } - return this.withInitializationCheck("setting logging level", initializedResult -> { + return this.withInitialization("setting logging level", initializedResult -> { var params = new McpSchema.SetLevelRequest(loggingLevel); return this.mcpSession.sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, new TypeReference() { }).then(); @@ -864,7 +859,7 @@ void setProtocolVersions(List protocolVersions) { * @see McpSchema.CompleteResult */ public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.withInitializationCheck("complete completions", initializedResult -> this.mcpSession + return this.withInitialization("complete completions", initializedResult -> this.mcpSession .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java new file mode 100644 index 00000000..6bda390c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -0,0 +1,57 @@ +package io.modelcontextprotocol.spec; + +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.Mono; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +public class DefaultMcpTransportSession implements McpTransportSession { + + private final Disposable.Composite openConnections = Disposables.composite(); + + private final AtomicBoolean initialized = new AtomicBoolean(false); + + private final AtomicReference sessionId = new AtomicReference<>(); + + public DefaultMcpTransportSession() { + } + + @Override + public Optional sessionId() { + return Optional.ofNullable(this.sessionId.get()); + } + + @Override + public void setSessionId(String sessionId) { + this.sessionId.set(sessionId); + } + + @Override + public boolean markInitialized() { + return this.initialized.compareAndSet(false, true); + } + + @Override + public void addConnection(Disposable connection) { + this.openConnections.add(connection); + } + + @Override + public void removeConnection(Disposable connection) { + this.openConnections.remove(connection); + } + + @Override + public void close() { + this.closeGracefully().subscribe(); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(this.openConnections::dispose); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 9ed45766..db219eee 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -15,7 +15,6 @@ import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.Disposable; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; @@ -38,9 +37,12 @@ */ public class McpClientSession implements McpSession { - /** Logger for this class */ private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class); + private static final Consumer DEFAULT_CONSUMER = t -> { + logger.warn("MCP transport issued an exception", t); + }; + /** Duration to wait for request responses before timing out */ private final Duration requestTimeout; @@ -62,8 +64,6 @@ public class McpClientSession implements McpSession { /** Atomic counter for generating unique request IDs */ private final AtomicLong requestCounter = new AtomicLong(0); - private final Disposable connection; - /** * Functional interface for handling incoming JSON-RPC requests. Implementations * should process the request parameters and return a response. @@ -102,6 +102,8 @@ public interface NotificationHandler { * Creates a new McpClientSession with the specified configuration and handlers. * @param requestTimeout Duration to wait for responses * @param transport Transport implementation for message exchange + * @param exceptionHandler A hook to take action when transport level exceptions + * happen. * @param requestHandlers Map of method names to request handlers * @param notificationHandlers Map of method names to notification handlers */ @@ -118,19 +120,22 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, C this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); - // TODO: consider mono.transformDeferredContextual where the Context contains - // the - // Observation associated with the individual message - it can be used to - // create child Observation and emit it together with the message to the - // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); - this.transport.registerExceptionHandler(t -> { - // 🤔 let's think for a moment - we only clear when the session is invalidated + this.transport.setExceptionHandler(t -> { + // We only clear when the session is invalidated if (t instanceof McpSessionNotFoundException) { this.pendingResponses.clear(); } exceptionHandler.accept(t); }); + this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); + } + + private void dismissPendingResponses() { + this.pendingResponses.forEach((id, sink) -> { + logger.warn("Abruptly terminating exchange for request {}", id); + sink.error(new RuntimeException("MCP session with server terminated")); + }); + this.pendingResponses.clear(); } /** @@ -139,11 +144,13 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, C * @param transport Transport implementation for message exchange * @param requestHandlers Map of method names to request handlers * @param notificationHandlers Map of method names to notification handlers + * @deprecated Use + * {@link #McpClientSession(Duration, McpClientTransport, Consumer, Map, Map)}. */ + @Deprecated public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map> requestHandlers, Map notificationHandlers) { - this(requestTimeout, transport, e -> { - }, requestHandlers, notificationHandlers); + this(requestTimeout, transport, DEFAULT_CONSUMER, requestHandlers, notificationHandlers); } private void handle(McpSchema.JSONRPCMessage message) { @@ -256,14 +263,11 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc this.pendingResponses.put(requestId, sink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); - this.transport.sendMessage(jsonrpcRequest) - .contextWrite(ctx) - // TODO: It's most efficient to create a dedicated Subscriber here - .subscribe(v -> { - }, error -> { - this.pendingResponses.remove(requestId); - sink.error(error); - }); + this.transport.sendMessage(jsonrpcRequest).contextWrite(ctx).subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); })).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { logger.error("Error handling request: {}", jsonRpcResponse.error()); @@ -299,10 +303,8 @@ public Mono sendNotification(String method, Object params) { */ @Override public Mono closeGracefully() { - return Mono.defer(() -> { - this.connection.dispose(); - return transport.closeGracefully(); - }); + dismissPendingResponses(); + return this.transport.closeGracefully(); } /** @@ -310,7 +312,7 @@ public Mono closeGracefully() { */ @Override public void close() { - this.connection.dispose(); + dismissPendingResponses(); transport.close(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index 6550121c..5c3b3313 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -9,22 +9,32 @@ import reactor.core.publisher.Mono; /** - * Marker interface for the client-side MCP transport. + * Interface for the client side of the {@link McpTransport}. It allows setting handlers + * for messages that are incoming from the MCP server and hooking in to exceptions raised + * on the transport layer. * * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ public interface McpClientTransport extends McpTransport { - // @Deprecated + /** + * Used to register the incoming messages' handler and potentially (eagerly) connect + * to the server. + * @param handler a transformer for incoming messages + * @return a {@link Mono} that terminates upon successful client setup. It can mean + * establishing a connection (which can be later disposed) but it doesn't have to, + * depending on the transport type. The successful termination of the returned + * {@link Mono} simply means the client can now be used. An error can be retried + * according to the application requirements. + */ Mono connect(Function, Mono> handler); - default void registerExceptionHandler(Consumer handler) { + /** + * Sets the exception handler for exceptions raised on the transport layer. + * @param handler Allows reacting to transport level exceptions by the higher layers + */ + default void setExceptionHandler(Consumer handler) { } - // default void connect(Consumer consumer) { - // this.connect((Function, - // Mono>) mono -> mono.doOnNext(consumer)).subscribe(); - // } - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 33ab8571..5a91b475 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -10,7 +10,11 @@ import java.util.List; import java.util.Map; -import com.fasterxml.jackson.annotation.*; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeInfo.As; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java index 1281d087..be07eea1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java @@ -1,5 +1,11 @@ package io.modelcontextprotocol.spec; +/** + * Exception that signifies that the server does not recognize the connecting client via + * the presented transport session identifier. + * + * @author Dariusz Jędrzejczyk + */ public class McpSessionNotFoundException extends RuntimeException { public McpSessionNotFoundException(String sessionId, Exception cause) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java index dd5c108e..902db90d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java @@ -1,49 +1,65 @@ package io.modelcontextprotocol.spec; -import reactor.core.Disposable; -import reactor.core.Disposables; -import reactor.core.publisher.Mono; +import org.reactivestreams.Publisher; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; +import java.util.Optional; -public class McpTransportSession { +/** + * An abstraction of the session as perceived from the MCP transport layer. Not to be + * confused with the {@link McpSession} type that operates at the level of the JSON-RPC + * communication protocol and matches asynchronous responses with previously issued + * requests. + * + * @param the resource representing the connection that the transport + * manages. + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportSession { - private final Disposable.Composite openConnections = Disposables.composite(); + /** + * In case of stateful MCP servers, the value is present and contains the String + * identifier for the transport-level session. + * @return optional session id + */ + Optional sessionId(); - private final AtomicBoolean initialized = new AtomicBoolean(false); + /** + * If the transport provides a session id for the communication, this method should be + * called to record the current identifier. + * @param sessionId session identifier as provided by the server + */ + void setSessionId(String sessionId); - private final AtomicReference sessionId = new AtomicReference<>(); + /** + * Stateful operation that flips the un-initialized state to initialized if this is + * the first call. + * @return if successful, this method returns {@code true} and means that a + * post-initialization step can be performed + */ + boolean markInitialized(); - public McpTransportSession() { - } + /** + * Adds a resource that this transport session can monitor and dismiss when needed. + * @param connection the managed resource + */ + void addConnection(CONNECTION connection); - public String sessionId() { - return this.sessionId.get(); - } + /** + * Called when the resource is terminating by itself and the transport session does + * not need to track it anymore. + * @param connection the resource to remove from the monitored collection + */ + void removeConnection(CONNECTION connection); - public void setSessionId(String sessionId) { - this.sessionId.set(sessionId); - } + /** + * Close and clear the monitored resources. Potentially asynchronous. + */ + void close(); - public boolean markInitialized() { - return this.initialized.compareAndSet(false, true); - } - - public void addConnection(Disposable connection) { - this.openConnections.add(connection); - } - - public void removeConnection(Disposable connection) { - this.openConnections.remove(connection); - } - - public void close() { - this.closeGracefully().subscribe(); - } - - public Mono closeGracefully() { - return Mono.fromRunnable(this.openConnections::dispose); - } + /** + * Close and clear the monitored resources in a graceful manner. + * @return completes once all resources have been dismissed + */ + Publisher closeGracefully(); } From 7a54c6a6aaf41269f45dc0377a379829a8601d81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 3 Jun 2025 23:36:31 +0200 Subject: [PATCH 11/20] Bulletproof initialization --- .../client/McpAsyncClient.java | 201 +++++++++++------- .../spec/McpClientSession.java | 47 +--- 2 files changed, 137 insertions(+), 111 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 38eaef97..e9535a76 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -11,6 +11,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.*; @@ -78,6 +79,21 @@ public class McpAsyncClient { private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { }; + public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference PAGINATED_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference INITIALIZE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { + }; + private final AtomicReference initialization = new AtomicReference<>(); /** @@ -85,12 +101,6 @@ public class McpAsyncClient { */ private final Duration initializationTimeout; - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final McpClientSession mcpSession; - /** * Client capabilities. */ @@ -122,13 +132,19 @@ public class McpAsyncClient { /** * Client transport implementation. */ - private final McpTransport transport; + private final McpClientTransport transport; /** * Supported protocol versions. */ private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + /** + * The MCP session supplier that manages bidirectional JSON-RPC communication between + * clients and servers. + */ + private final Supplier sessionSupplier; + /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. @@ -212,7 +228,8 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new McpClientSession(requestTimeout, transport, this::handleException, requestHandlers, + this.transport.setExceptionHandler(this::handleException); + this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); } @@ -220,8 +237,11 @@ public class McpAsyncClient { private void handleException(Throwable t) { logger.warn("Handling exception", t); if (t instanceof McpSessionNotFoundException) { - this.initialization.set(null); - withInitialization("re-initializing", result -> Mono.empty()).subscribe(); + Initialization previous = this.initialization.getAndSet(null); + if (previous != null) { + previous.close(); + } + withSession("re-initializing", result -> Mono.empty()).subscribe(); } } @@ -285,7 +305,11 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { - this.mcpSession.close(); + Initialization current = this.initialization.getAndSet(null); + if (current != null) { + current.close(); + } + this.transport.close(); } /** @@ -293,7 +317,9 @@ public void close() { * @return A Mono that completes when the connection is closed */ public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); + Initialization current = this.initialization.getAndSet(null); + Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); + return sessionClose.then(transport.closeGracefully()); } // -------------------------- @@ -326,10 +352,10 @@ public Mono closeGracefully() { *

    */ public Mono initialize() { - return withInitialization("by explicit API call", Mono::just); + return withSession("by explicit API call", init -> Mono.just(init.get())); } - private Mono doInitialize() { + private Mono doInitialize(McpClientSession session) { String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off @@ -337,9 +363,8 @@ private Mono doInitialize() { this.clientCapabilities, this.clientInfo); // @formatter:on - Mono result = this.mcpSession.sendRequest(McpSchema.METHOD_INITIALIZE, - initializeRequest, new TypeReference() { - }); + Mono result = session.sendRequest(McpSchema.METHOD_INITIALIZE, initializeRequest, + INITIALIZE_RESULT_TYPE_REF); return result.flatMap(initializeResult -> { logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", @@ -351,7 +376,7 @@ private Mono doInitialize() { "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); } - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) + return session.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) .thenReturn(initializeResult); }); } @@ -362,10 +387,47 @@ private static class Initialization { private final AtomicReference result = new AtomicReference<>(); + private final AtomicReference mcpClientSession = new AtomicReference<>(); + static Initialization create() { return new Initialization(); } + void setMcpClientSession(McpClientSession mcpClientSession) { + this.mcpClientSession.set(mcpClientSession); + } + + McpClientSession session() { + return this.mcpClientSession.get(); + } + + McpSchema.InitializeResult get() { + return this.result.get(); + } + + Mono await() { + return this.initSink.asMono(); + } + + void complete(McpSchema.InitializeResult initializeResult) { + // first ensure the result is cached + this.result.set(initializeResult); + // inform all the subscribers waiting for the initialization + this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST); + } + + void error(Throwable t) { + this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + + void close() { + this.session().close(); + } + + Mono closeGracefully() { + return this.session().closeGracefully(); + } + } /** @@ -376,30 +438,29 @@ static Initialization create() { * @param operation The operation to execute when the client is initialized * @return A Mono that completes with the result of the operation */ - private Mono withInitialization(String actionName, Function> operation) { + private Mono withSession(String actionName, Function> operation) { return Mono.defer(() -> { Initialization newInit = Initialization.create(); Initialization previous = this.initialization.compareAndExchange(null, newInit); boolean needsToInitialize = previous == null; logger.info(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); + if (needsToInitialize) { + newInit.setMcpClientSession(this.sessionSupplier.get()); + } Mono initializationJob = needsToInitialize - ? this.doInitialize().doOnNext(result -> { - // first ensure the result is cached - newInit.result.set(result); - // inform all the subscribers waiting for the initialization - newInit.initSink.emitValue(result, Sinks.EmitFailureHandler.FAIL_FAST); - }).onErrorResume(ex -> { - Initialization ongoing = this.initialization.getAndSet(null); - if (ongoing != null) { - ongoing.initSink.emitError(ex, Sinks.EmitFailureHandler.FAIL_FAST); - } + ? doInitialize(newInit.session()).doOnNext(newInit::complete).onErrorResume(ex -> { + newInit.error(ex); return Mono.error(ex); - }) : previous.initSink.asMono(); + }) : previous.await(); - return initializationJob.timeout(this.initializationTimeout) - .onErrorResume(ex -> Mono.error(new McpError("Client failed to initialize " + actionName))) + return initializationJob.map(initializeResult -> this.initialization.get()) + .timeout(this.initializationTimeout) + .onErrorResume(ex -> { + logger.warn("Failed to initialize", ex); + return Mono.error(new McpError("Client failed to initialize " + actionName)); + }) .flatMap(operation); }); } @@ -413,9 +474,8 @@ private Mono withInitialization(String actionName, Function ping() { - return this.withInitialization("pinging the server", initializedResult -> this.mcpSession - .sendRequest(McpSchema.METHOD_PING, null, new TypeReference() { - })); + return this.withSession("pinging the server", + init -> init.session().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); } // -------------------------- @@ -495,16 +555,14 @@ public Mono removeRoot(String rootUri) { * @return A Mono that completes when the notification is sent. */ public Mono rootsListChangedNotification() { - return this.withInitialization("sending roots list changed notification", - initResult -> this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); + return this.withSession("sending roots list changed notification", + init -> init.session().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); } private RequestHandler rootsListRequestHandler() { return params -> { @SuppressWarnings("unused") - McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); + McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, PAGINATED_REQUEST_TYPE_REF); List roots = this.roots.values().stream().toList(); @@ -517,9 +575,7 @@ private RequestHandler rootsListRequestHandler() { // -------------------------- private RequestHandler samplingCreateMessageHandler() { return params -> { - McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, - new TypeReference() { - }); + McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, CREATE_MESSAGE_REQUEST_TYPE_REF); return this.samplingHandler.apply(request); }; @@ -546,11 +602,11 @@ private RequestHandler samplingCreateMessageHandler() { * @see #listTools() */ public Mono callTool(McpSchema.CallToolRequest callToolRequest) { - return this.withInitialization("calling tools", initializedResult -> { - if (initializedResult.capabilities().tools() == null) { + return this.withSession("calling tools", init -> { + if (init.get().capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); + return init.session().sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); }); } @@ -568,12 +624,13 @@ public Mono listTools() { * @return A Mono that emits the list of tools result */ public Mono listTools(String cursor) { - return this.withInitialization("listing tools", initializedResult -> { - if (initializedResult.capabilities().tools() == null) { + return this.withSession("listing tools", init -> { + if (init.get().capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_TOOLS_RESULT_TYPE_REF); + return init.session() + .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_TOOLS_RESULT_TYPE_REF); }); } @@ -625,12 +682,13 @@ public Mono listResources() { * @see #readResource(McpSchema.Resource) */ public Mono listResources(String cursor) { - return this.withInitialization("listing resources", initializedResult -> { - if (initializedResult.capabilities().resources() == null) { + return this.withSession("listing resources", init -> { + if (init.get().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_RESOURCES_RESULT_TYPE_REF); + return init.session() + .sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_RESOURCES_RESULT_TYPE_REF); }); } @@ -656,12 +714,12 @@ public Mono readResource(McpSchema.Resource resour * @see McpSchema.ReadResourceResult */ public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.withInitialization("reading resources", initializedResult -> { - if (initializedResult.capabilities().resources() == null) { + return this.withSession("reading resources", init -> { + if (init.get().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, - READ_RESOURCE_RESULT_TYPE_REF); + return init.session() + .sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, READ_RESOURCE_RESULT_TYPE_REF); }); } @@ -685,12 +743,13 @@ public Mono listResourceTemplates() { * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates(String cursor) { - return this.withInitialization("listing resource templates", initializedResult -> { - if (initializedResult.capabilities().resources() == null) { + return this.withSession("listing resource templates", init -> { + if (init.get().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, - new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); + return init.session() + .sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); }); } @@ -704,7 +763,7 @@ public Mono listResourceTemplates(String * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) */ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.withInitialization("subscribing to resources", initializedResult -> this.mcpSession + return this.withSession("subscribing to resources", init -> init.session() .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); } @@ -718,7 +777,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) * @see #subscribeResource(McpSchema.SubscribeRequest) */ public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.withInitialization("unsubscribing from resources", initializedResult -> this.mcpSession + return this.withSession("unsubscribing from resources", init -> init.session() .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); } @@ -760,7 +819,7 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.withInitialization("listing prompts", initializedResult -> this.mcpSession + return this.withSession("listing prompts", init -> init.session() .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); } @@ -774,7 +833,7 @@ public Mono listPrompts(String cursor) { * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.withInitialization("getting prompts", initializedResult -> this.mcpSession + return this.withSession("getting prompts", init -> init.session() .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); } @@ -805,8 +864,7 @@ private NotificationHandler asyncLoggingNotificationHandler( return params -> { McpSchema.LoggingMessageNotification loggingMessageNotification = transport.unmarshalFrom(params, - new TypeReference() { - }); + LOGGING_MESSAGE_NOTIFICATION_TYPE_REF); return Flux.fromIterable(loggingConsumers) .flatMap(consumer -> consumer.apply(loggingMessageNotification)) @@ -826,10 +884,9 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { return Mono.error(new McpError("Logging level must not be null")); } - return this.withInitialization("setting logging level", initializedResult -> { + return this.withSession("setting logging level", init -> { var params = new McpSchema.SetLevelRequest(loggingLevel); - return this.mcpSession.sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, new TypeReference() { - }).then(); + return init.session().sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, OBJECT_TYPE_REF).then(); }); } @@ -859,7 +916,7 @@ void setProtocolVersions(List protocolVersions) { * @see McpSchema.CompleteResult */ public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.withInitialization("complete completions", initializedResult -> this.mcpSession + return this.withSession("complete completions", init -> init.session() .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index db219eee..c8399240 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -4,13 +4,6 @@ package io.modelcontextprotocol.spec; -import java.time.Duration; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Consumer; - import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; @@ -18,6 +11,12 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + /** * Default implementation of the MCP (Model Context Protocol) session that manages * bidirectional JSON-RPC communication between clients and servers. This implementation @@ -39,10 +38,6 @@ public class McpClientSession implements McpSession { private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class); - private static final Consumer DEFAULT_CONSUMER = t -> { - logger.warn("MCP transport issued an exception", t); - }; - /** Duration to wait for request responses before timing out */ private final Duration requestTimeout; @@ -102,12 +97,10 @@ public interface NotificationHandler { * Creates a new McpClientSession with the specified configuration and handlers. * @param requestTimeout Duration to wait for responses * @param transport Transport implementation for message exchange - * @param exceptionHandler A hook to take action when transport level exceptions - * happen. * @param requestHandlers Map of method names to request handlers * @param notificationHandlers Map of method names to notification handlers */ - public McpClientSession(Duration requestTimeout, McpClientTransport transport, Consumer exceptionHandler, + public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map> requestHandlers, Map notificationHandlers) { Assert.notNull(requestTimeout, "The requestTimeout can not be null"); @@ -120,13 +113,6 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, C this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); - this.transport.setExceptionHandler(t -> { - // We only clear when the session is invalidated - if (t instanceof McpSessionNotFoundException) { - this.pendingResponses.clear(); - } - exceptionHandler.accept(t); - }); this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); } @@ -138,21 +124,6 @@ private void dismissPendingResponses() { this.pendingResponses.clear(); } - /** - * Creates a new McpClientSession with the specified configuration and handlers. - * @param requestTimeout Duration to wait for responses - * @param transport Transport implementation for message exchange - * @param requestHandlers Map of method names to request handlers - * @param notificationHandlers Map of method names to notification handlers - * @deprecated Use - * {@link #McpClientSession(Duration, McpClientTransport, Consumer, Map, Map)}. - */ - @Deprecated - public McpClientSession(Duration requestTimeout, McpClientTransport transport, - Map> requestHandlers, Map notificationHandlers) { - this(requestTimeout, transport, DEFAULT_CONSUMER, requestHandlers, notificationHandlers); - } - private void handle(McpSchema.JSONRPCMessage message) { if (message instanceof McpSchema.JSONRPCResponse response) { logger.debug("Received Response: {}", response); @@ -303,8 +274,7 @@ public Mono sendNotification(String method, Object params) { */ @Override public Mono closeGracefully() { - dismissPendingResponses(); - return this.transport.closeGracefully(); + return Mono.fromRunnable(this::dismissPendingResponses); } /** @@ -313,7 +283,6 @@ public Mono closeGracefully() { @Override public void close() { dismissPendingResponses(); - transport.close(); } } From 807155629b33d39cfae269e893339487acb3d69f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 4 Jun 2025 16:05:05 +0200 Subject: [PATCH 12/20] Refactoring --- .../WebClientStreamableHttpTransport.java | 393 ++++++++---------- .../spec/DefaultMcpTransportSession.java | 24 +- .../spec/DefaultMcpTransportStream.java | 60 +++ .../spec/McpTransportSession.java | 13 +- .../spec/McpTransportStream.java | 17 + 5 files changed, 278 insertions(+), 229 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 22dbdb0b..259b4866 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -3,31 +3,31 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportSession; -import org.reactivestreams.Publisher; +import io.modelcontextprotocol.spec.McpTransportStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.util.context.ContextView; import reactor.util.function.Tuple2; import reactor.util.function.Tuples; import java.io.IOException; import java.util.List; import java.util.Optional; -import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -42,6 +42,9 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { */ private static final String MESSAGE_EVENT_TYPE = "message"; + public static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() { + }; + private final ObjectMapper objectMapper; private final WebClient webClient; @@ -75,7 +78,7 @@ public Mono connect(Function, Mono handler) { } private void handleException(Throwable t) { - logger.debug("Handling exception for session {}", sessionIdRepresentation(this.activeSession.get()), t); + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); if (t instanceof McpSessionNotFoundException) { McpTransportSession invalidSession = this.activeSession.getAndSet(new DefaultMcpTransportSession()); logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); @@ -112,83 +115,65 @@ public Mono closeGracefully() { }); } - // FIXME: Avoid passing the ContextView - add hook allowing the Reactor Context to be - // attached to the chain? - private void reconnect(McpStream stream, ContextView ctx) { - if (stream != null) { - logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); - } - else { - logger.debug("Reconnecting with no prior stream"); - } - // Here we attempt to initialize the client. - // In case the server supports SSE, we will establish a long-running session - // here and - // listen for messages. - // If it doesn't, nothing actually happens here, that's just the way it is... - final AtomicReference disposableRef = new AtomicReference<>(); - final McpTransportSession transportSession = this.activeSession.get(); - Disposable connection = webClient.get() - .uri(this.endpoint) - .accept(MediaType.TEXT_EVENT_STREAM) - .headers(httpHeaders -> { - transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); - if (stream != null) { - stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id)); - } - }) - .exchangeToFlux(response -> { - // Per spec, we are not checking whether it's 2xx, but only if the - // Accept header is proper. - if (response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent() - && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { - - McpStream sessionStream = stream != null ? stream : new McpStream(this.resumableStreams); - logger.debug("Established stream {}", sessionStream.streamId()); - - Flux, Iterable>> idWithMessages = response - .bodyToFlux(new ParameterizedTypeReference>() { - }) - .map(this::parse); - - return sessionStream.consumeSseStream(idWithMessages); - } - else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) { - logger.debug("The server does not support SSE streams, using request-response mode."); - return Flux.empty(); - } - else if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { - logger.warn("Session {} was not found on the MCP server", transportSession.sessionId()); - - McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( - sessionIdRepresentation(transportSession)); - // inform the stream/connection subscriber - return Flux.error(notFoundException); - } - else { - return response.createError().doOnError(e -> { - logger.info("Opening an SSE stream failed. This can be safely ignored.", e); - }).flux(); - } - }) - .onErrorResume(t -> { - this.handleException(t); - return Flux.empty(); - }) - .doFinally(s -> { - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); - } - }) - .contextWrite(ctx) - .subscribe(); - disposableRef.set(connection); - transportSession.addConnection(connection); - } + private Mono reconnect(McpTransportStream stream) { + return Mono.deferContextual(ctx -> { + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + // Here we attempt to initialize the client. In case the server supports SSE, + // we will establish a long-running + // session here and listen for messages. If it doesn't, that's ok, the server + // is a simple, stateless one. + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); - private static String sessionIdRepresentation(McpTransportSession transportSession) { - return transportSession.sessionId().orElse("[missing_session_id]"); + Disposable connection = webClient.get() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + if (stream != null) { + stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id)); + } + }) + .exchangeToFlux(response -> { + if (isEventStream(response)) { + return eventStream(stream, response); + } + else if (isNotAllowed(response)) { + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (isNotFound(response)) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + return mcpSessionNotFoundError(sessionIdRepresentation); + } + else { + return response.createError().doOnError(e -> { + logger.info("Opening an SSE stream failed. This can be safely ignored.", e); + }).flux(); + } + }) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); } @Override @@ -211,123 +196,62 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { }) .bodyValue(message) .exchangeToFlux(response -> { - if (transportSession.markInitialized()) { - if (!response.headers().header("mcp-session-id").isEmpty()) { - String sessionId = response.headers().asHttpHeaders().getFirst("mcp-session-id"); - logger.debug("Established session with id {}", sessionId); - transportSession.setSessionId(sessionId); - // Once we have a session, we try to open an async stream for - // the server to send notifications and requests out-of-band. - reconnect(null, sink.contextView()); - } + if (transportSession + .markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + reconnect(null).contextWrite(sink.contextView()).subscribe(); } + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + // The spec mentions only ACCEPTED, but the existing SDKs can return // 200 OK for notifications - // if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) { if (response.statusCode().is2xxSuccessful()) { + Optional contentType = response.headers().contentType(); // Existing SDKs consume notifications with no response body nor // content type - if (response.headers().contentType().isEmpty()) { - logger.trace("Message was successfuly sent via POST for session {}", - transportSession.sessionId()); + if (contentType.isEmpty()) { + logger.trace("Message was successfully sent via POST for session {}", + sessionRepresentation); // signal the caller that the message was successfully // delivered sink.success(); // communicate to downstream there is no streamed data coming return Flux.empty(); } - - MediaType contentType = response.headers().contentType().get(); - - if (contentType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { - // communicate to caller that the message was delivered - sink.success(); - - // starting a stream - McpStream sessionStream = new McpStream(this.resumableStreams); - - logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), - transportSession.sessionId()); - - Flux, Iterable>> idWithMessages = response - .bodyToFlux(new ParameterizedTypeReference>() { - }) - .map(this::parse); - - return sessionStream.consumeSseStream(idWithMessages); - } - else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { - logger.trace("Received response to POST for session {}", transportSession.sessionId()); - - // communicate to caller the message was delivered - sink.success(); - - // provide the response body as a stream of a single response - // to consume - return response.bodyToMono( - String.class).>handle((responseMessage, s) -> { - try { - McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema - .deserializeJsonRpcMessage(objectMapper, responseMessage); - s.next(List.of(jsonRpcResponse)); - } - catch (IOException e) { - s.error(e); - } - }) - .flatMapIterable(Function.identity()); - } else { - logger.warn("Unknown media type {} returned for POST in session {}", contentType, - transportSession.sessionId()); - return Flux.error(new RuntimeException("Unknown media type returned: " + contentType)); + MediaType mediaType = contentType.get(); + if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + // communicate to caller that the message was delivered + sink.success(); + // starting a stream + return newEventStream(response, sessionRepresentation); + } + else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + logger.trace("Received response to POST for session {}", sessionRepresentation); + // communicate to caller the message was delivered + sink.success(); + return responseFlux(response); + } + else { + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + return Flux.error(new RuntimeException("Unknown media type returned: " + contentType)); + } } } else { - if (response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND)) { - logger.warn("Session {} was not found on the MCP server", transportSession.sessionId()); - - McpSessionNotFoundException notFoundException = new McpSessionNotFoundException( - sessionIdRepresentation(transportSession)); - // inform the stream/connection subscriber - return Flux.error(notFoundException); + if (isNotFound(response)) { + return mcpSessionNotFoundError(sessionRepresentation); } - return response.createError().onErrorResume(e -> { - WebClientResponseException responseException = (WebClientResponseException) e; - byte[] body = responseException.getResponseBodyAsByteArray(); - McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; - Exception toPropagate; - try { - McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, - McpSchema.JSONRPCResponse.class); - jsonRpcError = jsonRpcResponse.error(); - toPropagate = new McpError(jsonRpcError); - } - catch (IOException ex) { - toPropagate = new RuntimeException("Sending request failed", e); - logger.debug("Received content together with {} HTTP code response: {}", - response.statusCode(), body); - } - - // Some implementations can return 400 when presented with a - // session id that it doesn't know about, so we will - // invalidate the session - // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 - if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { - return Mono.error(new McpSessionNotFoundException( - sessionIdRepresentation(this.activeSession.get()), toPropagate)); - } - return Mono.empty(); - }).flux(); + return extractError(response, sessionRepresentation); } }) - .map(Mono::just) - .flatMap(this.handler.get()) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) .onErrorResume(t -> { // handle the error first this.handleException(t); - // inform the caller of sendMessage sink.error(t); return Flux.empty(); @@ -345,6 +269,87 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) { }); } + private static Flux mcpSessionNotFoundError(String sessionRepresentation) { + logger.warn("Session {} was not found on the MCP server", sessionRepresentation); + // inform the stream/connection subscriber + return Flux.error(new McpSessionNotFoundException(sessionRepresentation)); + } + + private Flux extractError(ClientResponse response, String sessionRepresentation) { + return response.createError().onErrorResume(e -> { + WebClientResponseException responseException = (WebClientResponseException) e; + byte[] body = responseException.getResponseBodyAsByteArray(); + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; + Exception toPropagate; + try { + McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, + McpSchema.JSONRPCResponse.class); + jsonRpcError = jsonRpcResponse.error(); + toPropagate = new McpError(jsonRpcError); + } + catch (IOException ex) { + toPropagate = new RuntimeException("Sending request failed", e); + logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body); + } + + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { + return Mono.error(new McpSessionNotFoundException(sessionRepresentation, toPropagate)); + } + return Mono.empty(); + }).flux(); + } + + private Flux eventStream(McpTransportStream stream, ClientResponse response) { + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + var idWithMessages = response.bodyToFlux(PARAMETERIZED_TYPE_REF).map(this::parse); + return Flux.from(sessionStream.consumeSseStream(idWithMessages)); + } + + private static boolean isNotFound(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND); + } + + private static boolean isNotAllowed(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED); + } + + private static boolean isEventStream(ClientResponse response) { + return response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent() + && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + + private Flux responseFlux(ClientResponse response) { + return response.bodyToMono(String.class).>handle((responseMessage, s) -> { + try { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseMessage); + s.next(List.of(jsonRpcResponse)); + } + catch (IOException e) { + s.error(e); + } + }).flatMapIterable(Function.identity()); + } + + private Flux newEventStream(ClientResponse response, String sessionRepresentation) { + McpTransportStream sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), + sessionRepresentation); + return eventStream(sessionStream, response); + } + @Override public T unmarshalFrom(Object data, TypeReference typeRef) { return this.objectMapper.convertValue(data, typeRef); @@ -367,42 +372,4 @@ private Tuple2, Iterable> parse(Serve } } - private class McpStream { - - private static final AtomicLong counter = new AtomicLong(); - - private final AtomicReference lastId = new AtomicReference<>(); - - // Used only for internal accounting - private final long streamId; - - private final boolean resumable; - - McpStream(boolean resumable) { - this.streamId = counter.getAndIncrement(); - this.resumable = resumable; - } - - Optional lastId() { - return Optional.ofNullable(this.lastId.get()); - } - - long streamId() { - return this.streamId; - } - - Flux consumeSseStream( - Publisher, Iterable>> eventStream) { - return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { - if (resumable && !(e instanceof McpSessionNotFoundException)) { - reconnect(this, ctx); - } - }).doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { - String previousId = this.lastId.getAndSet(id); - logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); - })).flatMapIterable(Tuple2::getT2)); - } - - } - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java index 6bda390c..d66a6ccc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -1,5 +1,7 @@ package io.modelcontextprotocol.spec; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.Disposables; import reactor.core.publisher.Mono; @@ -10,6 +12,8 @@ public class DefaultMcpTransportSession implements McpTransportSession { + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpTransportSession.class); + private final Disposable.Composite openConnections = Disposables.composite(); private final AtomicBoolean initialized = new AtomicBoolean(false); @@ -25,13 +29,19 @@ public Optional sessionId() { } @Override - public void setSessionId(String sessionId) { - this.sessionId.set(sessionId); - } - - @Override - public boolean markInitialized() { - return this.initialized.compareAndSet(false, true); + public boolean markInitialized(String sessionId) { + boolean flipped = this.initialized.compareAndSet(false, true); + if (flipped) { + this.sessionId.set(sessionId); + logger.debug("Established session with id {}", sessionId); + } + else { + if (sessionId != null && !sessionId.equals(this.sessionId.get())) { + logger.warn("Different session id provided in response. Expecting {} but server returned {}", + this.sessionId.get(), sessionId); + } + } + return flipped; } @Override diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java new file mode 100644 index 00000000..e11263e4 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java @@ -0,0 +1,60 @@ +package io.modelcontextprotocol.spec; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +public class DefaultMcpTransportStream implements McpTransportStream { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpTransportStream.class); + + private static final AtomicLong counter = new AtomicLong(); + + private final AtomicReference lastId = new AtomicReference<>(); + + // Used only for internal accounting + private final long streamId; + + private final boolean resumable; + + private final Function, Publisher> reconnect; + + public DefaultMcpTransportStream(boolean resumable, + Function, Publisher> reconnect) { + this.reconnect = reconnect; + this.streamId = counter.getAndIncrement(); + this.resumable = resumable; + } + + @Override + public Optional lastId() { + return Optional.ofNullable(this.lastId.get()); + } + + @Override + public long streamId() { + return this.streamId; + } + + @Override + public Publisher consumeSseStream( + Publisher, Iterable>> eventStream) { + return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { + if (resumable && !(e instanceof McpSessionNotFoundException)) { + Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); + } + }).doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { + String previousId = this.lastId.getAndSet(id); + logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); + })).flatMapIterable(Tuple2::getT2)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java index 902db90d..555f018f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java @@ -23,20 +23,15 @@ public interface McpTransportSession { */ Optional sessionId(); - /** - * If the transport provides a session id for the communication, this method should be - * called to record the current identifier. - * @param sessionId session identifier as provided by the server - */ - void setSessionId(String sessionId); - /** * Stateful operation that flips the un-initialized state to initialized if this is - * the first call. + * the first call. If the transport provides a session id for the communication, + * argument should not be null to record the current identifier. + * @param sessionId session identifier as provided by the server * @return if successful, this method returns {@code true} and means that a * post-initialization step can be performed */ - boolean markInitialized(); + boolean markInitialized(String sessionId); /** * Adds a resource that this transport session can monitor and dismiss when needed. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java new file mode 100644 index 00000000..e78750b3 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java @@ -0,0 +1,17 @@ +package io.modelcontextprotocol.spec; + +import org.reactivestreams.Publisher; +import reactor.util.function.Tuple2; + +import java.util.Optional; + +public interface McpTransportStream { + + Optional lastId(); + + long streamId(); + + Publisher consumeSseStream( + Publisher, Iterable>> eventStream); + +} From 358ce25b2161571fe74bdd68cb0a219db2568dd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 4 Jun 2025 16:22:54 +0200 Subject: [PATCH 13/20] Cleanup --- mcp-spring/mcp-spring-webflux/pom.xml | 11 +-- .../client/transport/Main.java | 83 ------------------- .../src/main/resources/logback.xml | 18 ---- ...eamableHttpAsyncClientResiliencyTests.java | 2 + .../src/test/resources/logback.xml | 8 +- .../client/McpAsyncClient.java | 8 +- 6 files changed, 13 insertions(+), 117 deletions(-) delete mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java delete mode 100644 mcp-spring/mcp-spring-webflux/src/main/resources/logback.xml diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 4da36f11..26452fe9 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -88,15 +88,6 @@ ${byte-buddy.version} test - - org.springframework - spring-context - 6.2.6 - - - io.projectreactor.netty - reactor-netty-http - io.projectreactor reactor-test @@ -126,7 +117,7 @@ ch.qos.logback logback-classic ${logback.version} - + test diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java deleted file mode 100644 index f7c4b30f..00000000 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/Main.java +++ /dev/null @@ -1,83 +0,0 @@ -package io.modelcontextprotocol.client.transport; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpAsyncClient; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.McpSyncClient; -import io.modelcontextprotocol.spec.McpSchema; -import io.netty.channel.socket.nio.NioChannelOption; -import jdk.net.ExtendedSocketOptions; -import org.springframework.http.client.reactive.ReactorClientHttpConnector; -import org.springframework.web.reactive.function.client.WebClient; -import reactor.netty.http.client.HttpClient; - -import java.util.List; -import java.util.Map; -import java.util.Scanner; -import java.util.concurrent.atomic.AtomicReference; - -public class Main { - - public static void main(String[] args) throws InterruptedException { - McpSyncClient client = McpClient - .sync(new WebClientStreamableHttpTransport(new ObjectMapper(), - WebClient.builder() - .clientConnector(new ReactorClientHttpConnector( - HttpClient.create().option(NioChannelOption.of(ExtendedSocketOptions.TCP_KEEPIDLE), 5))) - .baseUrl("http://localhost:3001"), - "/mcp", true, false)) - .build(); - - /* - * Inspector does this: 1. -> POST initialize request 2. <- capabilities response - * (with sessionId) 3. -> POST initialized notification 4. -> GET initialize SSE - * connection (with sessionId) - * - * VS - * - * 1. -> GET initialize SSE connection 2. <- 2xx ok with sessionId 3. -> POST - * initialize request 4. <- capabilities response 5. -> POST initialized - * notification - * - * - * SERVER-A + SERVER-B LOAD BALANCING between SERVER-A and SERVER-B STATELESS - * SERVER - * - * 1. -> (A) POST initialize request 2. <- (A) 2xx ok with capabilities 3. -> (B) - * POST initialized notification 4. -> (B) 2xx ok 5. -> (A or B) POST request - * tools 6. -> 2xx response - */ - - List tools = null; - while (tools == null) { - try { - client.initialize(); - tools = client.listTools().tools(); - } - catch (Exception e) { - System.out.println("Got exception. Retrying in 5s. " + e); - Thread.sleep(5000); - } - } - - Scanner scanner = new Scanner(System.in); - while (scanner.hasNext()) { - String text = scanner.nextLine(); - if (text == null || text.isEmpty()) { - System.out.println("Done"); - break; - } - try { - McpSchema.CallToolResult result = client - .callTool(new McpSchema.CallToolRequest(tools.get(0).name(), Map.of("message", text))); - System.out.println("Tool call result: " + result); - } - catch (Exception e) { - System.out.println("Error calling tool " + e); - } - } - - client.closeGracefully(); - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/main/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/main/resources/logback.xml deleted file mode 100644 index e38239e7..00000000 --- a/mcp-spring/mcp-spring-webflux/src/main/resources/logback.xml +++ /dev/null @@ -1,18 +0,0 @@ - - - - - - - %d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n - - - - - - - - - - - diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java index 58ca5e00..8205789e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -3,8 +3,10 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; import org.springframework.web.reactive.function.client.WebClient; +@Timeout(15) public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml index 2652e2ee..abc831d1 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml @@ -9,13 +9,13 @@ - + - - + + - + diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e9535a76..e318b269 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -14,9 +14,12 @@ import java.util.function.Supplier; import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.spec.*; +import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; @@ -27,6 +30,7 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSessionNotFoundException; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; @@ -444,7 +448,7 @@ private Mono withSession(String actionName, Function Date: Fri, 6 Jun 2025 10:48:29 +0200 Subject: [PATCH 14/20] Bring back blocking op offloading to sync tests --- .../client/AbstractMcpSyncClientTests.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index c5c9ed2e..77989577 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -122,9 +122,9 @@ void verifyNotificationSucceedsWithImplicitInitialization(Consumer void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { - StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))) - .expectNextCount(1) - .verifyComplete(); + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) + // Offload the blocking call to the real scheduler + .subscribeOn(Schedulers.boundedElastic())).expectNextCount(1).verifyComplete(); }); } From 335fca0fc5df8e2ffbecc20f7e1b0ab351c9709d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 6 Jun 2025 11:06:01 +0200 Subject: [PATCH 15/20] Add diagnostic log for stdio client test and increase init timeout --- .../transport/StdioClientTransport.java | 27 ++++++++++++++----- .../client/StdioMcpSyncClientTests.java | 2 +- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 9d71cbb4..b2352597 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -124,13 +124,15 @@ public Mono connect(Function, Mono> h processBuilder.command(fullCommand); processBuilder.environment().putAll(params.getEnv()); - // Start the process - try { - this.process = processBuilder.start(); - } - catch (IOException e) { - throw new RuntimeException("Failed to start process with command: " + fullCommand, e); - } + measureTime(() -> { + // Start the process + try { + this.process = processBuilder.start(); + } + catch (IOException e) { + throw new RuntimeException("Failed to start process with command: " + fullCommand, e); + } + }, "Process start"); // Validate process streams if (this.process.getInputStream() == null || process.getOutputStream() == null) { @@ -391,4 +393,15 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { return this.objectMapper.convertValue(data, typeRef); } + private static void measureTime(Runnable op, String opName) { + long start = System.nanoTime(); + try { + op.run(); + } + finally { + long delta = System.nanoTime() - start; + logger.info("{} took {}ms", opName, Duration.ofNanos(delta).toMillis()); + } + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 706aa9b2..4b5f4f9c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -68,7 +68,7 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException { } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(6); + return Duration.ofSeconds(10); } } From 1d3c1256d193521203e2cecd27bcb59974e4517b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 6 Jun 2025 13:09:05 +0200 Subject: [PATCH 16/20] Remove temp diagnostics but add stdio process lifecycle logs --- .../transport/StdioClientTransport.java | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index b2352597..72a9f995 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -112,6 +112,7 @@ public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) @Override public Mono connect(Function, Mono> handler) { return Mono.fromRunnable(() -> { + logger.info("MCP server starting."); handleIncomingMessages(handler); handleIncomingErrors(); @@ -124,15 +125,13 @@ public Mono connect(Function, Mono> h processBuilder.command(fullCommand); processBuilder.environment().putAll(params.getEnv()); - measureTime(() -> { - // Start the process - try { - this.process = processBuilder.start(); - } - catch (IOException e) { - throw new RuntimeException("Failed to start process with command: " + fullCommand, e); - } - }, "Process start"); + // Start the process + try { + this.process = processBuilder.start(); + } + catch (IOException e) { + throw new RuntimeException("Failed to start process with command: " + fullCommand, e); + } // Validate process streams if (this.process.getInputStream() == null || process.getOutputStream() == null) { @@ -144,6 +143,7 @@ public Mono connect(Function, Mono> h startInboundProcessing(); startOutboundProcessing(); startErrorProcessing(); + logger.info("MCP server started"); }).subscribeOn(Schedulers.boundedElastic()); } @@ -367,6 +367,8 @@ public Mono closeGracefully() { })).doOnNext(process -> { if (process.exitValue() != 0) { logger.warn("Process terminated with code " + process.exitValue()); + } else { + logger.info("MCP server process stopped"); } }).then(Mono.fromRunnable(() -> { try { From e402f75023dbf4a2a6f520bbc716535d8ee08189 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 6 Jun 2025 13:19:43 +0200 Subject: [PATCH 17/20] Formatting --- .../client/transport/StdioClientTransport.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 72a9f995..76668931 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -367,7 +367,8 @@ public Mono closeGracefully() { })).doOnNext(process -> { if (process.exitValue() != 0) { logger.warn("Process terminated with code " + process.exitValue()); - } else { + } + else { logger.info("MCP server process stopped"); } }).then(Mono.fromRunnable(() -> { From 9b0b02a66fb954dd51541d268d30c34ce9f847ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Mon, 9 Jun 2025 14:58:53 +0200 Subject: [PATCH 18/20] Handle HTTP DELETE upon client close --- .../WebClientStreamableHttpTransport.java | 25 ++++++++++++++++--- ...AbstractMcpAsyncClientResiliencyTests.java | 11 ++++++++ .../client/McpAsyncClient.java | 8 +++--- .../spec/DefaultMcpTransportSession.java | 11 +++++--- 4 files changed, 46 insertions(+), 9 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 259b4866..6bba81a0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -10,6 +10,7 @@ import io.modelcontextprotocol.spec.McpSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportSession; import io.modelcontextprotocol.spec.McpTransportStream; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.ParameterizedTypeReference; @@ -31,6 +32,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; public class WebClientStreamableHttpTransport implements McpClientTransport { @@ -69,7 +71,24 @@ public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Bui this.endpoint = endpoint; this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; - this.activeSession.set(new DefaultMcpTransportSession()); + this.activeSession.set(createTransportSession()); + } + + private DefaultMcpTransportSession createTransportSession() { + Supplier> onClose = () -> { + DefaultMcpTransportSession transportSession = this.activeSession.get(); + return transportSession.sessionId().isEmpty() ? Mono.empty() : webClient + .delete() + .uri(this.endpoint) + .headers(httpHeaders -> { + httpHeaders.add("mcp-session-id", transportSession.sessionId().get()); + }) + .retrieve() + .toBodilessEntity() + .doOnError(e -> logger.info("Got response {}", e)) + .then(); + }; + return new DefaultMcpTransportSession(onClose); } @Override @@ -93,7 +112,7 @@ public void setExceptionHandler(Consumer handler) { private void handleException(Throwable t) { logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); if (t instanceof McpSessionNotFoundException) { - McpTransportSession invalidSession = this.activeSession.getAndSet(new DefaultMcpTransportSession()); + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); invalidSession.close(); } @@ -107,7 +126,7 @@ private void handleException(Throwable t) { public Mono closeGracefully() { return Mono.defer(() -> { logger.debug("Graceful close triggered"); - DefaultMcpTransportSession currentSession = this.activeSession.get(); + DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); if (currentSession != null) { return currentSession.closeGracefully(); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index 9990b8a3..d720df51 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -195,4 +195,15 @@ void testCallTool() { }); } + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next use should immediately re-initialize with no issue and send the request without any broken connections. + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e318b269..ee6cc6c1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -321,9 +321,11 @@ public void close() { * @return A Mono that completes when the connection is closed */ public Mono closeGracefully() { - Initialization current = this.initialization.getAndSet(null); - Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); - return sessionClose.then(transport.closeGracefully()); + return Mono.defer(() -> { + Initialization current = this.initialization.getAndSet(null); + Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); + return sessionClose.then(transport.closeGracefully()); + }); } // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java index d66a6ccc..e83ea66c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -1,5 +1,6 @@ package io.modelcontextprotocol.spec; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Disposable; @@ -9,6 +10,7 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; public class DefaultMcpTransportSession implements McpTransportSession { @@ -20,8 +22,11 @@ public class DefaultMcpTransportSession implements McpTransportSession sessionId = new AtomicReference<>(); - public DefaultMcpTransportSession() { - } + private final Supplier> onClose; + + public DefaultMcpTransportSession(Supplier> onClose) { + this.onClose = onClose; + } @Override public Optional sessionId() { @@ -61,7 +66,7 @@ public void close() { @Override public Mono closeGracefully() { - return Mono.fromRunnable(this.openConnections::dispose); + return Mono.from(this.onClose.get()).then(Mono.fromRunnable(this.openConnections::dispose)); } } From 0bcdbeeaad0533a21413831438e271e5bad23bbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Mon, 9 Jun 2025 16:32:47 +0200 Subject: [PATCH 19/20] Address review and format --- .../WebClientStreamableHttpTransport.java | 22 ++---- ...AbstractMcpAsyncClientResiliencyTests.java | 6 +- .../client/McpAsyncClient.java | 76 ++++++++++--------- .../spec/DefaultMcpTransportSession.java | 4 +- .../spec/DefaultMcpTransportStream.java | 2 +- ...McpTransportSessionNotFoundException.java} | 6 +- 6 files changed, 58 insertions(+), 58 deletions(-) rename mcp/src/main/java/io/modelcontextprotocol/spec/{McpSessionNotFoundException.java => McpTransportSessionNotFoundException.java} (61%) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 6bba81a0..6439f755 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -7,7 +7,7 @@ import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportSession; import io.modelcontextprotocol.spec.McpTransportStream; import org.reactivestreams.Publisher; @@ -77,16 +77,10 @@ public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Bui private DefaultMcpTransportSession createTransportSession() { Supplier> onClose = () -> { DefaultMcpTransportSession transportSession = this.activeSession.get(); - return transportSession.sessionId().isEmpty() ? Mono.empty() : webClient - .delete() - .uri(this.endpoint) - .headers(httpHeaders -> { - httpHeaders.add("mcp-session-id", transportSession.sessionId().get()); - }) - .retrieve() - .toBodilessEntity() - .doOnError(e -> logger.info("Got response {}", e)) - .then(); + return transportSession.sessionId().isEmpty() ? Mono.empty() + : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { + httpHeaders.add("mcp-session-id", transportSession.sessionId().get()); + }).retrieve().toBodilessEntity().doOnError(e -> logger.info("Got response {}", e)).then(); }; return new DefaultMcpTransportSession(onClose); } @@ -111,7 +105,7 @@ public void setExceptionHandler(Consumer handler) { private void handleException(Throwable t) { logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); - if (t instanceof McpSessionNotFoundException) { + if (t instanceof McpTransportSessionNotFoundException) { McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); invalidSession.close(); @@ -291,7 +285,7 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { private static Flux mcpSessionNotFoundError(String sessionRepresentation) { logger.warn("Session {} was not found on the MCP server", sessionRepresentation); // inform the stream/connection subscriber - return Flux.error(new McpSessionNotFoundException(sessionRepresentation)); + return Flux.error(new McpTransportSessionNotFoundException(sessionRepresentation)); } private Flux extractError(ClientResponse response, String sessionRepresentation) { @@ -316,7 +310,7 @@ private Flux extractError(ClientResponse response, Str // invalidate the session // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { - return Mono.error(new McpSessionNotFoundException(sessionRepresentation, toPropagate)); + return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); } return Mono.empty(); }).flux(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index d720df51..7809dd54 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -199,9 +199,11 @@ void testCallTool() { void testSessionClose() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); - // In case of Streamable HTTP this call should issue a HTTP DELETE request invalidating the session + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); - // The next use should immediately re-initialize with no issue and send the request without any broken connections. + // The next use should immediately re-initialize with no issue and send the + // request without any broken connections. StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); }); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index ee6cc6c1..66f29662 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -30,7 +30,7 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; @@ -98,7 +98,7 @@ public class McpAsyncClient { public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { }; - private final AtomicReference initialization = new AtomicReference<>(); + private final AtomicReference initializationRef = new AtomicReference<>(); /** * The max timeout to await for the client-server connection to be initialized. @@ -240,8 +240,8 @@ public class McpAsyncClient { private void handleException(Throwable t) { logger.warn("Handling exception", t); - if (t instanceof McpSessionNotFoundException) { - Initialization previous = this.initialization.getAndSet(null); + if (t instanceof McpTransportSessionNotFoundException) { + Initialization previous = this.initializationRef.getAndSet(null); if (previous != null) { previous.close(); } @@ -249,13 +249,18 @@ private void handleException(Throwable t) { } } + private McpSchema.InitializeResult currentInitializationResult() { + Initialization current = this.initializationRef.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + return initializeResult; + } + /** * Get the server capabilities that define the supported features and functionality. * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - Initialization current = this.initialization.get(); - McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); return initializeResult != null ? initializeResult.capabilities() : null; } @@ -265,8 +270,7 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server instructions */ public String getServerInstructions() { - Initialization current = this.initialization.get(); - McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); return initializeResult != null ? initializeResult.instructions() : null; } @@ -275,8 +279,7 @@ public String getServerInstructions() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - Initialization current = this.initialization.get(); - McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); return initializeResult != null ? initializeResult.serverInfo() : null; } @@ -285,7 +288,7 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - Initialization current = this.initialization.get(); + Initialization current = this.initializationRef.get(); return current != null && (current.result.get() != null); } @@ -309,7 +312,7 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { - Initialization current = this.initialization.getAndSet(null); + Initialization current = this.initializationRef.getAndSet(null); if (current != null) { current.close(); } @@ -322,7 +325,7 @@ public void close() { */ public Mono closeGracefully() { return Mono.defer(() -> { - Initialization current = this.initialization.getAndSet(null); + Initialization current = this.initializationRef.getAndSet(null); Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); return sessionClose.then(transport.closeGracefully()); }); @@ -361,7 +364,7 @@ public Mono initialize() { return withSession("by explicit API call", init -> Mono.just(init.get())); } - private Mono doInitialize(McpClientSession session) { + private Mono doInitialize(McpClientSession mcpClientSession) { String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off @@ -369,8 +372,8 @@ private Mono doInitialize(McpClientSession session) this.clientCapabilities, this.clientInfo); // @formatter:on - Mono result = session.sendRequest(McpSchema.METHOD_INITIALIZE, initializeRequest, - INITIALIZE_RESULT_TYPE_REF); + Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE, + initializeRequest, INITIALIZE_RESULT_TYPE_REF); return result.flatMap(initializeResult -> { logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", @@ -382,7 +385,7 @@ private Mono doInitialize(McpClientSession session) "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); } - return session.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) + return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) .thenReturn(initializeResult); }); } @@ -403,7 +406,7 @@ void setMcpClientSession(McpClientSession mcpClientSession) { this.mcpClientSession.set(mcpClientSession); } - McpClientSession session() { + McpClientSession mcpSession() { return this.mcpClientSession.get(); } @@ -427,11 +430,11 @@ void error(Throwable t) { } void close() { - this.session().close(); + this.mcpSession().close(); } Mono closeGracefully() { - return this.session().closeGracefully(); + return this.mcpSession().closeGracefully(); } } @@ -447,7 +450,7 @@ Mono closeGracefully() { private Mono withSession(String actionName, Function> operation) { return Mono.defer(() -> { Initialization newInit = Initialization.create(); - Initialization previous = this.initialization.compareAndExchange(null, newInit); + Initialization previous = this.initializationRef.compareAndExchange(null, newInit); boolean needsToInitialize = previous == null; logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); @@ -456,12 +459,12 @@ private Mono withSession(String actionName, Function initializationJob = needsToInitialize - ? doInitialize(newInit.session()).doOnNext(newInit::complete).onErrorResume(ex -> { + ? doInitialize(newInit.mcpSession()).doOnNext(newInit::complete).onErrorResume(ex -> { newInit.error(ex); return Mono.error(ex); }) : previous.await(); - return initializationJob.map(initializeResult -> this.initialization.get()) + return initializationJob.map(initializeResult -> this.initializationRef.get()) .timeout(this.initializationTimeout) .onErrorResume(ex -> { logger.warn("Failed to initialize", ex); @@ -481,7 +484,7 @@ private Mono withSession(String actionName, Function ping() { return this.withSession("pinging the server", - init -> init.session().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); + init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); } // -------------------------- @@ -562,7 +565,7 @@ public Mono removeRoot(String rootUri) { */ public Mono rootsListChangedNotification() { return this.withSession("sending roots list changed notification", - init -> init.session().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); + init -> init.mcpSession().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); } private RequestHandler rootsListRequestHandler() { @@ -612,7 +615,8 @@ public Mono callTool(McpSchema.CallToolRequest callToo if (init.get().capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } - return init.session().sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); }); } @@ -634,7 +638,7 @@ public Mono listTools(String cursor) { if (init.get().capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } - return init.session() + return init.mcpSession() .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), LIST_TOOLS_RESULT_TYPE_REF); }); @@ -692,7 +696,7 @@ public Mono listResources(String cursor) { if (init.get().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return init.session() + return init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), LIST_RESOURCES_RESULT_TYPE_REF); }); @@ -724,7 +728,7 @@ public Mono readResource(McpSchema.ReadResourceReq if (init.get().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return init.session() + return init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, READ_RESOURCE_RESULT_TYPE_REF); }); } @@ -753,7 +757,7 @@ public Mono listResourceTemplates(String if (init.get().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } - return init.session() + return init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); }); @@ -769,7 +773,7 @@ public Mono listResourceTemplates(String * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) */ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.withSession("subscribing to resources", init -> init.session() + return this.withSession("subscribing to resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); } @@ -783,7 +787,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) * @see #subscribeResource(McpSchema.SubscribeRequest) */ public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.withSession("unsubscribing from resources", init -> init.session() + return this.withSession("unsubscribing from resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); } @@ -825,7 +829,7 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.withSession("listing prompts", init -> init.session() + return this.withSession("listing prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); } @@ -839,7 +843,7 @@ public Mono listPrompts(String cursor) { * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.withSession("getting prompts", init -> init.session() + return this.withSession("getting prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); } @@ -892,7 +896,7 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { return this.withSession("setting logging level", init -> { var params = new McpSchema.SetLevelRequest(loggingLevel); - return init.session().sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, OBJECT_TYPE_REF).then(); + return init.mcpSession().sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, OBJECT_TYPE_REF).then(); }); } @@ -922,7 +926,7 @@ void setProtocolVersions(List protocolVersions) { * @see McpSchema.CompleteResult */ public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.withSession("complete completions", init -> init.session() + return this.withSession("complete completions", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java index e83ea66c..cd74c793 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -25,8 +25,8 @@ public class DefaultMcpTransportSession implements McpTransportSession> onClose; public DefaultMcpTransportSession(Supplier> onClose) { - this.onClose = onClose; - } + this.onClose = onClose; + } @Override public Optional sessionId() { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java index e11263e4..c9e29224 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java @@ -48,7 +48,7 @@ public long streamId() { public Publisher consumeSseStream( Publisher, Iterable>> eventStream) { return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { - if (resumable && !(e instanceof McpSessionNotFoundException)) { + if (resumable && !(e instanceof McpTransportSessionNotFoundException)) { Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); } }).doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java similarity index 61% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java rename to mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java index be07eea1..7b33e62f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSessionNotFoundException.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java @@ -6,14 +6,14 @@ * * @author Dariusz Jędrzejczyk */ -public class McpSessionNotFoundException extends RuntimeException { +public class McpTransportSessionNotFoundException extends RuntimeException { - public McpSessionNotFoundException(String sessionId, Exception cause) { + public McpTransportSessionNotFoundException(String sessionId, Exception cause) { super("Session " + sessionId + " not found on the server", cause); } - public McpSessionNotFoundException(String sessionId) { + public McpTransportSessionNotFoundException(String sessionId) { super("Session " + sessionId + " not found on the server"); } From 1680a27283342f65edeb717c0393a427eae090ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Mon, 9 Jun 2025 16:35:17 +0200 Subject: [PATCH 20/20] Add clarifying comment --- .../java/io/modelcontextprotocol/client/McpAsyncClient.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 66f29662..2e29e40a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -245,6 +245,8 @@ private void handleException(Throwable t) { if (previous != null) { previous.close(); } + // Providing an empty operation since we are only interested in triggering the + // implicit initialization step. withSession("re-initializing", result -> Mono.empty()).subscribe(); } }