Skip to content
58 changes: 49 additions & 9 deletions controller/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/http"
"os"
"os/signal"
"runtime"
"syscall"
"time"

Expand Down Expand Up @@ -72,6 +73,24 @@ import (
"sigs.k8s.io/external-dns/source/wrappers"
)

// sigtermSignals is a package-level signal channel that is registered in init().
// This way, SIGTERM is captured as soon as the package is loaded, preventing
// default process termination, even if application startup is delayed.
var sigtermSignals chan os.Signal

func init() {
sigtermSignals = make(chan os.Signal, 1)
signal.Notify(sigtermSignals, terminationSignals()...)
}

func terminationSignals() []os.Signal {
signals := []os.Signal{os.Interrupt}
if runtime.GOOS != "windows" {
signals = append(signals, syscall.SIGTERM)
}
return signals
}

func Execute() {
cfg := externaldns.NewConfig()
if err := cfg.ParseFlags(os.Args[1:]); err != nil {
Expand Down Expand Up @@ -99,8 +118,14 @@ func Execute() {

ctx, cancel := context.WithCancel(context.Background())

go serveMetrics(cfg.MetricsAddress)
go handleSigterm(cancel)
// Connect global SIGTERM capture to this run's context cancellation.
go func() {
<-sigtermSignals
log.Info("Received termination signal. Terminating...")
cancel()
}()

go serveMetrics(ctx, cfg.MetricsAddress)

endpointsSource, err := buildSource(ctx, cfg)
if err != nil {
Expand Down Expand Up @@ -468,22 +493,25 @@ func createDomainFilter(cfg *externaldns.Config) *endpoint.DomainFilter {
}
}

// handleSigterm listens for a SIGTERM signal and triggers the provided cancel function
// handleSigterm listens for termination signals and triggers the provided cancel function
// to gracefully terminate the application. It logs a message when the signal is received.
func handleSigterm(cancel func()) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGTERM)
signal.Notify(signals, terminationSignals()...)
<-signals
log.Info("Received SIGTERM. Terminating...")
log.Info("Received termination signal. Terminating...")
cancel()
signal.Stop(signals)
}

// serveMetrics starts an HTTP server that serves health and metrics endpoints.
// The /healthz endpoint returns a 200 OK status to indicate the service is healthy.
// The /metrics endpoint serves Prometheus metrics.
// The server listens on the specified address and logs debug information about the endpoints.
func serveMetrics(address string) {
http.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
func serveMetrics(ctx context.Context, address string) {
mux := http.NewServeMux()

mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
})
Expand All @@ -492,7 +520,19 @@ func serveMetrics(address string) {
log.Debugf("serving 'metrics' on '%s/metrics'", address)
log.Debugf("registered '%d' metrics", len(metrics.RegisterMetric.Metrics))

http.Handle("/metrics", promhttp.Handler())
mux.Handle("/metrics", promhttp.Handler())

srv := &http.Server{Addr: address, Handler: mux}

log.Fatal(http.ListenAndServe(address, nil))
// Shutdown server on context cancellation
go func() {
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
_ = srv.Shutdown(shutdownCtx)
cancel()
}()

if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatal(err)
}
}
217 changes: 217 additions & 0 deletions controller/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@ import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"os/signal"
"reflect"
"regexp"
"runtime"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -375,6 +380,95 @@ func TestCreateDomainFilter(t *testing.T) {
}
}

func getRandomPort() (int, error) {
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
return 0, err
}

l, err := net.ListenTCP("tcp", addr)
if err != nil {
return 0, err
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}

func sendTerminationSignal() error {
proc, err := os.FindProcess(os.Getpid())
if err != nil {
return err
}
if runtime.GOOS == "windows" {
return proc.Signal(os.Interrupt)
}
return proc.Signal(syscall.SIGTERM)
}

func TestServeMetrics(t *testing.T) {
// Use a fresh DefaultServeMux for this test (do not restore to avoid data race with server goroutine)
http.DefaultServeMux = http.NewServeMux()

port, err := getRandomPort()
require.NoError(t, err)
address := fmt.Sprintf("localhost:%d", port)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go serveMetrics(ctx, fmt.Sprintf(":%d", port))

// Wait for the TCP socket to be ready
require.Eventually(t, func() bool {
conn, err := net.Dial("tcp", address)
if err != nil {
return false
}
_ = conn.Close()
return true
}, 2*time.Second, 10*time.Millisecond, "server not ready with port open in time")

resp, err := http.Get(fmt.Sprintf("http://%s/healthz", address))
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
_ = resp.Body.Close()

resp, err = http.Get(fmt.Sprintf("http://%s/metrics", address))
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
_ = resp.Body.Close()

// Stop the server to avoid leaking goroutines across tests
cancel()
}

func TestHandleSigterm(t *testing.T) {
cancelCalled := make(chan bool, 1)
cancel := func() { cancelCalled <- true }

var logOutput bytes.Buffer
log.SetOutput(&logOutput)
defer log.SetOutput(os.Stderr)

go handleSigterm(cancel)

// Simulate sending a termination signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, terminationSignals()...)
defer signal.Stop(sigChan)
err := sendTerminationSignal()
assert.NoError(t, err)

// Wait for cancel to be called
select {
case <-cancelCalled:
assert.Contains(t, logOutput.String(), "Received termination signal. Terminating...")
case sig := <-sigChan:
assert.Contains(t, terminationSignals(), sig)
case <-time.After(1 * time.Second):
t.Fatal("cancel function was not called")
}
}

func TestBuildSource(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
Expand Down Expand Up @@ -629,6 +723,129 @@ func TestExecuteBuildControllerErrorExitsNonZero(t *testing.T) {
assert.NotEqual(t, 0, code)
}

// ValidateConfig triggers log.Fatalf (in-process).
func TestExecuteConfigValidationFatalInProcess(t *testing.T) {
// Prepare args to trigger validation error before any goroutines start
prevArgs := os.Args
os.Args = []string{
"external-dns",
"--source", "fake",
"--provider", "inmemory",
"--ignore-hostname-annotation", // triggers validation: FQDN template required when ignoring annotations
"--metrics-address", ":0",
}
t.Cleanup(func() { os.Args = prevArgs })

// Capture logs and replace Fatalf with Goexit to stop only the Execute goroutine
logger := log.StandardLogger()
prevExit := logger.ExitFunc
prevOut := logger.Out
buf := new(bytes.Buffer)
logger.SetOutput(buf)
logger.ExitFunc = func(int) { runtime.Goexit() }
t.Cleanup(func() { logger.ExitFunc = prevExit; logger.SetOutput(prevOut) })

done := make(chan struct{})
go func() {
defer close(done)
Execute()
}()

select {
case <-done:
// ok
case <-time.After(2 * time.Second):
t.Fatal("Execute did not exit after validation fatal")
}

// Do not assert on logger text to avoid flakiness with global logger
}

// Run path with --events; shut down via SIGTERM.
func TestExecuteDefaultRunWithEventsStopsOnSigterm(t *testing.T) {
// Use a fresh DefaultServeMux for this test (do not restore to avoid data race with server goroutine)
http.DefaultServeMux = http.NewServeMux()

// Prepare args to run Execute without --once and with --events
prevArgs := os.Args
os.Args = []string{
"external-dns",
"--source", "fake",
"--provider", "inmemory",
"--events",
"--dry-run",
"--metrics-address", ":0",
}
t.Cleanup(func() { os.Args = prevArgs })

// Prevent log.Fatal from terminating the test process
logger := log.StandardLogger()
prevExit := logger.ExitFunc
logger.ExitFunc = func(int) { runtime.Goexit() }
t.Cleanup(func() { logger.ExitFunc = prevExit })

done := make(chan struct{})
go func() {
defer close(done)
Execute()
}()

// Give goroutines time to start
time.Sleep(50 * time.Millisecond)

// Send termination signal to trigger handleSigterm(cancel)
require.NoError(t, sendTerminationSignal())

select {
case <-done:
// ok
case <-time.After(2 * time.Second):
t.Fatal("Execute did not stop after termination signal")
}
}

// Webhook server path; pre-bind 127.0.0.1:8888 to force a bind failure.
func TestExecuteWebhookServerFailsPortInUseInProcess(t *testing.T) {
// Use a fresh DefaultServeMux for this test (do not restore to avoid data race with server goroutine)
http.DefaultServeMux = http.NewServeMux()

// Pre-bind the webhook server port so it is unavailable
l, err := net.Listen("tcp", "127.0.0.1:8888")
if err != nil {
// If we cannot bind, assume something else is bound already, which is fine for this test
} else {
t.Cleanup(func() { _ = l.Close() })
}

prevArgs := os.Args
os.Args = []string{
"external-dns",
"--source", "fake",
"--provider", "inmemory",
"--webhook-server",
"--metrics-address", ":0",
}
t.Cleanup(func() { os.Args = prevArgs })

logger := log.StandardLogger()
prevExit := logger.ExitFunc
logger.ExitFunc = func(int) { runtime.Goexit() }
t.Cleanup(func() { logger.ExitFunc = prevExit })

done := make(chan struct{})
go func() {
defer close(done)
Execute()
}()

select {
case <-done:
// ok
case <-time.After(2 * time.Second):
t.Fatal("Execute did not exit after webhook server fatal")
}
}

// Controller run loop stops on context cancel.
func TestControllerRunCancelContextStopsLoop(t *testing.T) {
// Minimal controller using fake source and inmemory provider.
Expand Down
Loading