diff --git a/sse.go b/sse.go index 480261f..8b031b1 100644 --- a/sse.go +++ b/sse.go @@ -45,7 +45,7 @@ func Bind(obj interface{}, opts ...Options) flamego.Handler { } c.Set(reflect.ChanOf(reflect.SendDir, sse.sender.Type().Elem()), sse.sender) - go sse.handle(log, c.ResponseWriter()) + go sse.handle(log, c) } } @@ -59,7 +59,8 @@ func newOptions(opts []Options) Options { return opts[0] } -func (c *connection) handle(log *log.Logger, w flamego.ResponseWriter) { +func (c *connection) handle(log *log.Logger, ctx flamego.Context) { + w := ctx.ResponseWriter() ticker := time.NewTicker(c.PingInterval) defer func() { ticker.Stop() }() @@ -78,11 +79,13 @@ func (c *connection) handle(log *log.Logger, w flamego.ResponseWriter) { senderSend = iota tickerTick timeout + closed ) - cases := make([]reflect.SelectCase, 3) + cases := make([]reflect.SelectCase, 4) cases[senderSend] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: c.sender, Send: reflect.ValueOf(nil)} cases[tickerTick] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ticker.C), Send: reflect.ValueOf(nil)} cases[timeout] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(time.After(time.Hour)), Send: reflect.ValueOf(nil)} + cases[closed] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Request().Context().Done()), Send: reflect.ValueOf(nil)} loop: for { @@ -112,6 +115,9 @@ loop: write("events: stream timeout\n\n") w.Flush() break loop + + case closed: + return } } diff --git a/sse_test.go b/sse_test.go index 7c82483..717aa8c 100644 --- a/sse_test.go +++ b/sse_test.go @@ -6,6 +6,7 @@ package sse import ( "bytes" + "context" "net/http" "net/http/httptest" "sync" @@ -63,6 +64,27 @@ func TestBind(t *testing.T) { time.Sleep(1 * time.Second) }, ) + f.Get("/ticker", + Bind( + object{}, + Options{ + 100 * time.Millisecond, + }, + ), + func(ctx flamego.Context, msg chan<- *object) { + ticker := time.NewTicker(1 * time.Second) + defer func() { ticker.Stop() }() + + for { + select { + case <-ticker.C: + msg <- &object{Message: "Flamego"} + case <-ctx.Request().Context().Done(): + return + } + } + }, + ) t.Run("normal", func(t *testing.T) { resp := &mockResponseWriter{ @@ -118,4 +140,25 @@ data: {"Message":"Flamego"} ` assert.Equal(t, wantBody, resp.Body()) }) + + t.Run("close connection", func(t *testing.T) { + server := httptest.NewServer(f) + + reqContext, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(reqContext, http.MethodGet, server.URL+"/ticker", nil) + require.NoError(t, err) + + // Close request connection after 1 second. + go func() { + time.Sleep(1 * time.Second) + cancel() + }() + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + err = resp.Body.Close() + require.NoError(t, err) + + // Sleep for 3 seconds to wait for new responses that may be in a closed request. + time.Sleep(3 * time.Second) + }) }