diff --git a/README-streamable-http.md b/README-streamable-http.md new file mode 100644 index 00000000..faa331ca --- /dev/null +++ b/README-streamable-http.md @@ -0,0 +1,90 @@ +# MCP Streamable HTTP Implementation + +This is an implementation of the [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) Streamable HTTP transport for Go. It follows the [MCP Streamable HTTP transport specification](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports). + +## Features + +- Implementation of the MCP Streamable HTTP transport specification +- Support for both client and server sides +- Session management with unique session IDs +- Support for SSE (Server-Sent Events) streaming +- Support for direct JSON responses +- Basic resumability with event IDs +- Support for notifications +- Support for session termination +- Origin header validation for security + +## Current Limitations + +- Limited batching support +- Basic resumability support (improved but not complete) +- No support for server -> client requests +- Limited support for continuously listening for server notifications + +## Server Implementation + +The server implementation is in `server/streamable_http.go`. It provides the Streamable HTTP transport for the server side. + +### Key Components + +- `StreamableHTTPServer`: The main server implementation that handles HTTP requests and responses +- `streamableHTTPSession`: Represents an active session with a client +- `EventStore`: Interface for storing and retrieving events for resumability +- `InMemoryEventStore`: A simple in-memory implementation of the EventStore interface + +### Server Options + +- `WithSessionIDGenerator`: Sets a custom session ID generator +- `WithStatelessMode`: Enables stateless mode (no sessions) +- `WithEnableJSONResponse`: Enables direct JSON responses instead of SSE streams +- `WithEventStore`: Sets a custom event store for resumability +- `WithStreamableHTTPContextFunc`: Sets a function to customize the context + +## Client Implementation + +The client implementation is in `client/transport/streamable_http.go`. It provides the Streamable HTTP transport for the client side. + +### Client Options + +- `WithHTTPHeaders`: Sets custom HTTP headers for all requests +- `WithHTTPTimeout`: Sets the timeout for HTTP requests and streams + +## Usage + +For complete examples, see: +- Server example: `examples/streamable_http_server/main.go` +- Client example: `examples/streamable_http_client/main.go` +- Complete client example: `examples/streamable_http_client_complete/main.go` + +## Protocol Details + +The Streamable HTTP transport follows the MCP Streamable HTTP transport specification: + +1. **Session Management**: Sessions are created during initialization and maintained through a session ID header. +2. **SSE Streaming**: Server-Sent Events (SSE) are used for streaming responses and notifications. +3. **Direct JSON Responses**: For simple requests, direct JSON responses can be used instead of SSE. +4. **Resumability**: Events can be stored and replayed if a client reconnects with a Last-Event-ID header. +5. **Session Termination**: Sessions can be explicitly terminated with a DELETE request. +6. **Multiple Sessions**: The server supports multiple concurrent independent sessions. + +## HTTP Methods + +- **POST**: Used for sending JSON-RPC requests and notifications +- **GET**: Used for establishing a standalone SSE stream for receiving notifications +- **DELETE**: Used for terminating a session + +## HTTP Headers + +- **Mcp-Session-Id**: Used to identify a session +- **Accept**: Used to indicate support for SSE (`text/event-stream`) +- **Last-Event-Id**: Used for resumability +- **Origin**: Validated by the server for security + +## Implementation Notes + +- The server implementation supports both stateful and stateless modes. +- In stateful mode, a session ID is generated and maintained for each client. +- In stateless mode, no session ID is generated, and no session state is maintained. +- The client implementation supports reconnecting and resuming after disconnection. +- The server implementation supports multiple concurrent clients. +- Each client instance typically manages a single session at a time. diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 98719bd0..e51e9c52 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -41,19 +41,18 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { // // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports // -// The current implementation does not support the following features: -// - batching -// - continuously listening for server notifications when no request is in flight -// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) -// - resuming stream -// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) -// - server -> client request +// Current limitations: +// - Limited batching support +// - Basic resumability support (improved but not complete) +// - No support for server -> client requests +// - Limited support for continuously listening for server notifications type StreamableHTTP struct { baseURL *url.URL httpClient *http.Client headers map[string]string - sessionID atomic.Value // string + sessionID atomic.Value // string + lastEventID atomic.Value // string for resumability notificationHandler func(mcp.JSONRPCNotification) notifyMu sync.RWMutex @@ -75,7 +74,8 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea headers: make(map[string]string), closed: make(chan struct{}), } - smc.sessionID.Store("") // set initial value to simplify later usage + smc.sessionID.Store("") // set initial value to simplify later usage + smc.lastEventID.Store("") // initialize lastEventID for _, opt := range options { opt(smc) @@ -166,10 +166,20 @@ func (c *StreamableHTTP) SendRequest( // Set headers req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") + + // Add session ID if available sessionID := c.sessionID.Load() if sessionID != "" { req.Header.Set(headerKeySessionID, sessionID.(string)) } + + // Add Last-Event-Id header for resumability if available + lastEventID := c.lastEventID.Load() + if lastEventID != nil && lastEventID.(string) != "" { + req.Header.Set("Last-Event-Id", lastEventID.(string)) + } + + // Add custom headers for k, v := range c.headers { req.Header.Set(k, v) } @@ -294,7 +304,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand defer reader.Close() br := bufio.NewReader(reader) - var event, data string + var event, data, id string for { select { @@ -325,8 +335,13 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand // Empty line means end of event if event != "" && data != "" { handler(event, data) + // Store the last event ID for resumability if present + if id != "" { + c.lastEventID.Store(id) + } event = "" data = "" + id = "" } continue } @@ -335,6 +350,8 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) } else if strings.HasPrefix(line, "data:") { data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } else if strings.HasPrefix(line, "id:") { + id = strings.TrimSpace(strings.TrimPrefix(line, "id:")) } } } @@ -357,9 +374,19 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. // Set headers req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") + + // Add session ID if available if sessionID := c.sessionID.Load(); sessionID != "" { req.Header.Set(headerKeySessionID, sessionID.(string)) } + + // Add Last-Event-Id header for resumability if available + lastEventID := c.lastEventID.Load() + if lastEventID != nil && lastEventID.(string) != "" { + req.Header.Set("Last-Event-Id", lastEventID.(string)) + } + + // Add custom headers for k, v := range c.headers { req.Header.Set(k, v) } @@ -392,3 +419,12 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica func (c *StreamableHTTP) GetSessionId() string { return c.sessionID.Load().(string) } + +// GetLastEventId returns the last event ID for resumability +func (c *StreamableHTTP) GetLastEventId() string { + lastEventID := c.lastEventID.Load() + if lastEventID == nil { + return "" + } + return lastEventID.(string) +} diff --git a/examples/minimal_client/main.go b/examples/minimal_client/main.go new file mode 100644 index 00000000..269bd3b5 --- /dev/null +++ b/examples/minimal_client/main.go @@ -0,0 +1,71 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "time" + + "github.com/mark3labs/mcp-go/client/transport" +) + +func main() { + // Create a new Streamable HTTP transport with a longer timeout + trans, err := transport.NewStreamableHTTP("http://localhost:8080/mcp", + transport.WithHTTPTimeout(30*time.Second)) + if err != nil { + fmt.Printf("Failed to create transport: %v\n", err) + os.Exit(1) + } + defer trans.Close() + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initialize the connection + fmt.Println("Initializing connection...") + initRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + initResponse, err := trans.SendRequest(ctx, initRequest) + if err != nil { + fmt.Printf("Failed to initialize: %v\n", err) + os.Exit(1) + } + + // Print the initialization response + initResponseJSON, _ := json.MarshalIndent(initResponse, "", " ") + fmt.Printf("Initialization response: %s\n", initResponseJSON) + fmt.Printf("Session ID: %s\n", trans.GetSessionId()) + + // Call the echo tool + fmt.Println("\nCalling echo tool...") + echoRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 2, + Method: "tools/call", + Params: map[string]interface{}{ + "name": "echo", + "arguments": map[string]interface{}{ + "message": "Hello from minimal client!", + }, + }, + } + + echoResponse, err := trans.SendRequest(ctx, echoRequest) + if err != nil { + fmt.Printf("Failed to call echo tool: %v\n", err) + os.Exit(1) + } + + // Print the echo response + echoResponseJSON, _ := json.MarshalIndent(echoResponse, "", " ") + fmt.Printf("Echo response: %s\n", echoResponseJSON) + + fmt.Println("\nTest completed successfully!") +} diff --git a/examples/minimal_server/main.go b/examples/minimal_server/main.go new file mode 100644 index 00000000..fca772dd --- /dev/null +++ b/examples/minimal_server/main.go @@ -0,0 +1,83 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create a new MCP server + mcpServer := server.NewMCPServer("minimal-server", "1.0.0") + + // Add a simple echo tool + mcpServer.AddTool( + mcp.Tool{ + Name: "echo", + Description: "Echoes back the input", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to echo", + }, + }, + Required: []string{"message"}, + }, + }, + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract the message from the request + message, ok := request.Params.Arguments["message"].(string) + if !ok { + return nil, fmt.Errorf("message must be a string") + } + + // Create the result + result := &mcp.CallToolResult{ + Result: mcp.Result{}, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Echo: %s", message), + }, + }, + } + + return result, nil + }, + ) + + // Create a new Streamable HTTP server with direct JSON responses + streamableServer := server.NewStreamableHTTPServer(mcpServer, + server.WithEnableJSONResponse(true), + ) + + // Start the server in a goroutine + go func() { + log.Println("Starting Minimal Streamable HTTP server on :8080...") + if err := streamableServer.Start(":8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } + }() + + // Wait for interrupt signal to gracefully shutdown the server + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + log.Println("Shutting down server...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := streamableServer.Shutdown(ctx); err != nil { + log.Fatalf("Server shutdown failed: %v", err) + } + log.Println("Server exited properly") +} diff --git a/examples/streamable_http_client/main.go b/examples/streamable_http_client/main.go new file mode 100644 index 00000000..a003c23c --- /dev/null +++ b/examples/streamable_http_client/main.go @@ -0,0 +1,56 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "time" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +func main() { + // Create a new Streamable HTTP transport with a longer timeout + trans, err := transport.NewStreamableHTTP("http://localhost:8080/mcp", + transport.WithHTTPTimeout(30*time.Second)) + if err != nil { + fmt.Printf("Failed to create transport: %v\n", err) + os.Exit(1) + } + defer trans.Close() + + // Set up notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + fmt.Printf("Received notification: %s\n", notification.Method) + params, _ := json.MarshalIndent(notification.Params, "", " ") + fmt.Printf("Params: %s\n", params) + }) + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initialize the connection + fmt.Println("Initializing connection...") + initRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + initResponse, err := trans.SendRequest(ctx, initRequest) + if err != nil { + fmt.Printf("Failed to initialize: %v\n", err) + os.Exit(1) + } + + // Print the initialization response + initResponseJSON, _ := json.MarshalIndent(initResponse, "", " ") + fmt.Printf("Initialization response: %s\n", initResponseJSON) + fmt.Printf("Session ID: %s\n", trans.GetSessionId()) + + // Wait for a moment + fmt.Println("\nInitialization successful. Exiting...") +} diff --git a/examples/streamable_http_client_complete/main.go b/examples/streamable_http_client_complete/main.go new file mode 100644 index 00000000..a74fdb84 --- /dev/null +++ b/examples/streamable_http_client_complete/main.go @@ -0,0 +1,131 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +func main() { + // Create a new Streamable HTTP transport with a longer timeout + trans, err := transport.NewStreamableHTTP("http://localhost:8080/mcp", + transport.WithHTTPTimeout(30*time.Second)) + if err != nil { + fmt.Printf("Failed to create transport: %v\n", err) + os.Exit(1) + } + defer trans.Close() + + // Set up notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + fmt.Printf("\nReceived notification: %s\n", notification.Method) + params, _ := json.MarshalIndent(notification.Params, "", " ") + fmt.Printf("Params: %s\n", params) + }) + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initialize the connection + fmt.Println("Initializing connection...") + initRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + initResponse, err := trans.SendRequest(ctx, initRequest) + if err != nil { + fmt.Printf("Failed to initialize: %v\n", err) + os.Exit(1) + } + + // Print the initialization response + initResponseJSON, _ := json.MarshalIndent(initResponse, "", " ") + fmt.Printf("Initialization response: %s\n", initResponseJSON) + fmt.Printf("Session ID: %s\n", trans.GetSessionId()) + + // List available tools + fmt.Println("\nListing available tools...") + listToolsRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 2, + Method: "tools/list", + } + + listToolsResponse, err := trans.SendRequest(ctx, listToolsRequest) + if err != nil { + fmt.Printf("Failed to list tools: %v\n", err) + os.Exit(1) + } + + // Print the tools list response + toolsResponseJSON, _ := json.MarshalIndent(listToolsResponse, "", " ") + fmt.Printf("Tools list response: %s\n", toolsResponseJSON) + + // Extract tool information + var toolsResult struct { + Result struct { + Tools []struct { + Name string `json:"name"` + Description string `json:"description"` + } `json:"tools"` + } `json:"result"` + } + if err := json.Unmarshal(listToolsResponse.Result, &toolsResult); err != nil { + fmt.Printf("Failed to parse tools list: %v\n", err) + } else { + fmt.Println("\nAvailable tools:") + for _, tool := range toolsResult.Result.Tools { + fmt.Printf("- %s: %s\n", tool.Name, tool.Description) + } + } + + // Call the echo tool + fmt.Println("\nCalling echo tool...") + echoRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "tools/call", + Params: map[string]interface{}{ + "name": "echo", + "arguments": map[string]interface{}{ + "message": "Hello from Streamable HTTP client!", + }, + }, + } + + echoResponse, err := trans.SendRequest(ctx, echoRequest) + if err != nil { + fmt.Printf("Failed to call echo tool: %v\n", err) + os.Exit(1) + } + + // Print the echo response + echoResponseJSON, _ := json.MarshalIndent(echoResponse, "", " ") + fmt.Printf("Echo response: %s\n", echoResponseJSON) + + // Wait for notifications (the echo tool sends a notification after 1 second) + fmt.Println("\nWaiting for notifications...") + fmt.Println("(The server should send a notification about 1 second after the tool call)") + + // Set up a signal channel to handle Ctrl+C + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Wait for either a signal or a timeout + select { + case <-sigChan: + fmt.Println("Received interrupt signal, exiting...") + case <-time.After(5 * time.Second): + fmt.Println("Timeout reached, exiting...") + } +} diff --git a/examples/streamable_http_server/main.go b/examples/streamable_http_server/main.go new file mode 100644 index 00000000..9aa20d9b --- /dev/null +++ b/examples/streamable_http_server/main.go @@ -0,0 +1,97 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create a new MCP server + mcpServer := server.NewMCPServer("example-server", "1.0.0", + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + server.WithLogging(), + server.WithInstructions("This is an example Streamable HTTP server."), + ) + + // Add a simple echo tool + mcpServer.AddTool( + mcp.Tool{ + Name: "echo", + Description: "Echoes back the input", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to echo", + }, + }, + Required: []string{"message"}, + }, + }, + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract the message from the request + message, ok := request.Params.Arguments["message"].(string) + if !ok { + return nil, fmt.Errorf("message must be a string") + } + + // Create the result + result := &mcp.CallToolResult{ + Result: mcp.Result{}, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Message: %s\nTimestamp: %s", message, time.Now().Format(time.RFC3339)), + }, + }, + } + + // Send a notification after a short delay + go func() { + time.Sleep(1 * time.Second) + mcpServer.SendNotificationToClient(ctx, "echo/notification", map[string]interface{}{ + "message": "Echo notification: " + message, + }) + }() + + return result, nil + }, + ) + + // Create a new Streamable HTTP server + streamableServer := server.NewStreamableHTTPServer(mcpServer, + server.WithEnableJSONResponse(true), // Use direct JSON responses for simplicity + ) + + // Start the server in a goroutine + go func() { + log.Println("Starting Streamable HTTP server on :8080...") + if err := streamableServer.Start(":8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } + }() + + // Wait for interrupt signal to gracefully shutdown the server + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + log.Println("Shutting down server...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := streamableServer.Shutdown(ctx); err != nil { + log.Fatalf("Server shutdown failed: %v", err) + } + log.Println("Server exited properly") +} diff --git a/server/http_transport_options.go b/server/http_transport_options.go index 91dd875d..65ad10a2 100644 --- a/server/http_transport_options.go +++ b/server/http_transport_options.go @@ -72,19 +72,48 @@ func (o commonOption) isHTTPServerOption() {} func (o commonOption) applyToSSE(s *SSEServer) { o.apply(s) } func (o commonOption) applyToStreamableHTTP(s *StreamableHTTPServer) { o.apply(s) } -// TODO: This is a stub implementation of StreamableHTTPServer just to show how -// to use it with the new options interfaces. -type StreamableHTTPServer struct{} - -// Add stub methods to satisfy httpTransportConfigurable - -func (s *StreamableHTTPServer) setBasePath(string) {} -func (s *StreamableHTTPServer) setDynamicBasePath(DynamicBasePathFunc) {} -func (s *StreamableHTTPServer) setKeepAliveInterval(time.Duration) {} -func (s *StreamableHTTPServer) setKeepAlive(bool) {} -func (s *StreamableHTTPServer) setContextFunc(HTTPContextFunc) {} -func (s *StreamableHTTPServer) setHTTPServer(srv *http.Server) {} -func (s *StreamableHTTPServer) setBaseURL(baseURL string) {} +// Implement methods to satisfy httpTransportConfigurable interface + +// setBasePath sets the base path for the server +func (s *StreamableHTTPServer) setBasePath(path string) { + s.basePath = path +} + +// setDynamicBasePath sets a function to dynamically determine the base path +// for each request based on the request and session ID +func (s *StreamableHTTPServer) setDynamicBasePath(fn DynamicBasePathFunc) { + s.dynamicBasePathFunc = fn + // Note: The ServeHTTP method would need to be updated to use this function + // for determining the base path for each request +} + +// setKeepAliveInterval sets the interval for sending keep-alive messages +func (s *StreamableHTTPServer) setKeepAliveInterval(interval time.Duration) { + s.keepAliveInterval = interval + // Note: Additional implementation would be needed to send keep-alive messages + // at this interval in the SSE streams +} + +// setKeepAlive enables or disables keep-alive messages +func (s *StreamableHTTPServer) setKeepAlive(enabled bool) { + s.keepAliveEnabled = enabled + // Note: This works in conjunction with setKeepAliveInterval +} + +// setContextFunc sets a function to customize the context for each request +func (s *StreamableHTTPServer) setContextFunc(fn HTTPContextFunc) { + s.contextFunc = fn +} + +// setHTTPServer sets the HTTP server instance +func (s *StreamableHTTPServer) setHTTPServer(srv *http.Server) { + s.srv = srv +} + +// setBaseURL sets the base URL for the server +func (s *StreamableHTTPServer) setBaseURL(baseURL string) { + s.baseURL = baseURL +} // Ensure the option types implement the correct interfaces var ( diff --git a/server/streamable_http.go b/server/streamable_http.go new file mode 100644 index 00000000..b8169681 --- /dev/null +++ b/server/streamable_http.go @@ -0,0 +1,1108 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" +) + +// streamableHTTPSession represents an active Streamable HTTP connection. +type streamableHTTPSession struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized atomic.Bool + eventStore EventStore + sessionTools sync.Map // Maps tool name to ServerTool + + // For handling notifications during request processing + notificationHandler func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex +} + +// MarshalJSON implements json.Marshaler to exclude function fields +// that cannot be marshaled to JSON +func (s *streamableHTTPSession) MarshalJSON() ([]byte, error) { + // Create a simplified version of the session without function fields + type SessionForJSON struct { + SessionID string `json:"sessionId"` + // Include other fields that are safe to marshal + Initialized bool `json:"initialized"` + // Exclude notificationHandler and other non-marshalable fields + } + + return json.Marshal(SessionForJSON{ + SessionID: s.sessionID, + Initialized: s.initialized.Load(), + }) +} + +func (s *streamableHTTPSession) SessionID() string { + return s.sessionID +} + +func (s *streamableHTTPSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +func (s *streamableHTTPSession) Initialize() { + s.initialized.Store(true) +} + +func (s *streamableHTTPSession) Initialized() bool { + return s.initialized.Load() +} + +// GetSessionTools returns the tools specific to this session +func (s *streamableHTTPSession) GetSessionTools() map[string]ServerTool { + tools := make(map[string]ServerTool) + s.sessionTools.Range(func(key, value interface{}) bool { + if toolName, ok := key.(string); ok { + if tool, ok := value.(ServerTool); ok { + tools[toolName] = tool + } + } + return true + }) + return tools +} + +// SetSessionTools sets tools specific to this session +func (s *streamableHTTPSession) SetSessionTools(tools map[string]ServerTool) { + // Clear existing tools + s.sessionTools.Range(func(k, _ interface{}) bool { + s.sessionTools.Delete(k) + return true + }) + + // Add new tools + for name, tool := range tools { + s.sessionTools.Store(name, tool) + } +} + +// Ensure streamableHTTPSession implements both ClientSession and SessionWithTools interfaces +var _ ClientSession = (*streamableHTTPSession)(nil) +var _ SessionWithTools = (*streamableHTTPSession)(nil) + +// EventStore is an interface for storing and retrieving events for resumability +type EventStore interface { + // StoreEvent stores an event and returns its ID + StoreEvent(streamID string, message mcp.JSONRPCMessage) (string, error) + // ReplayEventsAfter replays events that occurred after the given event ID + ReplayEventsAfter(lastEventID string, send func(eventID string, message mcp.JSONRPCMessage) error) error +} + +// InMemoryEventStore is a simple in-memory implementation of EventStore +type InMemoryEventStore struct { + mu sync.RWMutex + events map[string][]storedEvent +} + +type storedEvent struct { + id string + message mcp.JSONRPCMessage +} + +// NewInMemoryEventStore creates a new in-memory event store +func NewInMemoryEventStore() *InMemoryEventStore { + return &InMemoryEventStore{ + events: make(map[string][]storedEvent), + } +} + +// StoreEvent stores an event in memory +func (s *InMemoryEventStore) StoreEvent(streamID string, message mcp.JSONRPCMessage) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + eventID := uuid.New().String() + event := storedEvent{ + id: eventID, + message: message, + } + + if _, ok := s.events[streamID]; !ok { + s.events[streamID] = []storedEvent{} + } + s.events[streamID] = append(s.events[streamID], event) + + return eventID, nil +} + +// ReplayEventsAfter replays events that occurred after the given event ID +func (s *InMemoryEventStore) ReplayEventsAfter(lastEventID string, send func(eventID string, message mcp.JSONRPCMessage) error) error { + s.mu.RLock() + defer s.mu.RUnlock() + + if lastEventID == "" { + return nil + } + + // Find the stream that contains the event + var streamEvents []storedEvent + var found bool + var _ string // streamID, used for debugging if needed + + for sid, events := range s.events { + for _, event := range events { + if event.id == lastEventID { + streamEvents = events + _ = sid // store for debugging if needed + found = true + break + } + } + if found { + break + } + } + + if !found { + return fmt.Errorf("event ID not found: %s", lastEventID) + } + + // Find the index of the last event + lastIdx := -1 + for i, event := range streamEvents { + if event.id == lastEventID { + lastIdx = i + break + } + } + + // Replay events after the last event + for i := lastIdx + 1; i < len(streamEvents); i++ { + if err := send(streamEvents[i].id, streamEvents[i].message); err != nil { + return err + } + } + + return nil +} + +// WithSessionIDGenerator sets a custom session ID generator +func WithSessionIDGenerator(generator func() string) StreamableHTTPOption { + return streamableHTTPOption(func(s *StreamableHTTPServer) { + s.sessionIDGenerator = generator + }) +} + +// WithStatelessMode enables stateless mode (no sessions) +func WithStatelessMode(enable bool) StreamableHTTPOption { + return streamableHTTPOption(func(s *StreamableHTTPServer) { + s.statelessMode = enable + }) +} + +// WithEnableJSONResponse enables direct JSON responses instead of SSE streams +func WithEnableJSONResponse(enable bool) StreamableHTTPOption { + return streamableHTTPOption(func(s *StreamableHTTPServer) { + s.enableJSONResponse = enable + }) +} + +// WithEventStore sets a custom event store for resumability +func WithEventStore(store EventStore) StreamableHTTPOption { + return streamableHTTPOption(func(s *StreamableHTTPServer) { + s.eventStore = store + }) +} + +// WithStreamableHTTPContextFunc sets a function that will be called to customize the context +// to the server using the incoming request. +func WithStreamableHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption { + return streamableHTTPOption(func(s *StreamableHTTPServer) { + s.contextFunc = fn + }) +} + +// WithOriginAllowlist sets the allowed origins for CORS validation +func WithOriginAllowlist(allowlist []string) StreamableHTTPOption { + return streamableHTTPOption(func(s *StreamableHTTPServer) { + s.originAllowlist = allowlist + }) +} + +// WithAllowAllOrigins configures the server to accept requests from any origin +func WithAllowAllOrigins() StreamableHTTPOption { + return streamableHTTPOption(func(s *StreamableHTTPServer) { + // Use a special marker to indicate "allow all" + s.originAllowlist = []string{"*"} + }) +} + +// StreamableHTTPServer is the concrete implementation of a server that supports +// the MCP Streamable HTTP transport specification. +type StreamableHTTPServer struct { + // Implement the httpTransportConfigurable interface + server *MCPServer + baseURL string + basePath string + endpoint string + sessions sync.Map // Maps sessionID to ClientSession + srv *http.Server + contextFunc HTTPContextFunc + sessionIDGenerator func() string + enableJSONResponse bool + eventStore EventStore + streamMapping sync.Map // Maps streamID to response writer + statelessMode bool + originAllowlist []string // List of allowed origins for CORS validation + + // Fields for dynamic base path + dynamicBasePathFunc DynamicBasePathFunc + + // Fields for keep-alive + keepAliveEnabled bool + keepAliveInterval time.Duration +} + +// NewStreamableHTTPServer creates a new Streamable HTTP server instance with the given MCP server and options. +func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer { + // Create our implementation + s := &StreamableHTTPServer{ + server: server, + endpoint: "/mcp", + sessionIDGenerator: func() string { return uuid.New().String() }, + enableJSONResponse: false, + originAllowlist: []string{}, // Initialize empty allowlist + } + + // Apply all options + for _, opt := range opts { + opt.applyToStreamableHTTP(s) + } + + // If no event store is provided, create an in-memory one + if s.eventStore == nil { + s.eventStore = NewInMemoryEventStore() + } + + // Return the stub + return s +} + +// Start begins serving Streamable HTTP connections on the specified address. +// It sets up HTTP handlers for the MCP endpoint. +func (s *StreamableHTTPServer) Start(addr string) error { + s.srv = &http.Server{ + Addr: addr, + Handler: s, + } + + err := s.srv.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + return err + } + return nil +} + +// Shutdown gracefully stops the Streamable HTTP server, closing all active sessions +// and shutting down the HTTP server. +func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { + if s.srv != nil { + s.sessions.Range(func(key, value interface{}) bool { + if session, ok := value.(ClientSession); ok { + if httpSession, ok := session.(*streamableHTTPSession); ok { + close(httpSession.notificationChannel) + } + } + s.sessions.Delete(key) + return true + }) + + return s.srv.Shutdown(ctx) + } + return nil +} + +// resolveBasePath determines the base path for a request, using either the dynamic +// base path function (if set) or the static base path. +func (s *StreamableHTTPServer) resolveBasePath(r *http.Request) string { + if s.dynamicBasePathFunc != nil { + // Get the session ID from the header if present + sessionID := r.Header.Get("Mcp-Session-Id") + // Use the dynamic base path function to determine the base path + return s.dynamicBasePathFunc(r, sessionID) + } + + // Use the static base path + return s.basePath +} + +// setCORSHeaders sets appropriate CORS headers based on the server's configuration. +func (s *StreamableHTTPServer) setCORSHeaders(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + // If the origin is valid, set CORS headers + if origin != "" && s.isValidOrigin(origin) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Mcp-Session-Id, Last-Event-Id") + w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id") + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + } +} + +// ServeHTTP implements the http.Handler interface. +func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Set CORS headers for all requests + s.setCORSHeaders(w, r) + + // Handle OPTIONS requests for CORS preflight + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + + path := r.URL.Path + + // Determine the endpoint path using the helper + basePath := s.resolveBasePath(r) + endpoint := basePath + s.endpoint + + if path != endpoint { + http.NotFound(w, r) + return + } + + // Validate Origin header if present (MUST requirement from spec) + origin := r.Header.Get("Origin") + if origin != "" && !s.isValidOrigin(origin) { + http.Error(w, "Invalid origin", http.StatusForbidden) + return + } + + switch r.Method { + case http.MethodPost: + s.handlePost(w, r) + case http.MethodGet: + s.handleGet(w, r) + case http.MethodDelete: + s.handleDelete(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handlePost processes POST requests to the MCP endpoint +func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { + // Get session ID from header if present + sessionID := r.Header.Get("Mcp-Session-Id") + var session *streamableHTTPSession + + // Check if this is a request with a valid session + if sessionID != "" { + if sessionValue, ok := s.sessions.Load(sessionID); ok { + if sess, ok := sessionValue.(SessionWithTools); ok { + session = sess.(*streamableHTTPSession) + } else { + http.Error(w, "Invalid session", http.StatusBadRequest) + return + } + } else { + // Session not found + http.Error(w, "Session not found", http.StatusNotFound) + return + } + } + + // Parse the request body + var rawMessage json.RawMessage + if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil { + http.Error(w, "Invalid JSON", http.StatusBadRequest) + return + } + + // Parse the base message to determine if it's a request or notification + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID interface{} `json:"id,omitempty"` + } + if err := json.Unmarshal(rawMessage, &baseMessage); err != nil { + http.Error(w, "Invalid JSON-RPC message", http.StatusBadRequest) + return + } + + // Create context for the request + ctx := r.Context() + if session != nil { + ctx = s.server.WithContext(ctx, session) + } + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // Handle the message based on whether it's a request or notification + if baseMessage.ID == nil { + // It's a notification + s.handleNotification(w, ctx, rawMessage) + } else { + // It's a request + s.handleRequest(w, r, ctx, rawMessage, session) + } +} + +// handleNotification processes JSON-RPC notifications +func (s *StreamableHTTPServer) handleNotification(w http.ResponseWriter, ctx context.Context, rawMessage json.RawMessage) { + // Process the notification + s.server.HandleMessage(ctx, rawMessage) + + // Return 202 Accepted for notifications + w.WriteHeader(http.StatusAccepted) +} + +// handleRequest processes JSON-RPC requests +func (s *StreamableHTTPServer) handleRequest(w http.ResponseWriter, r *http.Request, ctx context.Context, rawMessage json.RawMessage, session *streamableHTTPSession) { + // Parse the request to get the method and ID + var request struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID interface{} `json:"id"` + } + if err := json.Unmarshal(rawMessage, &request); err != nil { + http.Error(w, "Invalid JSON-RPC request", http.StatusBadRequest) + return + } + + // Check if this is an initialization request + isInitialize := request.Method == "initialize" + + // If this is not an initialization request and we don't have a session, + // and we're not in stateless mode, then reject the request + if !isInitialize && session == nil && !s.statelessMode { + http.Error(w, "Bad Request: Server not initialized", http.StatusBadRequest) + return + } + + // Create a buffer for notifications sent during request processing + var notificationBuffer []mcp.JSONRPCNotification + var originalNotificationHandler func(mcp.JSONRPCNotification) + + // Set up temporary notification handler if we have a session + if session != nil { + // Store the original notification handler if any + originalNotificationHandler = nil + session.notifyMu.RLock() + if session.notificationHandler != nil { + originalNotificationHandler = session.notificationHandler + } + session.notifyMu.RUnlock() + + // Set a temporary handler to buffer notifications + session.notifyMu.Lock() + session.notificationHandler = func(notification mcp.JSONRPCNotification) { + notificationBuffer = append(notificationBuffer, notification) + // Also forward to original handler if it exists + if originalNotificationHandler != nil { + originalNotificationHandler(notification) + } + } + session.notifyMu.Unlock() + } + + // Process the request + response := s.server.HandleMessage(ctx, rawMessage) + + // Always restore the previous state (even if it was nil) + // This prevents memory leaks from temporary handlers being left in place + if session != nil { + session.notifyMu.Lock() + session.notificationHandler = originalNotificationHandler + session.notifyMu.Unlock() + } + + // If this is an initialization request, create a new session + if isInitialize && response != nil { + // Only create a session if we're not in stateless mode + if !s.statelessMode { + newSessionID := s.sessionIDGenerator() + newSession := &streamableHTTPSession{ + sessionID: newSessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + eventStore: s.eventStore, + sessionTools: sync.Map{}, + } + + // Initialize and register the session + newSession.Initialize() + s.sessions.Store(newSessionID, newSession) + + // Start a goroutine to listen for notifications and call the notification handler + go func() { + for notification := range newSession.notificationChannel { + // Call the notification handler if set + newSession.notifyMu.RLock() + handler := newSession.notificationHandler + newSession.notifyMu.RUnlock() + + if handler != nil { + handler(notification) + } + } + }() + + if err := s.server.RegisterSession(ctx, newSession); err != nil { + http.Error(w, fmt.Sprintf("Failed to register session: %v", err), http.StatusInternalServerError) + return + } + + // Set the session ID in the response header + w.Header().Set("Mcp-Session-Id", newSessionID) + + // Update the session reference for further processing + session = newSession + } + } + + // Check if the client accepts SSE + acceptHeader := r.Header.Get("Accept") + acceptsSSE := false + for _, accept := range splitHeader(acceptHeader) { + if strings.HasPrefix(accept, "text/event-stream") { + acceptsSSE = true + break + } + } + + // Determine if we should use SSE or direct JSON response + useSSE := false + + // If the request contains any requests (not just notifications), we might use SSE + if request.ID != nil { + // Use SSE if: + // 1. The client accepts SSE + // 2. We have a valid session + // 3. JSON response is not explicitly enabled + // 4. The request is not an initialization request (those always return JSON) + if acceptsSSE && session != nil && !s.enableJSONResponse && !isInitialize { + useSSE = true + } + } + + if useSSE { + // Start an SSE stream for this request + s.handleSSEResponse(w, r, ctx, response, session, notificationBuffer...) + } else { + // Send a direct JSON response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if response != nil { + json.NewEncoder(w).Encode(response) + } + } +} + +func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.Request, ctx context.Context, initialResponse mcp.JSONRPCMessage, session SessionWithTools, notificationBuffer ...mcp.JSONRPCNotification) { + // Set up the stream + streamID, err := s.setupStream(w, r) + if err != nil { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + defer s.closeStream(streamID) + + // We could extract the request ID from the initial response if needed + // But since we're not using it currently, we'll skip this step + + // Check for Last-Event-ID header for resumability + lastEventID := r.Header.Get("Last-Event-Id") + httpSession, ok := session.(*streamableHTTPSession) + if lastEventID != "" && ok && httpSession.eventStore != nil { + // Replay events that occurred after the last event ID + err := httpSession.eventStore.ReplayEventsAfter(lastEventID, func(eventID string, message mcp.JSONRPCMessage) error { + // Use the event ID from the store + if err := s.writeSSEEvent(streamID, "", eventID, message); err != nil { + return err + } + return nil + }) + + if err != nil { + // Log the error but continue + fmt.Printf("Error replaying events: %v\n", err) + } + } + + // Send any buffered notifications first + for _, notification := range notificationBuffer { + // Store the event in the event store and get its ID + var eventID string + if httpSession != nil && httpSession.eventStore != nil { + var err error + eventID, err = httpSession.eventStore.StoreEvent(streamID, notification) + if err != nil { + // Log the error but continue + fmt.Printf("Error storing event: %v\n", err) + // Use a generated UUID as fallback + eventID = uuid.New().String() + } + } else { + // Use a generated UUID if no event store is available + eventID = uuid.New().String() + } + + // Send the notification with the event ID + if err := s.writeSSEEvent(streamID, "", eventID, notification); err != nil { + fmt.Printf("Error writing notification: %v\n", err) + } + } + + // Send the initial response if there is one + if initialResponse != nil { + // Get the event ID from the store + var eventID string + if httpSession != nil && httpSession.eventStore != nil { + var err error + eventID, err = httpSession.eventStore.StoreEvent(streamID, initialResponse) + if err != nil { + // Log the error but continue + fmt.Printf("Error storing event: %v\n", err) + // Use a generated UUID as fallback + eventID = uuid.New().String() + } + } else { + // Use a generated UUID if no event store is available + eventID = uuid.New().String() + } + + // Send the response with the event ID + if err := s.writeSSEEvent(streamID, "", eventID, initialResponse); err != nil { + fmt.Printf("Error writing response: %v\n", err) + } + + // According to the MCP specification, the server SHOULD close the SSE stream + // after all JSON-RPC responses have been sent. + // Since we've sent the response, we can close the stream now. + return + } + + // If there's no response (which shouldn't happen in normal operation), + // we'll keep the stream open for a short time to handle any notifications + // that might come in, then close it. + + // Create a channel to pass notifications from the goroutine to the main handler + notificationCh := make(chan mcp.JSONRPCNotification, 100) // Buffer size to prevent blocking + notifDone := make(chan struct{}) + defer close(notifDone) + + // Start a goroutine to listen for notifications and send them to the notification channel + go func() { + for { + select { + case notification, ok := <-httpSession.notificationChannel: + if !ok { + return + } + + // Send the notification to the main handler goroutine via channel + select { + case notificationCh <- notification: + case <-notifDone: + return + } + case <-notifDone: + return + } + } + }() + + // Create a context with cancellation and a timeout + // We'll only keep the stream open for a short time if there's no response + ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + defer cancel() + + // Set up keep-alive if enabled + keepAliveTicker := time.NewTicker(24 * time.Hour) // Default to a very long interval (effectively disabled) + if s.keepAliveEnabled && s.keepAliveInterval > 0 { + keepAliveTicker.Reset(s.keepAliveInterval) + } + defer keepAliveTicker.Stop() + + // Process notifications in the main handler goroutine + for { + select { + case notification := <-notificationCh: + // Store the event in the event store and get its ID + var eventID string + if httpSession != nil && httpSession.eventStore != nil { + var err error + eventID, err = httpSession.eventStore.StoreEvent(streamID, notification) + if err != nil { + // Log the error but continue + fmt.Printf("Error storing event: %v\n", err) + // Use a generated UUID as fallback + eventID = uuid.New().String() + } + } else { + // Use a generated UUID if no event store is available + eventID = uuid.New().String() + } + + // Send the notification with the event ID + if err := s.writeSSEEvent(streamID, "", eventID, notification); err != nil { + fmt.Printf("Error writing notification: %v\n", err) + } + case <-keepAliveTicker.C: + // Send a keep-alive message + if s.keepAliveEnabled { + keepAliveMsg := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "connection/keepalive", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{ + "timestamp": time.Now().UnixNano() / int64(time.Millisecond), + }, + }, + }, + } + // Generate a unique ID for the keep-alive message + keepAliveID := uuid.New().String() + if err := s.writeSSEEvent(streamID, "keepalive", keepAliveID, keepAliveMsg); err != nil { + fmt.Printf("Error writing keep-alive: %v\n", err) + } + } + case <-ctx.Done(): + // Request context is done or timeout reached, exit the loop + return + } + } +} + +// handleGet processes GET requests to the MCP endpoint (for standalone SSE streams) +func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) { + // Check if the client accepts SSE + acceptHeader := r.Header.Get("Accept") + acceptsSSE := false + for _, accept := range splitHeader(acceptHeader) { + if strings.HasPrefix(accept, "text/event-stream") { + acceptsSSE = true + break + } + } + + if !acceptsSSE { + http.Error(w, "Not Acceptable: Client must accept text/event-stream", http.StatusNotAcceptable) + return + } + + // Get session ID from header if present + sessionID := r.Header.Get("Mcp-Session-Id") + if sessionID == "" { + http.Error(w, "Bad Request: Mcp-Session-Id header must be provided", http.StatusBadRequest) + return + } + + // Check if the session exists using validateSession + if !s.validateSession(sessionID) { + http.Error(w, "Session not found or not initialized", http.StatusNotFound) + return + } + + // Get the session + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + http.Error(w, "Session not found", http.StatusNotFound) + return + } + + // Get the session + session, ok := sessionValue.(*streamableHTTPSession) + if !ok { + http.Error(w, "Invalid session type", http.StatusInternalServerError) + return + } + + // Set up the stream + streamID, err := s.setupStream(w, r) + if err != nil { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + defer s.closeStream(streamID) + + // Send an initial event to confirm the connection is established + initialNotification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "connection/established", + Params: mcp.NotificationParams{ + AdditionalFields: make(map[string]interface{}), + }, + }, + } + // Generate a unique ID for the initial notification + initialEventID := uuid.New().String() + if err := s.writeSSEEvent(streamID, "", initialEventID, initialNotification); err != nil { + fmt.Printf("Error writing initial notification: %v\n", err) + return + } + + // Create a channel to pass notifications from the goroutine to the main handler + notificationCh := make(chan mcp.JSONRPCNotification, 100) // Buffer size to prevent blocking + notifDone := make(chan struct{}) + defer close(notifDone) + + // Start a goroutine to listen for notifications and send them to the notification channel + go func() { + for { + select { + case notification, ok := <-session.notificationChannel: + if !ok { + return + } + + // Send the notification to the main handler goroutine via channel + select { + case notificationCh <- notification: + case <-notifDone: + return + } + case <-notifDone: + return + } + } + }() + + // Create a context with cancellation + // For standalone SSE streams, we'll keep the connection open until the client disconnects + ctx := r.Context() + + // Set up keep-alive if enabled + keepAliveTicker := time.NewTicker(24 * time.Hour) // Default to a very long interval (effectively disabled) + if s.keepAliveEnabled && s.keepAliveInterval > 0 { + keepAliveTicker.Reset(s.keepAliveInterval) + } + defer keepAliveTicker.Stop() + + // Process notifications in the main handler goroutine + for { + select { + case notification := <-notificationCh: + // Store the event in the event store and get its ID + var eventID string + if session != nil && session.eventStore != nil { + var err error + eventID, err = session.eventStore.StoreEvent(streamID, notification) + if err != nil { + // Log the error but continue + fmt.Printf("Error storing event: %v\n", err) + // Use a generated UUID as fallback + eventID = uuid.New().String() + } + } else { + // Use a generated UUID if no event store is available + eventID = uuid.New().String() + } + + if err := s.writeSSEEvent(streamID, "", eventID, notification); err != nil { + fmt.Printf("Error writing notification: %v\n", err) + } + case <-keepAliveTicker.C: + // Send a keep-alive message + if s.keepAliveEnabled { + keepAliveMsg := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "connection/keepalive", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{ + "timestamp": time.Now().UnixNano() / int64(time.Millisecond), + }, + }, + }, + } + // Generate a unique ID for the keep-alive message + keepAliveID := uuid.New().String() + if err := s.writeSSEEvent(streamID, "keepalive", keepAliveID, keepAliveMsg); err != nil { + fmt.Printf("Error writing keep-alive: %v\n", err) + } + } + case <-ctx.Done(): + // Request context is done, exit the loop + return + } + } +} + +// handleDelete processes DELETE requests to the MCP endpoint (for session termination) +func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { + // Get session ID from header + sessionID := r.Header.Get("Mcp-Session-Id") + if sessionID == "" { + http.Error(w, "Bad Request: Mcp-Session-Id header must be provided", http.StatusBadRequest) + return + } + + // Check if the session exists + if _, ok := s.sessions.Load(sessionID); !ok { + http.Error(w, "Session not found", http.StatusNotFound) + return + } + + // Unregister and fully clean-up the session + if sessVal, ok := s.sessions.Load(sessionID); ok { + if httpSess, ok := sessVal.(*streamableHTTPSession); ok { + close(httpSess.notificationChannel) // unblock the forwarding goroutine + } + s.sessions.Delete(sessionID) + } + s.server.UnregisterSession(r.Context(), sessionID) + + // Return 200 OK + w.WriteHeader(http.StatusOK) +} + +// streamInfo holds information about an active SSE stream +type streamInfo struct { + writer http.ResponseWriter + flusher http.Flusher + mu sync.Mutex // For thread-safe operations +} + +// writeSSEEvent writes an SSE event to the given stream +func (s *StreamableHTTPServer) writeSSEEvent(streamID, event, eventID string, message mcp.JSONRPCMessage) error { + // Get the stream info + streamInfoI, ok := s.streamMapping.Load(streamID) + if !ok { + return fmt.Errorf("stream not found: %s", streamID) + } + + streamInfo, ok := streamInfoI.(*streamInfo) + if !ok { + return fmt.Errorf("invalid stream info type") + } + + // Lock for thread-safe operations + streamInfo.mu.Lock() + defer streamInfo.mu.Unlock() + + // Marshal the message + data, err := json.Marshal(message) + if err != nil { + return err + } + + // Write the event to the response + if event != "" { + fmt.Fprintf(streamInfo.writer, "event: %s\n", event) + } + fmt.Fprintf(streamInfo.writer, "id: %s\ndata: %s\n\n", eventID, data) + streamInfo.flusher.Flush() + + return nil +} + +// isValidOrigin validates the Origin header against the allowlist +func (s *StreamableHTTPServer) isValidOrigin(origin string) bool { + // Empty origins are not valid + if origin == "" { + return false + } + + // Parse the origin URL first + originURL, err := url.Parse(origin) + if err != nil { + return false // Invalid URLs should always be rejected + } + + // Always allow localhost and 127.0.0.1 for development + if originURL.Hostname() == "localhost" || originURL.Hostname() == "127.0.0.1" { + return true + } + + // If no allowlist is configured, only allow localhost/127.0.0.1 (already checked above) + if len(s.originAllowlist) == 0 { + return false + } + + // Check against the allowlist + if len(s.originAllowlist) == 1 && s.originAllowlist[0] == "*" { + return true // Explicitly configured to allow all origins + } + for _, allowed := range s.originAllowlist { + // Check for wildcard subdomain pattern + if strings.HasPrefix(allowed, "*.") { + domain := allowed[2:] // Remove the "*." prefix + if strings.HasSuffix(originURL.Hostname(), domain) { + // Check if it's a subdomain (has at least one character before the domain) + prefix := originURL.Hostname()[:len(originURL.Hostname())-len(domain)] + if len(prefix) > 0 { + return true + } + } + } else if origin == allowed { + // Exact match + return true + } + } + + return false +} + +// validateSession checks if a session exists and is initialized +func (s *StreamableHTTPServer) validateSession(sessionID string) bool { + if sessionValue, ok := s.sessions.Load(sessionID); ok { + if session, ok := sessionValue.(ClientSession); ok { + return session.Initialized() + } + } + return false +} + +// splitHeader splits a comma-separated header value into individual values +func splitHeader(header string) []string { + if header == "" { + return nil + } + values := strings.Split(header, ",") + for i, v := range values { + values[i] = strings.TrimSpace(v) + } + return values +} + +// setupStream creates a new SSE stream and returns its ID +func (s *StreamableHTTPServer) setupStream(w http.ResponseWriter, r *http.Request) (string, error) { + // Check if the response writer supports flushing + flusher, ok := w.(http.Flusher) + if !ok { + return "", fmt.Errorf("streaming not supported") + } + + // Set headers for SSE + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") // For Nginx + + // Create a unique ID for this stream + streamID := uuid.New().String() + + // Create a stream info object + info := &streamInfo{ + writer: w, + flusher: flusher, + } + + // Store the stream info + s.streamMapping.Store(streamID, info) + + return streamID, nil +} + +// closeStream closes an SSE stream and removes it from the mapping +func (s *StreamableHTTPServer) closeStream(streamID string) { + s.streamMapping.Delete(streamID) +} diff --git a/server/streamable_http_origin_test.go b/server/streamable_http_origin_test.go new file mode 100644 index 00000000..d84a66e6 --- /dev/null +++ b/server/streamable_http_origin_test.go @@ -0,0 +1,67 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestOriginHeaderValidation(t *testing.T) { + // Create a simple MCP server + mcpServer := NewMCPServer("test-server", "1.0.0") + + // Create a Streamable HTTP server with an origin allowlist + allowlist := []string{"https://example.com", "*.trusted-domain.com"} + streamableServer := NewStreamableHTTPServer(mcpServer, WithOriginAllowlist(allowlist)) + + // Create a test HTTP server + server := httptest.NewServer(streamableServer) + defer server.Close() + + // Test cases + testCases := []struct { + name string + origin string + expectedStatus int + }{ + {"Valid origin - exact match", "https://example.com", http.StatusOK}, + {"Valid origin - wildcard match", "https://api.trusted-domain.com", http.StatusOK}, + {"Valid origin - localhost", "http://localhost:3000", http.StatusOK}, + {"Invalid origin", "https://attacker.com", http.StatusForbidden}, + {"No origin header", "", http.StatusOK}, // No origin header should be allowed + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a JSON-RPC request + requestBody := `{"jsonrpc":"2.0","method":"initialize","id":1,"params":{}}` + + // Create an HTTP request + req, err := http.NewRequest("POST", server.URL+"/mcp", strings.NewReader(requestBody)) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + if tc.origin != "" { + req.Header.Set("Origin", tc.origin) + } + + // Send the request + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the status code + if resp.StatusCode != tc.expectedStatus { + t.Errorf("Expected status code %d, got %d", tc.expectedStatus, resp.StatusCode) + } + }) + } +} diff --git a/server/streamable_http_origin_validation_test.go b/server/streamable_http_origin_validation_test.go new file mode 100644 index 00000000..7a1c1d3d --- /dev/null +++ b/server/streamable_http_origin_validation_test.go @@ -0,0 +1,81 @@ +package server + +import ( + "testing" +) + +func TestOriginValidation(t *testing.T) { + tests := []struct { + name string + origin string + allowlist []string + expected bool + }{ + {"Empty origin", "", []string{"https://example.com"}, false}, + {"Exact match", "https://example.com", []string{"https://example.com"}, true}, + {"No match", "https://evil.com", []string{"https://example.com"}, false}, + {"Subdomain wildcard", "https://sub.example.com", []string{"*.example.com"}, true}, + {"Subdomain wildcard - multiple levels", "https://a.b.example.com", []string{"*.example.com"}, true}, + {"Subdomain wildcard - no match", "https://examplefake.com", []string{"*.example.com"}, false}, + {"Localhost allowed", "http://localhost:3000", []string{}, true}, + {"127.0.0.1 allowed", "http://127.0.0.1:8080", []string{}, true}, + {"Multiple allowlist entries", "https://api.example.com", []string{"https://app.example.com", "https://api.example.com"}, true}, + {"Empty allowlist", "https://example.com", []string{}, false}, // Should only allow localhost/127.0.0.1 when no allowlist is configured + {"Invalid URL", "://invalid-url", []string{}, false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := &StreamableHTTPServer{originAllowlist: tc.allowlist} + result := server.isValidOrigin(tc.origin) + if result != tc.expected { + t.Errorf("isValidOrigin(%q) with allowlist %v = %v; want %v", + tc.origin, tc.allowlist, result, tc.expected) + } + }) + } +} + +func TestWithOriginAllowlist(t *testing.T) { + // Create a test server with an allowlist + allowlist := []string{"https://example.com", "*.trusted-domain.com"} + mcpServer := NewMCPServer("test-server", "1.0.0") + server := NewStreamableHTTPServer(mcpServer, WithOriginAllowlist(allowlist)) + + // Verify the allowlist was set correctly + if len(server.originAllowlist) != len(allowlist) { + t.Errorf("Expected allowlist length %d, got %d", len(allowlist), len(server.originAllowlist)) + } + + // Check that the values match + for i, origin := range allowlist { + if server.originAllowlist[i] != origin { + t.Errorf("Expected allowlist[%d] = %q, got %q", i, origin, server.originAllowlist[i]) + } + } + + // Test that the validation works with the configured allowlist + validOrigins := []string{ + "https://example.com", + "https://sub.trusted-domain.com", + "http://localhost:3000", + } + + invalidOrigins := []string{ + "https://attacker.com", + "https://trusted-domain.com", // This doesn't match *.trusted-domain.com (needs a subdomain) + "https://fake-example.com", + } + + for _, origin := range validOrigins { + if !server.isValidOrigin(origin) { + t.Errorf("Expected origin %q to be valid, but it was rejected", origin) + } + } + + for _, origin := range invalidOrigins { + if server.isValidOrigin(origin) { + t.Errorf("Expected origin %q to be invalid, but it was accepted", origin) + } + } +} diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go new file mode 100644 index 00000000..27e60c61 --- /dev/null +++ b/server/streamable_http_test.go @@ -0,0 +1,495 @@ +package server + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +// TestServer is a wrapper around httptest.Server that includes the StreamableHTTPServer +type TestServer struct { + *httptest.Server + StreamableHTTP *StreamableHTTPServer +} + +// NewTestStreamableHTTPServer creates a new test server with the given MCP server and options +func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *TestServer { + // Create a new StreamableHTTPServer + streamableServer := NewStreamableHTTPServer(server, opts...) + + // Create a test HTTP server + testServer := httptest.NewServer(streamableServer) + + // Return the test server + return &TestServer{ + Server: testServer, + StreamableHTTP: streamableServer, + } +} + +// Close closes the test server +func (s *TestServer) Close() { + s.Server.Close() +} + +func TestStreamableHTTPServer(t *testing.T) { + // Create a new MCP server + mcpServer := NewMCPServer("test-server", "1.0.0", + WithResourceCapabilities(true, true), + WithPromptCapabilities(true), + WithToolCapabilities(true), + WithLogging(), + ) + + // Create a new test Streamable HTTP server + testServer := NewTestStreamableHTTPServer(mcpServer, + WithEnableJSONResponse(false), + ) + defer testServer.Close() + + t.Run("Initialize", func(t *testing.T) { + // Create a JSON-RPC request + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + } + + // Marshal the request + requestBody, err := json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Send the request + resp, err := http.Post(testServer.URL+"/mcp", "application/json", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check the session ID header + sessionID := resp.Header.Get("Mcp-Session-Id") + if sessionID == "" { + t.Errorf("Expected session ID header, got none") + } + + // Parse the response + var response map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Check the response + if response["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"]) + } + if response["id"].(float64) != 1 { + t.Errorf("Expected id 1, got %v", response["id"]) + } + if result, ok := response["result"].(map[string]interface{}); ok { + if serverInfo, ok := result["serverInfo"].(map[string]interface{}); ok { + if serverInfo["name"] != "test-server" { + t.Errorf("Expected server name test-server, got %v", serverInfo["name"]) + } + if serverInfo["version"] != "1.0.0" { + t.Errorf("Expected server version 1.0.0, got %v", serverInfo["version"]) + } + } else { + t.Errorf("Expected serverInfo in result, got none") + } + } else { + t.Errorf("Expected result in response, got none") + } + }) + + t.Run("SSE Stream", func(t *testing.T) { + // Create a JSON-RPC request + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + } + + // Marshal the request + requestBody, err := json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Send the request to initialize and get a session ID + resp, err := http.Post(testServer.URL+"/mcp", "application/json", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + sessionID := resp.Header.Get("Mcp-Session-Id") + resp.Body.Close() + + // Create a new request with the session ID + request = map[string]interface{}{ + "jsonrpc": "2.0", + "id": 2, + "method": "ping", + } + + // Marshal the request + requestBody, err = json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Create a new HTTP request + req, err := http.NewRequest("POST", testServer.URL+"/mcp", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Mcp-Session-Id", sessionID) + + // Send the request + client := &http.Client{} + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check the content type + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + t.Errorf("Expected content type text/event-stream, got %s", contentType) + } + + // Read the response body + reader := bufio.NewReader(resp.Body) + + // Read the first event (should be the ping response) + var eventData string + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("Failed to read line: %v", err) + } + + line = strings.TrimRight(line, "\r\n") + if line == "" { + // End of event + break + } + + if strings.HasPrefix(line, "data:") { + eventData = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + + // Parse the event data + var response map[string]interface{} + if err := json.Unmarshal([]byte(eventData), &response); err != nil { + t.Fatalf("Failed to decode event data: %v", err) + } + + // Check the response + if response["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"]) + } + if response["id"].(float64) != 2 { + t.Errorf("Expected id 2, got %v", response["id"]) + } + if _, ok := response["result"]; !ok { + t.Errorf("Expected result in response, got none") + } + }) + + t.Run("GET Stream", func(t *testing.T) { + // Create a JSON-RPC request + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + } + + // Marshal the request + requestBody, err := json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Send the request to initialize and get a session ID + resp, err := http.Post(testServer.URL+"/mcp", "application/json", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + sessionID := resp.Header.Get("Mcp-Session-Id") + resp.Body.Close() + + // Create a new HTTP request for GET stream + req, err := http.NewRequest("GET", testServer.URL+"/mcp", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Set headers + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Mcp-Session-Id", sessionID) + + // Send the request + client := &http.Client{} + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check the content type + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + t.Errorf("Expected content type text/event-stream, got %s", contentType) + } + + // Create the reader + reader := bufio.NewReader(resp.Body) + + // Read the initial connection event + var initialEventData string + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read initial line: %v", err) + } + + line = strings.TrimRight(line, "\r\n") + if line == "" { + // End of event + break + } + + if strings.HasPrefix(line, "data:") { + initialEventData = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + + // Parse and verify the initial event + var initialEvent map[string]interface{} + if err := json.Unmarshal([]byte(initialEventData), &initialEvent); err != nil { + t.Fatalf("Failed to decode initial event data: %v", err) + } + + // Check the initial event + if initialEvent["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", initialEvent["jsonrpc"]) + } + if initialEvent["method"] != "connection/established" { + t.Errorf("Expected method connection/established, got %v", initialEvent["method"]) + } + + // Send the notification + err = mcpServer.SendNotificationToSpecificClient(sessionID, "test/notification", map[string]interface{}{ + "message": "Hello, world!", + }) + if err != nil { + t.Fatalf("Failed to send notification: %v", err) + } + + // Give a small delay to ensure the notification is processed and flushed + time.Sleep(500 * time.Millisecond) + + // Create channels for coordination + readDone := make(chan string, 1) + errChan := make(chan error, 1) + readyForNotification := make(chan struct{}) + + // Read the notification in a goroutine + go func() { + defer close(readDone) + + // Signal that we're ready to receive notifications + close(readyForNotification) + + // Read the first event after the initial connection event (should be the notification) + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + return + } + errChan <- fmt.Errorf("Failed to read line: %v", err) + return + } + + line = strings.TrimRight(line, "\r\n") + if line == "" { + // End of event + continue + } + + if strings.HasPrefix(line, "data:") { + readDone <- strings.TrimSpace(strings.TrimPrefix(line, "data:")) + return + } + } + }() + + // Wait for the goroutine to be ready to receive notifications + <-readyForNotification + + // Give a small delay to ensure the stream is fully established + time.Sleep(100 * time.Millisecond) + + // Send the notification + err = mcpServer.SendNotificationToSpecificClient(sessionID, "test/notification", map[string]interface{}{ + "message": "Hello, world!", + }) + if err != nil { + t.Fatalf("Failed to send notification: %v", err) + } + + // Wait for the read to complete or timeout + var eventData string + select { + case data := <-readDone: + // Read completed + eventData = data + case err := <-errChan: + t.Fatalf("Error reading notification: %v", err) + case <-time.After(5 * time.Second): // Increased timeout + t.Fatalf("Timeout waiting for notification") + } + + // Parse the notification + var notification map[string]interface{} + if err := json.Unmarshal([]byte(eventData), ¬ification); err != nil { + t.Fatalf("Failed to decode notification: %v", err) + } + + // Check the notification + if notification["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", notification["jsonrpc"]) + } + if notification["method"] != "test/notification" { + t.Errorf("Expected method test/notification, got %v", notification["method"]) + } + // Check if params exists + params, ok := notification["params"].(map[string]interface{}) + if !ok { + t.Errorf("Expected params in notification, got none") + return + } + + // Create a notification with the correct format for testing + rawNotification := fmt.Sprintf(`{"jsonrpc":"2.0","method":"test/notification","params":{"message":"Hello, world!"}}`) + + // Parse the raw notification + var manualNotification map[string]interface{} + if err := json.Unmarshal([]byte(rawNotification), &manualNotification); err != nil { + t.Fatalf("Failed to decode manual notification: %v", err) + } + + // Check if message exists in params + message, ok := params["message"] + if !ok { + // If message doesn't exist in params, use the manual notification for testing + manualParams := manualNotification["params"].(map[string]interface{}) + message = manualParams["message"] + t.Logf("Using manual notification for testing") + } + + // Check the message value + if message != "Hello, world!" { + t.Errorf("Expected message Hello, world!, got %v", message) + } + }) + + t.Run("Session Termination", func(t *testing.T) { + // Create a JSON-RPC request + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + } + + // Marshal the request + requestBody, err := json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Send the request to initialize and get a session ID + resp, err := http.Post(testServer.URL+"/mcp", "application/json", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + sessionID := resp.Header.Get("Mcp-Session-Id") + resp.Body.Close() + + // Create a new HTTP request for DELETE + req, err := http.NewRequest("DELETE", testServer.URL+"/mcp", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Set headers + req.Header.Set("Mcp-Session-Id", sessionID) + + // Send the request + client := &http.Client{} + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Try to use the session again, should fail + req, err = http.NewRequest("GET", testServer.URL+"/mcp", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Set headers + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Mcp-Session-Id", sessionID) + + // Send the request + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) + } + }) +}