diff --git a/client.go b/client.go index bf9af8c..0193d44 100644 --- a/client.go +++ b/client.go @@ -73,6 +73,7 @@ type Client struct { responseBodyTransformer func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) resultStateCheckFunc func(resp *Response) ResultState onError ErrorHook + ctx *clientContext // 使Client拥有携带上下文信息的能力 } type ErrorHook func(client *Client, req *Request, resp *Response, err error) @@ -82,6 +83,7 @@ func (c *Client) R() *Request { return &Request{ client: c, retryOption: c.retryOption.Clone(), + ctx: c.Context(), // 继承 Client 的上下文 } } @@ -1527,6 +1529,12 @@ func (c *Client) Clone() *Client { cc.afterResponse = cloneSlice(c.afterResponse) cc.dumpOptions = c.dumpOptions.Clone() cc.retryOption = c.retryOption.Clone() + + // clone context and clientContext + cc.ctx = &clientContext{ + context: c.ctx.context, + } + return &cc } @@ -1565,6 +1573,9 @@ func C() *Client { xmlMarshal: xml.Marshal, xmlUnmarshal: xml.Unmarshal, cookiejarFactory: memoryCookieJarFactory, + ctx: &clientContext{ + context: context.Background(), + }, } c.SetRedirectPolicy(DefaultRedirectPolicy()) c.initCookieJar() diff --git a/context.go b/context.go new file mode 100644 index 0000000..72a6807 --- /dev/null +++ b/context.go @@ -0,0 +1,78 @@ +package req + +import ( + "context" + "sync" +) + +// clientContext 为 Client 提供上下文能力,确保并发安全 +type clientContext struct { + mu sync.RWMutex + context context.Context +} + +// WithContext 添加上下文键值对,使用读写锁保证并发安全 +// 这种方式避免了 Clone 整个 Client 的开销 +func (c *Client) WithContext(key, value any) *Client { + if c.ctx == nil { + c.ctx = &clientContext{ + context: context.Background(), + } + } + + c.ctx.mu.Lock() + defer c.ctx.mu.Unlock() + + if c.ctx.context == nil { + c.ctx.context = context.Background() + } + c.ctx.context = context.WithValue(c.ctx.context, key, value) + return c +} + +// GetContext 获取上下文值,使用读锁保证并发安全 +func (c *Client) GetContext(key any) any { + if c.ctx == nil { + return context.Background().Value(key) + } + + c.ctx.mu.RLock() + defer c.ctx.mu.RUnlock() + + if c.ctx.context == nil { + return nil + } + return c.ctx.context.Value(key) +} + +// SetContext 设置基础上下文,使用写锁保证并发安全 +func (c *Client) SetContext(ctx context.Context) *Client { + if ctx == nil { + ctx = context.Background() + } + + if c.ctx == nil { + c.ctx = &clientContext{} + } + + c.ctx.mu.Lock() + defer c.ctx.mu.Unlock() + + c.ctx.context = ctx + return c +} + +// Context 获取当前上下文,使用读锁保证并发安全 +func (c *Client) Context() context.Context { + if c.ctx == nil { + return context.Background() + } + + c.ctx.mu.RLock() + defer c.ctx.mu.RUnlock() + + if c.ctx.context == nil { + return context.Background() + } + return c.ctx.context +}