diff --git a/client/sse.go b/client/sse.go index cf4a102..00ba59f 100644 --- a/client/sse.go +++ b/client/sse.go @@ -41,6 +41,9 @@ type SSEMCPClient struct { type ClientOption func(*SSEMCPClient) +// WithHeaders sets custom HTTP headers that will be included in all requests made by the client. +// This is particularly useful for authentication (e.g., bearer tokens, API keys) and other +// custom header requirements. func WithHeaders(headers map[string]string) ClientOption { return func(sc *SSEMCPClient) { sc.headers = headers @@ -55,6 +58,15 @@ func WithSSEReadTimeout(timeout time.Duration) ClientOption { // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. +// Example: +// +// // Create a client with authentication headers +// client, err := NewSSEMCPClient( +// "https://mcp.example.com", +// WithHeaders(map[string]string{ +// "Authorization": "Bearer your-token-here", +// }), +// ) func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) { parsedURL, err := url.Parse(baseURL) if err != nil { @@ -94,6 +106,16 @@ func (c *SSEMCPClient) Start(ctx context.Context) error { req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") + // Set custom http headers + for k, v := range c.headers { + // Skip standard headers that should not be overridden + switch k { + case "Accept", "Cache-Control", "Connection", "Content-Type": + continue + } + req.Header.Set(k, v) + } + resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("failed to connect to SSE stream: %w", err) @@ -301,8 +323,13 @@ func (c *SSEMCPClient) sendRequest( } req.Header.Set("Content-Type", "application/json") - // set custom http headers + // Set custom http headers for k, v := range c.headers { + // Skip standard headers that should not be overridden + switch k { + case "Accept", "Cache-Control", "Connection", "Content-Type": + continue + } req.Header.Set(k, v) } @@ -391,6 +418,15 @@ func (c *SSEMCPClient) Initialize( } req.Header.Set("Content-Type", "application/json") + // Set custom http headers + for k, v := range c.headers { + // Skip standard headers that should not be overridden + switch k { + case "Accept", "Cache-Control", "Connection", "Content-Type": + continue + } + req.Header.Set(k, v) + } resp, err := c.httpClient.Do(req) if err != nil {