diff --git a/go.mod b/go.mod index ecfb024612..1ab5f7babc 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ retract v0.26.1 // Tag was applied incorrectly due to a bug in the release workf retract v0.36.0 // Accidentally modified the tag. require ( + git.sr.ht/~marcopolo/di v0.0.4 github.com/benbjohnson/clock v1.3.5 github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 diff --git a/go.sum b/go.sum index 41f001038d..734f50de6d 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBr dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +git.sr.ht/~marcopolo/di v0.0.4 h1:BJOTly/cSV7lXvSurTg+derTF+gS2qgz31aN1OUxxtQ= +git.sr.ht/~marcopolo/di v0.0.4/go.mod h1:lLURtWN1LBR3r9P+VA9O3SCJ7hBxYDv55YMLP997M7Q= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= diff --git a/internal/limits/limits.go b/internal/limits/limits.go new file mode 100644 index 0000000000..637152c752 --- /dev/null +++ b/internal/limits/limits.go @@ -0,0 +1,113 @@ +package limits + +import ( + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/p2p/host/autonat" + rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" + circuit "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto" + relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" + "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" 
+ "github.com/libp2p/go-libp2p/p2p/protocol/identify" + "github.com/libp2p/go-libp2p/p2p/protocol/ping" +) + +// SetDefaultServiceLimits sets the default limits for bundled libp2p services +func SetDefaultServiceLimits(config *rcmgr.ScalingLimitConfig) { + // identify + config.AddServiceLimit( + identify.ServiceName, + rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 128, Memory: 4 << 20}, + rcmgr.BaseLimitIncrease{StreamsInbound: 64, StreamsOutbound: 64, Streams: 128, Memory: 4 << 20}, + ) + config.AddServicePeerLimit( + identify.ServiceName, + rcmgr.BaseLimit{StreamsInbound: 16, StreamsOutbound: 16, Streams: 32, Memory: 1 << 20}, + rcmgr.BaseLimitIncrease{}, + ) + for _, id := range [...]protocol.ID{identify.ID, identify.IDPush} { + config.AddProtocolLimit( + id, + rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 128, Memory: 4 << 20}, + rcmgr.BaseLimitIncrease{StreamsInbound: 64, StreamsOutbound: 64, Streams: 128, Memory: 4 << 20}, + ) + config.AddProtocolPeerLimit( + id, + rcmgr.BaseLimit{StreamsInbound: 16, StreamsOutbound: 16, Streams: 32, Memory: 32 * (256<<20 + 16<<10)}, + rcmgr.BaseLimitIncrease{}, + ) + } + + // ping + addServiceAndProtocolLimit(config, + ping.ServiceName, ping.ID, + rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 64, Memory: 4 << 20}, + rcmgr.BaseLimitIncrease{StreamsInbound: 64, StreamsOutbound: 64, Streams: 64, Memory: 4 << 20}, + ) + addServicePeerAndProtocolPeerLimit( + config, + ping.ServiceName, ping.ID, + rcmgr.BaseLimit{StreamsInbound: 2, StreamsOutbound: 3, Streams: 4, Memory: 32 * (256<<20 + 16<<10)}, + rcmgr.BaseLimitIncrease{}, + ) + + // autonat + addServiceAndProtocolLimit(config, + autonat.ServiceName, autonat.AutoNATProto, + rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 64, Memory: 4 << 20}, + rcmgr.BaseLimitIncrease{StreamsInbound: 4, StreamsOutbound: 4, Streams: 4, Memory: 2 << 20}, + ) + addServicePeerAndProtocolPeerLimit( + config, + 
autonat.ServiceName, autonat.AutoNATProto, + rcmgr.BaseLimit{StreamsInbound: 2, StreamsOutbound: 2, Streams: 2, Memory: 1 << 20}, + rcmgr.BaseLimitIncrease{}, + ) + + // holepunch + addServiceAndProtocolLimit(config, + holepunch.ServiceName, holepunch.Protocol, + rcmgr.BaseLimit{StreamsInbound: 32, StreamsOutbound: 32, Streams: 64, Memory: 4 << 20}, + rcmgr.BaseLimitIncrease{StreamsInbound: 8, StreamsOutbound: 8, Streams: 16, Memory: 4 << 20}, + ) + addServicePeerAndProtocolPeerLimit(config, + holepunch.ServiceName, holepunch.Protocol, + rcmgr.BaseLimit{StreamsInbound: 2, StreamsOutbound: 2, Streams: 2, Memory: 1 << 20}, + rcmgr.BaseLimitIncrease{}, + ) + + // relay/v2 + config.AddServiceLimit( + relayv2.ServiceName, + rcmgr.BaseLimit{StreamsInbound: 256, StreamsOutbound: 256, Streams: 256, Memory: 16 << 20}, + rcmgr.BaseLimitIncrease{StreamsInbound: 256, StreamsOutbound: 256, Streams: 256, Memory: 16 << 20}, + ) + config.AddServicePeerLimit( + relayv2.ServiceName, + rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 64, Memory: 1 << 20}, + rcmgr.BaseLimitIncrease{}, + ) + + // circuit protocols, both client and service + for _, proto := range [...]protocol.ID{circuit.ProtoIDv2Hop, circuit.ProtoIDv2Stop} { + config.AddProtocolLimit( + proto, + rcmgr.BaseLimit{StreamsInbound: 640, StreamsOutbound: 640, Streams: 640, Memory: 16 << 20}, + rcmgr.BaseLimitIncrease{StreamsInbound: 640, StreamsOutbound: 640, Streams: 640, Memory: 16 << 20}, + ) + config.AddProtocolPeerLimit( + proto, + rcmgr.BaseLimit{StreamsInbound: 128, StreamsOutbound: 128, Streams: 128, Memory: 32 << 20}, + rcmgr.BaseLimitIncrease{}, + ) + } +} + +func addServiceAndProtocolLimit(config *rcmgr.ScalingLimitConfig, service string, proto protocol.ID, limit rcmgr.BaseLimit, increase rcmgr.BaseLimitIncrease) { + config.AddServiceLimit(service, limit, increase) + config.AddProtocolLimit(proto, limit, increase) +} + +func addServicePeerAndProtocolPeerLimit(config *rcmgr.ScalingLimitConfig, 
service string, proto protocol.ID, limit rcmgr.BaseLimit, increase rcmgr.BaseLimitIncrease) { + config.AddServicePeerLimit(service, limit, increase) + config.AddProtocolPeerLimit(proto, limit, increase) +} diff --git a/limits.go b/limits.go index 5871577e51..6e09e2d628 100644 --- a/limits.go +++ b/limits.go @@ -1,113 +1,11 @@ package libp2p import ( - "github.com/libp2p/go-libp2p/core/protocol" - "github.com/libp2p/go-libp2p/p2p/host/autonat" + "github.com/libp2p/go-libp2p/internal/limits" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" - circuit "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto" - relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" - "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" - "github.com/libp2p/go-libp2p/p2p/protocol/identify" - "github.com/libp2p/go-libp2p/p2p/protocol/ping" ) // SetDefaultServiceLimits sets the default limits for bundled libp2p services func SetDefaultServiceLimits(config *rcmgr.ScalingLimitConfig) { - // identify - config.AddServiceLimit( - identify.ServiceName, - rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 128, Memory: 4 << 20}, - rcmgr.BaseLimitIncrease{StreamsInbound: 64, StreamsOutbound: 64, Streams: 128, Memory: 4 << 20}, - ) - config.AddServicePeerLimit( - identify.ServiceName, - rcmgr.BaseLimit{StreamsInbound: 16, StreamsOutbound: 16, Streams: 32, Memory: 1 << 20}, - rcmgr.BaseLimitIncrease{}, - ) - for _, id := range [...]protocol.ID{identify.ID, identify.IDPush} { - config.AddProtocolLimit( - id, - rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 128, Memory: 4 << 20}, - rcmgr.BaseLimitIncrease{StreamsInbound: 64, StreamsOutbound: 64, Streams: 128, Memory: 4 << 20}, - ) - config.AddProtocolPeerLimit( - id, - rcmgr.BaseLimit{StreamsInbound: 16, StreamsOutbound: 16, Streams: 32, Memory: 32 * (256<<20 + 16<<10)}, - rcmgr.BaseLimitIncrease{}, - ) - } - - // ping - addServiceAndProtocolLimit(config, - ping.ServiceName, ping.ID, - 
rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 64, Memory: 4 << 20}, - rcmgr.BaseLimitIncrease{StreamsInbound: 64, StreamsOutbound: 64, Streams: 64, Memory: 4 << 20}, - ) - addServicePeerAndProtocolPeerLimit( - config, - ping.ServiceName, ping.ID, - rcmgr.BaseLimit{StreamsInbound: 2, StreamsOutbound: 3, Streams: 4, Memory: 32 * (256<<20 + 16<<10)}, - rcmgr.BaseLimitIncrease{}, - ) - - // autonat - addServiceAndProtocolLimit(config, - autonat.ServiceName, autonat.AutoNATProto, - rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 64, Memory: 4 << 20}, - rcmgr.BaseLimitIncrease{StreamsInbound: 4, StreamsOutbound: 4, Streams: 4, Memory: 2 << 20}, - ) - addServicePeerAndProtocolPeerLimit( - config, - autonat.ServiceName, autonat.AutoNATProto, - rcmgr.BaseLimit{StreamsInbound: 2, StreamsOutbound: 2, Streams: 2, Memory: 1 << 20}, - rcmgr.BaseLimitIncrease{}, - ) - - // holepunch - addServiceAndProtocolLimit(config, - holepunch.ServiceName, holepunch.Protocol, - rcmgr.BaseLimit{StreamsInbound: 32, StreamsOutbound: 32, Streams: 64, Memory: 4 << 20}, - rcmgr.BaseLimitIncrease{StreamsInbound: 8, StreamsOutbound: 8, Streams: 16, Memory: 4 << 20}, - ) - addServicePeerAndProtocolPeerLimit(config, - holepunch.ServiceName, holepunch.Protocol, - rcmgr.BaseLimit{StreamsInbound: 2, StreamsOutbound: 2, Streams: 2, Memory: 1 << 20}, - rcmgr.BaseLimitIncrease{}, - ) - - // relay/v2 - config.AddServiceLimit( - relayv2.ServiceName, - rcmgr.BaseLimit{StreamsInbound: 256, StreamsOutbound: 256, Streams: 256, Memory: 16 << 20}, - rcmgr.BaseLimitIncrease{StreamsInbound: 256, StreamsOutbound: 256, Streams: 256, Memory: 16 << 20}, - ) - config.AddServicePeerLimit( - relayv2.ServiceName, - rcmgr.BaseLimit{StreamsInbound: 64, StreamsOutbound: 64, Streams: 64, Memory: 1 << 20}, - rcmgr.BaseLimitIncrease{}, - ) - - // circuit protocols, both client and service - for _, proto := range [...]protocol.ID{circuit.ProtoIDv2Hop, circuit.ProtoIDv2Stop} { - 
config.AddProtocolLimit( - proto, - rcmgr.BaseLimit{StreamsInbound: 640, StreamsOutbound: 640, Streams: 640, Memory: 16 << 20}, - rcmgr.BaseLimitIncrease{StreamsInbound: 640, StreamsOutbound: 640, Streams: 640, Memory: 16 << 20}, - ) - config.AddProtocolPeerLimit( - proto, - rcmgr.BaseLimit{StreamsInbound: 128, StreamsOutbound: 128, Streams: 128, Memory: 32 << 20}, - rcmgr.BaseLimitIncrease{}, - ) - } -} - -func addServiceAndProtocolLimit(config *rcmgr.ScalingLimitConfig, service string, proto protocol.ID, limit rcmgr.BaseLimit, increase rcmgr.BaseLimitIncrease) { - config.AddServiceLimit(service, limit, increase) - config.AddProtocolLimit(proto, limit, increase) -} - -func addServicePeerAndProtocolPeerLimit(config *rcmgr.ScalingLimitConfig, service string, proto protocol.ID, limit rcmgr.BaseLimit, increase rcmgr.BaseLimitIncrease) { - config.AddServicePeerLimit(service, limit, increase) - config.AddProtocolPeerLimit(proto, limit, increase) + limits.SetDefaultServiceLimits(config) } diff --git a/p2p/host/blank/blank.go b/p2p/host/blank/blank.go index 2e233c7173..180182c019 100644 --- a/p2p/host/blank/blank.go +++ b/p2p/host/blank/blank.go @@ -14,9 +14,8 @@ import ( "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/record" - "github.com/libp2p/go-libp2p/p2p/host/eventbus" - logging "github.com/libp2p/go-libp2p/gologshim" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" ma "github.com/multiformats/go-multiaddr" mstream "github.com/multiformats/go-multistream" @@ -24,20 +23,24 @@ import ( var log = logging.Logger("blankhost") -// BlankHost is the thinnest implementation of the host.Host interface +// BlankHost is a thin implementation of the host.Host interface type BlankHost struct { - n network.Network - mux *mstream.MultistreamMuxer[protocol.ID] - cmgr connmgr.ConnManager - eventbus event.Bus - emitters struct { + N network.Network + M *mstream.MultistreamMuxer[protocol.ID] + E event.Bus + 
ConnMgr connmgr.ConnManager + // SkipInitSignedRecord is a flag to skip the initialization of a signed record for the host + SkipInitSignedRecord bool + emitters struct { evtLocalProtocolsUpdated event.Emitter } + onStop []func() error } type config struct { - cmgr connmgr.ConnManager - eventBus event.Bus + cmgr connmgr.ConnManager + eventBus event.Bus + skipInitSignedRecord bool } type Option = func(cfg *config) @@ -54,6 +57,12 @@ func WithEventBus(eventBus event.Bus) Option { } } +func SkipInitSignedRecord() Option { + return func(cfg *config) { + cfg.skipInitSignedRecord = true + } +} + func NewBlankHost(n network.Network, options ...Option) *BlankHost { cfg := config{ cmgr: &connmgr.NullConnMgr{}, @@ -63,36 +72,72 @@ func NewBlankHost(n network.Network, options ...Option) *BlankHost { } bh := &BlankHost{ - n: n, - cmgr: cfg.cmgr, - mux: mstream.NewMultistreamMuxer[protocol.ID](), - eventbus: cfg.eventBus, + N: n, + ConnMgr: cfg.cmgr, + M: mstream.NewMultistreamMuxer[protocol.ID](), + E: cfg.eventBus, + + SkipInitSignedRecord: cfg.skipInitSignedRecord, } - if bh.eventbus == nil { - bh.eventbus = eventbus.NewBus(eventbus.WithMetricsTracer(eventbus.NewMetricsTracer())) + + if err := bh.Start(); err != nil { + log.Error("error creating blank host", "err", err) + return nil + } + + return bh +} + +func (bh *BlankHost) Start() error { + if bh.E == nil { + bh.E = eventbus.NewBus(eventbus.WithMetricsTracer(eventbus.NewMetricsTracer())) } // subscribe the connection manager to network notifications (has no effect with NullConnMgr) - n.Notify(bh.cmgr.Notifee()) + notifee := bh.ConnMgr.Notifee() + bh.N.Notify(notifee) + bh.onStop = append(bh.onStop, func() error { + bh.N.StopNotify(notifee) + return nil + }) var err error - if bh.emitters.evtLocalProtocolsUpdated, err = bh.eventbus.Emitter(&event.EvtLocalProtocolsUpdated{}); err != nil { - return nil + if bh.emitters.evtLocalProtocolsUpdated, err = bh.E.Emitter(&event.EvtLocalProtocolsUpdated{}); err != nil { + return err 
} + bh.onStop = append(bh.onStop, func() error { + bh.emitters.evtLocalProtocolsUpdated.Close() + return nil + }) - n.SetStreamHandler(bh.newStreamHandler) + bh.N.SetStreamHandler(bh.newStreamHandler) + bh.onStop = append(bh.onStop, func() error { + bh.N.SetStreamHandler(func(s network.Stream) { s.Reset() }) + return nil + }) // persist a signed peer record for self to the peerstore. - if err := bh.initSignedRecord(); err != nil { - log.Error("error creating blank host", "err", err) - return nil + if !bh.SkipInitSignedRecord { + if err := bh.initSignedRecord(); err != nil { + log.Error("error creating blank host", "err", err) + return err + } } - return bh + return nil +} + +func (bh *BlankHost) Stop() error { + var err error + for _, f := range bh.onStop { + err = errors.Join(err, f()) + } + bh.onStop = nil + return err } func (bh *BlankHost) initSignedRecord() error { - cab, ok := peerstore.GetCertifiedAddrBook(bh.n.Peerstore()) + cab, ok := peerstore.GetCertifiedAddrBook(bh.N.Peerstore()) if !ok { log.Error("peerstore does not support signed records") return errors.New("peerstore does not support signed records") @@ -114,7 +159,7 @@ func (bh *BlankHost) initSignedRecord() error { var _ host.Host = (*BlankHost)(nil) func (bh *BlankHost) Addrs() []ma.Multiaddr { - addrs, err := bh.n.InterfaceListenAddresses() + addrs, err := bh.N.InterfaceListenAddresses() if err != nil { log.Debug("error retrieving network interface addrs", "err", err) return nil @@ -124,14 +169,18 @@ func (bh *BlankHost) Addrs() []ma.Multiaddr { } func (bh *BlankHost) Close() error { - return bh.n.Close() + var err error + if bh.onStop != nil { + err = bh.Stop() + } + return errors.Join(err, bh.N.Close()) } func (bh *BlankHost) Connect(ctx context.Context, ai peer.AddrInfo) error { // absorb addresses into peerstore bh.Peerstore().AddAddrs(ai.ID, ai.Addrs, peerstore.TempAddrTTL) - cs := bh.n.ConnsToPeer(ai.ID) + cs := bh.N.ConnsToPeer(ai.ID) if len(cs) > 0 { return nil } @@ -144,15 +193,15 @@ 
func (bh *BlankHost) Connect(ctx context.Context, ai peer.AddrInfo) error { } func (bh *BlankHost) Peerstore() peerstore.Peerstore { - return bh.n.Peerstore() + return bh.N.Peerstore() } func (bh *BlankHost) ID() peer.ID { - return bh.n.LocalPeer() + return bh.N.LocalPeer() } func (bh *BlankHost) NewStream(ctx context.Context, p peer.ID, protos ...protocol.ID) (network.Stream, error) { - s, err := bh.n.NewStream(ctx, p) + s, err := bh.N.NewStream(ctx, p) if err != nil { return nil, fmt.Errorf("failed to open stream: %w", err) } @@ -204,7 +253,7 @@ func (bh *BlankHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) func (bh *BlankHost) newStreamHandler(s network.Stream) { protoID, handle, err := bh.Mux().Negotiate(s) if err != nil { - log.Info("protocol negotiation failed", "err", err) + log.Error("protocol negotiation failed", "err", err) s.Reset() return } @@ -216,18 +265,18 @@ func (bh *BlankHost) newStreamHandler(s network.Stream) { // TODO: i'm not sure this really needs to be here func (bh *BlankHost) Mux() protocol.Switch { - return bh.mux + return bh.M } // TODO: also not sure this fits... 
Might be better ways around this (leaky abstractions) func (bh *BlankHost) Network() network.Network { - return bh.n + return bh.N } func (bh *BlankHost) ConnManager() connmgr.ConnManager { - return bh.cmgr + return bh.ConnMgr } func (bh *BlankHost) EventBus() event.Bus { - return bh.eventbus + return bh.E } diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index 0bd5e5ab31..5217d5419f 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -14,7 +14,7 @@ import ( type conn struct { quicConn *quic.Conn - transport *transport + transport *Transport scope network.ConnManagementScope localPeer peer.ID diff --git a/p2p/transport/quic/conn_test.go b/p2p/transport/quic/conn_test.go index 703255a8b4..6c5c1c2cb8 100644 --- a/p2p/transport/quic/conn_test.go +++ b/p2p/transport/quic/conn_test.go @@ -88,12 +88,12 @@ func testHandshake(t *testing.T, tc *connTestCase) { serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer serverTransport.(io.Closer).Close() + defer serverTransport.Close() handshake := func(t *testing.T, ln tpt.Listener) { clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() @@ -141,7 +141,7 @@ func testResourceManagerSuccess(t *testing.T, tc *connTestCase) { serverRcmgr := mocknetwork.NewMockResourceManager(ctrl) serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, serverRcmgr) require.NoError(t, err) - defer serverTransport.(io.Closer).Close() + defer serverTransport.Close() ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1")) require.NoError(t, err) defer ln.Close() @@ -149,7 +149,7 @@ func 
testResourceManagerSuccess(t *testing.T, tc *connTestCase) { clientRcmgr := mocknetwork.NewMockResourceManager(ctrl) clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, clientRcmgr) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() connChan := make(chan tpt.CapableConn) serverConnScope := mocknetwork.NewMockConnManagementScope(ctrl) @@ -190,7 +190,7 @@ func testResourceManagerDialDenied(t *testing.T, tc *connTestCase) { rcmgr := mocknetwork.NewMockResourceManager(ctrl) clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, rcmgr) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() connScope := mocknetwork.NewMockConnManagementScope(ctrl) target := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1") @@ -223,7 +223,7 @@ func testResourceManagerAcceptDenied(t *testing.T, tc *connTestCase) { clientRcmgr := mocknetwork.NewMockResourceManager(ctrl) clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, clientRcmgr) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() serverRcmgr := mocknetwork.NewMockResourceManager(ctrl) serverConnScope := mocknetwork.NewMockConnManagementScope(ctrl) @@ -235,7 +235,7 @@ func testResourceManagerAcceptDenied(t *testing.T, tc *connTestCase) { ) serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, serverRcmgr) require.NoError(t, err) - defer serverTransport.(io.Closer).Close() + defer serverTransport.Close() ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1")) require.NoError(t, err) defer ln.Close() @@ -281,13 +281,13 @@ func testStreams(t *testing.T, tc *connTestCase) { serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer 
serverTransport.(io.Closer).Close() + defer serverTransport.Close() ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") defer ln.Close() clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() @@ -313,13 +313,13 @@ func testStreamsErrorCode(t *testing.T, tc *connTestCase) { serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer serverTransport.(io.Closer).Close() + defer serverTransport.Close() ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") defer ln.Close() clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() @@ -361,7 +361,7 @@ func testHandshakeFailPeerIDMismatch(t *testing.T, tc *connTestCase) { serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer serverTransport.(io.Closer).Close() + defer serverTransport.Close() ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) @@ -370,7 +370,7 @@ func testHandshakeFailPeerIDMismatch(t *testing.T, tc *connTestCase) { _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID) require.Error(t, err) require.Contains(t, err.Error(), "CRYPTO_ERROR") - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() acceptErr := make(chan error) go func() { @@ -406,7 +406,9 @@ 
func testConnectionGating(t *testing.T, tc *connTestCase) { t.Run("accepted connections", func(t *testing.T) { serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, cg, nil) - defer serverTransport.(io.Closer).Close() + defer func() { + _ = serverTransport.Close() + }() require.NoError(t, err) ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") defer ln.Close() @@ -422,7 +424,7 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() // make sure that connection attempts fails conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) // In rare instances, the connection gating error will already occur on Dial. @@ -451,7 +453,7 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { t.Run("secured connections", func(t *testing.T) { serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer serverTransport.(io.Closer).Close() + defer serverTransport.Close() ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") defer ln.Close() @@ -460,7 +462,7 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, cg, nil) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() // make sure that connection attempts fails _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) @@ -490,12 +492,12 @@ func testDialTwo(t *testing.T, tc *connTestCase) { serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer serverTransport.(io.Closer).Close() + defer serverTransport.Close() ln1 := runServer(t, 
serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") defer ln1.Close() serverTransport2, err := NewTransport(serverKey2, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer serverTransport2.(io.Closer).Close() + defer serverTransport2.Close() ln2 := runServer(t, serverTransport2, "/ip4/127.0.0.1/udp/0/quic-v1") defer ln2.Close() @@ -521,7 +523,7 @@ func testDialTwo(t *testing.T, tc *connTestCase) { clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID) require.NoError(t, err) defer c1.Close() @@ -576,7 +578,7 @@ func testStatelessReset(t *testing.T, tc *connTestCase) { serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer serverTransport.(io.Closer).Close() + defer serverTransport.Close() ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") var drop uint32 @@ -593,7 +595,7 @@ func testStatelessReset(t *testing.T, tc *connTestCase) { // establish a connection clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) - defer clientTransport.(io.Closer).Close() + defer clientTransport.Close() proxyAddr, err := quicreuse.ToQuicMultiaddr(proxy.LocalAddr(), quic.Version1) require.NoError(t, err) conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID) @@ -663,7 +665,7 @@ func TestHolePunching(t *testing.T) { t1, err := NewTransport(serverKey, newConnManager(t), nil, nil, nil) require.NoError(t, err) - defer t1.(io.Closer).Close() + defer t1.Close() laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic-v1") require.NoError(t, err) ln1, err := t1.Listen(laddr) @@ -677,7 +679,7 @@ func TestHolePunching(t *testing.T) { t2, err := NewTransport(clientKey, 
newConnManager(t), nil, nil, nil) require.NoError(t, err) - defer t2.(io.Closer).Close() + defer t2.Close() ln2, err := t2.Listen(laddr) require.NoError(t, err) done2 := make(chan struct{}) @@ -700,7 +702,7 @@ func TestHolePunching(t *testing.T) { // If it hasn't created the hole punch map entry, the connection will be accepted as a regular connection, // which would make this test fail. require.Eventually(t, func() bool { - tr := t2.(*transport) + tr := t2 tr.holePunchingMx.Lock() defer tr.holePunchingMx.Unlock() return len(tr.holePunching) > 0 diff --git a/p2p/transport/quic/listener.go b/p2p/transport/quic/listener.go index 009d33cfd3..ef0b987f66 100644 --- a/p2p/transport/quic/listener.go +++ b/p2p/transport/quic/listener.go @@ -18,14 +18,14 @@ import ( // A listener listens for QUIC connections. type listener struct { reuseListener quicreuse.Listener - transport *transport + transport *Transport rcmgr network.ResourceManager privKey ic.PrivKey localPeer peer.ID localMultiaddrs map[quic.Version]ma.Multiaddr } -func newListener(ln quicreuse.Listener, t *transport, localPeer peer.ID, key ic.PrivKey, rcmgr network.ResourceManager) (listener, error) { +func newListener(ln quicreuse.Listener, t *Transport, localPeer peer.ID, key ic.PrivKey, rcmgr network.ResourceManager) (listener, error) { localMultiaddrs := make(map[quic.Version]ma.Multiaddr) for _, addr := range ln.Multiaddrs() { if _, err := addr.ValueForProtocol(ma.P_QUIC_V1); err == nil { diff --git a/p2p/transport/quic/listener_test.go b/p2p/transport/quic/listener_test.go index 53d6001d35..68ef49c609 100644 --- a/p2p/transport/quic/listener_test.go +++ b/p2p/transport/quic/listener_test.go @@ -104,7 +104,7 @@ func TestAcceptAfterClose(t *testing.T) { func TestCorrectNumberOfVirtualListeners(t *testing.T) { tr := newTransport(t, nil) - tpt := tr.(*transport) + tpt := tr.(*Transport) defer tr.(io.Closer).Close() localAddrV1 := ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1") @@ -129,7 +129,7 @@ func 
TestCleanupConnWhenBlocked(t *testing.T) { }) server := newTransport(t, mockRcmgr) - serverTpt := server.(*transport) + serverTpt := server.(*Transport) defer server.(io.Closer).Close() localAddrV1 := ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1") diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 0176409e48..ebbfa827f3 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -35,7 +35,7 @@ var ErrHolePunching = errors.New("hole punching attempted; no active dial") var HolePunchTimeout = 5 * time.Second // The Transport implements the tpt.Transport interface for QUIC connections. -type transport struct { +type Transport struct { privKey ic.PrivKey localPeer peer.ID identity *p2ptls.Identity @@ -57,7 +57,7 @@ type transport struct { listeners map[string][]*virtualListener } -var _ tpt.Transport = &transport{} +var _ tpt.Transport = &Transport{} type holePunchKey struct { addr string @@ -70,7 +70,7 @@ type activeHolePunch struct { } // NewTransport creates a new QUIC transport -func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Transport, error) { +func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (*Transport, error) { if len(psk) > 0 { log.Error("QUIC doesn't support private networks yet.") return nil, errors.New("QUIC doesn't support private networks yet") @@ -88,7 +88,7 @@ func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.P rcmgr = &network.NullResourceManager{} } - return &transport{ + return &Transport{ privKey: key, localPeer: localPeer, identity: identity, @@ -103,12 +103,12 @@ func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.P }, nil } -func (t *transport) ListenOrder() int { +func (t *Transport) ListenOrder() int { return ListenOrder } // Dial 
dials a new QUIC connection -func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (_c tpt.CapableConn, _err error) { +func (t *Transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (_c tpt.CapableConn, _err error) { if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient { return t.holePunch(ctx, raddr, p) } @@ -127,7 +127,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (_c return c, nil } -func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { +func (t *Transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { if err := scope.SetPeer(p); err != nil { log.Debug("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "err", err) return nil, err @@ -174,19 +174,19 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee return c, nil } -func (t *transport) addConn(conn *quic.Conn, c *conn) { +func (t *Transport) addConn(conn *quic.Conn, c *conn) { t.connMx.Lock() t.conns[conn] = c t.connMx.Unlock() } -func (t *transport) removeConn(conn *quic.Conn) { +func (t *Transport) removeConn(conn *quic.Conn) { t.connMx.Lock() delete(t.conns, conn) t.connMx.Unlock() } -func (t *transport) holePunch(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { +func (t *Transport) holePunch(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { network, saddr, err := manet.DialArgs(raddr) if err != nil { return nil, err @@ -277,12 +277,12 @@ loop: var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC_V1)) // CanDial determines if we can dial to an address -func (t *transport) CanDial(addr ma.Multiaddr) bool { +func (t *Transport) CanDial(addr ma.Multiaddr) bool { return dialMatcher.Matches(addr) } 
// Listen listens for new QUIC connections on the passed multiaddr. -func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { +func (t *Transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { var tlsConf tls.Config tlsConf.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) { // return a tls.Config that verifies the peer's certificate chain. @@ -344,7 +344,7 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { return l, nil } -func (t *transport) allowWindowIncrease(conn *quic.Conn, size uint64) bool { +func (t *Transport) allowWindowIncrease(conn *quic.Conn, size uint64) bool { // If the QUIC connection tries to increase the window before we've inserted it // into our connections map (which we do right after dialing / accepting it), // we have no way to account for that memory. This should be very rare. @@ -359,24 +359,24 @@ func (t *transport) allowWindowIncrease(conn *quic.Conn, size uint64) bool { } // Proxy returns true if this transport proxies. -func (t *transport) Proxy() bool { +func (t *Transport) Proxy() bool { return false } // Protocols returns the set of protocols handled by this transport. 
-func (t *transport) Protocols() []int { +func (t *Transport) Protocols() []int { return t.connManager.Protocols() } -func (t *transport) String() string { +func (t *Transport) String() string { return "QUIC" } -func (t *transport) Close() error { +func (t *Transport) Close() error { return nil } -func (t *transport) CloseVirtualListener(l *virtualListener) error { +func (t *Transport) CloseVirtualListener(l *virtualListener) error { t.listenersMu.Lock() defer t.listenersMu.Unlock() diff --git a/p2p/transport/quic/virtuallistener.go b/p2p/transport/quic/virtuallistener.go index 5b23e4c507..1c70a91407 100644 --- a/p2p/transport/quic/virtuallistener.go +++ b/p2p/transport/quic/virtuallistener.go @@ -18,7 +18,7 @@ type virtualListener struct { *listener udpAddr string version quic.Version - t *transport + t *Transport acceptRunnner *acceptLoopRunner acceptChan chan acceptVal } diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index fd4b1187b2..99007d6cbc 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -28,7 +28,7 @@ func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } type conn struct { *connSecurityMultiaddrs - transport *transport + transport *Transport session *webtransport.Session scope network.ConnManagementScope @@ -37,7 +37,7 @@ type conn struct { var _ tpt.CapableConn = &conn{} -func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope, qconn *quic.Conn) *conn { +func newConn(tr *Transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope, qconn *quic.Conn) *conn { return &conn{ connSecurityMultiaddrs: sconn, transport: tr, diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 7cd647f72b..dba55795fe 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -27,7 +27,7 
@@ const handshakeTimeout = 10 * time.Second type connKey struct{} type listener struct { - transport *transport + transport *Transport isStaticTLSConf bool reuseListener quicreuse.Listener @@ -49,7 +49,7 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf bool) (tpt.Listener, error) { +func newListener(reuseListener quicreuse.Listener, t *Transport, isStaticTLSConf bool) (tpt.Listener, error) { localMultiaddr, err := toWebtransportMultiaddr(reuseListener.Addr()) if err != nil { return nil, err diff --git a/p2p/transport/webtransport/multiaddr_test.go b/p2p/transport/webtransport/multiaddr_test.go index 3f0a3ec0bf..3559497499 100644 --- a/p2p/transport/webtransport/multiaddr_test.go +++ b/p2p/transport/webtransport/multiaddr_test.go @@ -83,7 +83,7 @@ func TestWebtransportResolve(t *testing.T) { "/ip4/127.0.0.1/udp/1337/quic-v1/sni/example.com/webtransport", } - tpt := &transport{} + tpt := &Transport{} ctx := context.Background() for _, tc := range testCases { diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index c18ace1606..965fb49331 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -40,10 +40,10 @@ const errorCodeConnectionGating = 0x47415445 // GATE in ASCII const certValidity = 14 * 24 * time.Hour -type Option func(*transport) error +type Option func(*Transport) error func WithClock(cl clock.Clock) Option { - return func(t *transport) error { + return func(t *Transport) error { t.clock = cl return nil } @@ -54,20 +54,20 @@ func WithClock(cl clock.Clock) Option { // When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and // overwrite the VerifyPeerCertificate callback. 
func WithTLSClientConfig(c *tls.Config) Option { - return func(t *transport) error { + return func(t *Transport) error { t.tlsClientConf = c return nil } } func WithHandshakeTimeout(d time.Duration) Option { - return func(t *transport) error { + return func(t *Transport) error { t.handshakeTimeout = d return nil } } -type transport struct { +type Transport struct { privKey ic.PrivKey pid peer.ID clock clock.Clock @@ -90,11 +90,11 @@ type transport struct { handshakeTimeout time.Duration } -var _ tpt.Transport = &transport{} -var _ tpt.Resolver = &transport{} -var _ io.Closer = &transport{} +var _ tpt.Transport = &Transport{} +var _ tpt.Resolver = &Transport{} +var _ io.Closer = &Transport{} -func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) { +func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (*Transport, error) { if len(psk) > 0 { log.Error("WebTransport doesn't support private networks yet.") return nil, errors.New("WebTransport doesn't support private networks yet") @@ -106,7 +106,7 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater if err != nil { return nil, err } - t := &transport{ + t := &Transport{ pid: id, privKey: key, rcmgr: rcmgr, @@ -129,7 +129,7 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater return t, nil } -func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { +func (t *Transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) if err != nil { log.Debug("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) @@ -145,7 +145,7 @@ func (t *transport) Dial(ctx 
context.Context, raddr ma.Multiaddr, p peer.ID) (tp return c, nil } -func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { +func (t *Transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { _, addr, err := manet.DialArgs(raddr) if err != nil { return nil, err @@ -188,7 +188,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee return conn, nil } -func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, *quic.Conn, error) { +func (t *Transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, *quic.Conn, error) { var tlsConf *tls.Config if t.tlsClientConf != nil { tlsConf = t.tlsClientConf.Clone() @@ -232,7 +232,7 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string return sess, conn, err } -func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) { +func (t *Transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) { local, err := toWebtransportMultiaddr(sess.LocalAddr()) if err != nil { return nil, fmt.Errorf("error determining local addr: %w", err) @@ -302,12 +302,12 @@ func decodeCertHashesFromProtobuf(b [][]byte) ([]multihash.DecodedMultihash, err return hashes, nil } -func (t *transport) CanDial(addr ma.Multiaddr) bool { +func (t *Transport) CanDial(addr ma.Multiaddr) bool { ok, _ := IsWebtransportMultiaddr(addr) return ok } -func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { +func (t *Transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) 
{ isWebTransport, certhashCount := IsWebtransportMultiaddr(laddr) if !isWebTransport { return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) @@ -341,15 +341,15 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { return newListener(ln, t, t.staticTLSConf != nil) } -func (t *transport) Protocols() []int { +func (t *Transport) Protocols() []int { return []int{ma.P_WEBTRANSPORT} } -func (t *transport) Proxy() bool { +func (t *Transport) Proxy() bool { return false } -func (t *transport) Close() error { +func (t *Transport) Close() error { t.listenOnce.Do(func() {}) if t.certManager != nil { return t.certManager.Close() @@ -357,7 +357,7 @@ func (t *transport) Close() error { return nil } -func (t *transport) allowWindowIncrease(conn *quic.Conn, size uint64) bool { +func (t *Transport) allowWindowIncrease(conn *quic.Conn, size uint64) bool { t.connMx.Lock() defer t.connMx.Unlock() @@ -368,13 +368,13 @@ func (t *transport) allowWindowIncrease(conn *quic.Conn, size uint64) bool { return c.allowWindowIncrease(size) } -func (t *transport) addConn(conn *quic.Conn, c *conn) { +func (t *Transport) addConn(conn *quic.Conn, c *conn) { t.connMx.Lock() t.conns[conn] = c t.connMx.Unlock() } -func (t *transport) removeConn(conn *quic.Conn) { +func (t *Transport) removeConn(conn *quic.Conn) { t.connMx.Lock() delete(t.conns, conn) t.connMx.Unlock() @@ -403,7 +403,7 @@ func extractSNI(maddr ma.Multiaddr) (sni string, foundSniComponent bool) { } // Resolve implements transport.Resolver -func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { +func (t *Transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { sni, foundSniComponent := extractSNI(maddr) if foundSniComponent || sni == "" { @@ -433,7 +433,7 @@ func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiad // AddCertHashes adds the current certificate hashes to a multiaddress. 
// If called before Listen, it's a no-op. -func (t *transport) AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool) { +func (t *Transport) AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool) { if !t.hasCertManager.Load() { return m, false } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 266e01d18f..5ee26508b2 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -108,7 +108,7 @@ func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, nil) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) defer ln.Close() @@ -118,7 +118,7 @@ func TestTransport(t *testing.T) { _, clientKey := newIdentity(t) tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, nil) require.NoError(t, err) - defer tr2.(io.Closer).Close() + defer tr2.Close() conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) @@ -154,7 +154,7 @@ func TestHashVerification(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) done := make(chan struct{}) @@ -167,7 +167,7 @@ func TestHashVerification(t *testing.T) { _, clientKey := newIdentity(t) tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer tr2.(io.Closer).Close() + defer tr2.Close() foobarHash := getCerthashComponent(t, []byte("foobar")) @@ -212,7 +212,7 @@ func TestCanDial(t *testing.T) { _, key := 
newIdentity(t) tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() for _, addr := range valid { require.Truef(t, tr.CanDial(addr), "expected to be able to dial %s", addr) @@ -237,7 +237,7 @@ func TestListenAddrValidity(t *testing.T) { _, key := newIdentity(t) tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() for _, addr := range valid { ln, err := tr.Listen(addr) @@ -254,7 +254,7 @@ func TestListenerAddrs(t *testing.T) { _, key := newIdentity(t) tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) @@ -277,7 +277,7 @@ func TestResourceManagerDialing(t *testing.T) { _, key := newIdentity(t) tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() l, err := tr.Listen(addr) require.NoError(t, err) @@ -296,7 +296,7 @@ func TestResourceManagerListening(t *testing.T) { clientID, key := newIdentity(t) cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer cl.(io.Closer).Close() + defer cl.Close() t.Run("blocking the connection", func(t *testing.T) { serverID, key := newIdentity(t) @@ -375,7 +375,7 @@ func TestConnectionGaterDialing(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) 
require.NoError(t, err) defer ln.Close() @@ -386,7 +386,7 @@ func TestConnectionGaterDialing(t *testing.T) { _, key := newIdentity(t) cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), connGater, &network.NullResourceManager{}) require.NoError(t, err) - defer cl.(io.Closer).Close() + defer cl.Close() _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.EqualError(t, err, "secured connection gated") } @@ -399,7 +399,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{}) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) defer ln.Close() @@ -412,7 +412,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) { _, key := newIdentity(t) cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer cl.(io.Closer).Close() + defer cl.Close() _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.EqualError(t, err, "received status 403") } @@ -425,7 +425,7 @@ func TestConnectionGaterInterceptSecured(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{}) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) defer ln.Close() @@ -433,7 +433,7 @@ func TestConnectionGaterInterceptSecured(t *testing.T) { clientID, key := newIdentity(t) cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer cl.(io.Closer).Close() + defer cl.Close() 
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true) connGater.EXPECT().InterceptSecured(network.DirInbound, clientID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { @@ -461,7 +461,7 @@ func TestAcceptQueueFilledUp(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) defer ln.Close() @@ -471,7 +471,7 @@ func TestAcceptQueueFilledUp(t *testing.T) { _, key := newIdentity(t) cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) - defer cl.(io.Closer).Close() + defer cl.Close() return cl.Dial(context.Background(), ln.Multiaddr(), serverID) } @@ -566,7 +566,7 @@ func TestFlowControlWindowIncrease(t *testing.T) { serverRcmgr := &reportingRcmgr{report: serverWindowIncreases} tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, serverRcmgr) require.NoError(t, err) - defer tr.(io.Closer).Close() + defer tr.Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) defer ln.Close() @@ -594,7 +594,7 @@ func TestFlowControlWindowIncrease(t *testing.T) { clientRcmgr := &reportingRcmgr{report: clientWindowIncreases} tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, clientRcmgr) require.NoError(t, err) - defer tr2.(io.Closer).Close() + defer tr2.Close() var addr ma.Multiaddr for _, comp := range ln.Multiaddr() { @@ -842,7 +842,7 @@ func TestH3ConnClosed(t *testing.T) { _, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, nil, libp2pwebtransport.WithHandshakeTimeout(1*time.Second)) require.NoError(t, err) - defer tr.(io.Closer).Close() 
+ defer tr.Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) defer ln.Close() diff --git a/x/builder/builder.go b/x/builder/builder.go new file mode 100644 index 0000000000..cc0591de85 --- /dev/null +++ b/x/builder/builder.go @@ -0,0 +1,806 @@ +// package builder is an alternative and experimental way to build a go-libp2p +// node and services. +// See the commit details for history and motivation +package builder + +import ( + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "io" + "log/slog" + "net" + "slices" + "time" + + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/metrics" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/core/routing" + "github.com/libp2p/go-libp2p/core/sec" + "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/internal/limits" + "github.com/libp2p/go-libp2p/p2p/host/autorelay" + bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" + "github.com/libp2p/go-libp2p/p2p/host/observedaddrs" + "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" + rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" + routed "github.com/libp2p/go-libp2p/p2p/host/routed" + "github.com/libp2p/go-libp2p/p2p/muxer/yamux" + netconnmgr "github.com/libp2p/go-libp2p/p2p/net/connmgr" + "github.com/libp2p/go-libp2p/p2p/net/swarm" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" + relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" + 
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch" + "github.com/libp2p/go-libp2p/p2p/security/noise" + libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" + libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + "github.com/libp2p/go-libp2p/p2p/transport/tcp" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + ws "github.com/libp2p/go-libp2p/p2p/transport/websocket" + libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" + mstream "github.com/multiformats/go-multistream" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/crypto/hkdf" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/quic-go/quic-go" + + "git.sr.ht/~marcopolo/di" +) + +// SetDefaultServiceLimits sets the default limits for bundled libp2p services +func SetDefaultServiceLimits(config *rcmgr.ScalingLimitConfig) { + limits.SetDefaultServiceLimits(config) +} + +// Lifecycle can be used to register start functions and closers. Services +// should not cause any side-effects on instantiation. Instead, services should +// have a start function that executes side effects (such as spawning a worker +// goroutine). +type Lifecycle struct { + startFns []func() error + closers []io.Closer +} + +func (l *Lifecycle) OnStart(fn func() error) { + l.startFns = append(l.startFns, fn) +} + +func (l *Lifecycle) OnClose(c io.Closer) { + l.closers = append(l.closers, c) +} + +func (l *Lifecycle) Start() error { + for _, fn := range l.startFns { + if err := fn(); err != nil { + return err + } + } + return nil +} + +func (l *Lifecycle) Close() error { + var errs []error + for _, c := range l.closers { + if err := c.Close(); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) 
+} + +type ListenAddrs []ma.Multiaddr + +type SwarmConfig struct { + ListenAddrs ListenAddrs + ReadOnlyBlackHoleDetector bool + UDPBlackHoleSuccessCounter *swarm.BlackHoleSuccessCounter + IPv6BlackHoleSuccessCounter *swarm.BlackHoleSuccessCounter + MultiaddrDNSResolver di.Optional[network.MultiaddrDNSResolver] + + // Opts can be used by the user to set swarm options + Opts []swarm.Option + Swarm di.Provide[*swarm.Swarm] +} + +// IdentifyConfig specifies the configuration of the Identify Service +type IdentifyConfig struct { + // UserAgent is the identifier this node will send to other peers when + // identifying itself, e.g. via the identify protocol. + // + // Set it via the UserAgent option function. + UserAgent string + + // ProtocolVersion is the protocol version that identifies the family + // of protocols used by the peer in the Identify protocol. It is set + // using the [ProtocolVersion] option. + ProtocolVersion string +} + +type UpgraderConfig struct { + // Upgrader is the upgrader used to upgrade connections to the libp2p + // protocol. 
+ Upgrader di.Provide[transport.Upgrader] + + UpgraderOptions []di.Provide[tptu.Option] + + Muxers []tptu.StreamMuxer + Security []di.Provide[sec.SecureTransport] +} + +type DialConfig struct { + DialTimeout time.Duration + DialRanker di.Optional[network.DialRanker] +} + +type MetricsConfig struct { + BandwidthReporter metrics.Reporter + PrometheusRegisterer prometheus.Registerer +} + +// TCPTransportsConfig specifies the concrete transports that run on top of TCP +type TCPTransportsConfig struct { + SharedTCPConnMuxer di.Provide[*tcpreuse.ConnMgr] + + TCPOpts []tcp.Option + TcpTransport di.Provide[*tcp.TcpTransport] + + WsOpts []ws.Option + WSTransport di.Provide[*ws.WebsocketTransport] +} + +// UDPTransportsConfig specifies the concrete transports that run on top of UDP +type UDPTransportsConfig struct { + QUICConfig + + ListenUDPFn di.Provide[libp2pwebrtc.ListenUDPFn] + + QUICTransport di.Provide[*libp2pquic.Transport] + + WebTransportOpts []libp2pwebtransport.Option + WebTransportTransport di.Provide[*libp2pwebtransport.Transport] + + WebRTCOpts []libp2pwebrtc.Option + WebRTCTransport di.Provide[*libp2pwebrtc.WebRTCTransport] +} + +type TransportsConfig struct { + TCPTransportsConfig + UDPTransportsConfig + + PSK di.Provide[pnet.PSK] + + // Transports controls what transports are actually instantiated and used. + // While {TCP,UDP}TransportsConfig defines the concrete transports, they + // will not be instantiated unless they are provided here. + // + // This is because `di` is lazy, and won't call a constructor unless that + // constructor's return value is used.
+ Transports di.Provide[[]transport.Transport] +} + +type AutoRelayConfig struct { + Enabled bool + Opts []autorelay.Option +} + +type QUICConfig struct { + QUICReuse di.Provide[*quicreuse.ConnManager] + QUICReuseOpts []quicreuse.Option + + StatelessResetKey func(crypto.PrivKey) (quic.StatelessResetKey, error) + TokenKey func(crypto.PrivKey) (quic.TokenGeneratorKey, error) +} + +// use new types for AutoNAT specific properties when the type overlap with the +// host type. + +type AutoNatPrivKey crypto.PrivKey +type AutoNatPeerStore peerstore.Peerstore +type AutoNATHost host.Host +type AutoNatConfig struct { + PrivateKey func() (AutoNatPrivKey, error) + Peerstore di.Provide[AutoNatPeerStore] + + Host di.Provide[AutoNATHost] + Opts []autonatv2.AutoNATOption + + AutoNAT di.Provide[*autonatv2.AutoNAT] +} + +type RoutingC func(host.Host) (routing.PeerRouting, error) +type BasicHostConfig struct { + AddrsFactory di.Optional[bhost.AddrsFactory] + NATManager func(network.Network) bhost.NATManager + ObservedAddrsManager di.Provide[bhost.ObservedAddrsManager] + + EnablePing bool + EnableHolePunching bool + HolePunchingOptions []holepunch.Option + + Routing di.Optional[RoutingC] + + EnableRelay bool + RelayServiceOpts []relayv2.Option + + EnableMetrics bool +} + +type Config struct { + Logger *slog.Logger + Lifecycle func() *Lifecycle + + IdentifyConfig + + ResourceManager di.Provide[network.ResourceManager] + + EventBus func() event.Bus + Peerstore func() (peerstore.Peerstore, error) + ConnManager func(l *Lifecycle) (connmgr.ConnManager, error) + + PrivateKey func() (crypto.PrivKey, error) + PeerID func(crypto.PrivKey) (peer.ID, error) + + TransportsConfig + SwarmConfig + + ConnGater di.Provide[connmgr.ConnectionGater] + + UpgraderConfig + DialConfig + MetricsConfig + + AutoRelayConfig + BasicHostConfig + AutoNatConfig + + // SideEffects is used to do some extra processing on instantiated objects + // without producing loops. 
+ // For example if A needs to be linked to B, but B requires C on + // instantiation, and C requires A. We can't link A to B in A's constructor + // (A would depend on B would depend on C would depend on A). + // + // With SideEffects we can construct A and B separately, and then introduce them in another step. + // - Build A + // - Build B + // - Link A to B in a side effect. + // + // The di.SideEffect is not a special type. Any type would work here. + // di.SideEffect is chosen for convention. + SideEffects []di.Provide[di.SideEffect] + + Host di.Provide[host.Host] +} + +func PrivKeyToStatelessResetKey(key crypto.PrivKey) (quic.StatelessResetKey, error) { + const statelessResetKeyInfo = "libp2p quic stateless reset key" + var statelessResetKey quic.StatelessResetKey + keyBytes, err := key.Raw() + if err != nil { + return statelessResetKey, err + } + keyReader := hkdf.New(sha256.New, keyBytes, nil, []byte(statelessResetKeyInfo)) + if _, err := io.ReadFull(keyReader, statelessResetKey[:]); err != nil { + return statelessResetKey, err + } + return statelessResetKey, nil +} + +func PrivKeyToTokenGeneratorKey(key crypto.PrivKey) (quic.TokenGeneratorKey, error) { + const tokenGeneratorKeyInfo = "libp2p quic token generator key" + var tokenKey quic.TokenGeneratorKey + keyBytes, err := key.Raw() + if err != nil { + return tokenKey, err + } + keyReader := hkdf.New(sha256.New, keyBytes, nil, []byte(tokenGeneratorKeyInfo)) + if _, err := io.ReadFull(keyReader, tokenKey[:]); err != nil { + return tokenKey, err + } + return tokenKey, nil +} + +var DefaultTransports = TransportsConfig{ + TCPTransportsConfig: TCPTransportsConfig{ + SharedTCPConnMuxer: di.MustProvide[*tcpreuse.ConnMgr]( + func(cfg TCPTransportsConfig, upgrader transport.Upgrader) *tcpreuse.ConnMgr { + if cfg.TcpTransport.Nil || cfg.WSTransport.Nil { + return nil + } + return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, upgrader) + }, + ), + + TCPOpts: []tcp.Option{}, + TcpTransport: 
di.MustProvide[*tcp.TcpTransport](tcp.NewTCPTransport), + + WsOpts: []ws.Option{}, + WSTransport: di.MustProvide[*ws.WebsocketTransport](ws.New), + }, + + UDPTransportsConfig: UDPTransportsConfig{ + QUICConfig: QUICConfig{ + QUICReuse: di.MustProvide[*quicreuse.ConnManager]( + func( + l *Lifecycle, + statelessResetKey quic.StatelessResetKey, + tokenKey quic.TokenGeneratorKey, + opts []quicreuse.Option, + ) (*quicreuse.ConnManager, error) { + cm, err := quicreuse.NewConnManager(statelessResetKey, tokenKey, opts...) + if err != nil { + return nil, err + } + l.OnClose(cm) + return cm, nil + }), + StatelessResetKey: PrivKeyToStatelessResetKey, + TokenKey: PrivKeyToTokenGeneratorKey, + }, + ListenUDPFn: di.MustProvide[libp2pwebrtc.ListenUDPFn](func( + cm *quicreuse.ConnManager, + sw *swarm.Swarm, + ) libp2pwebrtc.ListenUDPFn { + hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool { + quicAddrPorts := map[string]struct{}{} + for _, addr := range sw.ListenAddresses() { + if _, err := addr.ValueForProtocol(ma.P_QUIC_V1); err == nil { + netw, addr, err := manet.DialArgs(addr) + if err != nil { + return false + } + quicAddrPorts[netw+"_"+addr] = struct{}{} + } + } + _, ok := quicAddrPorts[network+"_"+laddr.String()] + return ok + } + + return func(network string, laddr *net.UDPAddr) (net.PacketConn, error) { + if hasQuicAddrPortFor(network, laddr) { + return cm.SharedNonQUICPacketConn(network, laddr) + } + return net.ListenUDP(network, laddr) + } + }), + QUICTransport: di.MustProvide[*libp2pquic.Transport](libp2pquic.NewTransport), + + WebTransportOpts: []libp2pwebtransport.Option{}, + WebTransportTransport: di.MustProvide[*libp2pwebtransport.Transport](libp2pwebtransport.New), + + WebRTCOpts: []libp2pwebrtc.Option{}, + WebRTCTransport: di.MustProvide[*libp2pwebrtc.WebRTCTransport](libp2pwebrtc.New), + }, + + // We don't support PSKs by default + PSK: di.MustProvide[pnet.PSK](nil), + + Transports: di.MustProvide[[]transport.Transport]( + func( + tcp 
*tcp.TcpTransport, + ws *ws.WebsocketTransport, + quic *libp2pquic.Transport, + wt *libp2pwebtransport.Transport, + webrtc *libp2pwebrtc.WebRTCTransport, + ) (tpts []transport.Transport, err error) { + if tcp != nil { + tpts = append(tpts, tcp) + } + if ws != nil { + tpts = append(tpts, ws) + } + if quic != nil { + tpts = append(tpts, quic) + } + if wt != nil { + tpts = append(tpts, wt) + } + if webrtc != nil { + tpts = append(tpts, webrtc) + } + return tpts, nil + }, + ), +} + +var DefaultConfig = Config{ + Logger: slog.Default(), + Lifecycle: func() *Lifecycle { return &Lifecycle{} }, + + IdentifyConfig: IdentifyConfig{}, + + ResourceManager: di.MustProvide[network.ResourceManager](func(l *Lifecycle) (network.ResourceManager, error) { + // Default memory limit: 1/8th of total memory, minimum 128MB, maximum 1GB + limits := rcmgr.DefaultLimits + SetDefaultServiceLimits(&limits) + r, err := rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(limits.AutoScale())) + if err != nil { + return nil, err + } + l.OnClose(r) + + return r, nil + }), + + EventBus: func() event.Bus { return eventbus.NewBus() }, + Peerstore: func() (peerstore.Peerstore, error) { + return pstoremem.NewPeerstore() + }, + ConnManager: func(l *Lifecycle) (connmgr.ConnManager, error) { + + cm, err := netconnmgr.NewConnManager(160, 192) + if err != nil { + return nil, err + } + l.OnClose(cm) + return cm, nil + }, + + PrivateKey: func() (crypto.PrivKey, error) { + priv, _, err := crypto.GenerateEd25519Key(rand.Reader) + return priv, err + }, + PeerID: peer.IDFromPrivateKey, + + TransportsConfig: DefaultTransports, + SwarmConfig: SwarmConfig{ + UDPBlackHoleSuccessCounter: &swarm.BlackHoleSuccessCounter{N: 100, MinSuccesses: 5, Name: "UDP"}, + IPv6BlackHoleSuccessCounter: &swarm.BlackHoleSuccessCounter{N: 100, MinSuccesses: 5, Name: "IPv6"}, + ListenAddrs: []ma.Multiaddr{ + di.Must(ma.NewMultiaddr("/ip4/0.0.0.0/tcp/0")), + di.Must(ma.NewMultiaddr("/ip4/0.0.0.0/udp/0/quic-v1")), + 
di.Must(ma.NewMultiaddr("/ip4/0.0.0.0/udp/0/quic-v1/webtransport")), + di.Must(ma.NewMultiaddr("/ip4/0.0.0.0/udp/0/webrtc-direct")), + di.Must(ma.NewMultiaddr("/ip6/::/tcp/0")), + di.Must(ma.NewMultiaddr("/ip6/::/udp/0/quic-v1")), + di.Must(ma.NewMultiaddr("/ip6/::/udp/0/quic-v1/webtransport")), + di.Must(ma.NewMultiaddr("/ip6/::/udp/0/webrtc-direct")), + }, + Swarm: di.MustProvide[*swarm.Swarm](func( + l *Lifecycle, + p peer.ID, + k crypto.PrivKey, + ps peerstore.Peerstore, + b event.Bus, + listenAddrs ListenAddrs, + + swarmConfig SwarmConfig, + dialConfig DialConfig, + UDPBlackHoleSuccessCounter *swarm.BlackHoleSuccessCounter, + IPv6BlackHoleSuccessCounter *swarm.BlackHoleSuccessCounter, + rcmgr network.ResourceManager, + multiaddrResolver di.Optional[network.MultiaddrDNSResolver], + dialRanker di.Optional[network.DialRanker], + metricsCfg MetricsConfig, + gater connmgr.ConnectionGater, + ) (*swarm.Swarm, error) { + if ps == nil { + return nil, fmt.Errorf("no peerstore specified") + } + if err := ps.AddPrivKey(p, k); err != nil { + return nil, err + } + if err := ps.AddPubKey(p, k.GetPublic()); err != nil { + return nil, err + } + + opts := slices.Clone(swarmConfig.Opts) + opts = append(opts, + swarm.WithUDPBlackHoleSuccessCounter(UDPBlackHoleSuccessCounter), + swarm.WithIPv6BlackHoleSuccessCounter(IPv6BlackHoleSuccessCounter), + swarm.WithDialTimeout(dialConfig.DialTimeout), + swarm.WithResourceManager(rcmgr), + ) + if multiaddrResolver.IsSome { + opts = append(opts, swarm.WithMultiaddrResolver(multiaddrResolver.Unwrap())) + } + if dialRanker.IsSome { + opts = append(opts, swarm.WithDialRanker(dialRanker.Unwrap())) + } + if gater != nil { + opts = append(opts, swarm.WithConnectionGater(gater)) + } + if metricsCfg.PrometheusRegisterer != nil { + opts = append(opts, + swarm.WithMetricsTracer(swarm.NewMetricsTracer(swarm.WithRegisterer(metricsCfg.PrometheusRegisterer)))) + } + if metricsCfg.BandwidthReporter != nil { + opts = append(opts, 
swarm.WithMetrics(metricsCfg.BandwidthReporter)) + } + if swarmConfig.ReadOnlyBlackHoleDetector { + opts = append(opts, swarm.WithReadOnlyBlackHoleDetector()) + } + + s, err := swarm.NewSwarm(p, ps, b, opts...) + if err != nil { + return nil, err + } + + l.OnStart(func() error { + return s.Listen(slices.Clone(listenAddrs)...) + }) + l.OnClose(s) + + return s, nil + }), + }, + + ConnGater: di.MustProvide[connmgr.ConnectionGater](nil), + UpgraderConfig: UpgraderConfig{ + Muxers: []tptu.StreamMuxer{ + {ID: yamux.ID, Muxer: yamux.DefaultTransport}, + }, + Security: []di.Provide[sec.SecureTransport]{ + di.MustProvide[sec.SecureTransport](func( + privkey crypto.PrivKey, muxers []tptu.StreamMuxer) (sec.SecureTransport, error) { + return libp2ptls.New(libp2ptls.ID, privkey, muxers) + }), + di.MustProvide[sec.SecureTransport](func( + privkey crypto.PrivKey, muxers []tptu.StreamMuxer) (sec.SecureTransport, error) { + return noise.New(noise.ID, privkey, muxers) + }), + }, + UpgraderOptions: []di.Provide[tptu.Option]{}, + Upgrader: di.MustProvide[transport.Upgrader](func( + security []sec.SecureTransport, + muxers []tptu.StreamMuxer, + rcmgr network.ResourceManager, + connGater connmgr.ConnectionGater, + upgraderOpts []tptu.Option, + ) (transport.Upgrader, error) { + // No PSK. 
Use a different config for PSK + return tptu.New( + security, muxers, nil, rcmgr, connGater, upgraderOpts..., + ) + }), + }, + DialConfig: DialConfig{ + DialTimeout: 10 * time.Second, + }, + MetricsConfig: MetricsConfig{ + PrometheusRegisterer: prometheus.DefaultRegisterer, + }, + AutoRelayConfig: AutoRelayConfig{ + Enabled: false, + }, + BasicHostConfig: BasicHostConfig{ + ObservedAddrsManager: di.MustProvide[bhost.ObservedAddrsManager]( + func(l *Lifecycle, eventBus event.Bus, s *swarm.Swarm) (bhost.ObservedAddrsManager, error) { + o, err := observedaddrs.NewManager(eventBus, s) + if err != nil { + return nil, err + } + l.OnStart(func() error { + o.Start(s) + return nil + }) + l.OnClose(o) + return o, nil + }), + }, + + AutoNatConfig: AutoNatConfig{ + PrivateKey: func() (AutoNatPrivKey, error) { + priv, _, err := crypto.GenerateEd25519Key(rand.Reader) + return AutoNatPrivKey(priv), err + }, + Peerstore: di.MustProvide[AutoNatPeerStore]( + func(l *Lifecycle) (AutoNatPeerStore, error) { + ps, err := pstoremem.NewPeerstore() + l.OnClose(ps) + return AutoNatPeerStore(ps), err + }, + ), + Opts: []autonatv2.AutoNATOption{}, + + AutoNAT: di.MustProvide[*autonatv2.AutoNAT]( + func( + prometheusRegisterer prometheus.Registerer, + autonatHost AutoNATHost, + autonatOptions []autonatv2.AutoNATOption, + ) (*autonatv2.AutoNAT, error) { + if prometheusRegisterer != nil { + mt := autonatv2.NewMetricsTracer(prometheusRegisterer) + autonatOptions = append( + []autonatv2.AutoNATOption{autonatv2.WithMetricsTracer(mt)}, + autonatOptions..., + ) + } + autoNATv2, err := autonatv2.New(autonatHost, autonatOptions...) 
+ if err != nil { + return nil, fmt.Errorf("failed to create autonatv2: %w", err) + } + return autoNATv2, nil + }, + ), + Host: di.MustProvide[AutoNATHost]( + func( + config Config, + k AutoNatPrivKey, + ps AutoNatPeerStore, + l *Lifecycle, + ) (AutoNATHost, error) { + // Use the same provided config, but override some + autonatCfg := config + autonatCfg.ListenAddrs = nil + autonatCfg.Peerstore = func() (peerstore.Peerstore, error) { + return ps, nil + } + autonatCfg.PrivateKey = func() (crypto.PrivKey, error) { + return k, nil + } + autonatCfg.Lifecycle = func() *Lifecycle { + // Use the same lifecycle as our parent config + return l + } + autonatCfg.ReadOnlyBlackHoleDetector = true + autonatCfg.Host = di.MustProvide[host.Host](func( + l *Lifecycle, + swarm *swarm.Swarm, + ) host.Host { + mux := mstream.NewMultistreamMuxer[protocol.ID]() + h := &blankhost.BlankHost{ + N: swarm, + M: mux, + ConnMgr: connmgr.NullConnMgr{}, + E: nil, + // Don't need this for autonat + SkipInitSignedRecord: true, + } + l.OnStart(func() error { + return h.Start() + }) + l.OnClose(h) + return h + }) + type Result struct { + Host host.Host + _ []di.SideEffect + } + + res, err := di.New[Result](autonatCfg) + return res.Host, err + }, + ), + }, + + SideEffects: []di.Provide[di.SideEffect]{ + di.MustProvide[di.SideEffect](func(metricsRegisterer prometheus.Registerer) (di.SideEffect, error) { + rcmgr.MustRegisterWith(metricsRegisterer) + return di.SideEffect{}, nil + }), + di.MustProvide[di.SideEffect](func(logger *slog.Logger, rcmgr network.ResourceManager, cmgr connmgr.ConnManager) (di.SideEffect, error) { + if l, ok := rcmgr.(connmgr.GetConnLimiter); ok { + err := cmgr.CheckLimit(l) + if err != nil { + logger.Warn("rcmgr limit conflicts with connmgr limit", "err", err) + } + } + return di.SideEffect{}, nil + }), + di.MustProvide[di.SideEffect](func(s *swarm.Swarm, tpts []transport.Transport) (di.SideEffect, error) { + for _, t := range tpts { + err := s.AddTransport(t) + if err != nil { + 
return di.SideEffect{}, err + } + } + return di.SideEffect{}, nil + }), + }, + + Host: di.MustProvide[host.Host](func( + l *Lifecycle, + identifyConfig IdentifyConfig, + basicHostConfig BasicHostConfig, + observedAddrManager bhost.ObservedAddrsManager, + metricsConfig MetricsConfig, + autoRelayConfig AutoRelayConfig, + network *swarm.Swarm, + connmgr connmgr.ConnManager, + autonat *autonatv2.AutoNAT, + routingC di.Optional[RoutingC], + eventBus event.Bus, + ) (h host.Host, err error) { + bh, err := bhost.NewHost(network, &bhost.HostOpts{ + EventBus: eventBus, + ConnManager: connmgr, + AddrsFactory: basicHostConfig.AddrsFactory.Val, + NATManager: basicHostConfig.NATManager, + EnablePing: basicHostConfig.EnablePing, + UserAgent: identifyConfig.UserAgent, + ProtocolVersion: identifyConfig.ProtocolVersion, + EnableHolePunching: basicHostConfig.EnableHolePunching, + HolePunchingOptions: basicHostConfig.HolePunchingOptions, + EnableRelayService: basicHostConfig.EnableRelay, + RelayServiceOpts: basicHostConfig.RelayServiceOpts, + EnableMetrics: metricsConfig.PrometheusRegisterer != nil, + PrometheusRegisterer: metricsConfig.PrometheusRegisterer, + ObservedAddrsManager: observedAddrManager, + AutoNATv2: autonat, + }) + if err != nil { + return nil, err + } + l.OnStart(func() error { + bh.Start() + return nil + }) + l.OnClose(bh) + h = bh + + if routingC.IsSome { + router, err := routingC.Val(h) + if err != nil { + return nil, err + } + + h = routed.Wrap(bh, router) + } + + if autoRelayConfig.Enabled { + autorelayOpts := autoRelayConfig.Opts + if metricsConfig.PrometheusRegisterer != nil { + mt := autorelay.WithMetricsTracer( + autorelay.NewMetricsTracer(autorelay.WithRegisterer(metricsConfig.PrometheusRegisterer))) + mtOpts := []autorelay.Option{mt} + autorelayOpts = append(mtOpts, autoRelayConfig.Opts...) + } + ar, err := autorelay.NewAutoRelay(h, autorelayOpts...) 
+ if err != nil { + return nil, err + } + l.OnStart(func() error { + ar.Start() + return nil + }) + l.OnClose(ar) + } + + return h, nil + }), +} + +type hostWithLifecycle struct { + host.Host + lifecycle io.Closer +} + +func (h *hostWithLifecycle) Close() error { + return h.lifecycle.Close() +} + +// NewHost is a helper function for the common case of constructing host that +// will clean up the lifecycle of all instantiated objects when closed. +func NewHost(config Config) (host.Host, error) { + type Result struct { + L *Lifecycle + _ []di.SideEffect + Host host.Host + } + r, err := di.New[Result](config) + if err != nil { + return nil, err + } + + if err := r.L.Start(); err != nil { + return nil, err + } + + return &hostWithLifecycle{r.Host, r.L}, nil +} diff --git a/x/builder/builder_test.go b/x/builder/builder_test.go new file mode 100644 index 0000000000..54c28d01aa --- /dev/null +++ b/x/builder/builder_test.go @@ -0,0 +1,76 @@ +package builder + +import ( + "context" + "fmt" + "io" + "testing" + + "git.sr.ht/~marcopolo/di" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" +) + +func newHost(t *testing.T) host.Host { + type Result struct { + Host host.Host + L *Lifecycle + _ []di.SideEffect + } + var r Result + if err := di.Build(DefaultConfig, &r); err != nil { + t.Fatal(err) + } + + if r.Host == nil { + t.Fatal("host is nil") + } + + if err := r.L.Start(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := r.Host.Close(); err != nil { + t.Fatal(err) + } + }) + return r.Host +} + +func TestEcho(t *testing.T) { + a := newHost(t) + b := newHost(t) + + b.SetStreamHandler("/echo/1", func(s network.Stream) { + io.Copy(s, s) + s.Close() + }) + fmt.Println("B addrs", b.Addrs()) + a.Connect(context.Background(), peer.AddrInfo{ + ID: b.ID(), + Addrs: b.Addrs(), + }) + + s, err := a.NewStream(context.Background(), b.ID(), "/echo/1") + if err != nil { + t.Fatal(err) + } + _, err 
= s.Write([]byte("hello")) + if err != nil { + t.Fatal(err) + } + if err := s.CloseWrite(); err != nil { + t.Fatal(err) + } + msgBack, err := io.ReadAll(s) + if err != nil { + t.Fatal(err) + } + if string(msgBack) != "hello" { + t.Fatalf("expected 'hello', got '%s'", string(msgBack)) + } + + t.Logf("A Peer ID: %s\n", a.ID()) + t.Logf("B Peer ID: %s\n", b.ID()) +} diff --git a/x/builder/examples_test.go b/x/builder/examples_test.go new file mode 100644 index 0000000000..161abb98cd --- /dev/null +++ b/x/builder/examples_test.go @@ -0,0 +1,115 @@ +package builder_test + +import ( + "errors" + "fmt" + "slices" + "strings" + + "git.sr.ht/~marcopolo/di" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/transport" + libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" + "github.com/libp2p/go-libp2p/x/builder" + "github.com/multiformats/go-multiaddr" +) + +func ExampleDefaultConfig() { + config := builder.DefaultConfig + host, err := builder.NewHost(config) + if err != nil { + panic(err) + } + if len(host.Addrs()) > 0 { + fmt.Println("listening on some transports") + } + // Output: listening on some transports +} + +func ExampleDefaultConfig_extractEventBus() { + config := builder.DefaultConfig + type Result struct { + Host host.Host + Bus event.Bus + Lifecycle *builder.Lifecycle + _ []di.SideEffect + } + res, err := di.New[Result](config) + if err != nil { + panic(err) + } + + // We must call Lifecycle.Start to start all the services. 
+ if err := res.Lifecycle.Start(); err != nil { + panic(err) + } + + // And we must remember to end the lifecycle of these services by calling + // Lifecycle.Close() + defer func() { + if err := res.Lifecycle.Close(); err != nil { + panic(err) + } + }() + + if len(res.Host.Addrs()) > 0 { + fmt.Println("listening on some transports") + } + if res.Bus != nil { + fmt.Println("and I have a reference to an event bus") + } + // Output: + // listening on some transports + // and I have a reference to an event bus +} + +// ExampleDefaultConfig_onlyQUIC shows how to customize the default config to +// build a host with only a quic transport +func ExampleDefaultConfig_onlyQUIC() { + config := builder.DefaultConfig + config.Transports = di.MustProvide[[]transport.Transport]( + func( + quic *libp2pquic.Transport, + ) (tpts []transport.Transport, err error) { + if quic == nil { + return nil, errors.New("quic transport is required") + } + return append(tpts, quic), nil + }, + ) + type Result struct { + Host host.Host + Lifecycle *builder.Lifecycle + _ []di.SideEffect + } + res, err := di.New[Result](config) + if err != nil { + panic(err) + } + + // We must call Lifecycle.Start to start all the services. + if err := res.Lifecycle.Start(); err != nil { + panic(err) + } + + // And we must remember to end the lifecycle of these services by calling + // Lifecycle.Close() + defer func() { + if err := res.Lifecycle.Close(); err != nil { + panic(err) + } + }() + + addrs := res.Host.Addrs() + onlyQuicAddrs := slices.DeleteFunc(addrs, func(m multiaddr.Multiaddr) bool { + return !strings.Contains(m.String(), "quic-v1") + }) + + if len(onlyQuicAddrs) != len(addrs) { + panic("should only be listening on QUIC addresses") + } + fmt.Println("I have a QUIC listener") + // Output: + // I have a QUIC listener +}