Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 76 additions & 35 deletions p2p/host/basic/addrs_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,26 +226,8 @@ func (a *addrsManager) startBackgroundWorker() (retErr error) {
)
}

var relayAddrs []ma.Multiaddr
// update relay addrs in case we're private
select {
case e := <-autoRelayAddrsSub.Out():
if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
relayAddrs = slices.Clone(evt.RelayAddrs)
}
default:
}

select {
case e := <-autonatReachabilitySub.Out():
if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
a.hostReachability.Store(&evt.Reachability)
}
default:
}

a.wg.Add(1)
go a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter, localAddrsEmitter, relayAddrs)
go a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter, localAddrsEmitter)
return nil
}

Expand All @@ -254,7 +236,6 @@ func (a *addrsManager) background(
autonatReachabilitySub event.Subscription,
emitter event.Emitter,
localAddrsEmitter event.Emitter,
relayAddrs []ma.Multiaddr,
) {
defer a.wg.Done()
defer func() {
Expand All @@ -276,6 +257,24 @@ func (a *addrsManager) background(
}
}()

var relayAddrs []ma.Multiaddr
// update relay addrs in case we're private
select {
case e := <-autoRelayAddrsSub.Out():
if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
relayAddrs = slices.Clone(evt.RelayAddrs)
}
default:
}

select {
case e := <-autonatReachabilitySub.Out():
if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
a.hostReachability.Store(&evt.Reachability)
}
default:
}

ticker := time.NewTicker(addrChangeTickrInterval)
defer ticker.Stop()
var previousAddrs hostAddrs
Expand Down Expand Up @@ -317,13 +316,13 @@ func (a *addrsManager) updateAddrs(prevHostAddrs hostAddrs, relayAddrs []ma.Mult
currReachableAddrs, currUnreachableAddrs, currUnknownAddrs = a.getConfirmedAddrs(localAddrs)
}
relayAddrs = slices.Clone(relayAddrs)
currAddrs := a.getAddrs(slices.Clone(localAddrs), relayAddrs)
currAddrs := a.getDialableAddrs(localAddrs, currReachableAddrs, currUnreachableAddrs, relayAddrs)
currAddrs = a.applyAddrsFactory(currAddrs)

if areAddrsDifferent(prevHostAddrs.addrs, currAddrs) {
_, _, removed := diffAddrs(prevHostAddrs.addrs, currAddrs)
a.updatePeerStore(currAddrs, removed)
}

a.addrsMx.Lock()
a.currentAddrs = hostAddrs{
addrs: append(a.currentAddrs.addrs[:0], currAddrs...),
Expand Down Expand Up @@ -412,25 +411,36 @@ func (a *addrsManager) notifyAddrsUpdated(emitter event.Emitter, localAddrsEmitt
// the node's relay addresses and private network addresses.
func (a *addrsManager) Addrs() []ma.Multiaddr {
a.addrsMx.RLock()
directAddrs := slices.Clone(a.currentAddrs.localAddrs)
relayAddrs := slices.Clone(a.currentAddrs.relayAddrs)
addrs := a.getDialableAddrs(a.currentAddrs.localAddrs, a.currentAddrs.reachableAddrs, a.currentAddrs.unreachableAddrs, a.currentAddrs.relayAddrs)
a.addrsMx.RUnlock()
return a.getAddrs(directAddrs, relayAddrs)
// don't hold the lock while applying addrs factory
return a.applyAddrsFactory(addrs)
}

// getAddrs returns the node's dialable addresses. Mutates localAddrs
func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multiaddr) []ma.Multiaddr {
addrs := localAddrs
rch := a.hostReachability.Load()
if rch != nil && *rch == network.ReachabilityPrivate {
// Delete public addresses if the node's reachability is private, and we have relay addresses
if len(relayAddrs) > 0 {
// getDialableAddrs returns the node's dialable addrs. Doesn't mutate any argument.
func (a *addrsManager) getDialableAddrs(localAddrs, reachableAddrs, unreachableAddrs, relayAddrs []ma.Multiaddr) []ma.Multiaddr {
// remove known unreachable addrs
addrs := removeInSource(slices.Clone(localAddrs), unreachableAddrs)
// If we have no confirmed reachable addresses, add the relay addresses
if a.addrsReachabilityTracker != nil {
if len(reachableAddrs) == 0 {
addrs = append(addrs, relayAddrs...)
}
} else {
rch := a.hostReachability.Load()
// If we're only using autonatv1, remove public addrs and add relay addrs
if len(relayAddrs) > 0 && rch != nil && *rch == network.ReachabilityPrivate {
addrs = slices.DeleteFunc(addrs, manet.IsPublicAddr)
addrs = append(addrs, relayAddrs...)
}
}
// Make a copy. Consumers can modify the slice elements
addrs = slices.Clone(a.addrsFactory(addrs))
return addrs
}

func (a *addrsManager) applyAddrsFactory(addrs []ma.Multiaddr) []ma.Multiaddr {
af := a.addrsFactory(addrs)
// Copy to our slice in case addrsFactory returns its own same slice always.
addrs = append(addrs[:0], af...)
// Add certhashes for the addresses provided by the user via address factory.
addrs = a.addCertHashes(ma.Unique(addrs))
slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
Expand Down Expand Up @@ -871,7 +881,38 @@ func removeNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr {
}
// a is in source, nothing to do
}
// j is the current element, i is the lowest index nil element
// Move all the nils to the end.
// j is the current element, i is lowest index of a nil element.
// At the end of every iteration all elements from i to j are nil.
i := 0
for j := range len(addrs) {
if addrs[j] != nil {
addrs[i], addrs[j] = addrs[j], addrs[i]
i++
}
}
return addrs[:i]
}

// removeInSource removes items from addrs that are present in source.
// Modifies the addrs slice in place
// addrs and source must be sorted using multiaddr.Compare.
func removeInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr {
j := 0
// mark entries in source as nil
for i, a := range addrs {
// move right in source as long as a > source[j]
for j < len(source) && a.Compare(source[j]) > 0 {
j++
}
// a is in source, mark nil
if j < len(source) && a.Compare(source[j]) == 0 {
addrs[i] = nil
}
}
// Move all the nils to the end.
// j is the current element, i is lowest index of a nil element.
// At the end of every iteration all elements from i to j are nil.
i := 0
for j := range len(addrs) {
if addrs[j] != nil {
Expand Down
22 changes: 12 additions & 10 deletions p2p/host/basic/addrs_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package basichost
import (
"context"
"crypto/rand"
"errors"
"fmt"
"slices"
"sync/atomic"
Expand Down Expand Up @@ -406,10 +405,10 @@ func TestAddrsManagerReachabilityEvent(t *testing.T) {
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
if reqs[0].Addr.Equal(publicQUIC) {
return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil
} else if reqs[0].Addr.Equal(publicTCP) || reqs[0].Addr.Equal(publicQUIC2) {
} else if reqs[0].Addr.Equal(publicQUIC2) {
return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPrivate}, nil
}
return autonatv2.Result{}, errors.New("invalid")
return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityUnknown, AllAddrsRefused: true}, nil
},
},
})
Expand All @@ -429,17 +428,20 @@ func TestAddrsManagerReachabilityEvent(t *testing.T) {

// Wait for probes to complete and addresses to be classified
reachableAddrs := []ma.Multiaddr{publicQUIC}
unreachableAddrs := []ma.Multiaddr{publicTCP, publicQUIC2}
unreachableAddrs := []ma.Multiaddr{publicQUIC2}
unknownAddrs := []ma.Multiaddr{publicTCP}
select {
case e := <-sub.Out():
evt := e.(event.EvtHostReachableAddrsChanged)
require.ElementsMatch(t, reachableAddrs, evt.Reachable)
require.ElementsMatch(t, unreachableAddrs, evt.Unreachable)
require.Empty(t, evt.Unknown)
matest.AssertMultiaddrsMatch(t, reachableAddrs, evt.Reachable)
matest.AssertMultiaddrsMatch(t, unreachableAddrs, evt.Unreachable)
matest.AssertMultiaddrsMatch(t, unknownAddrs, evt.Unknown)
reachable, unreachable, unknown := am.ConfirmedAddrs()
require.ElementsMatch(t, reachable, reachableAddrs)
require.ElementsMatch(t, unreachable, unreachableAddrs)
require.Empty(t, unknown)
matest.AssertMultiaddrsMatch(t, reachableAddrs, reachable)
matest.AssertMultiaddrsMatch(t, unreachableAddrs, unreachable)
matest.AssertMultiaddrsMatch(t, unknownAddrs, unknown)
// unreachable addrs should be removed
matest.AssertMultiaddrsMatch(t, []ma.Multiaddr{publicQUIC, publicTCP}, am.Addrs())
case <-time.After(5 * time.Second):
t.Fatal("expected final event for reachability change after probing")
}
Expand Down
Loading