From 1999773cba368b5dfd8e540736e598680f64d1ed Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Sun, 27 Apr 2025 09:15:28 -0400 Subject: [PATCH 1/2] feat!(server/sse): Add support for dynamic base paths This change introduces the ability to mount SSE endpoints at dynamic paths with variable segments (e.g., `/api/{tenant}/sse`) by adding a new `WithDynamicBasePath` option and related functionality. This enables advanced use cases such as multi-tenant architectures or integration with routers that support path parameters. Key Features: * DynamicBasePathFunc: New function type and option (WithDynamicBasePath) to generate the SSE server's base path dynamically per request/session. * Flexible Routing: New SSEHandler() and MessageHandler() methods allow mounting handlers at arbitrary or dynamic paths using any router (e.g., net/http, chi, gorilla/mux). * Endpoint Generation: GetMessageEndpointForClient now supports both static and dynamic path modes, and correctly generates full URLs when configured. * Example: Added examples/dynamic_path/main.go demonstrating dynamic path mounting and usage. ```go mcpServer := mcp.NewMCPServer("dynamic-path-example", "1.0.0") sseServer := mcp.NewSSEServer( mcpServer, mcp.WithDynamicBasePath(func(r *http.Request, sessionID string) string { tenant := r.PathValue("tenant") return "/api/" + tenant }), mcp.WithBaseURL("http://localhost:8080"), ) mux := http.NewServeMux() mux.Handle("/api/{tenant}/sse", sseServer.SSEHandler()) mux.Handle("/api/{tenant}/message", sseServer.MessageHandler()) ``` --- examples/dynamic_path/main.go | 46 ++++++++ server/errors.go | 10 ++ server/sse.go | 150 ++++++++++++++++++++++---- server/sse_test.go | 192 ++++++++++++++++++++++++++++++++-- 4 files changed, 366 insertions(+), 32 deletions(-) create mode 100644 examples/dynamic_path/main.go diff --git a/examples/dynamic_path/main.go b/examples/dynamic_path/main.go new file mode 100644 index 00000000..5793fecb --- /dev/null +++ b/examples/dynamic_path/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + var addr string + flag.StringVar(&addr, "addr", ":8080", "address to listen on") + flag.Parse() + + mcpServer := server.NewMCPServer("dynamic-path-example", "1.0.0") + + // Add a trivial tool for demonstration + mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.Params.Arguments["message"])), nil + }) + + // Use a dynamic base path based on a path parameter (Go 1.22+) + sseServer := server.NewSSEServer( + mcpServer, + server.WithDynamicBasePath(func(r *http.Request, sessionID string) string { + tenant := r.PathValue("tenant") + return "/api/" + tenant + }), + server.WithBaseURL(fmt.Sprintf("http://localhost%s", addr)), + server.WithUseFullURLForMessageEndpoint(true), + ) + + mux := http.NewServeMux() + mux.Handle("/api/{tenant}/sse", sseServer.SSEHandler()) + mux.Handle("/api/{tenant}/message", sseServer.MessageHandler()) + + log.Printf("Dynamic SSE server listening on %s", addr) + if err := http.ListenAndServe(addr, mux); err != nil { + log.Fatalf("Server error: %v", err) + } +} + diff --git a/server/errors.go b/server/errors.go index 7ced5cf7..b984a28c 100644 --- a/server/errors.go +++ b/server/errors.go @@ -2,6 +2,7 @@ package server import ( "errors" + "fmt" ) var ( @@ -21,3 +22,12 @@ var ( ErrNotificationNotInitialized = errors.New("notification channel not initialized") ErrNotificationChannelBlocked = errors.New("notification channel full or blocked") ) + +// ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration +type ErrDynamicPathConfig struct { + Method string +} + +func (e *ErrDynamicPathConfig) Error() string { + return fmt.Sprintf("%s cannot be used with WithDynamicBasePath. Use dynamic path logic in your router.", e.Method) +} diff --git a/server/sse.go b/server/sse.go index 382664d4..705cc7d5 100644 --- a/server/sse.go +++ b/server/sse.go @@ -34,6 +34,13 @@ type sseSession struct { // content. This can be used to inject context values from headers, for example. type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context +// DynamicBasePathFunc allows the user to provide a function to generate the +// base path for a given request and sessionID. This is useful for cases where +// the base path is not known at the time of SSE server creation, such as when +// using a reverse proxy or when the base path is dynamically generated. The +// function should return the base path (e.g., "/mcp/tenant123"). +type DynamicBasePathFunc func(r *http.Request, sessionID string) string + func (s *sseSession) SessionID() string { return s.sessionID } @@ -58,19 +65,19 @@ type SSEServer struct { server *MCPServer baseURL string basePath string + appendQueryToMessageEndpoint bool useFullURLForMessageEndpoint bool messageEndpoint string sseEndpoint string sessions sync.Map srv *http.Server contextFunc SSEContextFunc + dynamicBasePathFunc DynamicBasePathFunc keepAlive bool keepAliveInterval time.Duration mu sync.RWMutex - - appendQueryToMessageEndpoint bool } // SSEOption defines a function type for configuring SSEServer @@ -99,7 +106,7 @@ func WithBaseURL(baseURL string) SSEOption { } } -// WithBasePath adds a new option for setting base path +// WithBasePath adds a new option for setting a static base path func WithBasePath(basePath string) SSEOption { return func(s *SSEServer) { // Ensure the path starts with / and doesn't end with / @@ -110,6 +117,24 @@ func WithBasePath(basePath string) SSEOption { } } +// WithDynamicBasePath accepts a function for generating the base path. This is +// useful for cases where the base path is not known at the time of SSE server +// creation, such as when using a reverse proxy or when the server is mounted +// at a dynamic path. +func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption { + return func(s *SSEServer) { + if fn != nil { + s.dynamicBasePathFunc = func(r *http.Request, sid string) string { + bp := fn(r, sid) + if !strings.HasPrefix(bp, "/") { + bp = "/" + bp + } + return strings.TrimSuffix(bp, "/") + } + } + } +} + // WithMessageEndpoint sets the message endpoint path func WithMessageEndpoint(endpoint string) SSEOption { return func(s *SSEServer) { @@ -208,8 +233,8 @@ func (s *SSEServer) Start(addr string) error { if s.srv == nil { s.srv = &http.Server{ - Addr: addr, - Handler: s, + Addr: addr, + Handler: s, } } else { if s.srv.Addr == "" { @@ -218,7 +243,7 @@ func (s *SSEServer) Start(addr string) error { return fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", s.srv.Addr, addr) } } - + return s.srv.ListenAndServe() } @@ -331,7 +356,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } // Send the initial endpoint event - endpoint := s.GetMessageEndpointForClient(sessionID) + endpoint := s.GetMessageEndpointForClient(r, sessionID) if s.appendQueryToMessageEndpoint && len(r.URL.RawQuery) > 0 { endpoint += "&" + r.URL.RawQuery } @@ -355,13 +380,20 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } // GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID -// based on the useFullURLForMessageEndpoint configuration. -func (s *SSEServer) GetMessageEndpointForClient(sessionID string) string { - messageEndpoint := s.messageEndpoint - if s.useFullURLForMessageEndpoint { - messageEndpoint = s.CompleteMessageEndpoint() +// for the given request. This is the canonical way to compute the message endpoint for a client. +// It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag. +func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID string) string { + basePath := s.basePath + if s.dynamicBasePathFunc != nil { + basePath = s.dynamicBasePathFunc(r, sessionID) } - return fmt.Sprintf("%s?sessionId=%s", messageEndpoint, sessionID) + + endpointPath := basePath + s.messageEndpoint + if s.useFullURLForMessageEndpoint && s.baseURL != "" { + endpointPath = s.baseURL + endpointPath + } + + return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID) } // handleMessage processes incoming JSON-RPC messages from clients and sends responses @@ -479,32 +511,108 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) { return parse.Path, nil } -func (s *SSEServer) CompleteSseEndpoint() string { - return s.baseURL + s.basePath + s.sseEndpoint +func (s *SSEServer) CompleteSseEndpoint() (string, error) { + if s.dynamicBasePathFunc != nil { + return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"} + } + return s.baseURL + s.basePath + s.sseEndpoint, nil } func (s *SSEServer) CompleteSsePath() string { - path, err := s.GetUrlPath(s.CompleteSseEndpoint()) + path, err := s.CompleteSseEndpoint() + if err != nil { + return s.basePath + s.sseEndpoint + } + urlPath, err := s.GetUrlPath(path) if err != nil { return s.basePath + s.sseEndpoint } - return path + return urlPath } -func (s *SSEServer) CompleteMessageEndpoint() string { - return s.baseURL + s.basePath + s.messageEndpoint +func (s *SSEServer) CompleteMessageEndpoint() (string, error) { + if s.dynamicBasePathFunc != nil { + return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"} + } + return s.baseURL + s.basePath + s.messageEndpoint, nil } func (s *SSEServer) CompleteMessagePath() string { - path, err := s.GetUrlPath(s.CompleteMessageEndpoint()) + path, err := s.CompleteMessageEndpoint() + if err != nil { + return s.basePath + s.messageEndpoint + } + urlPath, err := s.GetUrlPath(path) if err != nil { return s.basePath + s.messageEndpoint } - return path + return urlPath +} + +// SSEHandler returns an http.Handler for the SSE endpoint. +// +// This method allows you to mount the SSE handler at any arbitrary path +// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is +// intended for advanced scenarios where you want to control the routing or +// support dynamic segments. +// +// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios, +// you must use the WithDynamicBasePath option to ensure the correct base path +// is communicated to clients. +// +// Example usage: +// +// // Advanced/dynamic: +// sseServer := NewSSEServer(mcpServer, +// WithDynamicBasePath(func(r *http.Request, sessionID string) string { +// tenant := r.PathValue("tenant") +// return "/mcp/" + tenant +// }), +// WithBaseURL("http://localhost:8080") +// ) +// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler()) +// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler()) +// +// For non-dynamic cases, use ServeHTTP method instead. +func (s *SSEServer) SSEHandler() http.Handler { + return http.HandlerFunc(s.handleSSE) +} + +// MessageHandler returns an http.Handler for the message endpoint. +// +// This method allows you to mount the message handler at any arbitrary path +// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is +// intended for advanced scenarios where you want to control the routing or +// support dynamic segments. +// +// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios, +// you must use the WithDynamicBasePath option to ensure the correct base path +// is communicated to clients. +// +// Example usage: +// +// // Advanced/dynamic: +// sseServer := NewSSEServer(mcpServer, +// WithDynamicBasePath(func(r *http.Request, sessionID string) string { +// tenant := r.PathValue("tenant") +// return "/mcp/" + tenant +// }), +// WithBaseURL("http://localhost:8080") +// ) +// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler()) +// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler()) +// +// For non-dynamic cases, use ServeHTTP method instead. +func (s *SSEServer) MessageHandler() http.Handler { + return http.HandlerFunc(s.handleMessage) } // ServeHTTP implements the http.Handler interface. func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if s.dynamicBasePathFunc != nil { + http.Error(w, (&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(), http.StatusInternalServerError) + return + } path := r.URL.Path // Use exact path matching rather than Contains ssePath := s.CompleteSsePath() diff --git a/server/sse_test.go b/server/sse_test.go index a1ce01fa..9ef95510 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" ) func TestSSEServer(t *testing.T) { @@ -443,17 +444,25 @@ func TestSSEServer(t *testing.T) { t.Errorf("Expected status 200, got %d", resp.StatusCode) } - // Read the endpoint event - buf := make([]byte, 1024) - n, err := resp.Body.Read(buf) - if err != nil { - t.Fatalf("Failed to read SSE response: %v", err) + // Read the endpoint event using a bufio.Reader loop to ensure we get the full SSE frame + reader := bufio.NewReader(resp.Body) + var endpointEvent strings.Builder + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + endpointEvent.WriteString(line) + if line == "\n" || line == "\r\n" { + break // End of SSE frame + } } - - endpointEvent := string(buf[:n]) - messageURL := strings.TrimSpace( - strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], - ) + endpointEventStr := endpointEvent.String() + if !strings.Contains(endpointEventStr, "event: endpoint") { + t.Fatalf("Expected endpoint event, got: %s", endpointEventStr) + } + // Extract message endpoint and check correctness + messageURL := strings.TrimSpace(strings.Split(strings.Split(endpointEventStr, "data: ")[1], "\n")[0]) if !strings.HasPrefix(messageURL, sseServer.messageEndpoint) { t.Errorf("Expected messageURL to be %s, got %s", sseServer.messageEndpoint, messageURL) } @@ -613,7 +622,6 @@ func TestSSEServer(t *testing.T) { "application/json", bytes.NewBuffer(requestBody), ) - if err != nil { t.Fatalf("Failed to send message: %v", err) } @@ -861,6 +869,168 @@ func TestSSEServer(t *testing.T) { } } }) + + t.Run("TestSSEHandlerWithDynamicMounting", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + // MessageEndpointFunc that extracts tenant from the path using Go 1.22+ PathValue + + sseServer := NewSSEServer( + mcpServer, + WithDynamicBasePath(func(r *http.Request, sessionID string) string { + tenant := r.PathValue("tenant") + return "/mcp/" + tenant + }), + ) + + mux := http.NewServeMux() + mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler()) + mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler()) + + ts := httptest.NewServer(mux) + defer ts.Close() + + // Use a dynamic tenant + tenant := "tenant123" + // Connect to SSE endpoint + req, _ := http.NewRequest("GET", ts.URL+"/mcp/"+tenant+"/sse", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer resp.Body.Close() + + reader := bufio.NewReader(resp.Body) + var endpointEvent strings.Builder + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + endpointEvent.WriteString(line) + if line == "\n" || line == "\r\n" { + break // End of SSE frame + } + } + endpointEventStr := endpointEvent.String() + if !strings.Contains(endpointEventStr, "event: endpoint") { + t.Fatalf("Expected endpoint event, got: %s", endpointEventStr) + } + // Extract message endpoint and check correctness + messageURL := strings.TrimSpace(strings.Split(strings.Split(endpointEventStr, "data: ")[1], "\n")[0]) + if !strings.HasPrefix(messageURL, "/mcp/"+tenant+"/message") { + t.Errorf("Expected message endpoint to start with /mcp/%s/message, got %s", tenant, messageURL) + } + + // Optionally, test sending a message to the message endpoint + initRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]interface{}{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + requestBody, err := json.Marshal(initRequest) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // The message endpoint is relative, so prepend the test server URL + fullMessageURL := ts.URL + messageURL + resp2, err := http.Post(fullMessageURL, "application/json", bytes.NewBuffer(requestBody)) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp2.Body.Close() + + if resp2.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp2.StatusCode) + } + + // Read the response from the SSE stream + reader = bufio.NewReader(resp.Body) + var initResponse strings.Builder + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + initResponse.WriteString(line) + if line == "\n" || line == "\r\n" { + break // End of SSE frame + } + } + initResponseStr := initResponse.String() + if !strings.Contains(initResponseStr, "event: message") { + t.Fatalf("Expected message event, got: %s", initResponseStr) + } + + // Extract and parse the response data + respData := strings.TrimSpace(strings.Split(strings.Split(initResponseStr, "data: ")[1], "\n")[0]) + var response map[string]interface{} + if err := json.NewDecoder(strings.NewReader(respData)).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if response["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"]) + } + if response["id"].(float64) != 1 { + t.Errorf("Expected id 1, got %v", response["id"]) + } + }) + t.Run("TestSSEHandlerRequiresDynamicBasePath", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + sseServer := NewSSEServer(mcpServer) + require.NotPanics(t, func() { sseServer.SSEHandler() }) + require.NotPanics(t, func() { sseServer.MessageHandler() }) + + sseServer = NewSSEServer( + mcpServer, + WithDynamicBasePath(func(r *http.Request, sessionID string) string { + return "/foo" + }), + ) + req := httptest.NewRequest("GET", "/foo/sse", nil) + w := httptest.NewRecorder() + + sseServer.ServeHTTP(w, req) + require.Equal(t, http.StatusInternalServerError, w.Code) + require.Contains(t, w.Body.String(), "ServeHTTP cannot be used with WithDynamicBasePath") + }) + + t.Run("TestCompleteSseEndpointAndMessageEndpointErrors", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + sseServer := NewSSEServer(mcpServer, WithDynamicBasePath(func(r *http.Request, sessionID string) string { + return "/foo" + })) + + // Test CompleteSseEndpoint + endpoint, err := sseServer.CompleteSseEndpoint() + require.Error(t, err) + var dynamicPathErr *ErrDynamicPathConfig + require.ErrorAs(t, err, &dynamicPathErr) + require.Equal(t, "CompleteSseEndpoint", dynamicPathErr.Method) + require.Empty(t, endpoint) + + // Test CompleteMessageEndpoint + messageEndpoint, err := sseServer.CompleteMessageEndpoint() + require.Error(t, err) + require.ErrorAs(t, err, &dynamicPathErr) + require.Equal(t, "CompleteMessageEndpoint", dynamicPathErr.Method) + require.Empty(t, messageEndpoint) + + // Test that path methods still work and return fallback values + ssePath := sseServer.CompleteSsePath() + require.Equal(t, sseServer.basePath+sseServer.sseEndpoint, ssePath) + + messagePath := sseServer.CompleteMessagePath() + require.Equal(t, sseServer.basePath+sseServer.messageEndpoint, messagePath) + }) } func readSeeEvent(sseResp *http.Response) (string, error) { From 40c8a68761e5f3a2e7ebfa85b8f2d18a98dbb553 Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Thu, 1 May 2025 08:22:53 -0400 Subject: [PATCH 2/2] refactor(server): standardize URL path handling with normalizeURLPath Replace manual path manipulation with a dedicated normalizeURLPath function that properly handles path joining while ensuring consistent formatting. The function: - Always starts paths with a leading slash - Never ends paths with a trailing slash (except for root path "/") - Uses path.Join internally for proper path normalization - Handles edge cases like empty segments, double slashes, and parent references This eliminates duplicated code and creates a more consistent approach to URL path handling throughout the SSE server implementation. Comprehensive tests were added to validate the function's behavior. --- server/sse.go | 47 ++++++++++++++-------- server/sse_test.go | 98 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 16 deletions(-) diff --git a/server/sse.go b/server/sse.go index 705cc7d5..e380d20a 100644 --- a/server/sse.go +++ b/server/sse.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "path" "strings" "sync" "sync/atomic" @@ -109,11 +110,7 @@ func WithBaseURL(baseURL string) SSEOption { // WithBasePath adds a new option for setting a static base path func WithBasePath(basePath string) SSEOption { return func(s *SSEServer) { - // Ensure the path starts with / and doesn't end with / - if !strings.HasPrefix(basePath, "/") { - basePath = "/" + basePath - } - s.basePath = strings.TrimSuffix(basePath, "/") + s.basePath = normalizeURLPath(basePath) } } @@ -126,10 +123,7 @@ func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption { if fn != nil { s.dynamicBasePathFunc = func(r *http.Request, sid string) string { bp := fn(r, sid) - if !strings.HasPrefix(bp, "/") { - bp = "/" + bp - } - return strings.TrimSuffix(bp, "/") + return normalizeURLPath(bp) } } } @@ -388,7 +382,7 @@ func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID strin basePath = s.dynamicBasePathFunc(r, sessionID) } - endpointPath := basePath + s.messageEndpoint + endpointPath := normalizeURLPath(basePath, s.messageEndpoint) if s.useFullURLForMessageEndpoint && s.baseURL != "" { endpointPath = s.baseURL + endpointPath } @@ -515,17 +509,19 @@ func (s *SSEServer) CompleteSseEndpoint() (string, error) { if s.dynamicBasePathFunc != nil { return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"} } - return s.baseURL + s.basePath + s.sseEndpoint, nil + + path := normalizeURLPath(s.basePath, s.sseEndpoint) + return s.baseURL + path, nil } func (s *SSEServer) CompleteSsePath() string { path, err := s.CompleteSseEndpoint() if err != nil { - return s.basePath + s.sseEndpoint + return normalizeURLPath(s.basePath, s.sseEndpoint) } urlPath, err := s.GetUrlPath(path) if err != nil { - return s.basePath + s.sseEndpoint + return normalizeURLPath(s.basePath, s.sseEndpoint) } return urlPath } @@ -534,17 +530,18 @@ func (s *SSEServer) CompleteMessageEndpoint() (string, error) { if s.dynamicBasePathFunc != nil { return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"} } - return s.baseURL + s.basePath + s.messageEndpoint, nil + path := normalizeURLPath(s.basePath, s.messageEndpoint) + return s.baseURL + path, nil } func (s *SSEServer) CompleteMessagePath() string { path, err := s.CompleteMessageEndpoint() if err != nil { - return s.basePath + s.messageEndpoint + return normalizeURLPath(s.basePath, s.messageEndpoint) } urlPath, err := s.GetUrlPath(path) if err != nil { - return s.basePath + s.messageEndpoint + return normalizeURLPath(s.basePath, s.messageEndpoint) } return urlPath } @@ -628,3 +625,21 @@ func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) } + +// normalizeURLPath joins path elements like path.Join but ensures the +// result always starts with a leading slash and never ends with a slash +func normalizeURLPath(elem ...string) string { + joined := path.Join(elem...) + + // Ensure leading slash + if !strings.HasPrefix(joined, "/") { + joined = "/" + joined + } + + // Remove trailing slash if not just "/" + if len(joined) > 1 && strings.HasSuffix(joined, "/") { + joined = joined[:len(joined)-1] + } + + return joined +} diff --git a/server/sse_test.go b/server/sse_test.go index 9ef95510..a121581a 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -1031,6 +1031,104 @@ func TestSSEServer(t *testing.T) { messagePath := sseServer.CompleteMessagePath() require.Equal(t, sseServer.basePath+sseServer.messageEndpoint, messagePath) }) + + t.Run("TestNormalizeURLPath", func(t *testing.T) { + tests := []struct { + name string + inputs []string + expected string + }{ + // Basic path joining + { + name: "empty inputs", + inputs: []string{"", ""}, + expected: "/", + }, + { + name: "single path segment", + inputs: []string{"mcp"}, + expected: "/mcp", + }, + { + name: "multiple path segments", + inputs: []string{"mcp", "api", "message"}, + expected: "/mcp/api/message", + }, + + // Leading slash handling + { + name: "already has leading slash", + inputs: []string{"/mcp", "message"}, + expected: "/mcp/message", + }, + { + name: "mixed leading slashes", + inputs: []string{"/mcp", "/message"}, + expected: "/mcp/message", + }, + + // Trailing slash handling + { + name: "with trailing slashes", + inputs: []string{"mcp/", "message/"}, + expected: "/mcp/message", + }, + { + name: "mixed trailing slashes", + inputs: []string{"mcp", "message/"}, + expected: "/mcp/message", + }, + { + name: "root path", + inputs: []string{"/"}, + expected: "/", + }, + + // Path normalization + { + name: "normalize double slashes", + inputs: []string{"mcp//api", "//message"}, + expected: "/mcp/api/message", + }, + { + name: "normalize parent directory", + inputs: []string{"mcp/parent/../child", "message"}, + expected: "/mcp/child/message", + }, + { + name: "normalize current directory", + inputs: []string{"mcp/./api", "./message"}, + expected: "/mcp/api/message", + }, + + // Complex cases + { + name: "complex mixed case", + inputs: []string{"/mcp/", "/api//", "message/"}, + expected: "/mcp/api/message", + }, + { + name: "absolute path in second segment", + inputs: []string{"tenant", "/message"}, + expected: "/tenant/message", + }, + { + name: "URL pattern with parameters", + inputs: []string{"/mcp/{tenant}", "message"}, + expected: "/mcp/{tenant}/message", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeURLPath(tt.inputs...) + if result != tt.expected { + t.Errorf("normalizeURLPath(%q) = %q, want %q", + tt.inputs, result, tt.expected) + } + }) + } + }) } func readSeeEvent(sseResp *http.Response) (string, error) {