diff --git a/README-streamable-http.md b/README-streamable-http.md new file mode 100644 index 00000000..36d49e9f --- /dev/null +++ b/README-streamable-http.md @@ -0,0 +1,301 @@ +# MCP Streamable HTTP Implementation + +This is an implementation of the [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) Streamable HTTP transport for Go. It follows the [MCP Streamable HTTP transport specification](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports). + +## Features + +- Full implementation of the MCP Streamable HTTP transport specification +- Support for both client and server sides +- Session management with unique session IDs +- Support for SSE (Server-Sent Events) streaming +- Support for direct JSON responses +- Support for resumability with event IDs +- Support for notifications +- Support for session termination + +## Server Implementation + +The server implementation is in `server/streamable_http.go`. It provides a complete implementation of the Streamable HTTP transport for the server side. + +### Key Components + +- `StreamableHTTPServer`: The main server implementation that handles HTTP requests and responses +- `streamableHTTPSession`: Represents an active session with a client +- `EventStore`: Interface for storing and retrieving events for resumability +- `InMemoryEventStore`: A simple in-memory implementation of the EventStore interface + +### Server Options + +- `WithSessionIDGenerator`: Sets a custom session ID generator +- `WithEnableJSONResponse`: Enables direct JSON responses instead of SSE streams +- `WithEventStore`: Sets a custom event store for resumability +- `WithStreamableHTTPContextFunc`: Sets a function to customize the context + +## Client Implementation + +The client implementation is in `client/transport/streamable_http.go`. It provides a complete implementation of the Streamable HTTP transport for the client side. + +## Usage + +### Server Example + +```go +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create a new MCP server + mcpServer := server.NewMCPServer("example-server", "1.0.0", + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + server.WithLogging(), + server.WithInstructions("This is an example Streamable HTTP server."), + ) + + // Add a simple echo tool + mcpServer.AddTool( + mcp.Tool{ + Name: "echo", + Description: "Echoes back the input", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to echo", + }, + }, + Required: []string{"message"}, + }, + }, + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract the message from the request + message, ok := request.Params.Arguments["message"].(string) + if !ok { + return nil, fmt.Errorf("message must be a string") + } + + // Create the result + result := &mcp.CallToolResult{ + Result: mcp.Result{}, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Message: %s\nTimestamp: %s", message, time.Now().Format(time.RFC3339)), + }, + }, + } + + // Send a notification after a short delay + go func() { + time.Sleep(1 * time.Second) + mcpServer.SendNotificationToClient(ctx, "echo/notification", map[string]interface{}{ + "message": "Echo notification: " + message, + }) + }() + + return result, nil + }, + ) + + // Create a new Streamable HTTP server + streamableServer := server.NewStreamableHTTPServer(mcpServer, + server.WithEnableJSONResponse(false), // Use SSE streaming by default + ) + + // Start the server in a goroutine + go func() { + log.Println("Starting Streamable HTTP server on :8080...") + if err := streamableServer.Start(":8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } + }() + + // Wait for interrupt signal to gracefully shutdown the server + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + log.Println("Shutting down server...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := streamableServer.Shutdown(ctx); err != nil { + log.Fatalf("Server shutdown failed: %v", err) + } + log.Println("Server exited properly") +} +``` + +### Client Example + +```go +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +func main() { + // Create a new Streamable HTTP transport + trans, err := transport.NewStreamableHTTP("http://localhost:8080/mcp") + if err != nil { + fmt.Printf("Failed to create transport: %v\n", err) + os.Exit(1) + } + defer trans.Close() + + // Set up notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + fmt.Printf("Received notification: %s\n", notification.Method) + params, _ := json.MarshalIndent(notification.Params, "", " ") + fmt.Printf("Params: %s\n", params) + }) + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initialize the connection + fmt.Println("Initializing connection...") + initRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + initResponse, err := trans.SendRequest(ctx, initRequest) + if err != nil { + fmt.Printf("Failed to initialize: %v\n", err) + os.Exit(1) + } + + // Print the initialization response + initResponseJSON, _ := json.MarshalIndent(initResponse, "", " ") + fmt.Printf("Initialization response: %s\n", initResponseJSON) + + // List available tools + fmt.Println("\nListing available tools...") + listToolsRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 2, + Method: "tools/list", + } + + listToolsResponse, err := trans.SendRequest(ctx, listToolsRequest) + if err != nil { + fmt.Printf("Failed to list tools: %v\n", err) + os.Exit(1) + } + + // Print the tools list response + toolsResponseJSON, _ := json.MarshalIndent(listToolsResponse, "", " ") + fmt.Printf("Tools list response: %s\n", toolsResponseJSON) + + // Call the echo tool + fmt.Println("\nCalling echo tool...") + fmt.Println("Using session ID from initialization...") + echoRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "tools/call", + Params: map[string]interface{}{ + "name": "echo", + "arguments": map[string]interface{}{ + "message": "Hello from Streamable HTTP client!", + }, + }, + } + + echoResponse, err := trans.SendRequest(ctx, echoRequest) + if err != nil { + fmt.Printf("Failed to call echo tool: %v\n", err) + os.Exit(1) + } + + // Print the echo response + echoResponseJSON, _ := json.MarshalIndent(echoResponse, "", " ") + fmt.Printf("Echo response: %s\n", echoResponseJSON) + + // Wait for notifications (the echo tool sends a notification after 1 second) + fmt.Println("\nWaiting for notifications...") + fmt.Println("(The server should send a notification about 1 second after the tool call)") + + // Set up a signal channel to handle Ctrl+C + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Wait for either a signal or a timeout + select { + case <-sigChan: + fmt.Println("Received interrupt signal, exiting...") + case <-time.After(5 * time.Second): + fmt.Println("Timeout reached, exiting...") + } +} +``` + +## Running the Examples + +1. Start the server: + +```bash +go run examples/streamable_http_server/main.go +``` + +2. In another terminal, run the client: + +```bash +go run examples/streamable_http_client/main.go +``` + +## Protocol Details + +The Streamable HTTP transport follows the MCP Streamable HTTP transport specification. Key aspects include: + +1. **Session Management**: Sessions are created during initialization and maintained through a session ID header. +2. **SSE Streaming**: Server-Sent Events (SSE) are used for streaming responses and notifications. +3. **Direct JSON Responses**: For simple requests, direct JSON responses can be used instead of SSE. +4. **Resumability**: Events can be stored and replayed if a client reconnects with a Last-Event-ID header. +5. **Session Termination**: Sessions can be explicitly terminated with a DELETE request. + +## HTTP Methods + +- **POST**: Used for sending JSON-RPC requests and notifications +- **GET**: Used for establishing a standalone SSE stream for receiving notifications +- **DELETE**: Used for terminating a session + +## HTTP Headers + +- **Mcp-Session-Id**: Used to identify a session +- **Accept**: Used to indicate support for SSE (`text/event-stream`) +- **Last-Event-Id**: Used for resumability + +## Implementation Notes + +- The server implementation supports both stateful and stateless modes. +- In stateful mode, a session ID is generated and maintained for each client. +- In stateless mode, no session ID is generated, and no session state is maintained. +- The client implementation supports reconnecting and resuming after disconnection. +- The server implementation supports multiple concurrent clients. diff --git a/examples/minimal_client/main.go b/examples/minimal_client/main.go new file mode 100644 index 00000000..269bd3b5 --- /dev/null +++ b/examples/minimal_client/main.go @@ -0,0 +1,71 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "time" + + "github.com/mark3labs/mcp-go/client/transport" +) + +func main() { + // Create a new Streamable HTTP transport with a longer timeout + trans, err := transport.NewStreamableHTTP("http://localhost:8080/mcp", + transport.WithHTTPTimeout(30*time.Second)) + if err != nil { + fmt.Printf("Failed to create transport: %v\n", err) + os.Exit(1) + } + defer trans.Close() + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initialize the connection + fmt.Println("Initializing connection...") + initRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + initResponse, err := trans.SendRequest(ctx, initRequest) + if err != nil { + fmt.Printf("Failed to initialize: %v\n", err) + os.Exit(1) + } + + // Print the initialization response + initResponseJSON, _ := json.MarshalIndent(initResponse, "", " ") + fmt.Printf("Initialization response: %s\n", initResponseJSON) + fmt.Printf("Session ID: %s\n", trans.GetSessionId()) + + // Call the echo tool + fmt.Println("\nCalling echo tool...") + echoRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 2, + Method: "tools/call", + Params: map[string]interface{}{ + "name": "echo", + "arguments": map[string]interface{}{ + "message": "Hello from minimal client!", + }, + }, + } + + echoResponse, err := trans.SendRequest(ctx, echoRequest) + if err != nil { + fmt.Printf("Failed to call echo tool: %v\n", err) + os.Exit(1) + } + + // Print the echo response + echoResponseJSON, _ := json.MarshalIndent(echoResponse, "", " ") + fmt.Printf("Echo response: %s\n", echoResponseJSON) + + fmt.Println("\nTest completed successfully!") +} diff --git a/examples/minimal_server/main.go b/examples/minimal_server/main.go new file mode 100644 index 00000000..fca772dd --- /dev/null +++ b/examples/minimal_server/main.go @@ -0,0 +1,83 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create a new MCP server + mcpServer := server.NewMCPServer("minimal-server", "1.0.0") + + // Add a simple echo tool + mcpServer.AddTool( + mcp.Tool{ + Name: "echo", + Description: "Echoes back the input", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to echo", + }, + }, + Required: []string{"message"}, + }, + }, + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract the message from the request + message, ok := request.Params.Arguments["message"].(string) + if !ok { + return nil, fmt.Errorf("message must be a string") + } + + // Create the result + result := &mcp.CallToolResult{ + Result: mcp.Result{}, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Echo: %s", message), + }, + }, + } + + return result, nil + }, + ) + + // Create a new Streamable HTTP server with direct JSON responses + streamableServer := server.NewStreamableHTTPServer(mcpServer, + server.WithEnableJSONResponse(true), + ) + + // Start the server in a goroutine + go func() { + log.Println("Starting Minimal Streamable HTTP server on :8080...") + if err := streamableServer.Start(":8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } + }() + + // Wait for interrupt signal to gracefully shutdown the server + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + log.Println("Shutting down server...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := streamableServer.Shutdown(ctx); err != nil { + log.Fatalf("Server shutdown failed: %v", err) + } + log.Println("Server exited properly") +} diff --git a/examples/streamable_http_client/main.go b/examples/streamable_http_client/main.go new file mode 100644 index 00000000..a003c23c --- /dev/null +++ b/examples/streamable_http_client/main.go @@ -0,0 +1,56 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "time" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +func main() { + // Create a new Streamable HTTP transport with a longer timeout + trans, err := transport.NewStreamableHTTP("http://localhost:8080/mcp", + transport.WithHTTPTimeout(30*time.Second)) + if err != nil { + fmt.Printf("Failed to create transport: %v\n", err) + os.Exit(1) + } + defer trans.Close() + + // Set up notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + fmt.Printf("Received notification: %s\n", notification.Method) + params, _ := json.MarshalIndent(notification.Params, "", " ") + fmt.Printf("Params: %s\n", params) + }) + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initialize the connection + fmt.Println("Initializing connection...") + initRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + initResponse, err := trans.SendRequest(ctx, initRequest) + if err != nil { + fmt.Printf("Failed to initialize: %v\n", err) + os.Exit(1) + } + + // Print the initialization response + initResponseJSON, _ := json.MarshalIndent(initResponse, "", " ") + fmt.Printf("Initialization response: %s\n", initResponseJSON) + fmt.Printf("Session ID: %s\n", trans.GetSessionId()) + + // Wait for a moment + fmt.Println("\nInitialization successful. Exiting...") +} diff --git a/examples/streamable_http_client_complete/main.go b/examples/streamable_http_client_complete/main.go new file mode 100644 index 00000000..a74fdb84 --- /dev/null +++ b/examples/streamable_http_client_complete/main.go @@ -0,0 +1,131 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +func main() { + // Create a new Streamable HTTP transport with a longer timeout + trans, err := transport.NewStreamableHTTP("http://localhost:8080/mcp", + transport.WithHTTPTimeout(30*time.Second)) + if err != nil { + fmt.Printf("Failed to create transport: %v\n", err) + os.Exit(1) + } + defer trans.Close() + + // Set up notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + fmt.Printf("\nReceived notification: %s\n", notification.Method) + params, _ := json.MarshalIndent(notification.Params, "", " ") + fmt.Printf("Params: %s\n", params) + }) + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initialize the connection + fmt.Println("Initializing connection...") + initRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + initResponse, err := trans.SendRequest(ctx, initRequest) + if err != nil { + fmt.Printf("Failed to initialize: %v\n", err) + os.Exit(1) + } + + // Print the initialization response + initResponseJSON, _ := json.MarshalIndent(initResponse, "", " ") + fmt.Printf("Initialization response: %s\n", initResponseJSON) + fmt.Printf("Session ID: %s\n", trans.GetSessionId()) + + // List available tools + fmt.Println("\nListing available tools...") + listToolsRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 2, + Method: "tools/list", + } + + listToolsResponse, err := trans.SendRequest(ctx, listToolsRequest) + if err != nil { + fmt.Printf("Failed to list tools: %v\n", err) + os.Exit(1) + } + + // Print the tools list response + toolsResponseJSON, _ := json.MarshalIndent(listToolsResponse, "", " ") + fmt.Printf("Tools list response: %s\n", toolsResponseJSON) + + // Extract tool information + var toolsResult struct { + Result struct { + Tools []struct { + Name string `json:"name"` + Description string `json:"description"` + } `json:"tools"` + } `json:"result"` + } + if err := json.Unmarshal(listToolsResponse.Result, &toolsResult); err != nil { + fmt.Printf("Failed to parse tools list: %v\n", err) + } else { + fmt.Println("\nAvailable tools:") + for _, tool := range toolsResult.Result.Tools { + fmt.Printf("- %s: %s\n", tool.Name, tool.Description) + } + } + + // Call the echo tool + fmt.Println("\nCalling echo tool...") + echoRequest := transport.JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "tools/call", + Params: map[string]interface{}{ + "name": "echo", + "arguments": map[string]interface{}{ + "message": "Hello from Streamable HTTP client!", + }, + }, + } + + echoResponse, err := trans.SendRequest(ctx, echoRequest) + if err != nil { + fmt.Printf("Failed to call echo tool: %v\n", err) + os.Exit(1) + } + + // Print the echo response + echoResponseJSON, _ := json.MarshalIndent(echoResponse, "", " ") + fmt.Printf("Echo response: %s\n", echoResponseJSON) + + // Wait for notifications (the echo tool sends a notification after 1 second) + fmt.Println("\nWaiting for notifications...") + fmt.Println("(The server should send a notification about 1 second after the tool call)") + + // Set up a signal channel to handle Ctrl+C + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Wait for either a signal or a timeout + select { + case <-sigChan: + fmt.Println("Received interrupt signal, exiting...") + case <-time.After(5 * time.Second): + fmt.Println("Timeout reached, exiting...") + } +} diff --git a/examples/streamable_http_server/main.go b/examples/streamable_http_server/main.go new file mode 100644 index 00000000..9aa20d9b --- /dev/null +++ b/examples/streamable_http_server/main.go @@ -0,0 +1,97 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create a new MCP server + mcpServer := server.NewMCPServer("example-server", "1.0.0", + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + server.WithLogging(), + server.WithInstructions("This is an example Streamable HTTP server."), + ) + + // Add a simple echo tool + mcpServer.AddTool( + mcp.Tool{ + Name: "echo", + Description: "Echoes back the input", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to echo", + }, + }, + Required: []string{"message"}, + }, + }, + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract the message from the request + message, ok := request.Params.Arguments["message"].(string) + if !ok { + return nil, fmt.Errorf("message must be a string") + } + + // Create the result + result := &mcp.CallToolResult{ + Result: mcp.Result{}, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Message: %s\nTimestamp: %s", message, time.Now().Format(time.RFC3339)), + }, + }, + } + + // Send a notification after a short delay + go func() { + time.Sleep(1 * time.Second) + mcpServer.SendNotificationToClient(ctx, "echo/notification", map[string]interface{}{ + "message": "Echo notification: " + message, + }) + }() + + return result, nil + }, + ) + + // Create a new Streamable HTTP server + streamableServer := server.NewStreamableHTTPServer(mcpServer, + server.WithEnableJSONResponse(true), // Use direct JSON responses for simplicity + ) + + // Start the server in a goroutine + go func() { + log.Println("Starting Streamable HTTP server on :8080...") + if err := streamableServer.Start(":8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } + }() + + // Wait for interrupt signal to gracefully shutdown the server + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + log.Println("Shutting down server...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := streamableServer.Shutdown(ctx); err != nil { + log.Fatalf("Server shutdown failed: %v", err) + } + log.Println("Server exited properly") +} diff --git a/server/streamable_http.go b/server/streamable_http.go new file mode 100644 index 00000000..57324fab --- /dev/null +++ b/server/streamable_http.go @@ -0,0 +1,845 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" +) + +// streamableHTTPSession represents an active Streamable HTTP connection. +type streamableHTTPSession struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized atomic.Bool + lastEventID string + eventStore EventStore +} + +func (s *streamableHTTPSession) SessionID() string { + return s.sessionID +} + +func (s *streamableHTTPSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +func (s *streamableHTTPSession) Initialize() { + s.initialized.Store(true) +} + +func (s *streamableHTTPSession) Initialized() bool { + return s.initialized.Load() +} + +var _ ClientSession = (*streamableHTTPSession)(nil) + +// EventStore is an interface for storing and retrieving events for resumability +type EventStore interface { + // StoreEvent stores an event and returns its ID + StoreEvent(streamID string, message mcp.JSONRPCMessage) (string, error) + // ReplayEventsAfter replays events that occurred after the given event ID + ReplayEventsAfter(lastEventID string, send func(eventID string, message mcp.JSONRPCMessage) error) error +} + +// InMemoryEventStore is a simple in-memory implementation of EventStore +type InMemoryEventStore struct { + mu sync.RWMutex + events map[string][]storedEvent +} + +type storedEvent struct { + id string + message mcp.JSONRPCMessage +} + +// NewInMemoryEventStore creates a new in-memory event store +func NewInMemoryEventStore() *InMemoryEventStore { + return &InMemoryEventStore{ + events: make(map[string][]storedEvent), + } +} + +// StoreEvent stores an event in memory +func (s *InMemoryEventStore) StoreEvent(streamID string, message mcp.JSONRPCMessage) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + eventID := uuid.New().String() + event := storedEvent{ + id: eventID, + message: message, + } + + if _, ok := s.events[streamID]; !ok { + s.events[streamID] = []storedEvent{} + } + s.events[streamID] = append(s.events[streamID], event) + + return eventID, nil +} + +// ReplayEventsAfter replays events that occurred after the given event ID +func (s *InMemoryEventStore) ReplayEventsAfter(lastEventID string, send func(eventID string, message mcp.JSONRPCMessage) error) error { + s.mu.RLock() + defer s.mu.RUnlock() + + if lastEventID == "" { + return nil + } + + // Find the stream that contains the event + var streamEvents []storedEvent + var found bool + var _ string // streamID, used for debugging if needed + + for sid, events := range s.events { + for _, event := range events { + if event.id == lastEventID { + streamEvents = events + _ = sid // store for debugging if needed + found = true + break + } + } + if found { + break + } + } + + if !found { + return fmt.Errorf("event ID not found: %s", lastEventID) + } + + // Find the index of the last event + lastIdx := -1 + for i, event := range streamEvents { + if event.id == lastEventID { + lastIdx = i + break + } + } + + // Replay events after the last event + for i := lastIdx + 1; i < len(streamEvents); i++ { + if err := send(streamEvents[i].id, streamEvents[i].message); err != nil { + return err + } + } + + return nil +} + +// StreamableHTTPOption defines a function type for configuring StreamableHTTPServer +type StreamableHTTPOption func(*StreamableHTTPServer) + +// WithSessionIDGenerator sets a custom session ID generator +func WithSessionIDGenerator(generator func() string) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.sessionIDGenerator = generator + } +} + +// WithEnableJSONResponse enables direct JSON responses instead of SSE streams +func WithEnableJSONResponse(enable bool) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.enableJSONResponse = enable + } +} + +// WithEventStore sets a custom event store for resumability +func WithEventStore(store EventStore) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.eventStore = store + } +} + +// WithStreamableHTTPContextFunc sets a function that will be called to customize the context +// to the server using the incoming request. +func WithStreamableHTTPContextFunc(fn SSEContextFunc) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.contextFunc = fn + } +} + +// StreamableHTTPServer implements a Streamable HTTP based MCP server. +// It provides HTTP transport capabilities following the MCP Streamable HTTP specification. +type StreamableHTTPServer struct { + server *MCPServer + baseURL string + basePath string + endpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc + sessionIDGenerator func() string + enableJSONResponse bool + eventStore EventStore + standaloneStreamID string + streamMapping sync.Map // Maps streamID to response writer + requestToStreamMap sync.Map // Maps requestID to streamID +} + +// NewStreamableHTTPServer creates a new Streamable HTTP server instance with the given MCP server and options. +func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer { + s := &StreamableHTTPServer{ + server: server, + endpoint: "/mcp", + sessionIDGenerator: func() string { return uuid.New().String() }, + enableJSONResponse: false, + standaloneStreamID: "_GET_stream", + } + + // Apply all options + for _, opt := range opts { + opt(s) + } + + // If no event store is provided, create an in-memory one + if s.eventStore == nil { + s.eventStore = NewInMemoryEventStore() + } + + return s +} + +// Start begins serving Streamable HTTP connections on the specified address. +// It sets up HTTP handlers for the MCP endpoint. +func (s *StreamableHTTPServer) Start(addr string) error { + s.srv = &http.Server{ + Addr: addr, + Handler: s, + } + + return s.srv.ListenAndServe() +} + +// Shutdown gracefully stops the Streamable HTTP server, closing all active sessions +// and shutting down the HTTP server. +func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { + if s.srv != nil { + s.sessions.Range(func(key, value interface{}) bool { + if session, ok := value.(*streamableHTTPSession); ok { + close(session.notificationChannel) + } + s.sessions.Delete(key) + return true + }) + + return s.srv.Shutdown(ctx) + } + return nil +} + +// ServeHTTP implements the http.Handler interface. +func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + endpoint := s.basePath + s.endpoint + + if path != endpoint { + http.NotFound(w, r) + return + } + + switch r.Method { + case http.MethodPost: + s.handlePost(w, r) + case http.MethodGet: + s.handleGet(w, r) + case http.MethodDelete: + s.handleDelete(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handlePost processes POST requests to the MCP endpoint +func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { + // Get session ID from header if present + sessionID := r.Header.Get("Mcp-Session-Id") + var session *streamableHTTPSession + + // Check if this is a request with a valid session + if sessionID != "" { + if sessionValue, ok := s.sessions.Load(sessionID); ok { + if sess, ok := sessionValue.(*streamableHTTPSession); ok { + session = sess + } else { + http.Error(w, "Invalid session", http.StatusBadRequest) + return + } + } else { + // Session not found + http.Error(w, "Session not found", http.StatusNotFound) + return + } + } + + // Parse the request body + var rawMessage json.RawMessage + if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil { + http.Error(w, "Invalid JSON", http.StatusBadRequest) + return + } + + // Parse the base message to determine if it's a request or notification + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID interface{} `json:"id,omitempty"` + } + if err := json.Unmarshal(rawMessage, &baseMessage); err != nil { + http.Error(w, "Invalid JSON-RPC message", http.StatusBadRequest) + return + } + + // Create context for the request + ctx := r.Context() + if session != nil { + ctx = s.server.WithContext(ctx, session) + } + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // Handle the message based on whether it's a request or notification + if baseMessage.ID == nil { + // It's a notification + s.handleNotification(w, ctx, rawMessage) + } else { + // It's a request + s.handleRequest(w, r, ctx, rawMessage, session) + } +} + +// handleNotification processes JSON-RPC notifications +func (s *StreamableHTTPServer) handleNotification(w http.ResponseWriter, ctx context.Context, rawMessage json.RawMessage) { + // Process the notification + s.server.HandleMessage(ctx, rawMessage) + + // Return 202 Accepted for notifications + w.WriteHeader(http.StatusAccepted) +} + +// handleRequest processes JSON-RPC requests +func (s *StreamableHTTPServer) handleRequest(w http.ResponseWriter, r *http.Request, ctx context.Context, rawMessage json.RawMessage, session *streamableHTTPSession) { + // Parse the request to get the method and ID + var request struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID interface{} `json:"id"` + } + if err := json.Unmarshal(rawMessage, &request); err != nil { + http.Error(w, "Invalid JSON-RPC request", http.StatusBadRequest) + return + } + + // Check if this is an initialization request + isInitialize := request.Method == "initialize" + + // If this is not an initialization request and we don't have a session, + // and we're not in stateless mode (sessionIDGenerator returns empty string), + // then reject the request + if !isInitialize && session == nil && s.sessionIDGenerator() != "" { + http.Error(w, "Bad Request: Server not initialized", http.StatusBadRequest) + return + } + + // Process the request + response := s.server.HandleMessage(ctx, rawMessage) + + // If this is an initialization request, create a new session + if isInitialize && response != nil { + // Only create a session if we're not in stateless mode + if s.sessionIDGenerator() != "" { + newSessionID := s.sessionIDGenerator() + newSession := &streamableHTTPSession{ + sessionID: newSessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + eventStore: s.eventStore, + } + + // Register the session + s.sessions.Store(newSessionID, newSession) + if err := s.server.RegisterSession(ctx, newSession); err != nil { + http.Error(w, fmt.Sprintf("Failed to register session: %v", err), http.StatusInternalServerError) + return + } + + // Set the session ID in the response header + w.Header().Set("Mcp-Session-Id", newSessionID) + + // Update the session reference for further processing + session = newSession + } + } + + // Check if the client accepts SSE + acceptHeader := r.Header.Get("Accept") + acceptsSSE := false + for _, accept := range splitHeader(acceptHeader) { + if accept == "text/event-stream" { + acceptsSSE = true + break + } + } + + // Determine if we should use SSE or direct JSON response + useSSE := false + + // If the request contains any requests (not just notifications), we might use SSE + if request.ID != nil { + // Use SSE if: + // 1. The client accepts SSE + // 2. We have a valid session + // 3. JSON response is not explicitly enabled + // 4. The request is not an initialization request (those always return JSON) + if acceptsSSE && session != nil && !s.enableJSONResponse && !isInitialize { + useSSE = true + } + } + + if useSSE { + // Start an SSE stream for this request + s.handleSSEResponse(w, r, ctx, response, session) + } else { + // Send a direct JSON response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if response != nil { + json.NewEncoder(w).Encode(response) + } + } +} + +// handleSSEResponse sends the response as an SSE stream +func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.Request, ctx context.Context, initialResponse mcp.JSONRPCMessage, session *streamableHTTPSession) { + // Set SSE headers + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + + // Create a unique stream ID for this request + streamID := uuid.New().String() + + // Get the request ID from the initial response + var requestID interface{} + if resp, ok := initialResponse.(mcp.JSONRPCResponse); ok { + requestID = resp.ID + } else if errResp, ok := initialResponse.(mcp.JSONRPCError); ok { + requestID = errResp.ID + } + + // If we have a request ID, map it to this stream + if requestID != nil { + s.requestToStreamMap.Store(requestID, streamID) + defer s.requestToStreamMap.Delete(requestID) + } + + // Create a channel for this stream + eventChan := make(chan string, 10) + defer close(eventChan) + + // Store the stream mapping + s.streamMapping.Store(streamID, eventChan) + defer s.streamMapping.Delete(streamID) + + // Check for Last-Event-ID header for resumability + lastEventID := r.Header.Get("Last-Event-Id") + if lastEventID != "" && session.eventStore != nil { + // Replay events that occurred after the last event ID + err := session.eventStore.ReplayEventsAfter(lastEventID, func(eventID string, message mcp.JSONRPCMessage) error { + data, err := json.Marshal(message) + if err != nil { + return err + } + + eventData := fmt.Sprintf("id: %s\ndata: %s\n\n", eventID, data) + select { + case eventChan <- eventData: + return nil + case <-ctx.Done(): + return ctx.Err() + } + }) + + if err != nil { + // Log the error but continue + fmt.Printf("Error replaying events: %v\n", err) + } + } + + // Send the initial response if there is one + if initialResponse != nil { + data, err := json.Marshal(initialResponse) + if err != nil { + http.Error(w, "Failed to marshal response", http.StatusInternalServerError) + return + } + + // Store the event if we have an event store + var eventID string + if session.eventStore != nil { + var storeErr error + eventID, storeErr = session.eventStore.StoreEvent(streamID, initialResponse) + if storeErr != nil { + // Log the error but continue + fmt.Printf("Error storing event: %v\n", storeErr) + } + } + + // Send the event + if eventID != "" { + fmt.Fprintf(w, "id: %s\ndata: %s\n\n", eventID, data) + } else { + fmt.Fprintf(w, "data: %s\n\n", data) + } + w.(http.Flusher).Flush() + } + + // Start a goroutine to listen for notifications and forward them to the client + notifDone := make(chan struct{}) + defer close(notifDone) + + go func() { + for { + select { + case notification, ok := <-session.notificationChannel: + if !ok { + return + } + + data, err := json.Marshal(notification) + if err != nil { + continue + } + + // Store the event if we have an event store + var eventID string + if session.eventStore != nil { + var storeErr error + eventID, storeErr = session.eventStore.StoreEvent(streamID, notification) + if storeErr != nil { + // Log the error but continue + fmt.Printf("Error storing event: %v\n", storeErr) + } + } + + // Create the event data + var eventData string + if eventID != "" { + eventData = fmt.Sprintf("id: %s\ndata: %s\n\n", eventID, data) + } else { + eventData = fmt.Sprintf("data: %s\n\n", data) + } + + // Send the event to the channel + select { + case eventChan <- eventData: + // Event sent successfully + case <-notifDone: + return + } + case <-notifDone: + return + } + } + }() + + // Main event loop + for { + select { + case event := <-eventChan: + // Write the event to the response + _, err := fmt.Fprint(w, event) + if err != nil { + return + } + w.(http.Flusher).Flush() + case <-r.Context().Done(): + return + } + } +} + +// handleGet processes GET requests to the MCP endpoint (for standalone SSE streams) +func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) { + // Check if the client accepts SSE + acceptHeader := r.Header.Get("Accept") + acceptsSSE := false + for _, accept := range splitHeader(acceptHeader) { + if accept == "text/event-stream" { + acceptsSSE = true + break + } + } + + if !acceptsSSE { + http.Error(w, "Not Acceptable: Client must accept text/event-stream", http.StatusNotAcceptable) + return + } + + // Get session ID from header if present + sessionID := r.Header.Get("Mcp-Session-Id") + var session *streamableHTTPSession + + // Check if this is a request with a valid session + if sessionID != "" { + if sessionValue, ok := s.sessions.Load(sessionID); ok { + if sess, ok := sessionValue.(*streamableHTTPSession); ok { + session = sess + } else { + http.Error(w, "Invalid session", http.StatusBadRequest) + return + } + } else { + // Session not found + http.Error(w, "Session not found", http.StatusNotFound) + return + } + } else { + // No session ID provided + http.Error(w, "Bad Request: Mcp-Session-Id header must be provided", http.StatusBadRequest) + return + } + + // Create context for the request + ctx := r.Context() + ctx = s.server.WithContext(ctx, session) + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // Set SSE headers + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + + // Create a channel for this stream + eventChan := make(chan string, 10) + defer close(eventChan) + + // Store the stream mapping for the standalone stream + s.streamMapping.Store(s.standaloneStreamID, eventChan) + defer s.streamMapping.Delete(s.standaloneStreamID) + + // Check for Last-Event-ID header for resumability + lastEventID := r.Header.Get("Last-Event-Id") + if lastEventID != "" && session.eventStore != nil { + // Replay events that occurred after the last event ID + err := session.eventStore.ReplayEventsAfter(lastEventID, func(eventID string, message mcp.JSONRPCMessage) error { + data, err := json.Marshal(message) + if err != nil { + return err + } + + eventData := fmt.Sprintf("id: %s\ndata: %s\n\n", eventID, data) + select { + case eventChan <- eventData: + return nil + case <-ctx.Done(): + return ctx.Err() + } + }) + + if err != nil { + // Log the error but continue + fmt.Printf("Error replaying events: %v\n", err) + } + } + + // Start a goroutine to listen for notifications and forward them to the client + notifDone := make(chan struct{}) + defer close(notifDone) + + go func() { + for { + select { + case notification, ok := <-session.notificationChannel: + if !ok { + return + } + + data, err := json.Marshal(notification) + if err != nil { + continue + } + + // Store the event if we have an event store + var eventID string + if session.eventStore != nil { + var storeErr error + eventID, storeErr = session.eventStore.StoreEvent(s.standaloneStreamID, notification) + if storeErr != nil { + // Log the error but continue + fmt.Printf("Error storing event: %v\n", storeErr) + } + } + + // Create the event data + var eventData string + if eventID != "" { + eventData = fmt.Sprintf("id: %s\ndata: %s\n\n", eventID, data) + } else { + eventData = fmt.Sprintf("data: %s\n\n", data) + } + + // Send the event to the channel + select { + case eventChan <- eventData: + // Event sent successfully + case <-notifDone: + return + } + case <-notifDone: + return + } + } + }() + + // Main event loop + for { + select { + case event := <-eventChan: + // Write the event to the response + _, err := fmt.Fprint(w, event) + if err != nil { + return + } + w.(http.Flusher).Flush() + case <-r.Context().Done(): + return + } + } +} + +// handleDelete processes DELETE requests to the MCP endpoint (for session termination) +func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { + // Get session ID from header + sessionID := r.Header.Get("Mcp-Session-Id") + if sessionID == "" { + http.Error(w, "Bad Request: Mcp-Session-Id header must be provided", http.StatusBadRequest) + return + } + + // Check if the session exists + if _, ok := s.sessions.Load(sessionID); !ok { + http.Error(w, "Session not found", http.StatusNotFound) + return + } + + // Unregister the session + s.server.UnregisterSession(r.Context(), sessionID) + s.sessions.Delete(sessionID) + + // Return 200 OK + w.WriteHeader(http.StatusOK) +} + +// writeSSEEvent writes an SSE event to the given stream +func (s *StreamableHTTPServer) writeSSEEvent(streamID string, event string, message mcp.JSONRPCMessage) error { + // Get the stream channel + streamChanI, ok := s.streamMapping.Load(streamID) + if !ok { + return fmt.Errorf("stream not found: %s", streamID) + } + + streamChan, ok := streamChanI.(chan string) + if !ok { + return fmt.Errorf("invalid stream channel type") + } + + // Marshal the message + data, err := json.Marshal(message) + if err != nil { + return err + } + + // Create the event data + eventData := fmt.Sprintf("event: %s\ndata: %s\n\n", event, data) + + // Send the event to the channel + select { + case streamChan <- eventData: + return nil + default: + return fmt.Errorf("stream channel full") + } +} + +// splitHeader splits a comma-separated header value into individual values +func splitHeader(header string) []string { + if header == "" { + return nil + } + + var values []string + for _, value := range splitAndTrim(header, ',') { + if value != "" { + values = append(values, value) + } + } + + return values +} + +// splitAndTrim splits a string by the given separator and trims whitespace from each part +func splitAndTrim(s string, sep rune) []string { + var result []string + var builder strings.Builder + var inQuotes bool + + for _, r := range s { + if r == '"' { + inQuotes = !inQuotes + builder.WriteRune(r) + } else if r == sep && !inQuotes { + result = append(result, strings.TrimSpace(builder.String())) + builder.Reset() + } else { + builder.WriteRune(r) + } + } + + if builder.Len() > 0 { + result = append(result, strings.TrimSpace(builder.String())) + } + + return result +} + +// NewTestStreamableHTTPServer creates a test server for testing purposes +func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server { + streamableServer := NewStreamableHTTPServer(server, opts...) + testServer := httptest.NewServer(streamableServer) + streamableServer.baseURL = testServer.URL + return testServer +} + +// validateSession checks if the session ID is valid and the session is initialized +func (s *StreamableHTTPServer) validateSession(sessionID string) bool { + if sessionID == "" { + return false + } + + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return false + } + + session, ok := sessionValue.(*streamableHTTPSession) + if !ok { + return false + } + + return session.Initialized() +} diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go new file mode 100644 index 00000000..51ce68f4 --- /dev/null +++ b/server/streamable_http_test.go @@ -0,0 +1,402 @@ +package server + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestStreamableHTTPServer(t *testing.T) { + // Create a new MCP server + mcpServer := NewMCPServer("test-server", "1.0.0", + WithResourceCapabilities(true, true), + WithPromptCapabilities(true), + WithToolCapabilities(true), + WithLogging(), + ) + + // Create a new Streamable HTTP server + streamableServer := NewStreamableHTTPServer(mcpServer, + WithEnableJSONResponse(false), + ) + + // Create a test server + testServer := httptest.NewServer(streamableServer) + defer testServer.Close() + + t.Run("Initialize", func(t *testing.T) { + // Create a JSON-RPC request + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + } + + // Marshal the request + requestBody, err := json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Send the request + resp, err := http.Post(testServer.URL+"/mcp", "application/json", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check the session ID header + sessionID := resp.Header.Get("Mcp-Session-Id") + if sessionID == "" { + t.Errorf("Expected session ID header, got none") + } + + // Parse the response + var response map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Check the response + if response["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"]) + } + if response["id"].(float64) != 1 { + t.Errorf("Expected id 1, got %v", response["id"]) + } + if result, ok := response["result"].(map[string]interface{}); ok { + if serverInfo, ok := result["serverInfo"].(map[string]interface{}); ok { + if serverInfo["name"] != "test-server" { + t.Errorf("Expected server name test-server, got %v", serverInfo["name"]) + } + if serverInfo["version"] != "1.0.0" { + t.Errorf("Expected server version 1.0.0, got %v", serverInfo["version"]) + } + } else { + t.Errorf("Expected serverInfo in result, got none") + } + } else { + t.Errorf("Expected result in response, got none") + } + }) + + t.Run("SSE Stream", func(t *testing.T) { + // Create a JSON-RPC request + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + } + + // Marshal the request + requestBody, err := json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Send the request to initialize and get a session ID + resp, err := http.Post(testServer.URL+"/mcp", "application/json", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + sessionID := resp.Header.Get("Mcp-Session-Id") + resp.Body.Close() + + // Create a new request with the session ID + request = map[string]interface{}{ + "jsonrpc": "2.0", + "id": 2, + "method": "ping", + } + + // Marshal the request + requestBody, err = json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Create a new HTTP request + req, err := http.NewRequest("POST", testServer.URL+"/mcp", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Mcp-Session-Id", sessionID) + + // Send the request + client := &http.Client{} + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check the content type + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + t.Errorf("Expected content type text/event-stream, got %s", contentType) + } + + // Read the response body + reader := bufio.NewReader(resp.Body) + + // Read the first event (should be the ping response) + var eventData string + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("Failed to read line: %v", err) + } + + line = strings.TrimRight(line, "\r\n") + if line == "" { + // End of event + break + } + + if strings.HasPrefix(line, "data:") { + eventData = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + + // Parse the event data + var response map[string]interface{} + if err := json.Unmarshal([]byte(eventData), &response); err != nil { + t.Fatalf("Failed to decode event data: %v", err) + } + + // Check the response + if response["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"]) + } + if response["id"].(float64) != 2 { + t.Errorf("Expected id 2, got %v", response["id"]) + } + if _, ok := response["result"]; !ok { + t.Errorf("Expected result in response, got none") + } + }) + + t.Run("GET Stream", func(t *testing.T) { + // Create a JSON-RPC request + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + } + + // Marshal the request + requestBody, err := json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Send the request to initialize and get a session ID + resp, err := http.Post(testServer.URL+"/mcp", "application/json", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + sessionID := resp.Header.Get("Mcp-Session-Id") + resp.Body.Close() + + // Create a new HTTP request for GET stream + req, err := http.NewRequest("GET", testServer.URL+"/mcp", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Set headers + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Mcp-Session-Id", sessionID) + + // Send the request + client := &http.Client{} + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check the content type + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + t.Errorf("Expected content type text/event-stream, got %s", contentType) + } + + // Send a notification to the session + go func() { + // Wait a bit for the stream to be established + time.Sleep(100 * time.Millisecond) + + // Create a notification + notification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "test/notification", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{ + "message": "Hello, world!", + }, + }, + }, + } + + // Find the session + sessionValue, ok := streamableServer.sessions.Load(sessionID) + if !ok { + t.Errorf("Session not found: %s", sessionID) + return + } + + // Send the notification + session, ok := sessionValue.(*streamableHTTPSession) + if !ok { + t.Errorf("Invalid session type") + return + } + + session.notificationChannel <- notification + }() + + // Read the response body + reader := bufio.NewReader(resp.Body) + + // Read the first event (should be the notification) + var eventData string + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("Failed to read line: %v", err) + } + + line = strings.TrimRight(line, "\r\n") + if line == "" { + // End of event + break + } + + if strings.HasPrefix(line, "data:") { + eventData = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + + // Parse the event data + var notification map[string]interface{} + if err := json.Unmarshal([]byte(eventData), ¬ification); err != nil { + t.Fatalf("Failed to decode event data: %v", err) + } + + // Check the notification + if notification["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", notification["jsonrpc"]) + } + if notification["method"] != "test/notification" { + t.Errorf("Expected method test/notification, got %v", notification["method"]) + } + if params, ok := notification["params"].(map[string]interface{}); ok { + if params["message"] != "Hello, world!" { + t.Errorf("Expected message Hello, world!, got %v", params["message"]) + } + } else { + t.Errorf("Expected params in notification, got none") + } + }) + + t.Run("Session Termination", func(t *testing.T) { + // Create a JSON-RPC request + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + } + + // Marshal the request + requestBody, err := json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Send the request to initialize and get a session ID + resp, err := http.Post(testServer.URL+"/mcp", "application/json", strings.NewReader(string(requestBody))) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + sessionID := resp.Header.Get("Mcp-Session-Id") + resp.Body.Close() + + // Create a new HTTP request for DELETE + req, err := http.NewRequest("DELETE", testServer.URL+"/mcp", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Set headers + req.Header.Set("Mcp-Session-Id", sessionID) + + // Send the request + client := &http.Client{} + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Try to use the session again, should fail + req, err = http.NewRequest("GET", testServer.URL+"/mcp", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Set headers + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Mcp-Session-Id", sessionID) + + // Send the request + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) + } + }) +}