diff --git a/_examples/queue/queue.go b/_examples/queue/queue.go
index e6cacc3c2..48bfb386a 100644
--- a/_examples/queue/queue.go
+++ b/_examples/queue/queue.go
@@ -30,8 +30,10 @@ func main() {
 	})
 
 	for i := 0; i < 5; i++ {
+		ctx := colly.NewContext()
+		ctx.Put("num", i)
 		// Add URLs to the queue
-		q.AddURL(fmt.Sprintf("%s?n=%d", url, i))
+		q.AddURL(fmt.Sprintf("%s?n=%d", url, i), colly.WithRequestContext(ctx))
 	}
 	// Consume URLs
 	q.Run(c)
diff --git a/queue/queue.go b/queue/queue.go
index 0d0d78a66..6a987c435 100644
--- a/queue/queue.go
+++ b/queue/queue.go
@@ -78,7 +78,8 @@ func (q *Queue) IsEmpty() bool {
 }
 
-// AddURL adds a new URL to the queue
-func (q *Queue) AddURL(URL string) error {
+// AddURL adds a new URL to the queue. Any opts are passed to
+// colly.NewRequest to customize the request built for the URL.
+func (q *Queue) AddURL(URL string, opts ...colly.RequestOpt) error {
 	u, err := urlParser.Parse(URL)
 	if err != nil {
 		return err
@@ -87,10 +88,8 @@ func (q *Queue) AddURL(URL string) error {
 	if err != nil {
 		return err
 	}
-	r := &colly.Request{
-		URL:    u2,
-		Method: "GET",
-	}
+	r := colly.NewRequest(u2, opts...)
+
 	d, err := r.Marshal()
 	if err != nil {
 		return err
diff --git a/request.go b/request.go
index 6beef8347..fef904eb0 100644
--- a/request.go
+++ b/request.go
@@ -192,3 +192,16 @@ func (r *Request) Marshal() ([]byte, error) {
 	}
 	return json.Marshal(sr)
 }
+
+// NewRequest returns a Request for url whose Method defaults to "GET" and
+// then applies opts to it in the order given.
+func NewRequest(url *url.URL, opts ...RequestOpt) *Request {
+	req := &Request{
+		URL:    url,
+		Method: "GET",
+	}
+	for _, opt := range opts {
+		opt(req)
+	}
+	return req
+}
diff --git a/request_opt.go b/request_opt.go
new file mode 100644
index 000000000..26f7f459f
--- /dev/null
+++ b/request_opt.go
@@ -0,0 +1,80 @@
+package colly
+
+import (
+	"io"
+	"net/http"
+)
+
+// RequestOpt is a functional option that mutates a Request; NewRequest
+// applies the options it is given in order.
+type RequestOpt func(req *Request)
+
+// WithRequestProxyURL sets the ProxyURL field of the request.
+func WithRequestProxyURL(proxyUrl string) RequestOpt {
+	return func(req *Request) {
+		req.ProxyURL = proxyUrl
+	}
+}
+
+// WithRequestID sets the ID field of the request.
+func WithRequestID(id uint32) RequestOpt {
+	return func(req *Request) {
+		req.ID = id
+	}
+}
+
+// WithRequestResponseCharacterEncoding sets the ResponseCharacterEncoding
+// field of the request.
+func WithRequestResponseCharacterEncoding(responseCharacterEncoding string) RequestOpt {
+	return func(req *Request) {
+		req.ResponseCharacterEncoding = responseCharacterEncoding
+	}
+}
+
+// WithRequestBody sets the Body field of the request.
+func WithRequestBody(body io.Reader) RequestOpt {
+	return func(req *Request) {
+		req.Body = body
+	}
+}
+
+// WithRequestMethod sets the Method field of the request, replacing the
+// "GET" default assigned by NewRequest.
+func WithRequestMethod(method string) RequestOpt {
+	return func(req *Request) {
+		req.Method = method
+	}
+}
+
+// WithRequestHost sets the Host field of the request.
+func WithRequestHost(host string) RequestOpt {
+	return func(req *Request) {
+		req.Host = host
+	}
+}
+
+// WithRequestDepth sets the Depth field of the request.
+func WithRequestDepth(depth int) RequestOpt {
+	return func(req *Request) {
+		req.Depth = depth
+	}
+}
+
+// WithRequestSetHeader adds the value v under header key k, allocating
+// req.Headers first if it is nil. It uses Header.Add, so applying it more
+// than once with the same key appends values instead of replacing them.
+func WithRequestSetHeader(k, v string) RequestOpt {
+	return func(req *Request) {
+		if req.Headers == nil {
+			req.Headers = &http.Header{}
+		}
+		req.Headers.Add(k, v)
+	}
+}
+
+// WithRequestContext sets the Ctx field of the request.
+func WithRequestContext(ctx *Context) RequestOpt {
+	return func(req *Request) {
+		req.Ctx = ctx
+	}
+}
diff --git a/request_test.go b/request_test.go
new file mode 100644
index 000000000..edfe93468
--- /dev/null
+++ b/request_test.go
@@ -0,0 +1,71 @@
+package colly
+
+import (
+	"bytes"
+	"net/url"
+	"testing"
+)
+
+// TestNewRequest checks that NewRequest keeps the URL it is given and that
+// each RequestOpt writes its value to the matching Request field.
+func TestNewRequest(t *testing.T) {
+	ctx := &Context{}
+	host := "test host"
+	headerK := "headerk"
+	headerV := "headerv"
+	depth := 8
+	method := "POST"
+	body := bytes.NewReader([]byte{1, 2, 3})
+	respCharSet := "testC"
+	proxyURL := "testUrl"
+	u := &url.URL{Path: "test"}
+	var id uint32 = 123
+
+	req := NewRequest(
+		u,
+		WithRequestContext(ctx),
+		WithRequestSetHeader(headerK, headerV),
+		WithRequestDepth(depth),
+		WithRequestMethod(method),
+		WithRequestBody(body),
+		WithRequestResponseCharacterEncoding(respCharSet),
+		WithRequestProxyURL(proxyURL),
+		WithRequestHost(host),
+		WithRequestID(id),
+	)
+
+	// Check every field separately so a failure names the option that broke.
+	if req.URL != u {
+		t.Errorf("URL = %v, want %v", req.URL, u)
+	}
+	if req.Ctx != ctx {
+		t.Errorf("Ctx = %p, want %p", req.Ctx, ctx)
+	}
+	if req.Headers == nil {
+		t.Errorf("Headers = nil, want %q set to %q", headerK, headerV)
+	} else if got := req.Headers.Get(headerK); got != headerV {
+		t.Errorf("Headers.Get(%q) = %q, want %q", headerK, got, headerV)
+	}
+	if req.Depth != depth {
+		t.Errorf("Depth = %d, want %d", req.Depth, depth)
+	}
+	// Compare with the requested method so the "GET" default cannot pass.
+	if req.Method != method {
+		t.Errorf("Method = %q, want %q", req.Method, method)
+	}
+	if req.Body != body {
+		t.Error("Body is not the reader passed to WithRequestBody")
+	}
+	if req.ResponseCharacterEncoding != respCharSet {
+		t.Errorf("ResponseCharacterEncoding = %q, want %q", req.ResponseCharacterEncoding, respCharSet)
+	}
+	if req.ProxyURL != proxyURL {
+		t.Errorf("ProxyURL = %q, want %q", req.ProxyURL, proxyURL)
+	}
+	if req.Host != host {
+		t.Errorf("Host = %q, want %q", req.Host, host)
+	}
+	if req.ID != id {
+		t.Errorf("ID = %d, want %d", req.ID, id)
+	}
+}