Skip to content

feat: quick return tool-call request, send response via SSE in goroutine #163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 45 additions & 39 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -53,15 +54,15 @@ var _ ClientSession = (*sseSession)(nil)
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
type SSEServer struct {
server *MCPServer
baseURL string
basePath string
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc
server *MCPServer
baseURL string
basePath string
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc

keepAlive bool
keepAliveInterval time.Duration
Expand Down Expand Up @@ -158,12 +159,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
s := &SSEServer{
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
}

// Apply all options
Expand Down Expand Up @@ -293,7 +294,6 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}()
}


// Send the initial endpoint event
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID))
flusher.Flush()
Expand Down Expand Up @@ -323,7 +323,7 @@ func (s *SSEServer) GetMessageEndpointForClient(sessionID string) string {
}

// handleMessage processes incoming JSON-RPC messages from clients and sends responses
// back through both the SSE connection and HTTP response.
// back through the SSE connection and 202 code to HTTP response.
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed")
Expand Down Expand Up @@ -356,31 +356,37 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
return
}

// Process message through MCPServer
response := s.server.HandleMessage(ctx, rawMessage)
// quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE
w.WriteHeader(http.StatusAccepted)

// Only send response if there is one (not for notifications)
if response != nil {
eventData, _ := json.Marshal(response)
go func() {
// Process message through MCPServer
response := s.server.HandleMessage(ctx, rawMessage)

// Only send response if there is one (not for notifications)
if response != nil {
var message string
if eventData, err := json.Marshal(response); err != nil {
// If there is an error marshalling the response, send a generic error response
log.Printf("failed to marshal response: %v", err)
message = fmt.Sprintf("event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n")
return
} else {
message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
}

// Queue the event for sending via SSE
select {
case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
// Event queued successfully
case <-session.done:
// Session is closed, don't try to queue
default:
// Queue is full, could log this
// Queue the event for sending via SSE
select {
case session.eventQueue <- message:
// Event queued successfully
case <-session.done:
// Session is closed, don't try to queue
default:
// Queue is full, log this situation
log.Printf("Event queue full for session %s", sessionID)
}
}

// Send HTTP response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(response)
} else {
// For notifications, just send 202 Accepted with no body
w.WriteHeader(http.StatusAccepted)
}
}()
}

// writeJSONRPCError writes a JSON-RPC error response with the given error details.
Expand Down
64 changes: 40 additions & 24 deletions server/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,10 @@ func TestSSEServer(t *testing.T) {
defer sseResp.Body.Close()

// Read the endpoint event
buf := make([]byte, 1024)
n, err := sseResp.Body.Read(buf)
endpointEvent, err := readSeeEvent(sseResp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix test code to accommodate asynchronous response delivery.

The test code has been properly adapted to read and process responses from SSE events instead of directly from HTTP response bodies, which aligns with the server-side changes. However, there's a critical issue with using t.Fatalf in a goroutine.

Using t.Fatalf from a non-test goroutine (line 197) can cause race conditions or panics in the test framework. This was previously flagged in code reviews but hasn't been fixed.

endpointEvent, err = readSeeEvent(sseResp)
if err != nil {
-   t.Fatalf("Failed to read SSE response: %v", err)
+   t.Errorf("Session %d: Failed to read SSE response: %v", sessionNum, err)
+   return
}

Apply this fix to all instances of t.Fatalf used in goroutines in this file.

Also applies to: 195-204, 582-632, 672-682

if err != nil {
t.Fatalf("Failed to read SSE response: %v", err)
}

endpointEvent := string(buf[:n])
if !strings.Contains(endpointEvent, "event: endpoint") {
t.Fatalf("Expected endpoint event, got: %s", endpointEvent)
}
Expand Down Expand Up @@ -107,19 +104,6 @@ func TestSSEServer(t *testing.T) {
if resp.StatusCode != http.StatusAccepted {
t.Errorf("Expected status 202, got %d", resp.StatusCode)
}

// Verify response
var response map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}

if response["jsonrpc"] != "2.0" {
t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"])
}
if response["id"].(float64) != 1 {
t.Errorf("Expected id 1, got %v", response["id"])
}
})

t.Run("Can handle multiple sessions", func(t *testing.T) {
Expand Down Expand Up @@ -208,8 +192,17 @@ func TestSSEServer(t *testing.T) {
}
defer resp.Body.Close()

endpointEvent, err = readSeeEvent(sseResp)
if err != nil {
t.Fatalf("Failed to read SSE response: %v", err)
}
respFromSee := strings.TrimSpace(
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
)

fmt.Printf("========> %v", respFromSee)
var response map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
Comment on lines +195 to +205
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix: Testing function called from goroutine.

The static analysis tool flagged that t.Fatalf is being called from a non-test goroutine, which can lead to race conditions or panics in the test framework.

-					t.Fatalf("Failed to read SSE response: %v", err)
+					t.Errorf("Session %d: Failed to read SSE response: %v", sessionNum, err)
+					return

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 golangci-lint (1.64.8)

197-197: testinggoroutine: call to (*testing.T).Fatalf from a non-test goroutine

(govet)

t.Errorf(
"Session %d: Failed to decode response: %v",
sessionNum,
Expand Down Expand Up @@ -586,13 +579,10 @@ func TestSSEServer(t *testing.T) {
defer sseResp.Body.Close()

// Read the endpoint event
buf := make([]byte, 1024)
n, err := sseResp.Body.Read(buf)
endpointEvent, err := readSeeEvent(sseResp)
if err != nil {
t.Fatalf("Failed to read SSE response: %v", err)
}

endpointEvent := string(buf[:n])
messageURL := strings.TrimSpace(
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
)
Expand Down Expand Up @@ -632,8 +622,16 @@ func TestSSEServer(t *testing.T) {
}

// Verify response
endpointEvent, err = readSeeEvent(sseResp)
if err != nil {
t.Fatalf("Failed to read SSE response: %v", err)
}
respFromSee := strings.TrimSpace(
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
)

var response map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}

Expand Down Expand Up @@ -671,8 +669,17 @@ func TestSSEServer(t *testing.T) {
}
defer resp.Body.Close()

endpointEvent, err = readSeeEvent(sseResp)
if err != nil {
t.Fatalf("Failed to read SSE response: %v", err)
}

respFromSee = strings.TrimSpace(
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
)

response = make(map[string]interface{})
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}

Expand Down Expand Up @@ -740,3 +747,12 @@ func TestSSEServer(t *testing.T) {
}
})
}

func readSeeEvent(sseResp *http.Response) (string, error) {
buf := make([]byte, 1024)
n, err := sseResp.Body.Read(buf)
if err != nil {
return "", err
}
return string(buf[:n]), nil
}