@@ -24,6 +24,7 @@ import (
24
24
"context"
25
25
"fmt"
26
26
"net"
27
+ "net/url"
27
28
"strconv"
28
29
"sync"
29
30
"time"
@@ -36,6 +37,7 @@ import (
36
37
"github.com/dop251/goja"
37
38
k6common "go.k6.io/k6/js/common"
38
39
k6lib "go.k6.io/k6/lib"
40
+ k6types "go.k6.io/k6/lib/types"
39
41
k6stats "go.k6.io/k6/stats"
40
42
)
41
43
@@ -48,7 +50,7 @@ type NetworkManager struct {
48
50
49
51
ctx context.Context
50
52
logger * Logger
51
- session * Session
53
+ session session
52
54
parent * NetworkManager
53
55
frameManager * FrameManager
54
56
credentials * Credentials
@@ -235,10 +237,20 @@ func (m *NetworkManager) handleRequestRedirect(req *Request, redirectResponse *n
235
237
}
236
238
237
239
func (m * NetworkManager ) initDomains () error {
238
- action := network .Enable ()
239
- if err := action .Do (cdp .WithExecutor (m .ctx , m .session )); err != nil {
240
- return fmt .Errorf ("unable to execute %T: %w" , action , err )
240
+ actions := []Action {network .Enable ()}
241
+
242
+ // Only enable the Fetch domain if necessary, as it has a performance overhead.
243
+ if m .userReqInterceptionEnabled {
244
+ actions = append (actions ,
245
+ network .SetCacheDisabled (true ),
246
+ fetch .Enable ().WithPatterns ([]* fetch.RequestPattern {{URLPattern : "*" }}))
241
247
}
248
+ for _ , action := range actions {
249
+ if err := action .Do (cdp .WithExecutor (m .ctx , m .session )); err != nil {
250
+ return fmt .Errorf ("unable to execute %T: %w" , action , err )
251
+ }
252
+ }
253
+
242
254
return nil
243
255
}
244
256
@@ -250,6 +262,7 @@ func (m *NetworkManager) initEvents() {
250
262
cdproto .EventNetworkRequestWillBeSent ,
251
263
cdproto .EventNetworkRequestServedFromCache ,
252
264
cdproto .EventNetworkResponseReceived ,
265
+ cdproto .EventFetchRequestPaused ,
253
266
}, chHandler )
254
267
255
268
go func () {
@@ -279,6 +292,8 @@ func (m *NetworkManager) handleEvents(in <-chan Event) bool {
279
292
m .onRequestServedFromCache (ev )
280
293
case * network.EventResponseReceived :
281
294
m .onResponseReceived (ev )
295
+ case * fetch.EventRequestPaused :
296
+ m .onRequestPaused (ev )
282
297
}
283
298
}
284
299
return true
@@ -360,32 +375,58 @@ func (m *NetworkManager) onRequest(event *network.EventRequestWillBeSent, interc
360
375
m .reqsMu .Unlock ()
361
376
m .emitRequestMetrics (req )
362
377
m .frameManager .requestStarted (req )
378
+ }
363
379
364
- if m .userReqInterceptionEnabled {
365
- state := k6lib .GetState (m .ctx )
366
- ip := net .ParseIP (req .url .Host )
367
- blockedHosts := state .Options .BlockedHostnames .Trie
368
- if blockedHosts != nil && ip == nil {
369
- if match , blocked := blockedHosts .Contains (req .url .Host ); blocked {
370
- // Tell browser we've blocked this request.
371
- fetch .FailRequest (fetch .RequestID (req .getID ()), network .ErrorReasonBlockedByClient )
372
-
373
- // Throw exception into JS runtime
374
- rt := k6common .GetRuntime (m .ctx )
375
- // TODO: create PR to make netext.BlockedHostError a public struct in k6 perhaps?
376
- k6common .Throw (rt , fmt .Errorf ("hostname (%s) is in a blocked pattern (%s)" , req .url .Host , match ))
377
- }
380
+ func (m * NetworkManager ) onRequestPaused (event * fetch.EventRequestPaused ) {
381
+ m .logger .Debugf ("NetworkManager:onRequestPaused" ,
382
+ "sid:%s url:%v" , m .session .ID (), event .Request .URL )
383
+ defer m .logger .Debugf ("NetworkManager:onRequestPaused:return" ,
384
+ "sid:%s url:%v" , m .session .ID (), event .Request .URL )
385
+
386
+ var (
387
+ failReason string
388
+ state = k6lib .GetState (m .ctx )
389
+ )
390
+
391
+ defer func () { m .failOrContinueRequest (event , failReason ) }()
392
+
393
+ purl , err := url .Parse (event .Request .URL )
394
+ if err != nil {
395
+ m .logger .Errorf ("NetworkManager:onRequestPaused" ,
396
+ "error parsing URL: %s" , err .Error ())
397
+ return
398
+ }
399
+
400
+ failReason = handleBlockedHosts (purl , state .Options .BlockedHostnames .Trie )
401
+ }
402
+
403
+ func (m * NetworkManager ) failOrContinueRequest (event * fetch.EventRequestPaused , failReason string ) {
404
+ if failReason != "" {
405
+ action := fetch .FailRequest (event .RequestID , network .ErrorReasonBlockedByClient )
406
+ if err := action .Do (cdp .WithExecutor (m .ctx , m .session )); err != nil {
407
+ m .logger .Errorf ("NetworkManager:onRequestPaused" ,
408
+ "error interrupting request: %s" , err .Error ())
409
+ } else {
410
+ m .logger .Warnf ("NetworkManager:onRequestPaused" ,
411
+ "request %s %s was interrupted: %s" , event .Request .Method , event .Request .URL , failReason )
412
+ return
378
413
}
414
+ }
415
+ action := fetch .ContinueRequest (event .RequestID )
416
+ if err := action .Do (cdp .WithExecutor (m .ctx , m .session )); err != nil {
417
+ m .logger .Errorf ("NetworkManager:onRequestPaused" ,
418
+ "error continuing request: %s" , err .Error ())
419
+ }
420
+ }
379
421
380
- /*
381
- TODO: is there a way to do IP filtering without requiring a lookup here?
382
- for _, ipnet := range state.Options.BlacklistIPs {
383
- if ipnet.Contains(ev.Request.URL) {
384
- return "", netext.BlackListedIPError{ip: remote.IP, net: ipnet}
385
- }
386
- }
387
- */
422
+ func handleBlockedHosts (u * url.URL , blockedHosts * k6types.HostnameTrie ) string {
423
+ ip := net .ParseIP (u .Host )
424
+ if ip == nil {
425
+ if match , blocked := blockedHosts .Contains (u .Host ); blocked {
426
+ return fmt .Sprintf ("hostname %s is in a blocked pattern (%s)" , u .Host , match )
427
+ }
388
428
}
429
+ return ""
389
430
}
390
431
391
432
func (m * NetworkManager ) onRequestServedFromCache (event * network.EventRequestServedFromCache ) {
0 commit comments