diff --git a/tonic/handler.go b/tonic/handler.go index bcb99e8..9af4e10 100644 --- a/tonic/handler.go +++ b/tonic/handler.go @@ -49,16 +49,16 @@ func Handler(h interface{}, status int, options ...func(*Route)) gin.HandlerFunc // Wrap Gin handler. f := func(c *gin.Context) { _, ok := c.Get(tonicWantRouteInfos) + r := &Route{} + for _, opt := range options { + opt(r) + } if ok { - r := &Route{} r.defaultStatusCode = status r.handler = hv r.handlerType = ht r.inputType = in r.outputType = out - for _, opt := range options { - opt(r) - } c.Set(tonicRoutesInfos, r) c.Abort() return @@ -71,8 +71,14 @@ func Handler(h interface{}, status int, options ...func(*Route)) gin.HandlerFunc // binding. if in != nil { input := reflect.New(in) + routeBindHook := r.GetBindHook() + if routeBindHook == nil { + // use the default bindHook if the route + // does not have a custom one + routeBindHook = bindHook + } // Bind the body with the hook. - if err := bindHook(c, input.Interface()); err != nil { + if err := routeBindHook(c, input.Interface()); err != nil { handleError(c, BindError{message: err.Error(), typ: in}) return } @@ -116,7 +122,11 @@ func Handler(h interface{}, status int, options ...func(*Route)) gin.HandlerFunc handleError(c, err.(error)) return } - renderHook(c, status, val) + routeRenderHook := r.GetRenderHook() + if routeRenderHook == nil { + routeRenderHook = renderHook + } + routeRenderHook(c, status, val) } // Register route in tonic-enabled routes map route := &Route{ diff --git a/tonic/route.go b/tonic/route.go index db6fe07..5d76229 100644 --- a/tonic/route.go +++ b/tonic/route.go @@ -32,6 +32,18 @@ type Route struct { // outputType is the type of the output object. // This can be nil if the handler use none. outputType reflect.Type + + // the bindHook for this route + bindHook BindHook + + // the renderHook for this route + renderHook RenderHook + + // the request media type for this route + requestMediaType string + + // the response media type for this route + responseMediaType string } // GetVerb returns the HTTP verb of the route. @@ -40,6 +52,30 @@ func (r *Route) GetVerb() string { return r.Method } // GetPath returns the path of the route. func (r *Route) GetPath() string { return r.Path } +// GetRequestMediaType returns the media type of the route. +func (r *Route) GetRequestMediaType() string { return r.requestMediaType } + +// GetResponseMediaType returns the media type of the route. +func (r *Route) GetResponseMediaType() string { return r.responseMediaType } + +// SetRequestMediaType returns the media type of the route. +func (r *Route) SetRequestMediaType(mt string) { r.requestMediaType = mt } + +// SetResponseMediaType returns the media type of the route. +func (r *Route) SetResponseMediaType(mt string) { r.responseMediaType = mt } + +// GetBindHook returns the bind hook of the route. +func (r *Route) GetBindHook() BindHook { return r.bindHook } + +// SetBindHook returns the bind hook of the route. +func (r *Route) SetBindHook(h BindHook) { r.bindHook = h } + +// GetRenderHook returns the bind hook of the route. +func (r *Route) GetRenderHook() RenderHook { return r.renderHook } + +// SetRenderHook returns the bind hook of the route. +func (r *Route) SetRenderHook(h RenderHook) { r.renderHook = h } + // GetDescription returns the description of the route. func (r *Route) GetDescription() string { return r.description } diff --git a/tonic/tonic_test.go b/tonic/tonic_test.go index 694c159..1e6d8c1 100644 --- a/tonic/tonic_test.go +++ b/tonic/tonic_test.go @@ -35,6 +35,18 @@ func TestMain(m *testing.M) { g.GET("/query", tonic.Handler(queryHandler, 200)) g.GET("/query-old", tonic.Handler(queryHandlerOld, 200)) g.POST("/body", tonic.Handler(bodyHandler, 200)) + g.POST("/bodyYAML", tonic.Handler(bodyHandler, 200, func(r *tonic.Route) { + r.SetBindHook(func(c *gin.Context, i interface{}) error { + if err := c.ShouldBindYAML(i); err != nil { + return fmt.Errorf("error parsing request body: %s", err.Error()) + } + return nil + }) + r.SetRenderHook(func(c *gin.Context, statusCode int, payload interface{}) { + c.YAML(statusCode, payload) + }) + r.SetResponseMediaType("text/yaml") + })) r = g @@ -130,6 +142,15 @@ func TestBody(t *testing.T) { tester.Run() } +func TestBodyYAML(t *testing.T) { + + tester := iffy.NewTester(t, r) + + tester.AddCall("body", "POST", "/bodyYAML", `param: foo`).Checkers(iffy.ExpectStatus(200), expectStringInBody("param: foo")) + + tester.Run() +} + func errorHandler(c *gin.Context) error { return errors.New("error") } @@ -190,9 +211,9 @@ func queryHandlerOld(c *gin.Context, in *queryInOld) (*queryInOld, error) { } type bodyIn struct { - Param string `json:"param" validate:"required"` - ParamOptional string `json:"param-optional"` - ValidatedParamOptional string `json:"param-optional-validated" validate:"eq=|eq=foo|gt=10"` + Param string `json:"param" validate:"required" yaml:"param"` + ParamOptional string `json:"param-optional" yaml:"param-optional"` + ValidatedParamOptional string `json:"param-optional-validated" validate:"eq=|eq=foo|gt=10" yaml:"param-optional-validated"` } func bodyHandler(c *gin.Context, in *bodyIn) (*bodyIn, error) {