diff --git a/client/iface/device.go b/client/iface/device.go index 921f0ea9848..c0c829825ee 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -23,4 +23,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index a731684cccd..48346fc0f06 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net { return nil } +// GetICEBind returns the ICEBind instance +func (t *WGTunDevice) GetICEBind() EndpointManager { + return t.iceBind +} + func routesToString(routes []string) string { return strings.Join(routes, ";") } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 390efe0880b..acd5f6f11f0 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 96e4c8bcf1a..f96edf99246 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index cdac43a5314..2a836f84646 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns nil for kernel mode devices +func (t *TunKernelDevice) GetICEBind() EndpointManager { + return nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index e37321b6876..40d8fdac869 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -21,6 +21,7 @@ type Bind interface { conn.Bind GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) ActivityRecorder() *bind.ActivityRecorder + EndpointManager } type TunNetstackDevice struct { @@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device { func (t *TunNetstackDevice) GetNet() *netstack.Net { return t.net } + +// GetICEBind returns the bind instance +func (t *TunNetstackDevice) GetICEBind() EndpointManager { + return t.bind +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4cdd70a3266..24654fc0361 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error { func (t *USPDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *USPDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f1023bc0a9e..96350df8a6d 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/endpoint_manager.go b/client/iface/device/endpoint_manager.go new file mode 100644 index 00000000000..b53888baab6 --- /dev/null +++ b/client/iface/device/endpoint_manager.go @@ -0,0 +1,13 @@ +package device + +import ( + "net" + "net/netip" +) + +// EndpointManager manages fake IP to connection mappings for userspace bind implementations. +// Implemented by bind.ICEBind and bind.RelayBindJS. +type EndpointManager interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) +} diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 4649b8b97f6..cdfcea48dff 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,4 +21,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/iface.go b/client/iface/iface.go index 15867216065..07235a99598 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy { return w.wgProxyFactory.GetProxy() } +// GetBind returns the EndpointManager userspace bind mode. +func (w *WGIface) GetBind() device.EndpointManager { + w.mu.Lock() + defer w.mu.Unlock() + + if w.tun == nil { + return nil + } + return w.tun.GetICEBind() +} + // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind func (w *WGIface) IsUserspaceBind() bool { return w.userspaceBind diff --git a/client/internal/lazyconn/activity/lazy_conn.go b/client/internal/lazyconn/activity/lazy_conn.go new file mode 100644 index 00000000000..2564a9905cb --- /dev/null +++ b/client/internal/lazyconn/activity/lazy_conn.go @@ -0,0 +1,82 @@ +package activity + +import ( + "context" + "io" + "net" + "time" +) + +// lazyConn detects activity when WireGuard attempts to send packets. +// It does not deliver packets, only signals that activity occurred. +type lazyConn struct { + activityCh chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +// newLazyConn creates a new lazyConn for activity detection. +func newLazyConn() *lazyConn { + ctx, cancel := context.WithCancel(context.Background()) + return &lazyConn{ + activityCh: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + } +} + +// Read blocks until the connection is closed. +func (c *lazyConn) Read(_ []byte) (n int, err error) { + <-c.ctx.Done() + return 0, io.EOF +} + +// Write signals activity detection when ICEBind routes packets to this endpoint. +func (c *lazyConn) Write(b []byte) (n int, err error) { + if c.ctx.Err() != nil { + return 0, io.EOF + } + + select { + case c.activityCh <- struct{}{}: + default: + } + + return len(b), nil +} + +// ActivityChan returns the channel that signals when activity is detected. +func (c *lazyConn) ActivityChan() <-chan struct{} { + return c.activityCh +} + +// Close closes the connection. +func (c *lazyConn) Close() error { + c.cancel() + return nil +} + +// LocalAddr returns the local address. +func (c *lazyConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// RemoteAddr returns the remote address. +func (c *lazyConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// SetDeadline sets the read and write deadlines. +func (c *lazyConn) SetDeadline(_ time.Time) error { + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. +func (c *lazyConn) SetReadDeadline(_ time.Time) error { + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (c *lazyConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go new file mode 100644 index 00000000000..792d0421550 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -0,0 +1,127 @@ +package activity + +import ( + "fmt" + "net" + "net/netip" + "sync" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +type bindProvider interface { + GetBind() device.EndpointManager +} + +const ( + // lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers. + // The actual routing is done via fakeIP in ICEBind, not by this port. + lazyBindPort = 17473 +) + +// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode. +type BindListener struct { + wgIface WgInterface + peerCfg lazyconn.PeerConfig + done sync.WaitGroup + + lazyConn *lazyConn + bind device.EndpointManager + fakeIP netip.Addr +} + +// NewBindListener creates a listener that passes data directly through bind using LazyConn. +// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range. +func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) { + fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs) + if err != nil { + return nil, fmt.Errorf("derive fake IP: %w", err) + } + + d := &BindListener{ + wgIface: wgIface, + peerCfg: cfg, + bind: bind, + fakeIP: fakeIP, + } + + if err := d.setupLazyConn(); err != nil { + return nil, fmt.Errorf("setup lazy connection: %v", err) + } + + d.done.Add(1) + return d, nil +} + +// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. +// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). +// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { + if len(allowedIPs) == 0 { + return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") + } + + ourNetwork := wgIface.Address().Network + + var peerIP netip.Addr + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if !ip.Is4() { + continue + } + if ourNetwork.Contains(ip) { + peerIP = ip + break + } + } + + if !peerIP.IsValid() { + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + } + + octets := peerIP.As4() + fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) + return fakeIP, nil +} + +func (d *BindListener) setupLazyConn() error { + d.lazyConn = newLazyConn() + d.bind.SetEndpoint(d.fakeIP, d.lazyConn) + + endpoint := &net.UDPAddr{ + IP: d.fakeIP.AsSlice(), + Port: lazyBindPort, + } + return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil) +} + +// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed. +func (d *BindListener) ReadPackets() { + select { + case <-d.lazyConn.ActivityChan(): + d.peerCfg.Log.Infof("activity detected via LazyConn") + case <-d.lazyConn.ctx.Done(): + d.peerCfg.Log.Infof("exit from activity listener") + } + + d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey) + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { + d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) + } + + _ = d.lazyConn.Close() + d.bind.RemoveEndpoint(d.fakeIP) + d.done.Done() +} + +// Close stops the listener and cleans up resources. +func (d *BindListener) Close() { + d.peerCfg.Log.Infof("closing activity listener (LazyConn)") + + if err := d.lazyConn.Close(); err != nil { + d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err) + } + + d.done.Wait() +} diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go new file mode 100644 index 00000000000..f86dd3877fc --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -0,0 +1,291 @@ +package activity + +import ( + "net" + "net/netip" + "runtime" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/lazyconn" + peerid "github.com/netbirdio/netbird/client/internal/peer/id" +) + +func isBindListenerPlatform() bool { + return runtime.GOOS == "windows" || runtime.GOOS == "js" +} + +// mockEndpointManager implements device.EndpointManager for testing +type mockEndpointManager struct { + endpoints map[netip.Addr]net.Conn +} + +func newMockEndpointManager() *mockEndpointManager { + return &mockEndpointManager{ + endpoints: make(map[netip.Addr]net.Conn), + } +} + +func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + m.endpoints[fakeIP] = conn +} + +func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) { + delete(m.endpoints, fakeIP) +} + +func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn { + return m.endpoints[fakeIP] +} + +// MockWGIfaceBind mocks WgInterface with bind support +type MockWGIfaceBind struct { + endpointMgr *mockEndpointManager +} + +func (m *MockWGIfaceBind) RemovePeer(string) error { + return nil +} + +func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { + return nil +} + +func (m *MockWGIfaceBind) IsUserspaceBind() bool { + return true +} + +func (m *MockWGIfaceBind) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +func (m *MockWGIfaceBind) GetBind() device.EndpointManager { + return m.endpointMgr +} + +func TestBindListener_Creation(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + expectedFakeIP := netip.MustParseAddr("127.2.0.2") + conn := mockEndpointMgr.GetEndpoint(expectedFakeIP) + require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager") + + _, ok := conn.(*lazyConn) + assert.True(t, ok, "Registered endpoint should be a lazyConn") + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestBindListener_ActivityDetection(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + fakeIP := listener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection") +} + +func TestBindListener_Close(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + fakeIP := listener.fakeIP + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close") +} + +func TestManager_BindMode(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + err := mgr.MonitorPeerActivity(cfg) + require.NoError(t, err) + + listener, exists := mgr.GetPeerListener(cfg.PeerConnID) + require.True(t, exists, "Peer listener should be found") + + bindListener, ok := listener.(*BindListener) + require.True(t, ok, "Listener should be BindListener, got %T", listener) + + fakeIP := bindListener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case peerConnID := <-mgr.OnActivityChan: + assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notification") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity") +} + +func TestManager_BindMode_MultiplePeers(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer1 := &MocPeer{PeerID: "testPeer1"} + peer2 := &MocPeer{PeerID: "testPeer2"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg1 := lazyconn.PeerConfig{ + PublicKey: peer1.PeerID, + PeerConnID: peer1.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + cfg2 := lazyconn.PeerConfig{ + PublicKey: peer2.PeerID, + PeerConnID: peer2.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")}, + Log: log.WithField("peer", "testPeer2"), + } + + err := mgr.MonitorPeerActivity(cfg1) + require.NoError(t, err) + + err = mgr.MonitorPeerActivity(cfg2) + require.NoError(t, err) + + listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID) + require.True(t, exists, "Peer1 listener should be found") + bindListener1 := listener1.(*BindListener) + + listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID) + require.True(t, exists, "Peer2 listener should be found") + bindListener2 := listener2.(*BindListener) + + conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP) + require.NotNil(t, conn1, "Peer1 endpoint should be registered") + _, err = conn1.Write([]byte{0x01}) + require.NoError(t, err) + + conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP) + require.NotNil(t, conn2, "Peer2 endpoint should be registered") + _, err = conn2.Write([]byte{0x02}) + require.NoError(t, err) + + receivedPeers := make(map[peerid.ConnID]bool) + for i := 0; i < 2; i++ { + select { + case peerConnID := <-mgr.OnActivityChan: + receivedPeers[peerConnID] = true + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notifications") + } + } + + assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received") + assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received") +} diff --git a/client/internal/lazyconn/activity/listener_test.go b/client/internal/lazyconn/activity/listener_test.go deleted file mode 100644 index 98d7838d264..00000000000 --- a/client/internal/lazyconn/activity/listener_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package activity - -import ( - "testing" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/lazyconn" -) - -func TestNewListener(t *testing.T) { - peer := &MocPeer{ - PeerID: "examplePublicKey1", - } - - cfg := lazyconn.PeerConfig{ - PublicKey: peer.PeerID, - PeerConnID: peer.ConnID(), - Log: log.WithField("peer", "examplePublicKey1"), - } - - l, err := NewListener(MocWGIface{}, cfg) - if err != nil { - t.Fatalf("failed to create listener: %v", err) - } - - chanClosed := make(chan struct{}) - go func() { - defer close(chanClosed) - l.ReadPackets() - }() - - time.Sleep(1 * time.Second) - l.Close() - - select { - case <-chanClosed: - case <-time.After(time.Second): - } -} diff --git a/client/internal/lazyconn/activity/listener.go b/client/internal/lazyconn/activity/listener_udp.go similarity index 64% rename from client/internal/lazyconn/activity/listener.go rename to client/internal/lazyconn/activity/listener_udp.go index 817ff00c38c..e0b09be6c5b 100644 --- a/client/internal/lazyconn/activity/listener.go +++ b/client/internal/lazyconn/activity/listener_udp.go @@ -11,26 +11,27 @@ import ( "github.com/netbirdio/netbird/client/internal/lazyconn" ) -// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking -type Listener struct { +// UDPListener uses UDP sockets for activity detection in kernel mode. +type UDPListener struct { wgIface WgInterface peerCfg lazyconn.PeerConfig conn *net.UDPConn endpoint *net.UDPAddr done sync.Mutex - isClosed atomic.Bool // use to avoid error log when closing the listener + isClosed atomic.Bool } -func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) { - d := &Listener{ +// NewUDPListener creates a listener that detects activity via UDP socket reads. +func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) { + d := &UDPListener{ wgIface: wgIface, peerCfg: cfg, } conn, err := d.newConn() if err != nil { - return nil, fmt.Errorf("failed to creating activity listener: %v", err) + return nil, fmt.Errorf("create UDP connection: %v", err) } d.conn = conn d.endpoint = conn.LocalAddr().(*net.UDPAddr) @@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error if err := d.createEndpoint(); err != nil { return nil, err } + d.done.Lock() - cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String()) + cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String()) return d, nil } -func (d *Listener) ReadPackets() { +// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed. +func (d *UDPListener) ReadPackets() { for { n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) if err != nil { @@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() { } d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) - if err := d.removeEndpoint(); err != nil { + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) } - _ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection" + // Ignore close error as it may return "use of closed network connection" if already closed. + _ = d.conn.Close() d.done.Unlock() } -func (d *Listener) Close() { +// Close stops the listener and cleans up resources. +func (d *UDPListener) Close() { d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.isClosed.Store(true) @@ -82,16 +87,12 @@ func (d *Listener) Close() { d.done.Lock() } -func (d *Listener) removeEndpoint() error { - return d.wgIface.RemovePeer(d.peerCfg.PublicKey) -} - -func (d *Listener) createEndpoint() error { +func (d *UDPListener) createEndpoint() error { d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String()) return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil) } -func (d *Listener) newConn() (*net.UDPConn, error) { +func (d *UDPListener) newConn() (*net.UDPConn, error) { addr := &net.UDPAddr{ Port: 0, IP: listenIP, diff --git a/client/internal/lazyconn/activity/listener_udp_test.go b/client/internal/lazyconn/activity/listener_udp_test.go new file mode 100644 index 00000000000..d2adb9bf491 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_udp_test.go @@ -0,0 +1,110 @@ +package activity + +import ( + "net" + "net/netip" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +func TestUDPListener_Creation(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + require.NotNil(t, listener.conn) + require.NotNil(t, listener.endpoint) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestUDPListener_ActivityDetection(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + conn, err := net.Dial("udp", listener.conn.LocalAddr().String()) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } +} + +func TestUDPListener_Close(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed") +} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 915fb9cb89c..db283ec9afe 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -1,21 +1,32 @@ package activity import ( + "errors" "net" "net/netip" + "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) +// listener defines the contract for activity detection listeners. +type listener interface { + ReadPackets() + Close() +} + type WgInterface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + IsUserspaceBind() bool + Address() wgaddr.Address } type Manager struct { @@ -23,7 +34,7 @@ type Manager struct { wgIface WgInterface - peers map[peerid.ConnID]*Listener + peers map[peerid.ConnID]listener done chan struct{} mu sync.Mutex @@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager { m := &Manager{ OnActivityChan: make(chan peerid.ConnID, 1), wgIface: wgIface, - peers: make(map[peerid.ConnID]*Listener), + peers: make(map[peerid.ConnID]listener), done: make(chan struct{}), } return m @@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error { return nil } - listener, err := NewListener(m.wgIface, peerCfg) + listener, err := m.createListener(peerCfg) if err != nil { return err } - m.peers[peerCfg.PeerConnID] = listener + m.peers[peerCfg.PeerConnID] = listener go m.waitForTraffic(listener, peerCfg.PeerConnID) return nil } +func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) { + if !m.wgIface.IsUserspaceBind() { + return NewUDPListener(m.wgIface, peerCfg) + } + + // BindListener is only used on Windows and JS platforms: + // - JS: Cannot listen to UDP sockets + // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default + // gateway points to, preventing them from reaching the loopback interface. + // BindListener bypasses this by passing data directly through the bind. + if runtime.GOOS != "windows" && runtime.GOOS != "js" { + return NewUDPListener(m.wgIface, peerCfg) + } + + provider, ok := m.wgIface.(bindProvider) + if !ok { + return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") + } + + return NewBindListener(m.wgIface, provider.GetBind(), peerCfg) +} + func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) { m.mu.Lock() defer m.mu.Unlock() @@ -82,8 +115,8 @@ func (m *Manager) Close() { } } -func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) { - listener.ReadPackets() +func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) { + l.ReadPackets() m.mu.Lock() if _, ok := m.peers[peerConnID]; !ok { diff --git a/client/internal/lazyconn/activity/manager_test.go b/client/internal/lazyconn/activity/manager_test.go index ae6c31da4bb..0768d9219fa 100644 --- a/client/internal/lazyconn/activity/manager_test.go +++ b/client/internal/lazyconn/activity/manager_test.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) @@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error { func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { return nil +} +func (m MocWGIface) IsUserspaceBind() bool { + return false } -// Add this method to the Manager struct -func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) { +func (m MocWGIface) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +// GetPeerListener is a test helper to access listeners +func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) { m.mu.Lock() defer m.mu.Unlock() - listener, exists := m.peers[peerConnID] - return listener, exists + l, exists := m.peers[peerConnID] + return l, exists } func TestManager_MonitorPeerActivity(t *testing.T) { @@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) { t.Fatalf("peer listener not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + // Get the UDP listener's address for triggering + udpListener, ok := listener.(*UDPListener) + if !ok { + t.Fatalf("expected UDPListener") + } + if err := trigger(udpListener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) { t.Fatalf("failed to monitor peer activity: %v", err) } - addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String() + listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID) + udpListener, _ := listener.(*UDPListener) + addr := udpListener.conn.LocalAddr().String() mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID) @@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer1 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener1, _ := listener.(*UDPListener) + if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer2 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener2, _ := listener.(*UDPListener) + if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } diff --git a/client/internal/lazyconn/wgiface.go b/client/internal/lazyconn/wgiface.go index 0351904f7f9..0626c18154b 100644 --- a/client/internal/lazyconn/wgiface.go +++ b/client/internal/lazyconn/wgiface.go @@ -7,6 +7,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/monotime" ) @@ -14,5 +15,6 @@ type WGIface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error IsUserspaceBind() bool + Address() wgaddr.Address LastActivities() map[string]monotime.Time }