diff --git a/colly.go b/colly.go index fdca9451..1d6cc9eb 100644 --- a/colly.go +++ b/colly.go @@ -140,19 +140,19 @@ type Collector struct { type RequestCallback func(*Request) // ResponseHeadersCallback is a type alias for OnResponseHeaders callback functions -type ResponseHeadersCallback func(*Response) +type ResponseHeadersCallback func(*Response) error // ResponseCallback is a type alias for OnResponse callback functions -type ResponseCallback func(*Response) +type ResponseCallback func(*Response) error // HTMLCallback is a type alias for OnHTML callback functions -type HTMLCallback func(*HTMLElement) +type HTMLCallback func(*HTMLElement) error // XMLCallback is a type alias for OnXML callback functions -type XMLCallback func(*XMLElement) +type XMLCallback func(*XMLElement) error // ErrorCallback is a type alias for OnError callback functions -type ErrorCallback func(*Response, error) +type ErrorCallback func(*Response, error) error // ScrapedCallback is a type alias for OnScraped callback functions type ScrapedCallback func(*Response) @@ -691,13 +691,18 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct req = hTrace.WithTrace(req) } origURL := req.URL - checkHeadersFunc := func(req *http.Request, statusCode int, headers http.Header) bool { + checkHeadersFunc := func(req *http.Request, statusCode int, headers http.Header) error { if req.URL != origURL { request.URL = req.URL request.Headers = &req.Header } - c.handleOnResponseHeaders(&Response{Ctx: ctx, Request: request, StatusCode: statusCode, Headers: &headers}) - return !request.abort + if err := c.handleOnResponseHeaders(&Response{Ctx: ctx, Request: request, StatusCode: statusCode, Headers: &headers}); err != nil { + return err + } + if request.abort { + return ErrAbortedAfterHeaders + } + return nil } response, err := c.backend.Cache(req, c.MaxBodySize, checkHeadersFunc, c.CacheDir) if proxyURL, ok := req.Context().Value(ProxyURLKey).(string); ok { @@ -716,16 +721,24 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct return err } - c.handleOnResponse(response) + if err := c.handleOnResponse(response); err != nil { + if err = c.handleOnError(response, err, request, ctx); err != nil { + return err + } + } err = c.handleOnHTML(response) if err != nil { - c.handleOnError(response, err, request, ctx) + if err = c.handleOnError(response, err, request, ctx); err != nil { + return err + } } err = c.handleOnXML(response) if err != nil { - c.handleOnError(response, err, request, ctx) + if err = c.handleOnError(response, err, request, ctx); err != nil { + return err + } } c.handleOnScraped(response) @@ -1083,7 +1096,7 @@ func (c *Collector) handleOnRequest(r *Request) { } } -func (c *Collector) handleOnResponse(r *Response) { +func (c *Collector) handleOnResponse(r *Response) error { if c.debugger != nil { c.debugger.Event(createEvent("response", r.Request.ID, c.ID, map[string]string{ "url": r.Request.URL.String(), @@ -1091,11 +1104,14 @@ func (c *Collector) handleOnResponse(r *Response) { })) } for _, f := range c.responseCallbacks { - f(r) + if err := f(r); err != nil { + return err + } } + return nil } -func (c *Collector) handleOnResponseHeaders(r *Response) { +func (c *Collector) handleOnResponseHeaders(r *Response) error { if c.debugger != nil { c.debugger.Event(createEvent("responseHeaders", r.Request.ID, c.ID, map[string]string{ "url": r.Request.URL.String(), @@ -1103,17 +1119,20 @@ func (c *Collector) handleOnResponseHeaders(r *Response) { })) } for _, f := range c.responseHeadersCallbacks { - f(r) + if err := f(r); err != nil { + return err + } } + return nil } -func (c *Collector) handleOnHTML(resp *Response) error { +func (c *Collector) handleOnHTML(resp *Response) (err error) { if len(c.htmlCallbacks) == 0 || !strings.Contains(strings.ToLower(resp.Headers.Get("Content-Type")), "html") { return nil } doc, err := goquery.NewDocumentFromReader(bytes.NewBuffer(resp.Body)) if err != nil { - return err + return } if href, found := doc.Find("base[href]").Attr("href"); found { u, err := urlParser.ParseRef(resp.Request.URL.String(), href) @@ -1127,7 +1146,7 @@ func (c *Collector) handleOnHTML(resp *Response) error { } for _, cc := range c.htmlCallbacks { i := 0 - doc.Find(cc.Selector).Each(func(_ int, s *goquery.Selection) { + doc.Find(cc.Selector).EachWithBreak(func(_ int, s *goquery.Selection) bool { for _, n := range s.Nodes { e := NewHTMLElementFromSelectionNode(resp, s, n, i) i++ @@ -1137,11 +1156,17 @@ func (c *Collector) handleOnHTML(resp *Response) error { "url": resp.Request.URL.String(), })) } - cc.Function(e) + if err = cc.Function(e); err != nil { + return false + } } + return true }) + if err != nil { + return + } } - return nil + return } func (c *Collector) handleOnXML(resp *Response) error { @@ -1180,7 +1205,9 @@ func (c *Collector) handleOnXML(resp *Response) error { "url": resp.Request.URL.String(), })) } - cc.Function(e) + if err := cc.Function(e); err != nil { + return err + } } } } else if strings.Contains(contentType, "xml") || isXMLFile { @@ -1198,8 +1225,11 @@ func (c *Collector) handleOnXML(resp *Response) error { "url": resp.Request.URL.String(), })) } - cc.Function(e) + err = cc.Function(e) }) + if err != nil { + return err + } } } return nil @@ -1230,10 +1260,15 @@ func (c *Collector) handleOnError(response *Response, err error, request *Reques if response.Ctx == nil { response.Ctx = request.Ctx } + if len(c.errorCallbacks) < 1 { + return err + } for _, f := range c.errorCallbacks { - f(response, err) + if abortErr := f(response, err); abortErr != nil { + return abortErr + } } - return err + return nil } func (c *Collector) handleOnScraped(r *Response) { diff --git a/colly_test.go b/colly_test.go index 4358b63e..215a085e 100644 --- a/colly_test.go +++ b/colly_test.go @@ -447,7 +447,7 @@ func TestCollectorVisit(t *testing.T) { r.Ctx.Put("x", "y") }) - c.OnResponse(func(r *Response) { + c.OnResponse(func(r *Response) error { onResponseCalled = true if r.Ctx.Get("x") != "y" { @@ -457,6 +457,7 @@ func TestCollectorVisit(t *testing.T) { if !bytes.Equal(r.Body, serverIndexResponse) { t.Error("Response body does not match with the original content") } + return nil }) c.OnScraped(func(r *Response) { @@ -530,14 +531,16 @@ func TestCollectorVisitResponseHeaders(t *testing.T) { var onResponseHeadersCalled bool c := NewCollector() - c.OnResponseHeaders(func(r *Response) { + c.OnResponseHeaders(func(r *Response) error { onResponseHeadersCalled = true if r.Headers.Get("Content-Type") == "application/octet-stream" { r.Request.Abort() } + return nil }) - c.OnResponse(func(r *Response) { + c.OnResponse(func(r *Response) error { t.Error("OnResponse was called") + return nil }) c.Visit(ts.URL + "/large_binary") if !onResponseHeadersCalled { @@ -554,21 +557,23 @@ func TestCollectorOnHTML(t *testing.T) { titleCallbackCalled := false paragraphCallbackCount := 0 - c.OnHTML("title", func(e *HTMLElement) { + c.OnHTML("title", func(e *HTMLElement) error { titleCallbackCalled = true if e.Text != "Test Page" { t.Error("Title element text does not match, got", e.Text) } + return nil }) - c.OnHTML("p", func(e *HTMLElement) { + c.OnHTML("p", func(e *HTMLElement) error { paragraphCallbackCount++ if e.Attr("class") != "description" { t.Error("Failed to get paragraph's class attribute") } + return nil }) - c.OnHTML("body", func(e *HTMLElement) { + c.OnHTML("body", func(e *HTMLElement) error { if e.ChildAttr("p", "class") != "description" { t.Error("Invalid class value") } @@ -576,6 +581,7 @@ func TestCollectorOnHTML(t *testing.T) { if len(classes) != 2 { t.Error("Invalid class values") } + return nil }) c.Visit(ts.URL + "/html") @@ -629,11 +635,12 @@ func TestCollectorPostRevisit(t *testing.T) { visitCount := 0 c := NewCollector() - c.OnResponse(func(r *Response) { + c.OnResponse(func(r *Response) error { if postValue != string(r.Body) { t.Error("Failed to send data with POST") } visitCount++ + return nil }) c.Post(ts.URL+"/login", postData) @@ -730,13 +737,14 @@ func TestSetCookieRedirect(t *testing.T) { ts.Start() defer ts.Close() c := NewCollector() - c.OnResponse(func(r *Response) { + c.OnResponse(func(r *Response) error { if got, want := r.Body, serverIndexResponse; !bytes.Equal(got, want) { t.Errorf("bad response body got=%q want=%q", got, want) } if got, want := r.StatusCode, http.StatusOK; got != want { t.Errorf("bad response code got=%d want=%d", got, want) } + return nil }) if err := c.Visit(ts.URL); err != nil { t.Fatal(err) @@ -831,10 +839,11 @@ func TestCollectorPost(t *testing.T) { postValue := "hello" c := NewCollector() - c.OnResponse(func(r *Response) { + c.OnResponse(func(r *Response) error { if postValue != string(r.Body) { t.Error("Failed to send data with POST") } + return nil }) c.Post(ts.URL+"/login", map[string]string{ @@ -849,10 +858,11 @@ func TestCollectorPostRaw(t *testing.T) { postValue := "hello" c := NewCollector() - c.OnResponse(func(r *Response) { + c.OnResponse(func(r *Response) error { if postValue != string(r.Body) { t.Error("Failed to send data with POST") } + return nil }) c.PostRaw(ts.URL+"/login", []byte("name="+postValue)) @@ -867,11 +877,12 @@ func TestCollectorPostRawRevisit(t *testing.T) { visitCount := 0 c := NewCollector() - c.OnResponse(func(r *Response) { + c.OnResponse(func(r *Response) error { if postValue != string(r.Body) { t.Error("Failed to send data with POST RAW") } visitCount++ + return nil }) c.PostRaw(ts.URL+"/login", []byte(postData)) @@ -897,23 +908,26 @@ func TestRedirect(t *testing.T) { defer ts.Close() c := NewCollector() - c.OnHTML("a[href]", func(e *HTMLElement) { + c.OnHTML("a[href]", func(e *HTMLElement) error { u := e.Request.AbsoluteURL(e.Attr("href")) if !strings.HasSuffix(u, "/redirected/test") { t.Error("Invalid URL after redirect: " + u) } + return nil }) - c.OnResponseHeaders(func(r *Response) { + c.OnResponseHeaders(func(r *Response) error { if !strings.HasSuffix(r.Request.URL.String(), "/redirected/") { t.Error("Invalid URL in Request after redirect (OnResponseHeaders): " + r.Request.URL.String()) } + return nil }) - c.OnResponse(func(r *Response) { + c.OnResponse(func(r *Response) error { if !strings.HasSuffix(r.Request.URL.String(), "/redirected/") { t.Error("Invalid URL in Request after redirect (OnResponse): " + r.Request.URL.String()) } + return nil }) c.Visit(ts.URL + "/redirect") } @@ -937,12 +951,13 @@ func TestRedirectWithDisallowedURLs(t *testing.T) { c := NewCollector() c.DisallowedURLFilters = []*regexp.Regexp{regexp.MustCompile(ts.URL + "/redirected/test")} - c.OnHTML("a[href]", func(e *HTMLElement) { + c.OnHTML("a[href]", func(e *HTMLElement) error { u := e.Request.AbsoluteURL(e.Attr("href")) err := c.Visit(u) if !errors.Is(err, ErrForbiddenURL) { t.Error("URL should have been forbidden: " + u) } + return nil }) c.Visit(ts.URL + "/redirect") @@ -953,20 +968,22 @@ func TestBaseTag(t *testing.T) { defer ts.Close() c := NewCollector() - c.OnHTML("a[href]", func(e *HTMLElement) { + c.OnHTML("a[href]", func(e *HTMLElement) error { u := e.Request.AbsoluteURL(e.Attr("href")) if u != "http://xy.com/z" { t.Error("Invalid tag handling in OnHTML: expected https://xy.com/z, got " + u) } + return nil }) c.Visit(ts.URL + "/base") c2 := NewCollector() - c2.OnXML("//a", func(e *XMLElement) { + c2.OnXML("//a", func(e *XMLElement) error { u := e.Request.AbsoluteURL(e.Attr("href")) if u != "http://xy.com/z" { t.Error("Invalid tag handling in OnXML: expected https://xy.com/z, got " + u) } + return nil }) c2.Visit(ts.URL + "/base") } @@ -976,22 +993,24 @@ func TestBaseTagRelative(t *testing.T) { defer ts.Close() c := NewCollector() - c.OnHTML("a[href]", func(e *HTMLElement) { + c.OnHTML("a[href]", func(e *HTMLElement) error { u := e.Request.AbsoluteURL(e.Attr("href")) expected := ts.URL + "/foobar/z" if u != expected { t.Errorf("Invalid tag handling in OnHTML: expected %q, got %q", expected, u) } + return nil }) c.Visit(ts.URL + "/base_relative") c2 := NewCollector() - c2.OnXML("//a", func(e *XMLElement) { + c2.OnXML("//a", func(e *XMLElement) error { u := e.Request.AbsoluteURL(e.Attr("href")) expected := ts.URL + "/foobar/z" if u != expected { t.Errorf("Invalid tag handling in OnXML: expected %q, got %q", expected, u) } + return nil }) c2.Visit(ts.URL + "/base_relative") } @@ -1010,13 +1029,15 @@ func TestTabsAndNewlines(t *testing.T) { } c := NewCollector() - c.OnResponse(func(res *Response) { + c.OnResponse(func(res *Response) error { visited[res.Request.URL.EscapedPath()] = struct{}{} + return nil }) - c.OnHTML("a[href]", func(e *HTMLElement) { + c.OnHTML("a[href]", func(e *HTMLElement) error { if err := e.Request.Visit(e.Attr("href")); err != nil { t.Errorf("visit failed: %v", err) } + return nil }) if err := c.Visit(ts.URL + "/tabs_and_newlines"); err != nil { @@ -1035,8 +1056,9 @@ func TestLonePercent(t *testing.T) { var visitedPath string c := NewCollector() - c.OnResponse(func(res *Response) { + c.OnResponse(func(res *Response) error { visitedPath = res.Request.URL.RequestURI() + return nil }) if err := c.Visit(ts.URL + "/100%"); err != nil { t.Errorf("visit failed: %v", err) @@ -1083,10 +1105,11 @@ func TestRobotsWhenAllowed(t *testing.T) { c := NewCollector() c.IgnoreRobotsTxt = false - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { if resp.StatusCode != 200 { t.Fatalf("Wrong response code: %d", resp.StatusCode) } + return nil }) err := c.Visit(ts.URL + "/allowed") @@ -1103,8 +1126,9 @@ func TestRobotsWhenDisallowed(t *testing.T) { c := NewCollector() c.IgnoreRobotsTxt = false - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { t.Fatalf("Received response: %d", resp.StatusCode) + return nil }) err := c.Visit(ts.URL + "/disallowed") @@ -1120,8 +1144,9 @@ func TestRobotsWhenDisallowedWithQueryParameter(t *testing.T) { c := NewCollector() c.IgnoreRobotsTxt = false - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { t.Fatalf("Received response: %d", resp.StatusCode) + return nil }) err := c.Visit(ts.URL + "/allowed?q=1") @@ -1137,10 +1162,11 @@ func TestIgnoreRobotsWhenDisallowed(t *testing.T) { c := NewCollector() c.IgnoreRobotsTxt = true - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { if resp.StatusCode != 200 { t.Fatalf("Wrong response code: %d", resp.StatusCode) } + return nil }) err := c.Visit(ts.URL + "/disallowed") @@ -1175,10 +1201,11 @@ func TestEnvSettings(t *testing.T) { valid := false - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { if string(resp.Body) == "test" { valid = true } + return nil }) c.Visit(ts.URL + "/user_agent") @@ -1200,8 +1227,9 @@ func TestUserAgent(t *testing.T) { func() { c := NewCollector() - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { receivedUserAgent = string(resp.Body) + return nil }) c.Visit(ts.URL + "/user_agent") if got, want := receivedUserAgent, defaultUserAgent; got != want { @@ -1210,8 +1238,9 @@ func TestUserAgent(t *testing.T) { }() func() { c := NewCollector(UserAgent(exampleUserAgent1)) - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { receivedUserAgent = string(resp.Body) + return nil }) c.Visit(ts.URL + "/user_agent") if got, want := receivedUserAgent, exampleUserAgent1; got != want { @@ -1220,8 +1249,9 @@ func TestUserAgent(t *testing.T) { }() func() { c := NewCollector(UserAgent(exampleUserAgent1)) - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { receivedUserAgent = string(resp.Body) + return nil }) c.Request("GET", ts.URL+"/user_agent", nil, nil, nil) @@ -1231,8 +1261,9 @@ func TestUserAgent(t *testing.T) { }() func() { c := NewCollector(UserAgent(exampleUserAgent1)) - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { receivedUserAgent = string(resp.Body) + return nil }) c.Request("GET", ts.URL+"/user_agent", nil, nil, http.Header{}) @@ -1242,8 +1273,9 @@ func TestUserAgent(t *testing.T) { }() func() { c := NewCollector(UserAgent(exampleUserAgent1)) - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { receivedUserAgent = string(resp.Body) + return nil }) hdr := http.Header{} hdr.Set("User-Agent", "") @@ -1255,8 +1287,9 @@ func TestUserAgent(t *testing.T) { }() func() { c := NewCollector(UserAgent(exampleUserAgent1)) - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { receivedUserAgent = string(resp.Body) + return nil }) hdr := http.Header{} hdr.Set("User-Agent", exampleUserAgent2) @@ -1281,8 +1314,9 @@ func TestHeaders(t *testing.T) { c := NewCollector( Headers(map[string]string{"Host": exampleHostHeader}), ) - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { receivedHeader = string(resp.Body) + return nil }) c.Visit(ts.URL + "/host_header") if got, want := receivedHeader, exampleHostHeader; got != want { @@ -1293,8 +1327,9 @@ func TestHeaders(t *testing.T) { c := NewCollector( Headers(map[string]string{"Test": exampleTestHeader}), ) - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { receivedHeader = string(resp.Body) + return nil }) c.Visit(ts.URL + "/custom_header") if got, want := receivedHeader, exampleTestHeader; got != want { @@ -1312,10 +1347,11 @@ func TestParseHTTPErrorResponse(t *testing.T) { AllowURLRevisit(), ) - c.OnHTML("p", func(e *HTMLElement) { + c.OnHTML("p", func(e *HTMLElement) error { if e.Text == "error" { contentCount++ } + return nil }) c.Visit(ts.URL + "/500") @@ -1382,21 +1418,23 @@ func TestCollectorOnXMLWithHtml(t *testing.T) { titleCallbackCalled := false paragraphCallbackCount := 0 - c.OnXML("/html/head/title", func(e *XMLElement) { + c.OnXML("/html/head/title", func(e *XMLElement) error { titleCallbackCalled = true if e.Text != "Test Page" { t.Error("Title element text does not match, got", e.Text) } + return nil }) - c.OnXML("/html/body/p", func(e *XMLElement) { + c.OnXML("/html/body/p", func(e *XMLElement) error { paragraphCallbackCount++ if e.Attr("class") != "description" { t.Error("Failed to get paragraph's class attribute") } + return nil }) - c.OnXML("/html/body", func(e *XMLElement) { + c.OnXML("/html/body", func(e *XMLElement) error { if e.ChildAttr("p", "class") != "description" { t.Error("Invalid class value") } @@ -1404,6 +1442,7 @@ func TestCollectorOnXMLWithHtml(t *testing.T) { if len(classes) != 2 { t.Error("Invalid class values") } + return nil }) c.Visit(ts.URL + "/html") @@ -1426,21 +1465,23 @@ func TestCollectorOnXMLWithXML(t *testing.T) { titleCallbackCalled := false paragraphCallbackCount := 0 - c.OnXML("//page/title", func(e *XMLElement) { + c.OnXML("//page/title", func(e *XMLElement) error { titleCallbackCalled = true if e.Text != "Test Page" { t.Error("Title element text does not match, got", e.Text) } + return nil }) - c.OnXML("//page/paragraph", func(e *XMLElement) { + c.OnXML("//page/paragraph", func(e *XMLElement) error { paragraphCallbackCount++ if e.Attr("type") != "description" { t.Error("Failed to get paragraph's type attribute") } + return nil }) - c.OnXML("/page", func(e *XMLElement) { + c.OnXML("/page", func(e *XMLElement) error { if e.ChildAttr("paragraph", "type") != "description" { t.Error("Invalid type value") } @@ -1448,6 +1489,7 @@ func TestCollectorOnXMLWithXML(t *testing.T) { if len(classes) != 2 { t.Error("Invalid type values") } + return nil }) c.Visit(ts.URL + "/xml") @@ -1466,10 +1508,11 @@ func TestCollectorVisitWithTrace(t *testing.T) { defer ts.Close() c := NewCollector(AllowedDomains("localhost", "127.0.0.1", "::1"), TraceHTTP()) - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { if resp.Trace == nil { t.Error("Failed to initialize trace") } + return nil }) err := c.Visit(ts.URL) @@ -1484,8 +1527,9 @@ func TestCollectorVisitWithCheckHead(t *testing.T) { c := NewCollector(CheckHead()) var requestMethodChain []string - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { requestMethodChain = append(requestMethodChain, resp.Request.Method) + return nil }) err := c.Visit(ts.URL) @@ -1506,12 +1550,13 @@ func TestCollectorDepth(t *testing.T) { AllowURLRevisit(), ) requestCount := 0 - c1.OnResponse(func(resp *Response) { + c1.OnResponse(func(resp *Response) error { requestCount++ if requestCount >= 10 { - return + return nil } c1.Visit(ts.URL) + return nil }) c1.Visit(ts.URL) if requestCount < 10 { @@ -1520,9 +1565,10 @@ func TestCollectorDepth(t *testing.T) { c2 := c1.Clone() requestCount = 0 - c2.OnResponse(func(resp *Response) { + c2.OnResponse(func(resp *Response) error { requestCount++ resp.Request.Visit(ts.URL) + return nil }) c2.Visit(ts.URL) if requestCount != 2 { @@ -1550,9 +1596,10 @@ func TestCollectorRequests(t *testing.T) { AllowURLRevisit(), ) requestCount := 0 - c1.OnResponse(func(resp *Response) { + c1.OnResponse(func(resp *Response) error { requestCount++ c1.Visit(ts.URL) + return nil }) c1.Visit(ts.URL) if requestCount != 5 { @@ -1575,15 +1622,17 @@ func TestCollectorContext(t *testing.T) { onErrorCalled := false - c.OnResponse(func(resp *Response) { + c.OnResponse(func(resp *Response) error { t.Error("OnResponse was called, expected OnError") + return nil }) - c.OnError(func(resp *Response, err error) { + c.OnError(func(resp *Response, err error) error { onErrorCalled = true if err != context.DeadlineExceeded { t.Errorf("OnError got err=%#v, expected context.DeadlineExceeded", err) } + return err }) err := c.Visit(ts.URL + "/slow") @@ -1602,7 +1651,9 @@ func BenchmarkOnHTML(b *testing.B) { defer ts.Close() c := NewCollector() - c.OnHTML("p", func(_ *HTMLElement) {}) + c.OnHTML("p", func(_ *HTMLElement) error { + return nil + }) for n := 0; n < b.N; n++ { c.Visit(fmt.Sprintf("%s/html?q=%d", ts.URL, n)) @@ -1614,7 +1665,9 @@ func BenchmarkOnXML(b *testing.B) { defer ts.Close() c := NewCollector() - c.OnXML("//p", func(_ *XMLElement) {}) + c.OnXML("//p", func(_ *XMLElement) error { + return nil + }) for n := 0; n < b.N; n++ { c.Visit(fmt.Sprintf("%s/html?q=%d", ts.URL, n)) @@ -1627,7 +1680,9 @@ func BenchmarkOnResponse(b *testing.B) { c := NewCollector() c.AllowURLRevisit = true - c.OnResponse(func(_ *Response) {}) + c.OnResponse(func(_ *Response) error { + return nil + }) for n := 0; n < b.N; n++ { c.Visit(ts.URL) @@ -1665,3 +1720,72 @@ func requireSessionCookieAuthPage(handler http.Handler) http.Handler { handler.ServeHTTP(w, r) }) } +func TestResponseContextError(t *testing.T) { + // Test in progress of Responsing error routing to onerror + + ts := newTestServer() + defer ts.Close() + + //test response error + c := NewCollector() + var onErrorCalled = false + var responseErrored = errors.New("responseErrored") + c.OnResponse(func(resp *Response) error { + return responseErrored + }) + c.OnError(func(resp *Response, err error) error { + onErrorCalled = true + if err != responseErrored { + t.Errorf("OnError got err=%#v, expected responseErrored", err) + } + return nil + }) + err := c.Visit(ts.URL + "/html") + if err != nil { + t.Errorf("Visit return err=%#v, expected nil", err) + } + if !onErrorCalled { + t.Error("OnError was not called") + } + + //test onhtml error + c = NewCollector() + c.OnHTML("*", func(*HTMLElement) error { + return responseErrored + }) + c.OnError(func(resp *Response, err error) error { + onErrorCalled = true + if err != responseErrored { + t.Errorf("OnError got err=%#v, expected responseErrored", err) + } + return nil + }) + err = c.Visit(ts.URL + "/html") + if err != nil { + t.Errorf("Visit return err=%#v, expected nil", err) + } + if !onErrorCalled { + t.Error("OnError was not called") + } + + //test onxml error + c = NewCollector() + c.OnXML("*", func(*XMLElement) error { + return responseErrored + }) + c.OnError(func(resp *Response, err error) error { + onErrorCalled = true + if err != responseErrored { + t.Errorf("OnError got err=%#v, expected responseErrored", err) + } + return nil + }) + err = c.Visit(ts.URL + "/xml") + if err != nil { + t.Errorf("Visit return err=%#v, expected nil", err) + } + if !onErrorCalled { + t.Error("OnError was not called") + } + +} diff --git a/extensions/referer.go b/extensions/referer.go index 32a1c69e..0bdc4949 100644 --- a/extensions/referer.go +++ b/extensions/referer.go @@ -8,8 +8,9 @@ import ( // Warning: this extension works only if you use Request.Visit // from callbacks instead of Collector.Visit. func Referer(c *colly.Collector) { - c.OnResponse(func(r *colly.Response) { + c.OnResponse(func(r *colly.Response) error { r.Ctx.Put("_referer", r.Request.URL.String()) + return nil }) c.OnRequest(func(r *colly.Request) { if ref := r.Ctx.Get("_referer"); ref != "" { diff --git a/http_backend.go b/http_backend.go index f48df628..7d9a4b22 100644 --- a/http_backend.go +++ b/http_backend.go @@ -18,6 +18,7 @@ import ( "crypto/sha1" "encoding/gob" "encoding/hex" + "errors" "io" "io/ioutil" "math/rand" @@ -40,7 +41,9 @@ type httpBackend struct { lock *sync.RWMutex } -type checkHeadersFunc func(req *http.Request, statusCode int, header http.Header) bool +type checkHeadersFunc func(req *http.Request, statusCode int, header http.Header) error + +var CheckHeaderFuncCheckError = errors.New("header no check") // LimitRule provides connection restrictions for domains. // Both DomainRegexp and DomainGlob can be used to specify @@ -141,7 +144,9 @@ func (h *httpBackend) Cache(request *http.Request, bodySize int, checkHeadersFun resp := new(Response) err := gob.NewDecoder(file).Decode(resp) file.Close() - checkHeadersFunc(request, resp.StatusCode, *resp.Headers) + if err = checkHeadersFunc(request, resp.StatusCode, *resp.Headers); err != nil { + return resp, err + } if resp.StatusCode < 500 { return resp, err } @@ -191,10 +196,10 @@ func (h *httpBackend) Do(request *http.Request, bodySize int, checkHeadersFunc c if res.Request != nil { finalRequest = res.Request } - if !checkHeadersFunc(finalRequest, res.StatusCode, res.Header) { + if err := checkHeadersFunc(finalRequest, res.StatusCode, res.Header); err != nil { // closing res.Body (see defer above) without reading it aborts // the download - return nil, ErrAbortedAfterHeaders + return nil, err } var bodyReader io.Reader = res.Body diff --git a/queue/queue_test.go b/queue/queue_test.go index 1d10f837..352484d7 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -48,7 +48,7 @@ func TestQueue(t *testing.T) { c.OnRequest(func(req *colly.Request) { atomic.AddUint32(&requests, 1) }) - c.OnResponse(func(resp *colly.Response) { + c.OnResponse(func(resp *colly.Response) error { if resp.StatusCode == http.StatusOK { atomic.AddUint32(&success, 1) } else { @@ -60,9 +60,11 @@ func TestQueue(t *testing.T) { if toss { put() } + return nil }) - c.OnError(func(resp *colly.Response, err error) { + c.OnError(func(resp *colly.Response, err error) error { atomic.AddUint32(&failure, 1) + return err }) err = q.Run(c) if err != nil {