Commit f03e9f6

fix: Ensure all connections persist until queue-proxy drain
Fixes: WebSockets (and some HTTP connections) closing abruptly when the queue-proxy drains. Because net/http's server.Shutdown does not wait for hijacked connections, any active WebSocket connection ends as soon as the queue-proxy calls .Shutdown. See gorilla/websocket#448 and golang/go#17721 for details. This patch fixes the issue by introducing an atomic counter of active requests, incremented when a request arrives and decremented when its handler returns. During drain, this counter must reach zero (or the revision timeout must elapse) before .Shutdown is called. This also prevents premature closing of connections in the user container due to misconfigured SIGTERM handling, by delaying the SIGTERM until the queue-proxy has verified it has fully drained.
1 parent 24ff578 commit f03e9f6
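
In sketch form, the mechanism is: count requests in a middleware (mirroring the withRequestCounter handler added in this commit) and gate server.Shutdown on the counter. The wait loop below is illustrative only; the real sequencing lives in the queue-proxy drain logic, so its names and polling cadence are assumptions.

package main

import (
	"context"
	"net/http"
	"sync/atomic"
	"time"
)

var pendingRequests atomic.Int32

// Counting middleware: covers hijacked (WebSocket) connections too, since it
// only decrements once the handler actually returns.
func withCounter(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		pendingRequests.Add(1)
		defer pendingRequests.Add(-1)
		next.ServeHTTP(w, r)
	})
}

// waitForDrain blocks until no requests remain or the revision timeout
// elapses; only then is it safe to call server.Shutdown.
func waitForDrain(revisionTimeout time.Duration) {
	deadline := time.Now().Add(revisionTimeout)
	for time.Now().Before(deadline) && pendingRequests.Load() > 0 {
		time.Sleep(100 * time.Millisecond)
	}
}

func main() {
	srv := &http.Server{Addr: ":8080", Handler: withCounter(http.DefaultServeMux)}
	go srv.ListenAndServe()
	// ... on SIGTERM:
	waitForDrain(30 * time.Second)
	srv.Shutdown(context.Background())
}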

File tree

19 files changed: +945 -503 lines

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -9,3 +9,6 @@
 # Temporary output of build tools
 bazel-*
 *.out
+
+# Repomix outputs
+repomix*.xml

pkg/activator/net/throttler.go

Lines changed: 15 additions & 4 deletions

@@ -100,7 +100,13 @@ func (p *podTracker) Capacity() int {
 	if p.b == nil {
 		return 1
 	}
-	return p.b.Capacity()
+	capacity := p.b.Capacity()
+	// Safe conversion: breaker capacity is always reasonable for int
+	// Check for overflow before conversion
+	if capacity > 0x7FFFFFFF {
+		return 0x7FFFFFFF // Return max int32 value
+	}
+	return int(capacity)
 }
 
 func (p *podTracker) UpdateConcurrency(c int) {
@@ -118,7 +124,7 @@ func (p *podTracker) Reserve(ctx context.Context) (func(), bool) {
 }
 
 type breaker interface {
-	Capacity() int
+	Capacity() uint64
 	Maybe(ctx context.Context, thunk func()) error
 	UpdateConcurrency(int)
 	Reserve(ctx context.Context) (func(), bool)
@@ -721,8 +727,13 @@ func newInfiniteBreaker(logger *zap.SugaredLogger) *infiniteBreaker {
 }
 
 // Capacity returns the current capacity of the breaker
-func (ib *infiniteBreaker) Capacity() int {
-	return int(ib.concurrency.Load())
+func (ib *infiniteBreaker) Capacity() uint64 {
+	concurrency := ib.concurrency.Load()
+	// Safe conversion: concurrency is int32 and we check for non-negative
+	if concurrency >= 0 {
+		return uint64(concurrency)
+	}
+	return 0
 }
 
 func zeroOrOne(x int) int32 {
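
The Capacity change above clamps rather than truncates when the breaker's uint64 capacity exceeds the int32 range. A standalone sketch of the same guard, using math.MaxInt32 in place of the 0x7FFFFFFF literal (the helper name is illustrative):

package main

import (
	"fmt"
	"math"
)

// clampToInt mirrors the guard in podTracker.Capacity: a plain int(capacity)
// could overflow on platforms where int is 32 bits wide.
func clampToInt(capacity uint64) int {
	if capacity > math.MaxInt32 {
		return math.MaxInt32
	}
	return int(capacity)
}

func main() {
	fmt.Println(clampToInt(42))             // 42
	fmt.Println(clampToInt(math.MaxUint64)) // 2147483647
}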

pkg/activator/net/throttler_test.go

Lines changed: 10 additions & 10 deletions

@@ -226,7 +226,7 @@ func TestThrottlerUpdateCapacity(t *testing.T) {
 				rt.breaker = newInfiniteBreaker(logger)
 			}
 			rt.updateCapacity(tt.capacity)
-			if got := rt.breaker.Capacity(); got != tt.want {
+			if got := rt.breaker.Capacity(); got != uint64(tt.want) {
 				t.Errorf("Capacity = %d, want: %d", got, tt.want)
 			}
 			if tt.checkAssignedPod {
@@ -560,7 +560,7 @@ func TestThrottlerSuccesses(t *testing.T) {
 			rt.mux.RLock()
 			defer rt.mux.RUnlock()
 			if *cc != 0 {
-				return rt.activatorIndex.Load() != -1 && rt.breaker.Capacity() == wantCapacity &&
+				return rt.activatorIndex.Load() != -1 && rt.breaker.Capacity() == uint64(wantCapacity) &&
 					sortedTrackers(rt.assignedTrackers), nil
 			}
 			// If CC=0 then verify number of backends, rather the capacity of breaker.
@@ -638,7 +638,7 @@ func TestPodAssignmentFinite(t *testing.T) {
 	if got, want := trackerDestSet(rt.assignedTrackers), sets.New("ip0", "ip4"); !got.Equal(want) {
 		t.Errorf("Assigned trackers = %v, want: %v, diff: %s", got, want, cmp.Diff(want, got))
 	}
-	if got, want := rt.breaker.Capacity(), 2*42; got != want {
+	if got, want := rt.breaker.Capacity(), uint64(2*42); got != want {
 		t.Errorf("TotalCapacity = %d, want: %d", got, want)
 	}
 	if got, want := rt.assignedTrackers[0].Capacity(), 42; got != want {
@@ -657,7 +657,7 @@
 	if got, want := len(rt.assignedTrackers), 0; got != want {
 		t.Errorf("NumAssignedTrackers = %d, want: %d", got, want)
 	}
-	if got, want := rt.breaker.Capacity(), 0; got != want {
+	if got, want := rt.breaker.Capacity(), uint64(0); got != want {
 		t.Errorf("TotalCapacity = %d, want: %d", got, want)
 	}
 }
@@ -687,7 +687,7 @@ func TestPodAssignmentInfinite(t *testing.T) {
 	if got, want := len(rt.assignedTrackers), 3; got != want {
 		t.Errorf("NumAssigned trackers = %d, want: %d", got, want)
 	}
-	if got, want := rt.breaker.Capacity(), 1; got != want {
+	if got, want := rt.breaker.Capacity(), uint64(1); got != want {
 		t.Errorf("TotalCapacity = %d, want: %d", got, want)
 	}
 	if got, want := rt.assignedTrackers[0].Capacity(), 1; got != want {
@@ -703,7 +703,7 @@ func TestPodAssignmentInfinite(t *testing.T) {
 	if got, want := len(rt.assignedTrackers), 0; got != want {
 		t.Errorf("NumAssignedTrackers = %d, want: %d", got, want)
 	}
-	if got, want := rt.breaker.Capacity(), 0; got != want {
+	if got, want := rt.breaker.Capacity(), uint64(0); got != want {
 		t.Errorf("TotalCapacity = %d, want: %d", got, want)
 	}
 }
@@ -935,7 +935,7 @@ func TestInfiniteBreaker(t *testing.T) {
 	}
 
 	// Verify initial condition.
-	if got, want := b.Capacity(), 0; got != want {
+	if got, want := b.Capacity(), uint64(0); got != want {
 		t.Errorf("Cap=%d, want: %d", got, want)
 	}
 	if _, ok := b.Reserve(context.Background()); ok != true {
@@ -949,7 +949,7 @@
 	}
 
 	b.UpdateConcurrency(1)
-	if got, want := b.Capacity(), 1; got != want {
+	if got, want := b.Capacity(), uint64(1); got != want {
 		t.Errorf("Cap=%d, want: %d", got, want)
 	}
 
@@ -976,7 +976,7 @@
 	if err := b.Maybe(ctx, nil); err == nil {
 		t.Error("Should have failed, but didn't")
 	}
-	if got, want := b.Capacity(), 0; got != want {
+	if got, want := b.Capacity(), uint64(0); got != want {
 		t.Errorf("Cap=%d, want: %d", got, want)
 	}
 
@@ -1212,7 +1212,7 @@ func TestAssignSlice(t *testing.T) {
 				t.Errorf("Got=%v, want: %v; diff: %s", got, want,
 					cmp.Diff(want, got, opt))
 			}
-			if got, want := got[0].b.Capacity(), 0; got != want {
+			if got, want := got[0].b.Capacity(), uint64(0); got != want {
 				t.Errorf("Capacity for the tail pod = %d, want: %d", got, want)
 			}
 		})

pkg/autoscaler/metrics/stat.pb.go

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default.

pkg/queue/breaker.go

Lines changed: 24 additions & 12 deletions

@@ -43,7 +43,7 @@ type BreakerParams struct {
 // executions in excess of the concurrency limit. Function call attempts
 // beyond the limit of the queue are failed immediately.
 type Breaker struct {
-	inFlight   atomic.Int64
+	pending    atomic.Int64
 	totalSlots int64
 	sem        *semaphore
 
@@ -83,10 +83,10 @@ func NewBreaker(params BreakerParams) *Breaker {
 func (b *Breaker) tryAcquirePending() bool {
 	// This is an atomic version of:
 	//
-	//   if inFlight == totalSlots {
+	//   if pending == totalSlots {
 	//     return false
 	//   } else {
-	//     inFlight++
+	//     pending++
 	//     return true
 	//   }
 	//
@@ -96,19 +96,20 @@ func (b *Breaker) tryAcquirePending() bool {
 	// (it fails if we're raced to it) or if we don't fulfill the condition
 	// anymore.
 	for {
-		cur := b.inFlight.Load()
+		cur := b.pending.Load()
+		// 10000 + containerConcurrency = totalSlots
 		if cur == b.totalSlots {
 			return false
 		}
-		if b.inFlight.CompareAndSwap(cur, cur+1) {
+		if b.pending.CompareAndSwap(cur, cur+1) {
 			return true
 		}
 	}
 }
 
 // releasePending releases a slot on the pending "queue".
 func (b *Breaker) releasePending() {
-	b.inFlight.Add(-1)
+	b.pending.Add(-1)
 }
 
 // Reserve reserves an execution slot in the breaker, to permit
@@ -154,9 +155,9 @@ func (b *Breaker) Maybe(ctx context.Context, thunk func()) error {
 	return nil
 }
 
-// InFlight returns the number of requests currently in flight in this breaker.
-func (b *Breaker) InFlight() int {
-	return int(b.inFlight.Load())
+// Pending returns the number of requests currently pending to this breaker.
+func (b *Breaker) Pending() int {
+	return int(b.pending.Load())
 }
 
 // UpdateConcurrency updates the maximum number of in-flight requests.
@@ -165,10 +166,15 @@ func (b *Breaker) UpdateConcurrency(size int) {
 }
 
 // Capacity returns the number of allowed in-flight requests on this breaker.
-func (b *Breaker) Capacity() int {
+func (b *Breaker) Capacity() uint64 {
 	return b.sem.Capacity()
 }
 
+// InFlight returns the number of requests currently in-flight on this breaker.
+func (b *Breaker) InFlight() uint64 {
+	return b.sem.InFlight()
+}
+
 // newSemaphore creates a semaphore with the desired initial capacity.
 func newSemaphore(maxCapacity, initialCapacity int) *semaphore {
 	queue := make(chan struct{}, maxCapacity)
@@ -288,9 +294,15 @@
 }
 
 // Capacity is the capacity of the semaphore.
-func (s *semaphore) Capacity() int {
+func (s *semaphore) Capacity() uint64 {
 	capacity, _ := unpack(s.state.Load())
-	return int(capacity) //nolint:gosec // TODO(dprotaso) - capacity should be uint64
+	return capacity
+}
+
+// InFlight is the number of the inflight requests of the semaphore.
+func (s *semaphore) InFlight() uint64 {
+	_, inFlight := unpack(s.state.Load())
+	return inFlight
 }
 
 // unpack takes an uint64 and returns two uint32 (as uint64) comprised of the leftmost
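
Capacity and InFlight each read a single atomic word because the semaphore packs both counters into one uint64, so a load always observes a consistent pair. A minimal sketch of that packing, assuming capacity sits in the upper 32 bits and the in-flight count in the lower 32 (per the unpack comment above):

package main

import "fmt"

// pack stores capacity in the upper 32 bits and in-flight in the lower 32.
func pack(capacity, inFlight uint64) uint64 {
	return capacity<<32 | inFlight&0xffffffff
}

// unpack reverses pack, returning both halves as uint64.
func unpack(state uint64) (capacity, inFlight uint64) {
	return state >> 32, state & 0xffffffff
}

func main() {
	state := pack(10, 3)
	capacity, inFlight := unpack(state)
	fmt.Println(capacity, inFlight) // 10 3
}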

pkg/queue/breaker_test.go

Lines changed: 4 additions & 4 deletions

@@ -212,12 +212,12 @@ func TestBreakerUpdateConcurrency(t *testing.T) {
 	params := BreakerParams{QueueDepth: 1, MaxConcurrency: 1, InitialCapacity: 0}
 	b := NewBreaker(params)
 	b.UpdateConcurrency(1)
-	if got, want := b.Capacity(), 1; got != want {
+	if got, want := b.Capacity(), uint64(1); got != want {
 		t.Errorf("Capacity() = %d, want: %d", got, want)
 	}
 
 	b.UpdateConcurrency(0)
-	if got, want := b.Capacity(), 0; got != want {
+	if got, want := b.Capacity(), uint64(0); got != want {
 		t.Errorf("Capacity() = %d, want: %d", got, want)
 	}
 }
@@ -294,12 +294,12 @@ func TestSemaphoreRelease(t *testing.T) {
 func TestSemaphoreUpdateCapacity(t *testing.T) {
 	const initialCapacity = 1
 	sem := newSemaphore(3, initialCapacity)
-	if got, want := sem.Capacity(), 1; got != want {
+	if got, want := sem.Capacity(), uint64(1); got != want {
 		t.Errorf("Capacity = %d, want: %d", got, want)
 	}
 	sem.acquire(context.Background())
 	sem.updateCapacity(initialCapacity + 2)
-	if got, want := sem.Capacity(), 3; got != want {
+	if got, want := sem.Capacity(), uint64(3); got != want {
 		t.Errorf("Capacity = %d, want: %d", got, want)
 	}
 }

pkg/queue/request_metric.go

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ func (h *appRequestMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
 	startTime := h.clock.Now()
 
 	if h.breaker != nil {
-		h.queueLen.Record(r.Context(), int64(h.breaker.InFlight()))
+		h.queueLen.Record(r.Context(), int64(h.breaker.Pending()))
 	}
 	defer func() {
 		// Filter probe requests for revision metrics.
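
With this rename the queueLen metric records admission pressure rather than executing work: Breaker.Pending() counts every request the breaker has admitted (queued plus executing, bounded by the pending slots), while the new Breaker.InFlight() counts only requests holding a semaphore slot. A hypothetical illustration; it would need to live in the queue package, and assumes the first thunk is still running when the counters are read:

package queue

import (
	"context"
	"fmt"
	"time"
)

// pendingVsInFlight is a hypothetical illustration of the two counters.
func pendingVsInFlight() {
	b := NewBreaker(BreakerParams{QueueDepth: 10, MaxConcurrency: 1, InitialCapacity: 1})
	go b.Maybe(context.Background(), func() { time.Sleep(time.Second) }) // holds the only slot
	go b.Maybe(context.Background(), func() {})                          // admitted, but queued
	time.Sleep(100 * time.Millisecond)                                   // let both goroutines enter
	fmt.Println(b.Pending())  // 2: one executing plus one queued
	fmt.Println(b.InFlight()) // 1: holding a semaphore slot
}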

pkg/queue/sharedmain/handlers.go

Lines changed: 30 additions & 2 deletions

@@ -18,8 +18,11 @@ package sharedmain
 
 import (
 	"context"
+	"fmt"
 	"net"
 	"net/http"
+	"strings"
+	"sync/atomic"
 	"time"
 
 	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
@@ -30,6 +33,7 @@ import (
 	netheader "knative.dev/networking/pkg/http/header"
 	netproxy "knative.dev/networking/pkg/http/proxy"
 	netstats "knative.dev/networking/pkg/http/stats"
+	"knative.dev/pkg/network"
 	pkghandler "knative.dev/pkg/network/handlers"
 	"knative.dev/serving/pkg/activator"
 	pkghttp "knative.dev/serving/pkg/http"
@@ -46,6 +50,7 @@ func mainHandler(
 	logger *zap.SugaredLogger,
 	mp metric.MeterProvider,
 	tp trace.TracerProvider,
+	pendingRequests *atomic.Int32,
 ) (http.Handler, *pkghandler.Drainer) {
 	target := net.JoinHostPort("127.0.0.1", env.UserPort)
 	tracer := tp.Tracer("knative.dev/serving/pkg/queue")
@@ -73,6 +78,7 @@
 
 	composedHandler = requestAppMetricsHandler(logger, composedHandler, breaker, mp)
 	composedHandler = queue.ProxyHandler(tracer, breaker, stats, composedHandler)
+
 	composedHandler = queue.ForwardedShimHandler(composedHandler)
 	composedHandler = handler.NewTimeoutHandler(composedHandler, "request timeout", func(r *http.Request) (time.Duration, time.Duration, time.Duration) {
 		return timeout, responseStartTimeout, idleTimeout
@@ -81,6 +87,8 @@
 	composedHandler = queue.NewRouteTagHandler(composedHandler)
 	composedHandler = withFullDuplex(composedHandler, env.EnableHTTPFullDuplex, logger)
 
+	composedHandler = withRequestCounter(composedHandler, pendingRequests)
+
 	drainer := &pkghandler.Drainer{
 		QuietPeriod: drainSleepDuration,
 		// Add Activator probe header to the drainer so it can handle probes directly from activator
@@ -105,11 +113,10 @@
 			return !netheader.IsProbe(r)
 		}),
 	)
-
 	return composedHandler, drainer
 }
 
-func adminHandler(ctx context.Context, logger *zap.SugaredLogger, drainer *pkghandler.Drainer) http.Handler {
+func adminHandler(ctx context.Context, logger *zap.SugaredLogger, drainer *pkghandler.Drainer, pendingRequests *atomic.Int32) http.Handler {
 	mux := http.NewServeMux()
 	mux.HandleFunc(queue.RequestQueueDrainPath, func(w http.ResponseWriter, r *http.Request) {
 		logger.Info("Attached drain handler from user-container", r)
@@ -130,6 +137,17 @@ func adminHandler(ctx context.Context, logger *zap.SugaredLogger, drainer *pkgha
 		w.WriteHeader(http.StatusOK)
 	})
 
+	// New endpoint that returns 200 only when all requests are drained
+	mux.HandleFunc("/drain-complete", func(w http.ResponseWriter, r *http.Request) {
+		if pendingRequests.Load() <= 0 {
+			w.WriteHeader(http.StatusOK)
+			w.Write([]byte("drained"))
+		} else {
+			w.WriteHeader(http.StatusServiceUnavailable)
+			fmt.Fprintf(w, "pending requests: %d", pendingRequests.Load())
+		}
+	})
+
 	return mux
 }
 
@@ -145,3 +163,13 @@
 		h.ServeHTTP(w, r)
 	})
 }
+
+func withRequestCounter(h http.Handler, pendingRequests *atomic.Int32) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.Header.Get(network.ProbeHeaderName) != network.ProbeHeaderValue && !strings.HasPrefix(r.Header.Get("User-Agent"), "kube-probe/") {
+			pendingRequests.Add(1)
+			defer pendingRequests.Add(-1)
+		}
+		h.ServeHTTP(w, r)
+	})
+}
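
The /drain-complete endpoint gives the drain sequence something concrete to poll before the user container is sent SIGTERM. A hypothetical polling client; the admin address and cadence are assumptions, while the path mirrors the handler above:

package main

import (
	"fmt"
	"net/http"
	"time"
)

// waitForDrainComplete polls the queue-proxy admin endpoint until it reports
// that all counted requests have finished, or the timeout elapses.
func waitForDrainComplete(adminAddr string, timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		resp, err := http.Get("http://" + adminAddr + "/drain-complete")
		if err == nil {
			resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				return nil // queue-proxy reports zero pending requests
			}
		}
		time.Sleep(250 * time.Millisecond)
	}
	return fmt.Errorf("drain did not complete within %v", timeout)
}

func main() {
	if err := waitForDrainComplete("127.0.0.1:8022", 30*time.Second); err != nil {
		fmt.Println(err)
	}
}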
