Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 105 additions & 2 deletions source/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type serviceSource struct {
endpointSlicesInformer discoveryinformers.EndpointSliceInformer
podInformer coreinformers.PodInformer
nodeInformer coreinformers.NodeInformer
nodeEventHandler func()
serviceTypeFilter *serviceTypes
exposeInternalIPv6 bool

Expand Down Expand Up @@ -845,9 +846,111 @@ func (sc *serviceSource) AddEventHandler(_ context.Context, handler func()) {
if sc.listenEndpointEvents && sc.serviceTypeFilter.isRequired(v1.ServiceTypeNodePort, v1.ServiceTypeClusterIP) {
_, _ = sc.endpointSlicesInformer.Informer().AddEventHandler(eventHandlerFunc(handler))
}
if sc.serviceTypeFilter.isRequired(v1.ServiceTypeNodePort) {
_, _ = sc.nodeInformer.Informer().AddEventHandler(eventHandlerFunc(handler))
sc.nodeEventHandler = handler
// Register node handler eagerly if node data may influence endpoints.
if sc.nodeInformer != nil && sc.nodeEventHandler != nil {
explicitNodePort := sc.serviceTypeFilter != nil && sc.serviceTypeFilter.enabled && sc.serviceTypeFilter.types[v1.ServiceTypeNodePort]
if sc.publishHostIP || explicitNodePort || sc.anyServiceRequiresNodeData() {
_, _ = sc.nodeInformer.Informer().AddEventHandler(&nodeAddressChangeHandler{source: sc, handler: sc.nodeEventHandler})
}
}
}

// anyServiceRequiresNodeData checks current services for node-related endpoint needs
func (sc *serviceSource) anyServiceRequiresNodeData() bool {
if sc.serviceInformer == nil {
return false
}
var services []*v1.Service
var err error

func() {
defer func() { _ = recover() }()
services, err = sc.serviceInformer.Lister().Services(sc.namespace).List(sc.labelSelector)
}()
if err != nil {
return false
}
services = sc.filterByServiceType(services)
services, err = sc.filterByAnnotations(services)
if err != nil {
return true
}
if len(services) == 0 {
return false
}
for _, svc := range services {
if svc.Spec.Type == v1.ServiceTypeNodePort {
return true
}
if svc.Spec.ClusterIP == v1.ClusterIPNone {
endpointsType := getEndpointsTypeFromAnnotations(svc.Annotations)
if endpointsType == EndpointsTypeNodeExternalIP || sc.publishHostIP {
return true
}
}
}
return false
}

type nodeAddressChangeHandler struct {
source *serviceSource
handler func()
}

func (h *nodeAddressChangeHandler) OnAdd(obj interface{}, isInInitialList bool) {
if isInInitialList {
return
}
h.handler()
}

func (h *nodeAddressChangeHandler) OnUpdate(oldObj, newObj interface{}) {
oldNode, okOld := oldObj.(*v1.Node)
newNode, okNew := newObj.(*v1.Node)
if !okOld || !okNew {
h.handler()
return
}
if nodeAddressSetsEqual(oldNode, newNode) {
return
}
h.handler()
}

func (h *nodeAddressChangeHandler) OnDelete(obj interface{}) {
h.handler()
}

func nodeAddressSetsEqual(a, b *v1.Node) bool {
return addressSetEquals(nodeAddressSet(a), nodeAddressSet(b))
}

func nodeAddressSet(node *v1.Node) map[string]struct{} {
if node == nil {
return nil
}
result := make(map[string]struct{})
for _, address := range node.Status.Addresses {
switch address.Type {
case v1.NodeExternalIP, v1.NodeInternalIP:
key := string(address.Type) + "|" + address.Address
result[key] = struct{}{}
}
}
return result
}

func addressSetEquals(a, b map[string]struct{}) bool {
if len(a) != len(b) {
return false
}
for key := range a {
if _, ok := b[key]; !ok {
return false
}
}
return true
}

type serviceTypes struct {
Expand Down
23 changes: 14 additions & 9 deletions source/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5018,40 +5018,41 @@ func TestServiceSource_AddEventHandler(t *testing.T) {
name string
filter []string
times int
prepare func(s *serviceSource)
asserts func(t *testing.T, s *serviceSource)
}{
{
name: "AddEventHandler should trigger all event handlers when empty filter is provided",
name: "AddEventHandler registers service + endpoint handlers only (empty filter, no node need)",
filter: []string{},
times: 3,
times: 2,
asserts: func(t *testing.T, s *serviceSource) {
fakeServiceInformer.AssertNumberOfCalls(t, "Informer", 1)
fakeServiceInformer.AssertNumberOfCalls(t, "Informer", 2)
fakeEdpInformer.AssertNumberOfCalls(t, "Informer", 1)
fakeNodeInformer.AssertNumberOfCalls(t, "Informer", 1)
fakeNodeInformer.AssertNumberOfCalls(t, "Informer", 0)
},
},
{
name: "AddEventHandler should trigger only service event handler",
name: "AddEventHandler only service handler (no endpoint slice types enabled)",
filter: []string{string(v1.ServiceTypeExternalName), string(v1.ServiceTypeLoadBalancer)},
times: 1,
asserts: func(t *testing.T, s *serviceSource) {
fakeServiceInformer.AssertNumberOfCalls(t, "Informer", 1)
fakeServiceInformer.AssertNumberOfCalls(t, "Informer", 2)
fakeEdpInformer.AssertNumberOfCalls(t, "Informer", 0)
fakeNodeInformer.AssertNumberOfCalls(t, "Informer", 0)
},
},
{
name: "AddEventHandler should configure only service event handler",
name: "AddEventHandler service + endpoint slice handlers",
filter: []string{string(v1.ServiceTypeExternalName), string(v1.ServiceTypeLoadBalancer), string(v1.ServiceTypeClusterIP)},
times: 2,
asserts: func(t *testing.T, s *serviceSource) {
fakeServiceInformer.AssertNumberOfCalls(t, "Informer", 1)
fakeServiceInformer.AssertNumberOfCalls(t, "Informer", 2)
fakeEdpInformer.AssertNumberOfCalls(t, "Informer", 1)
fakeNodeInformer.AssertNumberOfCalls(t, "Informer", 0)
},
},
{
name: "AddEventHandler should configure all service event handlers",
name: "AddEventHandler registers node handler for NodePort",
filter: []string{string(v1.ServiceTypeNodePort)},
times: 3,
asserts: func(t *testing.T, s *serviceSource) {
Expand Down Expand Up @@ -5085,6 +5086,10 @@ func TestServiceSource_AddEventHandler(t *testing.T) {
listenEndpointEvents: true,
}

if tt.prepare != nil {
tt.prepare(svcSource)
}

svcSource.AddEventHandler(t.Context(), func() {})

assert.Equal(t, tt.times, infSvc.times+infEdp.times+infNode.times)
Expand Down
Loading