Skip to content

feat: implement rate limiting for server and tools #125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion examples/current_time_server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/ThinkInAIXYZ/go-mcp/protocol"
"github.com/ThinkInAIXYZ/go-mcp/server"
"github.com/ThinkInAIXYZ/go-mcp/server/components"
"github.com/ThinkInAIXYZ/go-mcp/transport"
)

Expand All @@ -27,13 +28,24 @@ func main() {
Name: "current-time-v2-server",
Version: "1.0.0",
}),
// 创建一个每秒5个请求,突发容量为10的限速器
server.WithRateLimiter(components.NewTokenBucketLimiter(components.Rate{
Limit: 5.0, // 每秒5个请求
Burst: 10, // 最多允许10个请求的突发
})),
)
if err != nil {
log.Fatalf("Failed to create server: %v", err)
}

// new protocol tool with name, descipriton and properties
tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{})
tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{},
// 为指定工具设置每秒10个请求,突发容量为20的限速器
protocol.WithRateLimit(srv.GetLimiter(), components.Rate{
Limit: 10.0, // 每秒10个请求
Burst: 20, // 最多允许20个请求的突发
}),
)
if err != nil {
log.Fatalf("Failed to create tool: %v", err)
return
Expand Down
1 change: 1 addition & 0 deletions pkg/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var (
ErrSessionHasNotInitialized = errors.New("the session has not been initialized")
ErrLackSession = errors.New("lack session")
ErrSendEOF = errors.New("send EOF")
ErrRateLimitExceeded = errors.New("rate limit exceeded")
)

type ResponseError struct {
Expand Down
22 changes: 19 additions & 3 deletions protocol/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,18 @@ import (
"fmt"

"github.com/ThinkInAIXYZ/go-mcp/pkg"
"github.com/ThinkInAIXYZ/go-mcp/server/components"
)

type Option func(*Tool)

// WithRateLimit 设置工具的速率限制
func WithRateLimit(rateLimit components.RateLimiter, rate components.Rate) Option {
return func(t *Tool) {
rateLimit.SetToolLimit(t.Name, rate)
}
}

// ListToolsRequest represents a request to list available tools
type ListToolsRequest struct{}

Expand Down Expand Up @@ -177,17 +187,23 @@ type ToolListChangedNotification struct {
}

// NewTool create a tool
func NewTool(name string, description string, inputReqStruct interface{}) (*Tool, error) {
func NewTool(name string, description string, inputReqStruct interface{}, opts ...Option) (*Tool, error) {
schema, err := generateSchemaFromReqStruct(inputReqStruct)
if err != nil {
return nil, err
}

return &Tool{
t := &Tool{
Name: name,
Description: description,
InputSchema: *schema,
}, nil
}

for _, opt := range opts {
opt(t)
}

return t, nil
}

func NewToolWithRawSchema(name, description string, schema json.RawMessage) *Tool {
Expand Down
95 changes: 95 additions & 0 deletions server/components/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package components

import (
"sync"
"time"
)

// RateLimiter 定义速率限制接口
type RateLimiter interface {
Allow(toolName string) bool
SetToolLimit(toolName string, rate Rate)
}

// TokenBucketLimiter 令牌桶限速器实现
type TokenBucketLimiter struct {
mu sync.RWMutex
buckets map[string]*bucket
defaultLimit Rate
toolLimits map[string]Rate
}

// Rate 定义速率限制参数
type Rate struct {
Limit float64 // 每秒允许的请求数
Burst int // 突发请求上限
}

// bucket 令牌桶
type bucket struct {
tokens float64
lastTimestamp time.Time
rate Rate
}

// NewTokenBucketLimiter 创建新的令牌桶限速器
func NewTokenBucketLimiter(defaultRate Rate) *TokenBucketLimiter {
return &TokenBucketLimiter{
buckets: make(map[string]*bucket),
defaultLimit: defaultRate,
toolLimits: make(map[string]Rate),
}
}

// SetToolLimit 为特定工具设置限制
func (l *TokenBucketLimiter) SetToolLimit(toolName string, rate Rate) {
l.mu.Lock()
defer l.mu.Unlock()

l.toolLimits[toolName] = rate
// 如果已有桶,更新其速率
if b, exists := l.buckets[toolName]; exists {
b.rate = rate
}
}

// Allow 检查请求是否被允许
func (l *TokenBucketLimiter) Allow(toolName string) bool {
l.mu.RLock()
defer l.mu.RUnlock()

now := time.Now()

// 获取或创建桶
b, exists := l.buckets[toolName]
if !exists {
// 查找工具特定的限制,如果没有则使用默认限制
rate, exists := l.toolLimits[toolName]
if !exists {
rate = l.defaultLimit
}

b = &bucket{
tokens: float64(rate.Burst),
lastTimestamp: now,
rate: rate,
}
l.buckets[toolName] = b
}

// 计算从上次请求到现在应该添加的令牌
elapsed := now.Sub(b.lastTimestamp).Seconds()
b.lastTimestamp = now

// 添加令牌,但不超过最大值
b.tokens += elapsed * b.rate.Limit
if b.tokens > float64(b.rate.Burst) {
b.tokens = float64(b.rate.Burst)
}

if b.tokens >= 1.0 {
b.tokens -= 1.0
return true
}
return false
}
5 changes: 5 additions & 0 deletions server/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ func (server *Server) handleRequestWithCallTool(rawParams json.RawMessage) (*pro
return nil, err
}

// 检查速率限制
if server.limiter != nil && !server.limiter.Allow(request.Name) {
return nil, pkg.ErrRateLimitExceeded
}

entry, ok := server.tools.Load(request.Name)
if !ok {
return nil, fmt.Errorf("missing tool, toolName=%s", request.Name)
Expand Down
23 changes: 23 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/ThinkInAIXYZ/go-mcp/pkg"
"github.com/ThinkInAIXYZ/go-mcp/protocol"
"github.com/ThinkInAIXYZ/go-mcp/server/components"
"github.com/ThinkInAIXYZ/go-mcp/server/session"
"github.com/ThinkInAIXYZ/go-mcp/transport"
)
Expand Down Expand Up @@ -44,6 +45,24 @@ func WithLogger(logger pkg.Logger) Option {
}
}

// WithRateLimiter 添加一个创建带速率限制的服务器的选项
func WithRateLimiter(limiter components.RateLimiter) Option {
return func(s *Server) {
s.limiter = limiter
}
}

// WithDefaultRateLimiter 添加一个便捷函数用默认参数创建限速器
func WithDefaultRateLimiter() Option {
const defaultLimit = 5.0
const defaultBurst = 10
limiter := components.NewTokenBucketLimiter(components.Rate{
Limit: defaultLimit,
Burst: defaultBurst,
})
return WithRateLimiter(limiter)
}

type Server struct {
transport transport.ServerTransport

Expand All @@ -53,6 +72,7 @@ type Server struct {
resourceTemplates pkg.SyncMap[*resourceTemplateEntry]

sessionManager *session.Manager
limiter components.RateLimiter

inShutdown *pkg.AtomicBool // true when server is in shutdown
inFlyRequest sync.WaitGroup
Expand Down Expand Up @@ -244,3 +264,6 @@ func (server *Server) sessionDetection(ctx context.Context, sessionID string) er
}
return nil
}

// GetLimiter 获取限速器
func (server *Server) GetLimiter() components.RateLimiter { return server.limiter }
Loading