diff --git a/examples/current_time_server/main.go b/examples/current_time_server/main.go
index 7a54f7d..e10c16b 100644
--- a/examples/current_time_server/main.go
+++ b/examples/current_time_server/main.go
@@ -12,6 +12,7 @@ import (
 
 	"github.com/ThinkInAIXYZ/go-mcp/protocol"
 	"github.com/ThinkInAIXYZ/go-mcp/server"
+	"github.com/ThinkInAIXYZ/go-mcp/server/components"
 	"github.com/ThinkInAIXYZ/go-mcp/transport"
 )
 
@@ -27,13 +28,24 @@ func main() {
 			Name:    "current-time-v2-server",
 			Version: "1.0.0",
 		}),
+		// Create a rate limiter allowing 5 requests per second with a burst capacity of 10.
+		server.WithRateLimiter(components.NewTokenBucketLimiter(components.Rate{
+			Limit: 5.0, // 5 requests per second
+			Burst: 10,  // bursts of up to 10 requests
+		})),
 	)
 	if err != nil {
 		log.Fatalf("Failed to create server: %v", err)
 	}
 
 	// new protocol tool with name, descipriton and properties
-	tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{})
+	tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{},
+		// Give this tool its own limit: 10 requests per second with a burst capacity of 20.
+		protocol.WithRateLimit(srv.GetLimiter(), components.Rate{
+			Limit: 10.0, // 10 requests per second
+			Burst: 20,   // bursts of up to 20 requests
+		}),
+	)
 	if err != nil {
 		log.Fatalf("Failed to create tool: %v", err)
 		return
diff --git a/pkg/errors.go b/pkg/errors.go
index 3fe9eec..fc7e3b1 100644
--- a/pkg/errors.go
+++ b/pkg/errors.go
@@ -15,6 +15,7 @@ var (
 	ErrSessionHasNotInitialized = errors.New("the session has not been initialized")
 	ErrLackSession              = errors.New("lack session")
 	ErrSendEOF                  = errors.New("send EOF")
+	ErrRateLimitExceeded        = errors.New("rate limit exceeded")
 )
 
 type ResponseError struct {
diff --git a/protocol/tools.go b/protocol/tools.go
index 7135b8c..9f83323 100644
--- a/protocol/tools.go
+++ b/protocol/tools.go
@@ -5,8 +5,24 @@ import (
 	"fmt"
 
 	"github.com/ThinkInAIXYZ/go-mcp/pkg"
+	"github.com/ThinkInAIXYZ/go-mcp/server/components"
 )
 
+// Option customizes a Tool while NewTool is constructing it.
+type Option func(*Tool)
+
+// WithRateLimit returns an Option that registers rate for the tool's name on
+// rateLimit when the tool is built. A nil rateLimit (for example, a server
+// created without a limiter) makes the option a no-op instead of panicking.
+func WithRateLimit(rateLimit components.RateLimiter, rate components.Rate) Option {
+	return func(t *Tool) {
+		if rateLimit == nil {
+			return
+		}
+		rateLimit.SetToolLimit(t.Name, rate)
+	}
+}
+
 // ListToolsRequest represents a request to list available tools
 type ListToolsRequest struct{}
 
@@ -177,17 +193,23 @@ type ToolListChangedNotification struct {
 }
 
 // NewTool create a tool
-func NewTool(name string, description string, inputReqStruct interface{}) (*Tool, error) {
+func NewTool(name string, description string, inputReqStruct interface{}, opts ...Option) (*Tool, error) {
 	schema, err := generateSchemaFromReqStruct(inputReqStruct)
 	if err != nil {
 		return nil, err
 	}
 
-	return &Tool{
+	t := &Tool{
 		Name:        name,
 		Description: description,
 		InputSchema: *schema,
-	}, nil
+	}
+
+	for _, opt := range opts {
+		opt(t)
+	}
+
+	return t, nil
 }
 
 func NewToolWithRawSchema(name, description string, schema json.RawMessage) *Tool {
diff --git a/server/components/limiter.go b/server/components/limiter.go
new file mode 100644
index 0000000..3bd4ade
--- /dev/null
+++ b/server/components/limiter.go
@@ -0,0 +1,95 @@
+package components
+
+import (
+	"sync"
+	"time"
+)
+
+// RateLimiter is the rate-limiting interface, keyed by tool name.
+type RateLimiter interface {
+	Allow(toolName string) bool
+	SetToolLimit(toolName string, rate Rate)
+}
+
+// TokenBucketLimiter is a token-bucket RateLimiter that keeps one bucket per tool name.
+type TokenBucketLimiter struct {
+	mu           sync.Mutex // guards buckets, toolLimits and the fields of every bucket
+	buckets      map[string]*bucket
+	defaultLimit Rate
+	toolLimits   map[string]Rate
+}
+
+// Rate holds the parameters of a rate limit.
+type Rate struct {
+	Limit float64 // requests allowed per second
+	Burst int     // maximum burst size
+}
+
+// bucket is the token-bucket state for a single tool.
+type bucket struct {
+	tokens        float64
+	lastTimestamp time.Time
+	rate          Rate
+}
+
+// NewTokenBucketLimiter creates a limiter that applies defaultRate to tools without their own limit.
+func NewTokenBucketLimiter(defaultRate Rate) *TokenBucketLimiter {
+	return &TokenBucketLimiter{
+		buckets:      make(map[string]*bucket),
+		defaultLimit: defaultRate,
+		toolLimits:   make(map[string]Rate),
+	}
+}
+
+// SetToolLimit sets the limit for one tool.
+func (l *TokenBucketLimiter) SetToolLimit(toolName string, rate Rate) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	l.toolLimits[toolName] = rate
+	// If the tool already has a bucket, update its rate as well.
+	if b, exists := l.buckets[toolName]; exists {
+		b.rate = rate
+	}
+}
+
+// Allow reports whether a call to toolName may proceed, consuming one token if so.
+func (l *TokenBucketLimiter) Allow(toolName string) bool {
+	l.mu.Lock() // exclusive: Allow inserts into l.buckets and mutates the bucket's state
+	defer l.mu.Unlock()
+
+	now := time.Now()
+
+	// Get or create the bucket.
+	b, exists := l.buckets[toolName]
+	if !exists {
+		// Use the tool-specific limit, falling back to the default.
+		rate, exists := l.toolLimits[toolName]
+		if !exists {
+			rate = l.defaultLimit
+		}
+
+		b = &bucket{
+			tokens:        float64(rate.Burst),
+			lastTimestamp: now,
+			rate:          rate,
+		}
+		l.buckets[toolName] = b
+	}
+
+	// Work out how many tokens accrued since the previous request.
+	elapsed := now.Sub(b.lastTimestamp).Seconds()
+	b.lastTimestamp = now
+
+	// Add tokens, capped at the burst size.
+	b.tokens += elapsed * b.rate.Limit
+	if b.tokens > float64(b.rate.Burst) {
+		b.tokens = float64(b.rate.Burst)
+	}
+
+	if b.tokens >= 1.0 {
+		b.tokens -= 1.0
+		return true
+	}
+	return false
+}
diff --git a/server/handle.go b/server/handle.go
index 86c7f77..829b5a9 100644
--- a/server/handle.go
+++ b/server/handle.go
@@ -230,6 +230,11 @@ func (server *Server) handleRequestWithCallTool(rawParams json.RawMessage) (*pro
 		return nil, err
 	}
 
+	// Check the rate limit.
+	if server.limiter != nil && !server.limiter.Allow(request.Name) {
+		return nil, pkg.ErrRateLimitExceeded
+	}
+
 	entry, ok := server.tools.Load(request.Name)
 	if !ok {
 		return nil, fmt.Errorf("missing tool, toolName=%s", request.Name)
diff --git a/server/server.go b/server/server.go
index 4af102d..a463c55 100644
--- a/server/server.go
+++ b/server/server.go
@@ -8,6 +8,7 @@ import (
 
 	"github.com/ThinkInAIXYZ/go-mcp/pkg"
 	"github.com/ThinkInAIXYZ/go-mcp/protocol"
+	"github.com/ThinkInAIXYZ/go-mcp/server/components"
 	"github.com/ThinkInAIXYZ/go-mcp/server/session"
 	"github.com/ThinkInAIXYZ/go-mcp/transport"
 )
@@ -44,6 +45,24 @@ func WithLogger(logger pkg.Logger) Option {
 	}
 }
 
+// WithRateLimiter returns an Option that makes the server rate-limit tool calls with limiter.
+func WithRateLimiter(limiter components.RateLimiter) Option {
+	return func(s *Server) {
+		s.limiter = limiter
+	}
+}
+
+// WithDefaultRateLimiter is a convenience Option that installs a limiter with default parameters.
+func WithDefaultRateLimiter() Option {
+	const defaultLimit = 5.0
+	const defaultBurst = 10
+	limiter := components.NewTokenBucketLimiter(components.Rate{
+		Limit: defaultLimit,
+		Burst: defaultBurst,
+	})
+	return WithRateLimiter(limiter)
+}
+
 type Server struct {
 	transport transport.ServerTransport
 
@@ -53,6 +72,7 @@ type Server struct {
 	resourceTemplates pkg.SyncMap[*resourceTemplateEntry]
 
 	sessionManager *session.Manager
+	limiter        components.RateLimiter
 
 	inShutdown   *pkg.AtomicBool // true when server is in shutdown
 	inFlyRequest sync.WaitGroup
@@ -244,3 +264,6 @@ func (server *Server) sessionDetection(ctx context.Context, sessionID string) er
 	}
 	return nil
 }
+
+// GetLimiter returns the server's rate limiter, or nil if none was configured.
+func (server *Server) GetLimiter() components.RateLimiter { return server.limiter }
diff --git a/server/server_test.go b/server/server_test.go
index eaa7f1e..5a01ecb 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -6,11 +6,13 @@ import (
 	"io"
 	"reflect"
 	"testing"
+	"time"
 
 	"github.com/google/uuid"
 
 	"github.com/ThinkInAIXYZ/go-mcp/pkg"
 	"github.com/ThinkInAIXYZ/go-mcp/protocol"
+	"github.com/ThinkInAIXYZ/go-mcp/server/components"
 	"github.com/ThinkInAIXYZ/go-mcp/transport"
 )
 
@@ -505,3 +507,199 @@ func testServerInit(t *testing.T, server *Server, in io.Writer, outScan *bufio.S
 		t.Fatalf("in Write: %+v", err)
 	}
 }
+
+type testLimiter struct {
+	name               string
+	rate               components.Rate
+	numRequests        int
+	requestInterval    time.Duration // Interval between requests
+	expectedErrorCount int
+	description        string
+}
+
+func TestServerRateLimiters(t *testing.T) {
+	tests := []testLimiter{
+		{
+			name: "rapid_requests_exceed_burst",
+			rate: components.Rate{
+				Limit: 5.0,
+				Burst: 10,
+			},
+			numRequests:        15,
+			requestInterval:    0, // No delay between requests
+			expectedErrorCount: 5,
+			description:        "Sending requests rapidly should exceed burst limit and trigger rate limiting",
+		},
+		{
+			name: "slow_requests_under_limit",
+			rate: components.Rate{
+				Limit: 5.0,
+				Burst: 5,
+			},
+			numRequests:        10,
+			requestInterval:    210 * time.Millisecond, // ~4.7 req/s, under the 5.0 limit
+			expectedErrorCount: 0,
+			description:        "Sending requests under the rate limit should not trigger rate limiting",
+		},
+		{
+			name: "mixed_rate_pattern",
+			rate: components.Rate{
+				Limit: 10.0,
+				Burst: 5,
+			},
+			numRequests:        20,
+			requestInterval:    50 * time.Millisecond, // 20 req/s, above the 10.0 limit
+			expectedErrorCount: 5,
+			description:        "Sending requests at a rate higher than limit should trigger rate limiting after burst is consumed",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			testServerRateLimiter(t, tt)
+		})
+	}
+}
+
+func testServerRateLimiter(t *testing.T, tt testLimiter) {
+	// Set up pipes for communication
+	reader1, writer1 := io.Pipe()
+	reader2, writer2 := io.Pipe()
+
+	var (
+		in = struct {
+			reader io.ReadCloser
+			writer io.WriteCloser
+		}{
+			reader: reader1,
+			writer: writer1,
+		}
+
+		out = struct {
+			reader io.ReadCloser
+			writer io.WriteCloser
+		}{
+			reader: reader2,
+			writer: writer2,
+		}
+
+		outScan = bufio.NewScanner(out.reader)
+	)
+
+	// Create server with rate limiter
+	server, err := NewServer(
+		transport.NewMockServerTransport(in.reader, out.writer),
+		WithServerInfo(protocol.Implementation{
+			Name:    "ExampleServer",
+			Version: "1.0.0",
+		}),
+		WithRateLimiter(components.NewTokenBucketLimiter(tt.rate)),
+	)
+	if err != nil {
+		t.Fatalf("NewServer: %+v", err)
+	}
+
+	// Register test tool
+	testTool, err := protocol.NewTool("test_tool", "test_tool", currentTimeReq{})
+	if err != nil {
+		t.Fatalf("NewTool: %+v", err)
+		return
+	}
+	testToolCallContent := protocol.TextContent{
+		Type: "text",
+		Text: "pong",
+	}
+
+	// Add minimal processing delay to simulate real-world scenario
+	server.RegisterTool(testTool, func(_ *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
+		time.Sleep(5 * time.Millisecond) // Small processing delay
+		return &protocol.CallToolResult{
+			Content: []protocol.Content{testToolCallContent},
+		}, nil
+	})
+
+	// Start server
+	serverErrCh := make(chan error, 1)
+	go func() {
+		if err := server.Run(); err != nil {
+			serverErrCh <- err
+		}
+	}()
+
+	// Initialize server
+	testServerInit(t, server, in.writer, outScan)
+
+	// Test rate limiting by sending multiple requests
+	errorCount := 0
+	successCount := 0
+
+	for i := 0; i < tt.numRequests; i++ {
+		reqID, err := uuid.NewUUID()
+		if err != nil {
+			t.Fatalf("uuid NewUUID: %+v", err)
+		}
+		req := protocol.NewJSONRPCRequest(reqID, protocol.ToolsCall, protocol.CallToolRequest{
+			Name: testTool.Name,
+		})
+		reqBytes, err := json.Marshal(req)
+		if err != nil {
+			t.Fatalf("json Marshal: %+v", err)
+		}
+
+		if _, err = in.writer.Write(append(reqBytes, "\n"...)); err != nil {
+			t.Fatalf("in Write: %+v", err)
+		}
+
+		// Scanner.Err is only meaningful once Scan has returned false.
+		if !outScan.Scan() {
+			t.Fatalf("outScan: %+v", outScan.Err())
+		}
+		respBytes := outScan.Bytes()
+
+		var resp map[string]interface{}
+		if err = pkg.JSONUnmarshal(respBytes, &resp); err != nil {
+			t.Fatal(err)
+		}
+
+		// Check if response contains error
+		if errObj, exists := resp["error"]; exists {
+			errorObj, ok := errObj.(map[string]interface{})
+			if ok {
+				// Check if it's a rate limit error
+				if code, codeExists := errorObj["code"].(float64); codeExists && code == float64(-32603) {
+					errorCount++
+				}
+			}
+		} else {
+			successCount++
+		}
+
+		// Apply interval between requests if specified
+		if tt.requestInterval > 0 && i < tt.numRequests-1 {
+			time.Sleep(tt.requestInterval)
+		}
+	}
+
+	// Verify that we got the expected number of rate limit errors
+	if errorCount != tt.expectedErrorCount {
+		t.Errorf("Expected %d rate limit errors, got %d", tt.expectedErrorCount, errorCount)
+	}
+
+	// Verify that successful + errors = total requests
+	if successCount+errorCount != tt.numRequests {
+		t.Errorf("Request count mismatch: got %d successes + %d errors = %d, expected total %d",
+			successCount, errorCount, successCount+errorCount, tt.numRequests)
+	}
+
+	// Cleanup
+	in.writer.Close()
+	out.reader.Close()
+
+	// Check if server encountered errors
+	select {
+	case err := <-serverErrCh:
+		t.Fatalf("Server error: %v", err)
+	default:
+		// No error, continue
+	}
+}