From c7b7a5128aaab776fdc1794b3bf011515c37d1e3 Mon Sep 17 00:00:00 2001 From: root <1@root.com> Date: Wed, 9 Jul 2025 01:31:37 +0900 Subject: [PATCH] =?UTF-8?q?=E4=BD=BFclient=E6=9C=89=E6=90=BA=E5=B8=A6?= =?UTF-8?q?=E4=B8=8A=E4=B8=8B=E6=96=87=E7=9A=84=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client.go | 11 ++++++++ context.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 context.go diff --git a/client.go b/client.go index bf9af8c..0193d44 100644 --- a/client.go +++ b/client.go @@ -73,6 +73,7 @@ type Client struct { responseBodyTransformer func(rawBody []byte, req *Request, resp *Response) (transformedBody []byte, err error) resultStateCheckFunc func(resp *Response) ResultState onError ErrorHook + ctx *clientContext // 使Client拥有携带上下文信息的能力 } type ErrorHook func(client *Client, req *Request, resp *Response, err error) @@ -82,6 +83,7 @@ func (c *Client) R() *Request { return &Request{ client: c, retryOption: c.retryOption.Clone(), + ctx: c.Context(), // 继承 Client 的上下文 } } @@ -1527,6 +1529,12 @@ func (c *Client) Clone() *Client { cc.afterResponse = cloneSlice(c.afterResponse) cc.dumpOptions = c.dumpOptions.Clone() cc.retryOption = c.retryOption.Clone() + + // clone context and clientContext + cc.ctx = &clientContext{ + context: c.ctx.context, + } + return &cc } @@ -1565,6 +1573,9 @@ func C() *Client { xmlMarshal: xml.Marshal, xmlUnmarshal: xml.Unmarshal, cookiejarFactory: memoryCookieJarFactory, + ctx: &clientContext{ + context: context.Background(), + }, } c.SetRedirectPolicy(DefaultRedirectPolicy()) c.initCookieJar() diff --git a/context.go b/context.go new file mode 100644 index 0000000..72a6807 --- /dev/null +++ b/context.go @@ -0,0 +1,78 @@ +package req + +import ( + "context" + "sync" +) + +// clientContext 为 Client 提供上下文能力,确保并发安全 +type clientContext struct { + mu sync.RWMutex + context context.Context +} + +// WithContext 添加上下文键值对,使用读写锁保证并发安全 +// 这种方式避免了 Clone 整个 Client 的开销 +func (c *Client) WithContext(key, value any) *Client { + if c.ctx == nil { + c.ctx = &clientContext{ + context: context.Background(), + } + } + + c.ctx.mu.Lock() + defer c.ctx.mu.Unlock() + + if c.ctx.context == nil { + c.ctx.context = context.Background() + } + c.ctx.context = context.WithValue(c.ctx.context, key, value) + return c +} + +// GetContext 获取上下文值,使用读锁保证并发安全 +func (c *Client) GetContext(key any) any { + if c.ctx == nil { + return context.Background().Value(key) + } + + c.ctx.mu.RLock() + defer c.ctx.mu.RUnlock() + + if c.ctx.context == nil { + return nil + } + return c.ctx.context.Value(key) +} + +// SetContext 设置基础上下文,使用写锁保证并发安全 +func (c *Client) SetContext(ctx context.Context) *Client { + if ctx == nil { + ctx = context.Background() + } + + if c.ctx == nil { + c.ctx = &clientContext{} + } + + c.ctx.mu.Lock() + defer c.ctx.mu.Unlock() + + c.ctx.context = ctx + return c +} + +// Context 获取当前上下文,使用读锁保证并发安全 +func (c *Client) Context() context.Context { + if c.ctx == nil { + return context.Background() + } + + c.ctx.mu.RLock() + defer c.ctx.mu.RUnlock() + + if c.ctx.context == nil { + return context.Background() + } + return c.ctx.context +}