+ * The transport is capable of resumability and reconnects. It reacts to transport-level
+ * session invalidation and will propagate {@link McpTransportSessionNotFoundException
+ * appropriate exceptions} to the higher level abstraction layer when needed in order to
+ * allow proper state management. The implementation handles servers that are stateful and
+ * provide session meta information, but can also communicate with stateless servers that
+ * do not provide a session identifier and do not support SSE streams.
+ *
+ *
+ * This implementation does not handle backwards compatibility with the "HTTP
+ * with SSE" transport. In order to communicate over the phased-out
+ * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or
+ * {@link WebFluxSseClientTransport}.
+ *
+ *
+ * @author Dariusz Jędrzejczyk
+ * @see Streamable
+ * HTTP transport specification
+ */
+public class WebClientStreamableHttpTransport implements McpClientTransport {
+
+ private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class);
+
+ private static final String DEFAULT_ENDPOINT = "/mcp";
+
+ /**
+ * Event type for JSON-RPC messages received through the SSE connection. The server
+ * sends messages with this event type to transmit JSON-RPC protocol data.
+ */
+ private static final String MESSAGE_EVENT_TYPE = "message";
+
+ private static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() {
+ };
+
+ private final ObjectMapper objectMapper;
+
+ private final WebClient webClient;
+
+ private final String endpoint;
+
+ private final boolean openConnectionOnStartup;
+
+ private final boolean resumableStreams;
+
+ private final AtomicReference activeSession = new AtomicReference<>();
+
+ private final AtomicReference, Mono>> handler = new AtomicReference<>();
+
+ private final AtomicReference> exceptionHandler = new AtomicReference<>();
+
+ private WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder,
+ String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) {
+ this.objectMapper = objectMapper;
+ this.webClient = webClientBuilder.build();
+ this.endpoint = endpoint;
+ this.resumableStreams = resumableStreams;
+ this.openConnectionOnStartup = openConnectionOnStartup;
+ this.activeSession.set(createTransportSession());
+ }
+
+ /**
+ * Create a stateful builder for creating {@link WebClientStreamableHttpTransport}
+ * instances.
+ * @param webClientBuilder the {@link WebClient.Builder} to use
+ * @return a builder which will create an instance of
+ * {@link WebClientStreamableHttpTransport} once {@link Builder#build()} is called
+ */
+ public static Builder builder(WebClient.Builder webClientBuilder) {
+ return new Builder(webClientBuilder);
+ }
+
+ @Override
+ public Mono connect(Function, Mono> handler) {
+ return Mono.deferContextual(ctx -> {
+ this.handler.set(handler);
+ if (openConnectionOnStartup) {
+ logger.debug("Eagerly opening connection on startup");
+ return this.reconnect(null).then();
+ }
+ return Mono.empty();
+ });
+ }
+
+ private DefaultMcpTransportSession createTransportSession() {
+ Supplier> onClose = () -> {
+ DefaultMcpTransportSession transportSession = this.activeSession.get();
+ return transportSession.sessionId().isEmpty() ? Mono.empty()
+ : webClient.delete().uri(this.endpoint).headers(httpHeaders -> {
+ httpHeaders.add("mcp-session-id", transportSession.sessionId().get());
+ }).retrieve().toBodilessEntity().doOnError(e -> logger.info("Got response {}", e)).then();
+ };
+ return new DefaultMcpTransportSession(onClose);
+ }
+
+ @Override
+ public void setExceptionHandler(Consumer handler) {
+ logger.debug("Exception handler registered");
+ this.exceptionHandler.set(handler);
+ }
+
+ private void handleException(Throwable t) {
+ logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t);
+ if (t instanceof McpTransportSessionNotFoundException) {
+ McpTransportSession> invalidSession = this.activeSession.getAndSet(createTransportSession());
+ logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId());
+ invalidSession.close();
+ }
+ Consumer handler = this.exceptionHandler.get();
+ if (handler != null) {
+ handler.accept(t);
+ }
+ }
+
+ @Override
+ public Mono closeGracefully() {
+ return Mono.defer(() -> {
+ logger.debug("Graceful close triggered");
+ DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession());
+ if (currentSession != null) {
+ return currentSession.closeGracefully();
+ }
+ return Mono.empty();
+ });
+ }
+
+ private Mono reconnect(McpTransportStream stream) {
+ return Mono.deferContextual(ctx -> {
+ if (stream != null) {
+ logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId());
+ }
+ else {
+ logger.debug("Reconnecting with no prior stream");
+ }
+ // Here we attempt to initialize the client. In case the server supports SSE,
+ // we will establish a long-running
+ // session here and listen for messages. If it doesn't, that's ok, the server
+ // is a simple, stateless one.
+ final AtomicReference disposableRef = new AtomicReference<>();
+ final McpTransportSession transportSession = this.activeSession.get();
+
+ Disposable connection = webClient.get()
+ .uri(this.endpoint)
+ .accept(MediaType.TEXT_EVENT_STREAM)
+ .headers(httpHeaders -> {
+ transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id));
+ if (stream != null) {
+ stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id));
+ }
+ })
+ .exchangeToFlux(response -> {
+ if (isEventStream(response)) {
+ return eventStream(stream, response);
+ }
+ else if (isNotAllowed(response)) {
+ logger.debug("The server does not support SSE streams, using request-response mode.");
+ return Flux.empty();
+ }
+ else if (isNotFound(response)) {
+ String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession);
+ return mcpSessionNotFoundError(sessionIdRepresentation);
+ }
+ else {
+ return response.createError().doOnError(e -> {
+ logger.info("Opening an SSE stream failed. This can be safely ignored.", e);
+ }).flux();
+ }
+ })
+ .onErrorComplete(t -> {
+ this.handleException(t);
+ return true;
+ })
+ .doFinally(s -> {
+ Disposable ref = disposableRef.getAndSet(null);
+ if (ref != null) {
+ transportSession.removeConnection(ref);
+ }
+ })
+ .contextWrite(ctx)
+ .subscribe();
+
+ disposableRef.set(connection);
+ transportSession.addConnection(connection);
+ return Mono.just(connection);
+ });
+ }
+
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
+ return Mono.create(sink -> {
+ logger.debug("Sending message {}", message);
+ // Here we attempt to initialize the client.
+ // In case the server supports SSE, we will establish a long-running session
+ // here and
+ // listen for messages.
+ // If it doesn't, nothing actually happens here, that's just the way it is...
+ final AtomicReference disposableRef = new AtomicReference<>();
+ final McpTransportSession transportSession = this.activeSession.get();
+
+ Disposable connection = webClient.post()
+ .uri(this.endpoint)
+ .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON)
+ .headers(httpHeaders -> {
+ transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id));
+ })
+ .bodyValue(message)
+ .exchangeToFlux(response -> {
+ if (transportSession
+ .markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) {
+ // Once we have a session, we try to open an async stream for
+ // the server to send notifications and requests out-of-band.
+ reconnect(null).contextWrite(sink.contextView()).subscribe();
+ }
+
+ String sessionRepresentation = sessionIdOrPlaceholder(transportSession);
+
+ // The spec mentions only ACCEPTED, but the existing SDKs can return
+ // 200 OK for notifications
+ if (response.statusCode().is2xxSuccessful()) {
+ Optional contentType = response.headers().contentType();
+ // Existing SDKs consume notifications with no response body nor
+ // content type
+ if (contentType.isEmpty()) {
+ logger.trace("Message was successfully sent via POST for session {}",
+ sessionRepresentation);
+ // signal the caller that the message was successfully
+ // delivered
+ sink.success();
+ // communicate to downstream there is no streamed data coming
+ return Flux.empty();
+ }
+ else {
+ MediaType mediaType = contentType.get();
+ if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
+ // communicate to caller that the message was delivered
+ sink.success();
+ // starting a stream
+ return newEventStream(response, sessionRepresentation);
+ }
+ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
+ logger.trace("Received response to POST for session {}", sessionRepresentation);
+ // communicate to caller the message was delivered
+ sink.success();
+ return responseFlux(response);
+ }
+ else {
+ logger.warn("Unknown media type {} returned for POST in session {}", contentType,
+ sessionRepresentation);
+ return Flux.error(new RuntimeException("Unknown media type returned: " + contentType));
+ }
+ }
+ }
+ else {
+ if (isNotFound(response)) {
+ return mcpSessionNotFoundError(sessionRepresentation);
+ }
+ return extractError(response, sessionRepresentation);
+ }
+ })
+ .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage)))
+ .onErrorResume(t -> {
+ // handle the error first
+ this.handleException(t);
+ // inform the caller of sendMessage
+ sink.error(t);
+ return Flux.empty();
+ })
+ .doFinally(s -> {
+ Disposable ref = disposableRef.getAndSet(null);
+ if (ref != null) {
+ transportSession.removeConnection(ref);
+ }
+ })
+ .contextWrite(sink.contextView())
+ .subscribe();
+ disposableRef.set(connection);
+ transportSession.addConnection(connection);
+ });
+ }
+
+ private static Flux mcpSessionNotFoundError(String sessionRepresentation) {
+ logger.warn("Session {} was not found on the MCP server", sessionRepresentation);
+ // inform the stream/connection subscriber
+ return Flux.error(new McpTransportSessionNotFoundException(sessionRepresentation));
+ }
+
+ private Flux extractError(ClientResponse response, String sessionRepresentation) {
+ return response.createError().onErrorResume(e -> {
+ WebClientResponseException responseException = (WebClientResponseException) e;
+ byte[] body = responseException.getResponseBodyAsByteArray();
+ McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null;
+ Exception toPropagate;
+ try {
+ McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body,
+ McpSchema.JSONRPCResponse.class);
+ jsonRpcError = jsonRpcResponse.error();
+ toPropagate = new McpError(jsonRpcError);
+ }
+ catch (IOException ex) {
+ toPropagate = new RuntimeException("Sending request failed", e);
+ logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body);
+ }
+
+ // Some implementations can return 400 when presented with a
+ // session id that it doesn't know about, so we will
+ // invalidate the session
+ // https://github.com/modelcontextprotocol/typescript-sdk/issues/389
+ if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) {
+ return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate));
+ }
+ return Mono.empty();
+ }).flux();
+ }
+
+ private Flux eventStream(McpTransportStream stream, ClientResponse response) {
+ McpTransportStream sessionStream = stream != null ? stream
+ : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect);
+ logger.debug("Connected stream {}", sessionStream.streamId());
+
+ var idWithMessages = response.bodyToFlux(PARAMETERIZED_TYPE_REF).map(this::parse);
+ return Flux.from(sessionStream.consumeSseStream(idWithMessages));
+ }
+
+ private static boolean isNotFound(ClientResponse response) {
+ return response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND);
+ }
+
+ private static boolean isNotAllowed(ClientResponse response) {
+ return response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED);
+ }
+
+ private static boolean isEventStream(ClientResponse response) {
+ return response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent()
+ && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM);
+ }
+
+ private static String sessionIdOrPlaceholder(McpTransportSession> transportSession) {
+ return transportSession.sessionId().orElse("[missing_session_id]");
+ }
+
+ private Flux responseFlux(ClientResponse response) {
+ return response.bodyToMono(String.class).>handle((responseMessage, s) -> {
+ try {
+ McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper,
+ responseMessage);
+ s.next(List.of(jsonRpcResponse));
+ }
+ catch (IOException e) {
+ s.error(e);
+ }
+ }).flatMapIterable(Function.identity());
+ }
+
+ private Flux newEventStream(ClientResponse response, String sessionRepresentation) {
+ McpTransportStream sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams,
+ this::reconnect);
+ logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(),
+ sessionRepresentation);
+ return eventStream(sessionStream, response);
+ }
+
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return this.objectMapper.convertValue(data, typeRef);
+ }
+
+ private Tuple2, Iterable> parse(ServerSentEvent event) {
+ if (MESSAGE_EVENT_TYPE.equals(event.event())) {
+ try {
+ // We don't support batching ATM and probably won't since the next version
+ // considers removing it.
+ McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data());
+ return Tuples.of(Optional.ofNullable(event.id()), List.of(message));
+ }
+ catch (IOException ioException) {
+ throw new McpError("Error parsing JSON-RPC message: " + event.data());
+ }
+ }
+ else {
+ throw new McpError("Received unrecognized SSE event type: " + event.event());
+ }
+ }
+
+ /**
+ * Builder for {@link WebClientStreamableHttpTransport}.
+ */
+ public static class Builder {
+
+ private ObjectMapper objectMapper;
+
+ private WebClient.Builder webClientBuilder;
+
+ private String endpoint = DEFAULT_ENDPOINT;
+
+ private boolean resumableStreams = true;
+
+ private boolean openConnectionOnStartup = false;
+
+ private Builder(WebClient.Builder webClientBuilder) {
+ Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
+ this.webClientBuilder = webClientBuilder;
+ }
+
+ /**
+ * Configure the {@link ObjectMapper} to use.
+ * @param objectMapper instance to use
+ * @return the builder instance
+ */
+ public Builder objectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ /**
+ * Configure the {@link WebClient.Builder} to construct the {@link WebClient}.
+ * @param webClientBuilder instance to use
+ * @return the builder instance
+ */
+ public Builder webClientBuilder(WebClient.Builder webClientBuilder) {
+ Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
+ this.webClientBuilder = webClientBuilder;
+ return this;
+ }
+
+ /**
+ * Configure the endpoint to make HTTP requests against.
+ * @param endpoint endpoint to use
+ * @return the builder instance
+ */
+ public Builder endpoint(String endpoint) {
+ Assert.hasText(endpoint, "endpoint must be a non-empty String");
+ this.endpoint = endpoint;
+ return this;
+ }
+
+ /**
+ * Configure whether to use the stream resumability feature by keeping track of
+ * SSE event ids.
+ * @param resumableStreams if {@code true} event ids will be tracked and upon
+ * disconnection, the last seen id will be used upon reconnection as a header to
+ * resume consuming messages.
+ * @return the builder instance
+ */
+ public Builder resumableStreams(boolean resumableStreams) {
+ this.resumableStreams = resumableStreams;
+ return this;
+ }
+
+ /**
+ * Configure whether the client should open an SSE connection upon startup. Not
+ * all servers support this (although it is in theory possible with the current
+ * specification), so use with caution. By default, this value is {@code false}.
+ * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)}
+ * method call will try to open an SSE connection before sending any JSON-RPC
+ * request
+ * @return the builder instance
+ */
+ public Builder openConnectionOnStartup(boolean openConnectionOnStartup) {
+ this.openConnectionOnStartup = openConnectionOnStartup;
+ return this;
+ }
+
+ /**
+ * Construct a fresh instance of {@link WebClientStreamableHttpTransport} using
+ * the current builder configuration.
+ * @return a new instance of {@link WebClientStreamableHttpTransport}
+ */
+ public WebClientStreamableHttpTransport build() {
+ ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper();
+
+ return new WebClientStreamableHttpTransport(objectMapper, this.webClientBuilder, endpoint, resumableStreams,
+ openConnectionOnStartup);
+ }
+
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java
index 37abe295..128cda4c 100644
--- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java
+++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java
@@ -190,6 +190,9 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe
*/
@Override
public Mono connect(Function, Mono> handler) {
+ // TODO: Avoid eager connection opening and enable resilience
+ // -> upon disconnects, re-establish connection
+ // -> allow optimizing for eager connection start using a constructor flag
Flux> events = eventStream();
this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> {
if (ENDPOINT_EVENT_TYPE.equals(event.event())) {
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java
new file mode 100644
index 00000000..80fc671e
--- /dev/null
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java
@@ -0,0 +1,17 @@
+package io.modelcontextprotocol.client;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import org.junit.jupiter.api.Timeout;
+import org.springframework.web.reactive.function.client.WebClient;
+
+@Timeout(15)
+public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests {
+
+ @Override
+ protected McpClientTransport createMcpTransport() {
+ return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build();
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java
new file mode 100644
index 00000000..4c803265
--- /dev/null
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java
@@ -0,0 +1,42 @@
+package io.modelcontextprotocol.client;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import org.junit.jupiter.api.Timeout;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.containers.wait.strategy.Wait;
+import org.testcontainers.images.builder.ImageFromDockerfile;
+
+@Timeout(15)
+public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests {
+
+ static String host = "http://localhost:3001";
+
+ // Uses the https://github.com/tzolov/mcp-everything-server-docker-image
+ @SuppressWarnings("resource")
+ GenericContainer> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2")
+ .withCommand("node dist/index.js streamableHttp")
+ .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
+ .withExposedPorts(3001)
+ .waitingFor(Wait.forHttp("/").forStatusCode(404));
+
+ @Override
+ protected McpClientTransport createMcpTransport() {
+ return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build();
+ }
+
+ @Override
+ protected void onStart() {
+ container.start();
+ int port = container.getMappedPort(3001);
+ host = "http://" + container.getHost() + ":" + port;
+ }
+
+ @Override
+ public void onClose() {
+ container.stop();
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java
new file mode 100644
index 00000000..a8cad489
--- /dev/null
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java
@@ -0,0 +1,41 @@
+package io.modelcontextprotocol.client;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import org.junit.jupiter.api.Timeout;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.containers.wait.strategy.Wait;
+
+@Timeout(15)
+public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests {
+
+ static String host = "http://localhost:3001";
+
+ // Uses the https://github.com/tzolov/mcp-everything-server-docker-image
+ @SuppressWarnings("resource")
+ GenericContainer> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2")
+ .withCommand("node dist/index.js streamableHttp")
+ .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
+ .withExposedPorts(3001)
+ .waitingFor(Wait.forHttp("/").forStatusCode(404));
+
+ @Override
+ protected McpClientTransport createMcpTransport() {
+ return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build();
+ }
+
+ @Override
+ protected void onStart() {
+ container.start();
+ int port = container.getMappedPort(3001);
+ host = "http://" + container.getHost() + ":" + port;
+ }
+
+ @Override
+ public void onClose() {
+ container.stop();
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java
index b43c1449..f0533cb4 100644
--- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java
@@ -26,7 +26,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests {
// Uses the https://github.com/tzolov/mcp-everything-server-docker-image
@SuppressWarnings("resource")
- GenericContainer> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1")
+ GenericContainer> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2")
+ .withCommand("node dist/index.js sse")
.withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
.withExposedPorts(3001)
.waitingFor(Wait.forHttp("/").forStatusCode(404));
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java
index 66ac8a6d..9b0959a3 100644
--- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java
@@ -26,7 +26,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests {
// Uses the https://github.com/tzolov/mcp-everything-server-docker-image
@SuppressWarnings("resource")
- GenericContainer> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1")
+ GenericContainer> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2")
+ .withCommand("node dist/index.js sse")
.withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
.withExposedPorts(3001)
.waitingFor(Wait.forHttp("/").forStatusCode(404));
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java
index c757d3da..42b91d14 100644
--- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java
@@ -41,7 +41,8 @@ class WebFluxSseClientTransportTests {
static String host = "http://localhost:3001";
@SuppressWarnings("resource")
- GenericContainer> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1")
+ GenericContainer> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2")
+ .withCommand("node dist/index.js sse")
.withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
.withExposedPorts(3001)
.waitingFor(Wait.forHttp("/").forStatusCode(404));
diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml
index 5ad73374..abc831d1 100644
--- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml
+++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml
@@ -9,13 +9,13 @@
-
+
-
-
+
+
-
+
diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml
index a6e5bdb0..9998569d 100644
--- a/mcp-test/pom.xml
+++ b/mcp-test/pom.xml
@@ -68,6 +68,11 @@
junit-jupiter${testcontainers.version}
+
+ org.testcontainers
+ toxiproxy
+ ${toxiproxy.version}
+ org.awaitility
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java
new file mode 100644
index 00000000..85d6a88e
--- /dev/null
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java
@@ -0,0 +1,222 @@
+package io.modelcontextprotocol.client;
+
+import eu.rekawek.toxiproxy.Proxy;
+import eu.rekawek.toxiproxy.ToxiproxyClient;
+import eu.rekawek.toxiproxy.model.ToxicDirection;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpTransport;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.containers.Network;
+import org.testcontainers.containers.ToxiproxyContainer;
+import org.testcontainers.containers.wait.strategy.Wait;
+import reactor.test.StepVerifier;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+import static org.assertj.core.api.Assertions.assertThatCode;
+
+/**
+ * Resiliency test suite for the {@link McpAsyncClient} that can be used with different
+ * {@link McpTransport} implementations that support Streamable HTTP.
+ *
+ * The purpose of these tests is to allow validating the transport layer resiliency
+ * instead of the functionality offered by the logical layer of MCP concepts such as
+ * tools, resources, prompts, etc.
+ *
+ * @author Dariusz Jędrzejczyk
+ */
+public abstract class AbstractMcpAsyncClientResiliencyTests {
+
+ private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class);
+
+ static Network network = Network.newNetwork();
+ static String host = "http://localhost:3001";
+
+ // Uses the https://github.com/tzolov/mcp-everything-server-docker-image
+ @SuppressWarnings("resource")
+ static GenericContainer> container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2")
+ .withCommand("node dist/index.js streamableHttp")
+ .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
+ .withNetwork(network)
+ .withNetworkAliases("everything-server")
+ .withExposedPorts(3001)
+ .waitingFor(Wait.forHttp("/").forStatusCode(404));
+
+ static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network)
+ .withExposedPorts(8474, 3000);
+
+ static Proxy proxy;
+
+ static {
+ container.start();
+
+ toxiproxy.start();
+
+ final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort());
+ try {
+ proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001");
+ }
+ catch (IOException e) {
+ throw new RuntimeException("Can't create proxy!", e);
+ }
+
+ final String ipAddressViaToxiproxy = toxiproxy.getHost();
+ final int portViaToxiproxy = toxiproxy.getMappedPort(3000);
+
+ host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy;
+ }
+
+ private static void disconnect() {
+ long start = System.nanoTime();
+ try {
+ // disconnect
+ // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM",
+ // ToxicDirection.DOWNSTREAM, 0);
+ // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM",
+ // ToxicDirection.UPSTREAM, 0);
+ proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0);
+ proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0);
+ logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis());
+ }
+ catch (IOException e) {
+ throw new RuntimeException("Failed to disconnect", e);
+ }
+ }
+
+ private static void reconnect() {
+ long start = System.nanoTime();
+ try {
+ proxy.toxics().get("RESET_UPSTREAM").remove();
+ proxy.toxics().get("RESET_DOWNSTREAM").remove();
+ // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove();
+ // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove();
+ logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis());
+ }
+ catch (IOException e) {
+ throw new RuntimeException("Failed to reconnect", e);
+ }
+ }
+
+ private static void restartMcpServer() {
+ container.stop();
+ container.start();
+ }
+
+ abstract McpClientTransport createMcpTransport();
+
+ protected Duration getRequestTimeout() {
+ return Duration.ofSeconds(14);
+ }
+
+ protected Duration getInitializationTimeout() {
+ return Duration.ofSeconds(2);
+ }
+
+ McpAsyncClient client(McpClientTransport transport) {
+ return client(transport, Function.identity());
+ }
+
+ McpAsyncClient client(McpClientTransport transport, Function customizer) {
+ AtomicReference client = new AtomicReference<>();
+
+ assertThatCode(() -> {
+ McpClient.AsyncSpec builder = McpClient.async(transport)
+ .requestTimeout(getRequestTimeout())
+ .initializationTimeout(getInitializationTimeout())
+ .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build());
+ builder = customizer.apply(builder);
+ client.set(builder.build());
+ }).doesNotThrowAnyException();
+
+ return client.get();
+ }
+
+ void withClient(McpClientTransport transport, Consumer c) {
+ withClient(transport, Function.identity(), c);
+ }
+
+ void withClient(McpClientTransport transport, Function customizer,
+ Consumer c) {
+ var client = client(transport, customizer);
+ try {
+ c.accept(client);
+ }
+ finally {
+ StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10));
+ }
+ }
+
+ @Test
+ void testPing() {
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete();
+
+ disconnect();
+
+ StepVerifier.create(mcpAsyncClient.ping()).expectError().verify();
+
+ reconnect();
+
+ StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete();
+ });
+ }
+
+ @Test
+ void testSessionInvalidation() {
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete();
+
+ restartMcpServer();
+
+ // The first try will face the session mismatch exception and the second one
+ // will go through the re-initialization process.
+ StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete();
+ });
+ }
+
+ @Test
+ void testCallTool() {
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ AtomicReference> tools = new AtomicReference<>();
+ StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete();
+ StepVerifier.create(mcpAsyncClient.listTools())
+ .consumeNextWith(list -> tools.set(list.tools()))
+ .verifyComplete();
+
+ disconnect();
+
+ String name = tools.get().get(0).name();
+ // Assuming this is the echo tool
+ McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello"));
+ StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify();
+
+ reconnect();
+
+ StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete();
+ });
+ }
+
+ @Test
+ void testSessionClose() {
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete();
+ // In case of Streamable HTTP this call should issue a HTTP DELETE request
+ // invalidating the session
+ StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify();
+ // The next use should immediately re-initialize with no issue and send the
+ // request without any broken connections.
+ StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete();
+ });
+ }
+
+}
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
index 5452c8ea..049bea00 100644
--- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java
@@ -110,14 +110,16 @@ void tearDown() {
onClose();
}
- void verifyInitializationTimeout(Function> operation, String action) {
+ void verifyNotificationSucceedsWithImplicitInitialization(Function> operation,
+ String action) {
withClient(createMcpTransport(), mcpAsyncClient -> {
- StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient))
- .expectSubscription()
- .thenAwait(getInitializationTimeout())
- .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before " + action))
- .verify();
+ StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete();
+ });
+ }
+
+ void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) {
+ withClient(createMcpTransport(), mcpAsyncClient -> {
+ StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete();
});
}
@@ -133,7 +135,7 @@ void testConstructorWithInvalidArguments() {
@Test
void testListToolsWithoutInitialization() {
- verifyInitializationTimeout(client -> client.listTools(null), "listing tools");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools");
}
@Test
@@ -153,7 +155,7 @@ void testListTools() {
@Test
void testPingWithoutInitialization() {
- verifyInitializationTimeout(client -> client.ping(), "pinging the server");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server");
}
@Test
@@ -168,7 +170,7 @@ void testPing() {
@Test
void testCallToolWithoutInitialization() {
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));
- verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools");
}
@Test
@@ -202,7 +204,7 @@ void testCallToolWithInvalidTool() {
@Test
void testListResourcesWithoutInitialization() {
- verifyInitializationTimeout(client -> client.listResources(null), "listing resources");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources");
}
@Test
@@ -233,7 +235,7 @@ void testMcpAsyncClientState() {
@Test
void testListPromptsWithoutInitialization() {
- verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(null), "listing " + "prompts");
}
@Test
@@ -258,7 +260,7 @@ void testListPrompts() {
@Test
void testGetPromptWithoutInitialization() {
GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of());
- verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts");
}
@Test
@@ -279,7 +281,7 @@ void testGetPrompt() {
@Test
void testRootsListChangedWithoutInitialization() {
- verifyInitializationTimeout(client -> client.rootsListChangedNotification(),
+ verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(),
"sending roots list changed notification");
}
@@ -354,7 +356,8 @@ void testReadResource() {
@Test
void testListResourceTemplatesWithoutInitialization() {
- verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(),
+ "listing resource templates");
}
@Test
@@ -447,8 +450,8 @@ void testInitializeWithAllCapabilities() {
@Test
void testLoggingLevelsWithoutInitialization() {
- verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG),
- "setting logging level");
+ verifyNotificationSucceedsWithImplicitInitialization(
+ client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level");
}
@Test
diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
index 128441f8..3785fd64 100644
--- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
+++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java
@@ -5,6 +5,7 @@
package io.modelcontextprotocol.client;
import java.time.Duration;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@@ -12,7 +13,6 @@
import java.util.function.Function;
import io.modelcontextprotocol.spec.McpClientTransport;
-import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
@@ -112,33 +112,18 @@ void tearDown() {
static final Object DUMMY_RETURN_VALUE = new Object();
- void verifyNotificationTimesOut(Consumer operation, String action) {
- verifyCallTimesOut(client -> {
+ void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) {
+ verifyCallSucceedsWithImplicitInitialization(client -> {
operation.accept(client);
return DUMMY_RETURN_VALUE;
}, action);
}
- void verifyCallTimesOut(Function blockingOperation, String action) {
+ void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) {
withClient(createMcpTransport(), mcpSyncClient -> {
- // This scheduler is not replaced by virtual time scheduler
- Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic");
-
- StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))
- // Offload the blocking call to the real scheduler
- .subscribeOn(customScheduler))
- .expectSubscription()
- // This works without actually waiting but executes all the
- // tasks pending execution on the VirtualTimeScheduler.
- // It is possible to execute the blocking code from the operation
- // because it is blocked on a dedicated Scheduler and the main
- // flow is not blocked and uses the VirtualTimeScheduler.
- .thenAwait(getInitializationTimeout())
- .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class)
- .hasMessage("Client must be initialized before " + action))
- .verify();
-
- customScheduler.dispose();
+ StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)))
+ .expectNextCount(1)
+ .verifyComplete();
});
}
@@ -154,7 +139,7 @@ void testConstructorWithInvalidArguments() {
@Test
void testListToolsWithoutInitialization() {
- verifyCallTimesOut(client -> client.listTools(null), "listing tools");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools");
}
@Test
@@ -175,8 +160,8 @@ void testListTools() {
@Test
void testCallToolsWithoutInitialization() {
- verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))),
- "calling tools");
+ verifyCallSucceedsWithImplicitInitialization(
+ client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools");
}
@Test
@@ -200,7 +185,7 @@ void testCallTools() {
@Test
void testPingWithoutInitialization() {
- verifyCallTimesOut(client -> client.ping(), "pinging the server");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server");
}
@Test
@@ -214,7 +199,7 @@ void testPing() {
@Test
void testCallToolWithoutInitialization() {
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE));
- verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools");
}
@Test
@@ -243,7 +228,7 @@ void testCallToolWithInvalidTool() {
@Test
void testRootsListChangedWithoutInitialization() {
- verifyNotificationTimesOut(client -> client.rootsListChangedNotification(),
+ verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(),
"sending roots list changed notification");
}
@@ -257,7 +242,7 @@ void testRootsListChanged() {
@Test
void testListResourcesWithoutInitialization() {
- verifyCallTimesOut(client -> client.listResources(null), "listing resources");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources");
}
@Test
@@ -333,8 +318,14 @@ void testRemoveNonExistentRoot() {
@Test
void testReadResourceWithoutInitialization() {
- Resource resource = new Resource("test://uri", "Test Resource", null, null, null);
- verifyCallTimesOut(client -> client.readResource(resource), "reading resources");
+ AtomicReference> resources = new AtomicReference<>();
+ withClient(createMcpTransport(), mcpSyncClient -> {
+ mcpSyncClient.initialize();
+ resources.set(mcpSyncClient.listResources().resources());
+ });
+
+ verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)),
+ "reading resources");
}
@Test
@@ -355,7 +346,8 @@ void testReadResource() {
@Test
void testListResourceTemplatesWithoutInitialization() {
- verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates");
+ verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(null),
+ "listing resource templates");
}
@Test
@@ -413,8 +405,8 @@ void testNotificationHandlers() {
@Test
void testLoggingLevelsWithoutInitialization() {
- verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG),
- "setting logging level");
+ verifyNotificationSucceedsWithImplicitInitialization(
+ client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level");
}
@Test
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
index a22ef6b5..8f0433eb 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
@@ -9,9 +9,9 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
-import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
+import java.util.function.Supplier;
import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.spec.McpClientSession;
@@ -32,7 +32,7 @@
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest;
import io.modelcontextprotocol.spec.McpSchema.Root;
-import io.modelcontextprotocol.spec.McpTransport;
+import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import org.slf4j.Logger;
@@ -77,29 +77,37 @@
* @see McpClient
* @see McpSchema
* @see McpClientSession
+ * @see McpClientTransport
*/
public class McpAsyncClient {
private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class);
- private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() {
+ private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() {
};
- protected final Sinks.One initializedSink = Sinks.one();
+ public static final TypeReference
*/
public Mono initialize() {
+ return withSession("by explicit API call", init -> Mono.just(init.get()));
+ }
+ private Mono doInitialize(McpClientSession mcpClientSession) {
String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1);
McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off
@@ -356,16 +395,10 @@ public Mono initialize() {
this.clientCapabilities,
this.clientInfo); // @formatter:on
- Mono result = this.mcpSession.sendRequest(McpSchema.METHOD_INITIALIZE,
- initializeRequest, new TypeReference() {
- });
+ Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE,
+ initializeRequest, INITIALIZE_RESULT_TYPE_REF);
return result.flatMap(initializeResult -> {
-
- this.serverCapabilities = initializeResult.capabilities();
- this.serverInstructions = initializeResult.instructions();
- this.serverInfo = initializeResult.serverInfo();
-
logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}",
initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(),
initializeResult.instructions());
@@ -375,28 +408,93 @@ public Mono initialize() {
"Unsupported protocol version from the server: " + initializeResult.protocolVersion()));
}
- return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null).doOnSuccess(v -> {
- this.initialized.set(true);
- this.initializedSink.tryEmitValue(initializeResult);
- }).thenReturn(initializeResult);
+ return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null)
+ .thenReturn(initializeResult);
});
}
+ private static class Initialization {
+
+ private final Sinks.One initSink = Sinks.one();
+
+ private final AtomicReference result = new AtomicReference<>();
+
+ private final AtomicReference mcpClientSession = new AtomicReference<>();
+
+ static Initialization create() {
+ return new Initialization();
+ }
+
+ void setMcpClientSession(McpClientSession mcpClientSession) {
+ this.mcpClientSession.set(mcpClientSession);
+ }
+
+ McpClientSession mcpSession() {
+ return this.mcpClientSession.get();
+ }
+
+ McpSchema.InitializeResult get() {
+ return this.result.get();
+ }
+
+ Mono await() {
+ return this.initSink.asMono();
+ }
+
+ void complete(McpSchema.InitializeResult initializeResult) {
+ // first ensure the result is cached
+ this.result.set(initializeResult);
+ // inform all the subscribers waiting for the initialization
+ this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST);
+ }
+
+ void error(Throwable t) {
+ this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST);
+ }
+
+ void close() {
+ this.mcpSession().close();
+ }
+
+ Mono closeGracefully() {
+ return this.mcpSession().closeGracefully();
+ }
+
+ }
+
/**
- * Utility method to handle the common pattern of checking initialization before
+ * Utility method to handle the common pattern of ensuring initialization before
* executing an operation.
* @param The type of the result Mono
- * @param actionName The action to perform if the client is initialized
- * @param operation The operation to execute if the client is initialized
+ * @param actionName The action to perform when the client is initialized
+ * @param operation The operation to execute when the client is initialized
* @return A Mono that completes with the result of the operation
*/
- private Mono withInitializationCheck(String actionName,
- Function> operation) {
- return this.initializedSink.asMono()
- .timeout(this.initializationTimeout)
- .onErrorResume(TimeoutException.class,
- ex -> Mono.error(new McpError("Client must be initialized before " + actionName)))
- .flatMap(operation);
+ private Mono withSession(String actionName, Function> operation) {
+ return Mono.defer(() -> {
+ Initialization newInit = Initialization.create();
+ Initialization previous = this.initializationRef.compareAndExchange(null, newInit);
+
+ boolean needsToInitialize = previous == null;
+ logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization");
+ if (needsToInitialize) {
+ newInit.setMcpClientSession(this.sessionSupplier.get());
+ }
+
+ Mono initializationJob = needsToInitialize
+ ? doInitialize(newInit.mcpSession()).doOnNext(newInit::complete).onErrorResume(ex -> {
+ newInit.error(ex);
+ return Mono.error(ex);
+ }) : previous.await();
+
+ return initializationJob.map(initializeResult -> this.initializationRef.get())
+ .timeout(this.initializationTimeout)
+ .onErrorResume(ex -> {
+ logger.warn("Failed to initialize", ex);
+ return Mono.error(new McpError("Client failed to initialize " + actionName));
+ })
+ .flatMap(operation);
+ });
}
// --------------------------
@@ -408,9 +506,8 @@ private Mono withInitializationCheck(String actionName,
* @return A Mono that completes with the server's ping response
*/
public Mono