diff --git a/server/sse.go b/server/sse.go index f69451c6..34a3752e 100644 --- a/server/sse.go +++ b/server/sse.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log" "net/http" "net/http/httptest" "net/url" @@ -53,15 +54,15 @@ var _ ClientSession = (*sseSession)(nil) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { - server *MCPServer - baseURL string - basePath string - useFullURLForMessageEndpoint bool - messageEndpoint string - sseEndpoint string - sessions sync.Map - srv *http.Server - contextFunc SSEContextFunc + server *MCPServer + baseURL string + basePath string + useFullURLForMessageEndpoint bool + messageEndpoint string + sseEndpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc keepAlive bool keepAliveInterval time.Duration @@ -158,12 +159,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption { // NewSSEServer creates a new SSE server instance with the given MCP server and options. func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { s := &SSEServer{ - server: server, - sseEndpoint: "/sse", - messageEndpoint: "/message", - useFullURLForMessageEndpoint: true, - keepAlive: false, - keepAliveInterval: 10 * time.Second, + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + useFullURLForMessageEndpoint: true, + keepAlive: false, + keepAliveInterval: 10 * time.Second, } // Apply all options @@ -293,7 +294,6 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { }() } - // Send the initial endpoint event fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID)) flusher.Flush() @@ -323,7 +323,7 @@ func (s *SSEServer) GetMessageEndpointForClient(sessionID string) string { } // handleMessage processes incoming JSON-RPC messages from clients and sends responses -// back through both the SSE connection and HTTP response. +// back through the SSE connection and 202 code to HTTP response. func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed") @@ -356,31 +356,37 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { return } - // Process message through MCPServer - response := s.server.HandleMessage(ctx, rawMessage) + // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE + w.WriteHeader(http.StatusAccepted) - // Only send response if there is one (not for notifications) - if response != nil { - eventData, _ := json.Marshal(response) + go func() { + // Process message through MCPServer + response := s.server.HandleMessage(ctx, rawMessage) + + // Only send response if there is one (not for notifications) + if response != nil { + var message string + if eventData, err := json.Marshal(response); err != nil { + // If there is an error marshalling the response, send a generic error response + log.Printf("failed to marshal response: %v", err) + message = fmt.Sprintf("event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n") + return + } else { + message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) + } - // Queue the event for sending via SSE - select { - case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): - // Event queued successfully - case <-session.done: - // Session is closed, don't try to queue - default: - // Queue is full, could log this + // Queue the event for sending via SSE + select { + case session.eventQueue <- message: + // Event queued successfully + case <-session.done: + // Session is closed, don't try to queue + default: + // Queue is full, log this situation + log.Printf("Event queue full for session %s", sessionID) + } } - - // Send HTTP response - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusAccepted) - json.NewEncoder(w).Encode(response) - } else { - // For notifications, just send 202 Accepted with no body - w.WriteHeader(http.StatusAccepted) - } + }() } // writeJSONRPCError writes a JSON-RPC error response with the given error details. diff --git a/server/sse_test.go b/server/sse_test.go index 111c5845..7be96a20 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -59,13 +59,10 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event - buf := make([]byte, 1024) - n, err := sseResp.Body.Read(buf) + endpointEvent, err := readSeeEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } - - endpointEvent := string(buf[:n]) if !strings.Contains(endpointEvent, "event: endpoint") { t.Fatalf("Expected endpoint event, got: %s", endpointEvent) } @@ -107,19 +104,6 @@ func TestSSEServer(t *testing.T) { if resp.StatusCode != http.StatusAccepted { t.Errorf("Expected status 202, got %d", resp.StatusCode) } - - // Verify response - var response map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { - t.Fatalf("Failed to decode response: %v", err) - } - - if response["jsonrpc"] != "2.0" { - t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"]) - } - if response["id"].(float64) != 1 { - t.Errorf("Expected id 1, got %v", response["id"]) - } }) t.Run("Can handle multiple sessions", func(t *testing.T) { @@ -208,8 +192,17 @@ func TestSSEServer(t *testing.T) { } defer resp.Body.Close() + endpointEvent, err = readSeeEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + respFromSee := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + fmt.Printf("========> %v", respFromSee) var response map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { t.Errorf( "Session %d: Failed to decode response: %v", sessionNum, @@ -586,13 +579,10 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event - buf := make([]byte, 1024) - n, err := sseResp.Body.Read(buf) + endpointEvent, err := readSeeEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } - - endpointEvent := string(buf[:n]) messageURL := strings.TrimSpace( strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], ) @@ -632,8 +622,16 @@ func TestSSEServer(t *testing.T) { } // Verify response + endpointEvent, err = readSeeEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + respFromSee := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + var response map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -671,8 +669,17 @@ func TestSSEServer(t *testing.T) { } defer resp.Body.Close() + endpointEvent, err = readSeeEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + respFromSee = strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + response = make(map[string]interface{}) - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -740,3 +747,12 @@ func TestSSEServer(t *testing.T) { } }) } + +func readSeeEvent(sseResp *http.Response) (string, error) { + buf := make([]byte, 1024) + n, err := sseResp.Body.Read(buf) + if err != nil { + return "", err + } + return string(buf[:n]), nil +}