diff --git a/colly.go b/colly.go
index fdca9451..1d6cc9eb 100644
--- a/colly.go
+++ b/colly.go
@@ -140,19 +140,19 @@ type Collector struct {
type RequestCallback func(*Request)
// ResponseHeadersCallback is a type alias for OnResponseHeaders callback functions
-type ResponseHeadersCallback func(*Response)
+type ResponseHeadersCallback func(*Response) error
// ResponseCallback is a type alias for OnResponse callback functions
-type ResponseCallback func(*Response)
+type ResponseCallback func(*Response) error
// HTMLCallback is a type alias for OnHTML callback functions
-type HTMLCallback func(*HTMLElement)
+type HTMLCallback func(*HTMLElement) error
// XMLCallback is a type alias for OnXML callback functions
-type XMLCallback func(*XMLElement)
+type XMLCallback func(*XMLElement) error
// ErrorCallback is a type alias for OnError callback functions
-type ErrorCallback func(*Response, error)
+type ErrorCallback func(*Response, error) error
// ScrapedCallback is a type alias for OnScraped callback functions
type ScrapedCallback func(*Response)
@@ -691,13 +691,18 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct
req = hTrace.WithTrace(req)
}
origURL := req.URL
- checkHeadersFunc := func(req *http.Request, statusCode int, headers http.Header) bool {
+ checkHeadersFunc := func(req *http.Request, statusCode int, headers http.Header) error {
if req.URL != origURL {
request.URL = req.URL
request.Headers = &req.Header
}
- c.handleOnResponseHeaders(&Response{Ctx: ctx, Request: request, StatusCode: statusCode, Headers: &headers})
- return !request.abort
+ if err := c.handleOnResponseHeaders(&Response{Ctx: ctx, Request: request, StatusCode: statusCode, Headers: &headers}); err != nil {
+ return err
+ }
+ if request.abort {
+ return ErrAbortedAfterHeaders
+ }
+ return nil
}
response, err := c.backend.Cache(req, c.MaxBodySize, checkHeadersFunc, c.CacheDir)
if proxyURL, ok := req.Context().Value(ProxyURLKey).(string); ok {
@@ -716,16 +721,24 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct
return err
}
- c.handleOnResponse(response)
+ if err := c.handleOnResponse(response); err != nil {
+ if err = c.handleOnError(response, err, request, ctx); err != nil {
+ return err
+ }
+ }
err = c.handleOnHTML(response)
if err != nil {
- c.handleOnError(response, err, request, ctx)
+ if err = c.handleOnError(response, err, request, ctx); err != nil {
+ return err
+ }
}
err = c.handleOnXML(response)
if err != nil {
- c.handleOnError(response, err, request, ctx)
+ if err = c.handleOnError(response, err, request, ctx); err != nil {
+ return err
+ }
}
c.handleOnScraped(response)
@@ -1083,7 +1096,7 @@ func (c *Collector) handleOnRequest(r *Request) {
}
}
-func (c *Collector) handleOnResponse(r *Response) {
+func (c *Collector) handleOnResponse(r *Response) error {
if c.debugger != nil {
c.debugger.Event(createEvent("response", r.Request.ID, c.ID, map[string]string{
"url": r.Request.URL.String(),
@@ -1091,11 +1104,14 @@ func (c *Collector) handleOnResponse(r *Response) {
}))
}
for _, f := range c.responseCallbacks {
- f(r)
+ if err := f(r); err != nil {
+ return err
+ }
}
+ return nil
}
-func (c *Collector) handleOnResponseHeaders(r *Response) {
+func (c *Collector) handleOnResponseHeaders(r *Response) error {
if c.debugger != nil {
c.debugger.Event(createEvent("responseHeaders", r.Request.ID, c.ID, map[string]string{
"url": r.Request.URL.String(),
@@ -1103,17 +1119,20 @@ func (c *Collector) handleOnResponseHeaders(r *Response) {
}))
}
for _, f := range c.responseHeadersCallbacks {
- f(r)
+ if err := f(r); err != nil {
+ return err
+ }
}
+ return nil
}
-func (c *Collector) handleOnHTML(resp *Response) error {
+func (c *Collector) handleOnHTML(resp *Response) (err error) {
if len(c.htmlCallbacks) == 0 || !strings.Contains(strings.ToLower(resp.Headers.Get("Content-Type")), "html") {
return nil
}
doc, err := goquery.NewDocumentFromReader(bytes.NewBuffer(resp.Body))
if err != nil {
- return err
+ return
}
if href, found := doc.Find("base[href]").Attr("href"); found {
u, err := urlParser.ParseRef(resp.Request.URL.String(), href)
@@ -1127,7 +1146,7 @@ func (c *Collector) handleOnHTML(resp *Response) error {
}
for _, cc := range c.htmlCallbacks {
i := 0
- doc.Find(cc.Selector).Each(func(_ int, s *goquery.Selection) {
+ doc.Find(cc.Selector).EachWithBreak(func(_ int, s *goquery.Selection) bool {
for _, n := range s.Nodes {
e := NewHTMLElementFromSelectionNode(resp, s, n, i)
i++
@@ -1137,11 +1156,17 @@ func (c *Collector) handleOnHTML(resp *Response) error {
"url": resp.Request.URL.String(),
}))
}
- cc.Function(e)
+ if err = cc.Function(e); err != nil {
+ return false
+ }
}
+ return true
})
+ if err != nil {
+ return
+ }
}
- return nil
+ return
}
func (c *Collector) handleOnXML(resp *Response) error {
@@ -1180,7 +1205,9 @@ func (c *Collector) handleOnXML(resp *Response) error {
"url": resp.Request.URL.String(),
}))
}
- cc.Function(e)
+ if err := cc.Function(e); err != nil {
+ return err
+ }
}
}
} else if strings.Contains(contentType, "xml") || isXMLFile {
@@ -1198,8 +1225,15 @@ func (c *Collector) handleOnXML(resp *Response) error {
"url": resp.Request.URL.String(),
}))
}
- cc.Function(e)
+				// FindEach cannot be interrupted: keep the first error and skip
+				// the remaining callbacks so a later nil cannot overwrite it.
+				if err == nil {
+					err = cc.Function(e)
+				}
})
+			if err != nil {
+				return err
+			}
}
}
return nil
@@ -1230,10 +1260,15 @@ func (c *Collector) handleOnError(response *Response, err error, request *Reques
if response.Ctx == nil {
response.Ctx = request.Ctx
}
+ if len(c.errorCallbacks) < 1 {
+ return err
+ }
for _, f := range c.errorCallbacks {
- f(response, err)
+ if abortErr := f(response, err); abortErr != nil {
+ return abortErr
+ }
}
- return err
+ return nil
}
func (c *Collector) handleOnScraped(r *Response) {
diff --git a/colly_test.go b/colly_test.go
index 4358b63e..215a085e 100644
--- a/colly_test.go
+++ b/colly_test.go
@@ -447,7 +447,7 @@ func TestCollectorVisit(t *testing.T) {
r.Ctx.Put("x", "y")
})
- c.OnResponse(func(r *Response) {
+ c.OnResponse(func(r *Response) error {
onResponseCalled = true
if r.Ctx.Get("x") != "y" {
@@ -457,6 +457,7 @@ func TestCollectorVisit(t *testing.T) {
if !bytes.Equal(r.Body, serverIndexResponse) {
t.Error("Response body does not match with the original content")
}
+ return nil
})
c.OnScraped(func(r *Response) {
@@ -530,14 +531,16 @@ func TestCollectorVisitResponseHeaders(t *testing.T) {
var onResponseHeadersCalled bool
c := NewCollector()
- c.OnResponseHeaders(func(r *Response) {
+ c.OnResponseHeaders(func(r *Response) error {
onResponseHeadersCalled = true
if r.Headers.Get("Content-Type") == "application/octet-stream" {
r.Request.Abort()
}
+ return nil
})
- c.OnResponse(func(r *Response) {
+ c.OnResponse(func(r *Response) error {
t.Error("OnResponse was called")
+ return nil
})
c.Visit(ts.URL + "/large_binary")
if !onResponseHeadersCalled {
@@ -554,21 +557,23 @@ func TestCollectorOnHTML(t *testing.T) {
titleCallbackCalled := false
paragraphCallbackCount := 0
- c.OnHTML("title", func(e *HTMLElement) {
+ c.OnHTML("title", func(e *HTMLElement) error {
titleCallbackCalled = true
if e.Text != "Test Page" {
t.Error("Title element text does not match, got", e.Text)
}
+ return nil
})
- c.OnHTML("p", func(e *HTMLElement) {
+ c.OnHTML("p", func(e *HTMLElement) error {
paragraphCallbackCount++
if e.Attr("class") != "description" {
t.Error("Failed to get paragraph's class attribute")
}
+ return nil
})
- c.OnHTML("body", func(e *HTMLElement) {
+ c.OnHTML("body", func(e *HTMLElement) error {
if e.ChildAttr("p", "class") != "description" {
t.Error("Invalid class value")
}
@@ -576,6 +581,7 @@ func TestCollectorOnHTML(t *testing.T) {
if len(classes) != 2 {
t.Error("Invalid class values")
}
+ return nil
})
c.Visit(ts.URL + "/html")
@@ -629,11 +635,12 @@ func TestCollectorPostRevisit(t *testing.T) {
visitCount := 0
c := NewCollector()
- c.OnResponse(func(r *Response) {
+ c.OnResponse(func(r *Response) error {
if postValue != string(r.Body) {
t.Error("Failed to send data with POST")
}
visitCount++
+ return nil
})
c.Post(ts.URL+"/login", postData)
@@ -730,13 +737,14 @@ func TestSetCookieRedirect(t *testing.T) {
ts.Start()
defer ts.Close()
c := NewCollector()
- c.OnResponse(func(r *Response) {
+ c.OnResponse(func(r *Response) error {
if got, want := r.Body, serverIndexResponse; !bytes.Equal(got, want) {
t.Errorf("bad response body got=%q want=%q", got, want)
}
if got, want := r.StatusCode, http.StatusOK; got != want {
t.Errorf("bad response code got=%d want=%d", got, want)
}
+ return nil
})
if err := c.Visit(ts.URL); err != nil {
t.Fatal(err)
@@ -831,10 +839,11 @@ func TestCollectorPost(t *testing.T) {
postValue := "hello"
c := NewCollector()
- c.OnResponse(func(r *Response) {
+ c.OnResponse(func(r *Response) error {
if postValue != string(r.Body) {
t.Error("Failed to send data with POST")
}
+ return nil
})
c.Post(ts.URL+"/login", map[string]string{
@@ -849,10 +858,11 @@ func TestCollectorPostRaw(t *testing.T) {
postValue := "hello"
c := NewCollector()
- c.OnResponse(func(r *Response) {
+ c.OnResponse(func(r *Response) error {
if postValue != string(r.Body) {
t.Error("Failed to send data with POST")
}
+ return nil
})
c.PostRaw(ts.URL+"/login", []byte("name="+postValue))
@@ -867,11 +877,12 @@ func TestCollectorPostRawRevisit(t *testing.T) {
visitCount := 0
c := NewCollector()
- c.OnResponse(func(r *Response) {
+ c.OnResponse(func(r *Response) error {
if postValue != string(r.Body) {
t.Error("Failed to send data with POST RAW")
}
visitCount++
+ return nil
})
c.PostRaw(ts.URL+"/login", []byte(postData))
@@ -897,23 +908,26 @@ func TestRedirect(t *testing.T) {
defer ts.Close()
c := NewCollector()
- c.OnHTML("a[href]", func(e *HTMLElement) {
+ c.OnHTML("a[href]", func(e *HTMLElement) error {
u := e.Request.AbsoluteURL(e.Attr("href"))
if !strings.HasSuffix(u, "/redirected/test") {
t.Error("Invalid URL after redirect: " + u)
}
+ return nil
})
- c.OnResponseHeaders(func(r *Response) {
+ c.OnResponseHeaders(func(r *Response) error {
if !strings.HasSuffix(r.Request.URL.String(), "/redirected/") {
t.Error("Invalid URL in Request after redirect (OnResponseHeaders): " + r.Request.URL.String())
}
+ return nil
})
- c.OnResponse(func(r *Response) {
+ c.OnResponse(func(r *Response) error {
if !strings.HasSuffix(r.Request.URL.String(), "/redirected/") {
t.Error("Invalid URL in Request after redirect (OnResponse): " + r.Request.URL.String())
}
+ return nil
})
c.Visit(ts.URL + "/redirect")
}
@@ -937,12 +951,13 @@ func TestRedirectWithDisallowedURLs(t *testing.T) {
c := NewCollector()
c.DisallowedURLFilters = []*regexp.Regexp{regexp.MustCompile(ts.URL + "/redirected/test")}
- c.OnHTML("a[href]", func(e *HTMLElement) {
+ c.OnHTML("a[href]", func(e *HTMLElement) error {
u := e.Request.AbsoluteURL(e.Attr("href"))
err := c.Visit(u)
if !errors.Is(err, ErrForbiddenURL) {
t.Error("URL should have been forbidden: " + u)
}
+ return nil
})
c.Visit(ts.URL + "/redirect")
@@ -953,20 +968,22 @@ func TestBaseTag(t *testing.T) {
defer ts.Close()
c := NewCollector()
- c.OnHTML("a[href]", func(e *HTMLElement) {
+ c.OnHTML("a[href]", func(e *HTMLElement) error {
u := e.Request.AbsoluteURL(e.Attr("href"))
if u != "http://xy.com/z" {
t.Error("Invalid tag handling in OnHTML: expected https://xy.com/z, got " + u)
}
+ return nil
})
c.Visit(ts.URL + "/base")
c2 := NewCollector()
- c2.OnXML("//a", func(e *XMLElement) {
+ c2.OnXML("//a", func(e *XMLElement) error {
u := e.Request.AbsoluteURL(e.Attr("href"))
if u != "http://xy.com/z" {
t.Error("Invalid tag handling in OnXML: expected https://xy.com/z, got " + u)
}
+ return nil
})
c2.Visit(ts.URL + "/base")
}
@@ -976,22 +993,24 @@ func TestBaseTagRelative(t *testing.T) {
defer ts.Close()
c := NewCollector()
- c.OnHTML("a[href]", func(e *HTMLElement) {
+ c.OnHTML("a[href]", func(e *HTMLElement) error {
u := e.Request.AbsoluteURL(e.Attr("href"))
expected := ts.URL + "/foobar/z"
if u != expected {
t.Errorf("Invalid tag handling in OnHTML: expected %q, got %q", expected, u)
}
+ return nil
})
c.Visit(ts.URL + "/base_relative")
c2 := NewCollector()
- c2.OnXML("//a", func(e *XMLElement) {
+ c2.OnXML("//a", func(e *XMLElement) error {
u := e.Request.AbsoluteURL(e.Attr("href"))
expected := ts.URL + "/foobar/z"
if u != expected {
t.Errorf("Invalid tag handling in OnXML: expected %q, got %q", expected, u)
}
+ return nil
})
c2.Visit(ts.URL + "/base_relative")
}
@@ -1010,13 +1029,15 @@ func TestTabsAndNewlines(t *testing.T) {
}
c := NewCollector()
- c.OnResponse(func(res *Response) {
+ c.OnResponse(func(res *Response) error {
visited[res.Request.URL.EscapedPath()] = struct{}{}
+ return nil
})
- c.OnHTML("a[href]", func(e *HTMLElement) {
+ c.OnHTML("a[href]", func(e *HTMLElement) error {
if err := e.Request.Visit(e.Attr("href")); err != nil {
t.Errorf("visit failed: %v", err)
}
+ return nil
})
if err := c.Visit(ts.URL + "/tabs_and_newlines"); err != nil {
@@ -1035,8 +1056,9 @@ func TestLonePercent(t *testing.T) {
var visitedPath string
c := NewCollector()
- c.OnResponse(func(res *Response) {
+ c.OnResponse(func(res *Response) error {
visitedPath = res.Request.URL.RequestURI()
+ return nil
})
if err := c.Visit(ts.URL + "/100%"); err != nil {
t.Errorf("visit failed: %v", err)
@@ -1083,10 +1105,11 @@ func TestRobotsWhenAllowed(t *testing.T) {
c := NewCollector()
c.IgnoreRobotsTxt = false
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
if resp.StatusCode != 200 {
t.Fatalf("Wrong response code: %d", resp.StatusCode)
}
+ return nil
})
err := c.Visit(ts.URL + "/allowed")
@@ -1103,8 +1126,9 @@ func TestRobotsWhenDisallowed(t *testing.T) {
c := NewCollector()
c.IgnoreRobotsTxt = false
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
t.Fatalf("Received response: %d", resp.StatusCode)
+ return nil
})
err := c.Visit(ts.URL + "/disallowed")
@@ -1120,8 +1144,9 @@ func TestRobotsWhenDisallowedWithQueryParameter(t *testing.T) {
c := NewCollector()
c.IgnoreRobotsTxt = false
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
t.Fatalf("Received response: %d", resp.StatusCode)
+ return nil
})
err := c.Visit(ts.URL + "/allowed?q=1")
@@ -1137,10 +1162,11 @@ func TestIgnoreRobotsWhenDisallowed(t *testing.T) {
c := NewCollector()
c.IgnoreRobotsTxt = true
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
if resp.StatusCode != 200 {
t.Fatalf("Wrong response code: %d", resp.StatusCode)
}
+ return nil
})
err := c.Visit(ts.URL + "/disallowed")
@@ -1175,10 +1201,11 @@ func TestEnvSettings(t *testing.T) {
valid := false
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
if string(resp.Body) == "test" {
valid = true
}
+ return nil
})
c.Visit(ts.URL + "/user_agent")
@@ -1200,8 +1227,9 @@ func TestUserAgent(t *testing.T) {
func() {
c := NewCollector()
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
receivedUserAgent = string(resp.Body)
+ return nil
})
c.Visit(ts.URL + "/user_agent")
if got, want := receivedUserAgent, defaultUserAgent; got != want {
@@ -1210,8 +1238,9 @@ func TestUserAgent(t *testing.T) {
}()
func() {
c := NewCollector(UserAgent(exampleUserAgent1))
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
receivedUserAgent = string(resp.Body)
+ return nil
})
c.Visit(ts.URL + "/user_agent")
if got, want := receivedUserAgent, exampleUserAgent1; got != want {
@@ -1220,8 +1249,9 @@ func TestUserAgent(t *testing.T) {
}()
func() {
c := NewCollector(UserAgent(exampleUserAgent1))
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
receivedUserAgent = string(resp.Body)
+ return nil
})
c.Request("GET", ts.URL+"/user_agent", nil, nil, nil)
@@ -1231,8 +1261,9 @@ func TestUserAgent(t *testing.T) {
}()
func() {
c := NewCollector(UserAgent(exampleUserAgent1))
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
receivedUserAgent = string(resp.Body)
+ return nil
})
c.Request("GET", ts.URL+"/user_agent", nil, nil, http.Header{})
@@ -1242,8 +1273,9 @@ func TestUserAgent(t *testing.T) {
}()
func() {
c := NewCollector(UserAgent(exampleUserAgent1))
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
receivedUserAgent = string(resp.Body)
+ return nil
})
hdr := http.Header{}
hdr.Set("User-Agent", "")
@@ -1255,8 +1287,9 @@ func TestUserAgent(t *testing.T) {
}()
func() {
c := NewCollector(UserAgent(exampleUserAgent1))
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
receivedUserAgent = string(resp.Body)
+ return nil
})
hdr := http.Header{}
hdr.Set("User-Agent", exampleUserAgent2)
@@ -1281,8 +1314,9 @@ func TestHeaders(t *testing.T) {
c := NewCollector(
Headers(map[string]string{"Host": exampleHostHeader}),
)
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
receivedHeader = string(resp.Body)
+ return nil
})
c.Visit(ts.URL + "/host_header")
if got, want := receivedHeader, exampleHostHeader; got != want {
@@ -1293,8 +1327,9 @@ func TestHeaders(t *testing.T) {
c := NewCollector(
Headers(map[string]string{"Test": exampleTestHeader}),
)
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
receivedHeader = string(resp.Body)
+ return nil
})
c.Visit(ts.URL + "/custom_header")
if got, want := receivedHeader, exampleTestHeader; got != want {
@@ -1312,10 +1347,11 @@ func TestParseHTTPErrorResponse(t *testing.T) {
AllowURLRevisit(),
)
- c.OnHTML("p", func(e *HTMLElement) {
+ c.OnHTML("p", func(e *HTMLElement) error {
if e.Text == "error" {
contentCount++
}
+ return nil
})
c.Visit(ts.URL + "/500")
@@ -1382,21 +1418,23 @@ func TestCollectorOnXMLWithHtml(t *testing.T) {
titleCallbackCalled := false
paragraphCallbackCount := 0
- c.OnXML("/html/head/title", func(e *XMLElement) {
+ c.OnXML("/html/head/title", func(e *XMLElement) error {
titleCallbackCalled = true
if e.Text != "Test Page" {
t.Error("Title element text does not match, got", e.Text)
}
+ return nil
})
- c.OnXML("/html/body/p", func(e *XMLElement) {
+ c.OnXML("/html/body/p", func(e *XMLElement) error {
paragraphCallbackCount++
if e.Attr("class") != "description" {
t.Error("Failed to get paragraph's class attribute")
}
+ return nil
})
- c.OnXML("/html/body", func(e *XMLElement) {
+ c.OnXML("/html/body", func(e *XMLElement) error {
if e.ChildAttr("p", "class") != "description" {
t.Error("Invalid class value")
}
@@ -1404,6 +1442,7 @@ func TestCollectorOnXMLWithHtml(t *testing.T) {
if len(classes) != 2 {
t.Error("Invalid class values")
}
+ return nil
})
c.Visit(ts.URL + "/html")
@@ -1426,21 +1465,23 @@ func TestCollectorOnXMLWithXML(t *testing.T) {
titleCallbackCalled := false
paragraphCallbackCount := 0
- c.OnXML("//page/title", func(e *XMLElement) {
+ c.OnXML("//page/title", func(e *XMLElement) error {
titleCallbackCalled = true
if e.Text != "Test Page" {
t.Error("Title element text does not match, got", e.Text)
}
+ return nil
})
- c.OnXML("//page/paragraph", func(e *XMLElement) {
+ c.OnXML("//page/paragraph", func(e *XMLElement) error {
paragraphCallbackCount++
if e.Attr("type") != "description" {
t.Error("Failed to get paragraph's type attribute")
}
+ return nil
})
- c.OnXML("/page", func(e *XMLElement) {
+ c.OnXML("/page", func(e *XMLElement) error {
if e.ChildAttr("paragraph", "type") != "description" {
t.Error("Invalid type value")
}
@@ -1448,6 +1489,7 @@ func TestCollectorOnXMLWithXML(t *testing.T) {
if len(classes) != 2 {
t.Error("Invalid type values")
}
+ return nil
})
c.Visit(ts.URL + "/xml")
@@ -1466,10 +1508,11 @@ func TestCollectorVisitWithTrace(t *testing.T) {
defer ts.Close()
c := NewCollector(AllowedDomains("localhost", "127.0.0.1", "::1"), TraceHTTP())
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
if resp.Trace == nil {
t.Error("Failed to initialize trace")
}
+ return nil
})
err := c.Visit(ts.URL)
@@ -1484,8 +1527,9 @@ func TestCollectorVisitWithCheckHead(t *testing.T) {
c := NewCollector(CheckHead())
var requestMethodChain []string
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
requestMethodChain = append(requestMethodChain, resp.Request.Method)
+ return nil
})
err := c.Visit(ts.URL)
@@ -1506,12 +1550,13 @@ func TestCollectorDepth(t *testing.T) {
AllowURLRevisit(),
)
requestCount := 0
- c1.OnResponse(func(resp *Response) {
+ c1.OnResponse(func(resp *Response) error {
requestCount++
if requestCount >= 10 {
- return
+ return nil
}
c1.Visit(ts.URL)
+ return nil
})
c1.Visit(ts.URL)
if requestCount < 10 {
@@ -1520,9 +1565,10 @@ func TestCollectorDepth(t *testing.T) {
c2 := c1.Clone()
requestCount = 0
- c2.OnResponse(func(resp *Response) {
+ c2.OnResponse(func(resp *Response) error {
requestCount++
resp.Request.Visit(ts.URL)
+ return nil
})
c2.Visit(ts.URL)
if requestCount != 2 {
@@ -1550,9 +1596,10 @@ func TestCollectorRequests(t *testing.T) {
AllowURLRevisit(),
)
requestCount := 0
- c1.OnResponse(func(resp *Response) {
+ c1.OnResponse(func(resp *Response) error {
requestCount++
c1.Visit(ts.URL)
+ return nil
})
c1.Visit(ts.URL)
if requestCount != 5 {
@@ -1575,15 +1622,17 @@ func TestCollectorContext(t *testing.T) {
onErrorCalled := false
- c.OnResponse(func(resp *Response) {
+ c.OnResponse(func(resp *Response) error {
t.Error("OnResponse was called, expected OnError")
+ return nil
})
- c.OnError(func(resp *Response, err error) {
+ c.OnError(func(resp *Response, err error) error {
onErrorCalled = true
if err != context.DeadlineExceeded {
t.Errorf("OnError got err=%#v, expected context.DeadlineExceeded", err)
}
+ return err
})
err := c.Visit(ts.URL + "/slow")
@@ -1602,7 +1651,9 @@ func BenchmarkOnHTML(b *testing.B) {
defer ts.Close()
c := NewCollector()
- c.OnHTML("p", func(_ *HTMLElement) {})
+ c.OnHTML("p", func(_ *HTMLElement) error {
+ return nil
+ })
for n := 0; n < b.N; n++ {
c.Visit(fmt.Sprintf("%s/html?q=%d", ts.URL, n))
@@ -1614,7 +1665,9 @@ func BenchmarkOnXML(b *testing.B) {
defer ts.Close()
c := NewCollector()
- c.OnXML("//p", func(_ *XMLElement) {})
+ c.OnXML("//p", func(_ *XMLElement) error {
+ return nil
+ })
for n := 0; n < b.N; n++ {
c.Visit(fmt.Sprintf("%s/html?q=%d", ts.URL, n))
@@ -1627,7 +1680,9 @@ func BenchmarkOnResponse(b *testing.B) {
c := NewCollector()
c.AllowURLRevisit = true
- c.OnResponse(func(_ *Response) {})
+ c.OnResponse(func(_ *Response) error {
+ return nil
+ })
for n := 0; n < b.N; n++ {
c.Visit(ts.URL)
@@ -1665,3 +1720,72 @@ func requireSessionCookieAuthPage(handler http.Handler) http.Handler {
handler.ServeHTTP(w, r)
})
}
+func TestResponseContextError(t *testing.T) {
+	// Callback errors must be routed to OnError; an OnError that returns nil swallows them.
+	ts := newTestServer()
+	defer ts.Close()
+
+	// Case 1: an error returned from an OnResponse callback reaches OnError.
+	c := NewCollector()
+	var onErrorCalled = false
+	var responseErrored = errors.New("responseErrored")
+	c.OnResponse(func(resp *Response) error {
+		return responseErrored
+	})
+	c.OnError(func(resp *Response, err error) error {
+		onErrorCalled = true
+		if err != responseErrored {
+			t.Errorf("OnError got err=%#v, expected responseErrored", err)
+		}
+		return nil
+	})
+	err := c.Visit(ts.URL + "/html")
+	if err != nil {
+		t.Errorf("Visit return err=%#v, expected nil", err)
+	}
+	if !onErrorCalled {
+		t.Error("OnError was not called")
+	}
+
+	// Case 2: an error returned from an OnHTML callback reaches OnError.
+	c = NewCollector()
+	onErrorCalled = false // reset so the previous case cannot satisfy this check
+	c.OnHTML("*", func(*HTMLElement) error {
+		return responseErrored
+	})
+	c.OnError(func(resp *Response, err error) error {
+		onErrorCalled = true
+		if err != responseErrored {
+			t.Errorf("OnError got err=%#v, expected responseErrored", err)
+		}
+		return nil
+	})
+	err = c.Visit(ts.URL + "/html")
+	if err != nil {
+		t.Errorf("Visit return err=%#v, expected nil", err)
+	}
+	if !onErrorCalled {
+		t.Error("OnError was not called")
+	}
+
+	// Case 3: an error returned from an OnXML callback reaches OnError.
+	c = NewCollector()
+	onErrorCalled = false // reset so the previous cases cannot satisfy this check
+	c.OnXML("*", func(*XMLElement) error {
+		return responseErrored
+	})
+	c.OnError(func(resp *Response, err error) error {
+		onErrorCalled = true
+		if err != responseErrored {
+			t.Errorf("OnError got err=%#v, expected responseErrored", err)
+		}
+		return nil
+	})
+	err = c.Visit(ts.URL + "/xml")
+	if err != nil {
+		t.Errorf("Visit return err=%#v, expected nil", err)
+	}
+	if !onErrorCalled {
+		t.Error("OnError was not called")
+	}
+}
diff --git a/extensions/referer.go b/extensions/referer.go
index 32a1c69e..0bdc4949 100644
--- a/extensions/referer.go
+++ b/extensions/referer.go
@@ -8,8 +8,9 @@ import (
// Warning: this extension works only if you use Request.Visit
// from callbacks instead of Collector.Visit.
func Referer(c *colly.Collector) {
- c.OnResponse(func(r *colly.Response) {
+ c.OnResponse(func(r *colly.Response) error {
r.Ctx.Put("_referer", r.Request.URL.String())
+ return nil
})
c.OnRequest(func(r *colly.Request) {
if ref := r.Ctx.Get("_referer"); ref != "" {
diff --git a/http_backend.go b/http_backend.go
index f48df628..7d9a4b22 100644
--- a/http_backend.go
+++ b/http_backend.go
@@ -18,6 +18,7 @@ import (
"crypto/sha1"
"encoding/gob"
"encoding/hex"
+ "errors"
"io"
"io/ioutil"
"math/rand"
@@ -40,7 +41,9 @@ type httpBackend struct {
lock *sync.RWMutex
}
-type checkHeadersFunc func(req *http.Request, statusCode int, header http.Header) bool
+type checkHeadersFunc func(req *http.Request, statusCode int, header http.Header) error
+
+var CheckHeaderFuncCheckError = errors.New("header no check") // NOTE(review): not referenced by any visible hunk of this change; Go convention names sentinel errors ErrXxx — confirm this is needed.
// LimitRule provides connection restrictions for domains.
// Both DomainRegexp and DomainGlob can be used to specify
@@ -141,7 +144,9 @@ func (h *httpBackend) Cache(request *http.Request, bodySize int, checkHeadersFun
resp := new(Response)
err := gob.NewDecoder(file).Decode(resp)
file.Close()
- checkHeadersFunc(request, resp.StatusCode, *resp.Headers)
+		if checkErr := checkHeadersFunc(request, resp.StatusCode, *resp.Headers); checkErr != nil {
+			return resp, checkErr // separate variable: err must keep the gob decode result returned below
+		}
if resp.StatusCode < 500 {
return resp, err
}
@@ -191,10 +196,10 @@ func (h *httpBackend) Do(request *http.Request, bodySize int, checkHeadersFunc c
if res.Request != nil {
finalRequest = res.Request
}
- if !checkHeadersFunc(finalRequest, res.StatusCode, res.Header) {
+ if err := checkHeadersFunc(finalRequest, res.StatusCode, res.Header); err != nil {
// closing res.Body (see defer above) without reading it aborts
// the download
- return nil, ErrAbortedAfterHeaders
+ return nil, err
}
var bodyReader io.Reader = res.Body
diff --git a/queue/queue_test.go b/queue/queue_test.go
index 1d10f837..352484d7 100644
--- a/queue/queue_test.go
+++ b/queue/queue_test.go
@@ -48,7 +48,7 @@ func TestQueue(t *testing.T) {
c.OnRequest(func(req *colly.Request) {
atomic.AddUint32(&requests, 1)
})
- c.OnResponse(func(resp *colly.Response) {
+ c.OnResponse(func(resp *colly.Response) error {
if resp.StatusCode == http.StatusOK {
atomic.AddUint32(&success, 1)
} else {
@@ -60,9 +60,11 @@ func TestQueue(t *testing.T) {
if toss {
put()
}
+ return nil
})
- c.OnError(func(resp *colly.Response, err error) {
+ c.OnError(func(resp *colly.Response, err error) error {
atomic.AddUint32(&failure, 1)
+ return err
})
err = q.Run(c)
if err != nil {