diff --git a/README.md b/README.md index 727a436..f0a8550 100644 --- a/README.md +++ b/README.md @@ -64,8 +64,31 @@ func main() { srv = cache.Middleware(srv) // ... do more things } -```` +``` Doing it, Gqlgen write the lowest max-age defined in cacheControl extensions. For more informations, see `_example` folder. + +### Force cache control in case of error + +By default, any existing cache hints will not result in a Cache-Control header in case of an error. + +It's possible to force it by using the helper `cache.ContextWithForceCacheControl` to mark the request as cacheable even when GraphQL errors are present. + + +```go +func (h *handler) GraphqlHandler() gin.HandlerFunc { + // ... setup gqlgen/graphql/handler + // ... setup server + srv.Use(cache.Extension{}) + cachedServer := cache.Middleware(srv) + + return func(c *gin.Context) { + ctx := cache.ContextWithForceCacheControl(c.Request.Context(), true) + c.Request = c.Request.WithContext(ctx) + + cachedServer.ServeHTTP(c.Writer, c.Request) + } +} +``` diff --git a/cache/context.go b/cache/context.go new file mode 100644 index 0000000..926e41c --- /dev/null +++ b/cache/context.go @@ -0,0 +1,12 @@ +package cache + +import "context" + +type contextKey string + +const forceCacheControlKey = contextKey("forceCacheControl") + +// ContextWithForceCacheControl creates a new context with the forceCacheControl option. +func ContextWithForceCacheControl(ctx context.Context, forceCacheControl bool) context.Context { + return context.WithValue(ctx, forceCacheControlKey, forceCacheControl) +} diff --git a/cache/context_test.go b/cache/context_test.go new file mode 100644 index 0000000..573617c --- /dev/null +++ b/cache/context_test.go @@ -0,0 +1,37 @@ +package cache + +import ( + "context" + "reflect" + "testing" +) + +func TestContextWithForceCacheControl(t *testing.T) { + type args struct { + ctx context.Context + forceCacheControl bool + } + tests := []struct { + name string + args args + want context.Context + }{ + { + "forceCacheControl is false", + args{context.Background(), false}, + context.WithValue(context.Background(), forceCacheControlKey, false), + }, + { + "valid key and value forceCacheControl", + args{context.Background(), true}, + context.WithValue(context.Background(), forceCacheControlKey, true), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ContextWithForceCacheControl(tt.args.ctx, tt.args.forceCacheControl); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ContextWithForceCacheControl() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cache/middleware.go b/cache/middleware.go index 3edf167..85e0aff 100644 --- a/cache/middleware.go +++ b/cache/middleware.go @@ -40,11 +40,13 @@ func (c responseWriter) Header() http.Header { } func (c responseWriter) Write(bytes []byte) (int, error) { + forceCacheControl, _ := c.r.Context().Value(forceCacheControlKey).(bool) + if c.w.Header().Get("Cache-Control") == "" { resp := graphql.Response{} err := json.Unmarshal(bytes, &resp) if err == nil { - writeCacheControl(c.r.Context(), c.w, &resp) + writeCacheControl(c.r.Context(), c.w, &resp, forceCacheControl) } } @@ -55,8 +57,8 @@ func (c responseWriter) WriteHeader(statusCode int) { c.w.WriteHeader(statusCode) } -func writeCacheControl(ctx context.Context, w http.ResponseWriter, response *graphql.Response) { - if len(response.Errors) > 0 { +func writeCacheControl(ctx context.Context, w http.ResponseWriter, response *graphql.Response, forceCacheControl bool) { + if len(response.Errors) > 0 && !forceCacheControl { return }