diff --git a/client/transport/sse.go b/client/transport/sse.go index 501036d1..a8b56fd3 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -17,6 +17,18 @@ import ( "github.com/mark3labs/mcp-go/mcp" ) +// OnBeforeRequestFunc is called before sending the request, with context. +type OnBeforeRequestFunc func(ctx context.Context, req *http.Request) + +// OnAfterResponseFunc is called after receiving the response, with context. (Regardless of error, when err is not nil resp may be nil.) The req parameter is included. +type OnAfterResponseFunc func(ctx context.Context, req *http.Request, resp *http.Response, err error) + +// SSEHooks supports multiple before and after processing functions. +type SSEHooks struct { + OnBeforeRequest []OnBeforeRequestFunc + OnAfterResponse []OnAfterResponseFunc +} + // SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE). // It maintains a persistent HTTP connection to receive server-pushed events // while sending requests over regular HTTP POST calls. The client handles @@ -32,6 +44,8 @@ type SSE struct { endpointChan chan struct{} headers map[string]string + hooks SSEHooks + started atomic.Bool closed atomic.Bool cancelSSEStream context.CancelFunc @@ -45,6 +59,13 @@ func WithHeaders(headers map[string]string) ClientOption { } } +// Register a set of hooks (overwrites existing hooks) +func WithSSEHooks(hooks SSEHooks) ClientOption { + return func(sc *SSE) { + sc.hooks = hooks + } +} + // NewSSE creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { @@ -261,6 +282,13 @@ func (c *SSE) SendRequest( req.Header.Set(k, v) } + // hooks: before request + for _, hook := range c.hooks.OnBeforeRequest { + if hook != nil { + hook(ctx, req) + } + } + // Register response channel responseChan := make(chan *JSONRPCResponse, 1) c.mu.Lock() @@ -274,6 +302,14 @@ func (c *SSE) SendRequest( // Send request resp, err := c.httpClient.Do(req) + + // hooks: after response + for _, hook := range c.hooks.OnAfterResponse { + if hook != nil { + hook(ctx, req, resp, err) + } + } + if err != nil { deleteResponseChan() return nil, fmt.Errorf("failed to send request: %w", err) @@ -348,7 +384,22 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti req.Header.Set(k, v) } + // hooks: before request + for _, hook := range c.hooks.OnBeforeRequest { + if hook != nil { + hook(ctx, req) + } + } + resp, err := c.httpClient.Do(req) + + // hooks: after response + for _, hook := range c.hooks.OnAfterResponse { + if hook != nil { + hook(ctx, req, resp, err) + } + } + if err != nil { return fmt.Errorf("failed to send notification: %w", err) }