diff --git a/pkg/asyncapi/v2/schema.go b/pkg/asyncapi/v2/schema.go index 104c530f..e0bd9183 100644 --- a/pkg/asyncapi/v2/schema.go +++ b/pkg/asyncapi/v2/schema.go @@ -62,7 +62,10 @@ func NewSchema() Schema { // generateMetadata generates metadata for the schema and its children. func (s *Schema) generateMetadata(name string, isRequired bool) error { - s.Name = template.Namify(name) + // Do not set name if this is a reference - it will be resolved later + if s.Reference == "" { + s.Name = template.Namify(name) + } // Generate Properties metadata if err := s.generatePropertiesMetadata(); err != nil { diff --git a/pkg/codegen/generators/v2/templates/helpers.go b/pkg/codegen/generators/v2/templates/helpers.go index 8e4d6034..19152f39 100644 --- a/pkg/codegen/generators/v2/templates/helpers.go +++ b/pkg/codegen/generators/v2/templates/helpers.go @@ -152,6 +152,22 @@ func ForcePointerOnFields() { } } +// ExtractSchemaNameFromReference extracts schema name from reference path like "#/components/schemas/TestCreatedEvent". +func ExtractSchemaNameFromReference(ref string) string { + if ref == "" { + return "" + } + + // Split by "/" and get the last part + parts := strings.Split(ref, "/") + if len(parts) == 0 { + return "" + } + + schemaName := parts[len(parts)-1] + return templateutil.Namify(schemaName) + "Schema" +} + // HelpersFunctions returns the functions that can be used as helpers // in a golang template. func HelpersFunctions() template.FuncMap { @@ -166,5 +182,6 @@ func HelpersFunctions() template.FuncMap { "referenceToTypeName": ReferenceToTypeName, "generateValidateTags": generators.GenerateValidateTags[asyncapi.Schema], "generateJSONTags": generators.GenerateJSONTags[asyncapi.Schema], + "extractSchemaNameFromReference": ExtractSchemaNameFromReference, } } diff --git a/pkg/codegen/generators/v2/templates/schema_name.tmpl b/pkg/codegen/generators/v2/templates/schema_name.tmpl index 79c6940f..400614df 100644 --- a/pkg/codegen/generators/v2/templates/schema_name.tmpl +++ b/pkg/codegen/generators/v2/templates/schema_name.tmpl @@ -1,9 +1,39 @@ {{define "schema-name" -}} - {{- /* ------------------------- Custom Go type ------------------------- */ -}} {{- if .ExtGoType -}} {{ .ExtGoType }} +{{- /* ---------------------------- Reference (highest priority) ------- */ -}} +{{- else if .ReferenceTo -}} +{{ namify .ReferenceTo.Name }} + +{{- /* ----------------------- Reference without ReferenceTo (fallback) ---- */ -}} +{{- else if .Reference -}} +{{ extractSchemaNameFromReference .Reference }} + +{{- /* ------------------------- AllOf schema ------------------------- */ -}} +{{- else if .AllOf -}} +{{ namify .Name }} + +{{- /* ------------------------- AnyOf or OneOf ------------------------- */ -}} +{{- else if or .AnyOf .OneOf -}} +{{$xxxOf := $.AnyOf}}{{- if .OneOf }}{{$xxxOf = $.OneOf}}{{end -}} + +struct { + {{- if .OneOf }} + // WARNING: only one of the following field can be used + {{ end }} + +{{- range $key, $value := $xxxOf}} + // {{ if $value.Reference}}{{ .ReferenceTo.Name }}{{else}}AnyOf{{$key}}{{end}} +{{- if $value.Description}} + // Description: {{multiLineComment $value.Description}} +{{- end}} + {{ if $value.Reference}}{{ .ReferenceTo.Name }}{{else}}AnyOf{{$key}}{{end}} *{{template "schema-name" $value}} +{{end -}} +} + +{{- /* ------------------------- Type handling ------------------------- */ -}} {{- else if .Type -}} {{- /* --------------------------- Type Object -------------------------- */ -}} @@ -49,31 +79,8 @@ float64 // WARNING: no generation occured here as it has unknown type '{{.Type}}' {{- end -}} -{{- /* ------------------------- AnyOf or OneOf ------------------------- */ -}} -{{- else if or .AnyOf .OneOf -}} -{{$xxxOf := $.AnyOf}}{{- if .OneOf }}{{$xxxOf = $.OneOf}}{{end -}} - -struct { - {{- if .OneOf }} - // WARNING: only one of the following field can be used - {{ end }} - -{{- range $key, $value := $xxxOf}} - // {{ if $value.Reference}}{{ .ReferenceTo.Name }}{{else}}AnyOf{{$key}}{{end}} -{{- if $value.Description}} - // Description: {{multiLineComment $value.Description}} -{{- end}} - {{ if $value.Reference}}{{ .ReferenceTo.Name }}{{else}}AnyOf{{$key}}{{end}} *{{template "schema" $value}} -{{end -}} -} - -{{- /* ---------------------------- Reference --------------------------- */ -}} -{{- else if .ReferenceTo -}} -{{ namify .Follow.Name }} - {{- /* ----------------------- Unsupported use case ---------------------- */ -}} {{- else -}} -interface{} // WARNING: potential error in AsyncAPI generation // Infos on type: {{ describeStruct . }} {{- end -}} diff --git a/test/v2/issues/290/asyncapi.gen.go b/test/v2/issues/290/asyncapi.gen.go new file mode 100644 index 00000000..b45d1f85 --- /dev/null +++ b/test/v2/issues/290/asyncapi.gen.go @@ -0,0 +1,515 @@ +// Package "issue290" provides primitives to interact with the AsyncAPI specification. +// +// Code generated by github.com/lerenn/asyncapi-codegen version (devel) DO NOT EDIT. +package issue290 + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/lerenn/asyncapi-codegen/pkg/extensions" +) + +// AppController is the structure that provides publishing capabilities to the +// developer and and connect the broker with the App +type AppController struct { + controller +} + +// NewAppController links the App to the broker +func NewAppController(bc extensions.BrokerController, options ...ControllerOption) (*AppController, error) { + // Check if broker controller has been provided + if bc == nil { + return nil, extensions.ErrNilBrokerController + } + + // Create default controller + controller := controller{ + broker: bc, + subscriptions: make(map[string]extensions.BrokerChannelSubscription), + logger: extensions.DummyLogger{}, + middlewares: make([]extensions.Middleware, 0), + errorHandler: extensions.DefaultErrorHandler(), + } + + // Apply options + for _, option := range options { + option(&controller) + } + + return &AppController{controller: controller}, nil +} + +func (c AppController) wrapMiddlewares( + middlewares []extensions.Middleware, + callback extensions.NextMiddleware, +) func(ctx context.Context, msg *extensions.BrokerMessage) error { + var called bool + + // If there is no more middleware + if len(middlewares) == 0 { + return func(ctx context.Context, msg *extensions.BrokerMessage) error { + // Call the callback if it exists and it has not been called already + if callback != nil && !called { + called = true + return callback(ctx) + } + + // Nil can be returned, as the callback has already been called + return nil + } + } + + // Get the next function to call from next middlewares or callback + next := c.wrapMiddlewares(middlewares[1:], callback) + + // Wrap middleware into a check function that will call execute the middleware + // and call the next wrapped middleware if the returned function has not been + // called already + return func(ctx context.Context, msg *extensions.BrokerMessage) error { + // Call the middleware and the following if it has not been done already + if !called { + // Create the next call with the context and the message + nextWithArgs := func(ctx context.Context) error { + return next(ctx, msg) + } + + // Call the middleware and register it as already called + called = true + if err := middlewares[0](ctx, msg, nextWithArgs); err != nil { + return err + } + + // If next has already been called in middleware, it should not be executed again + return nextWithArgs(ctx) + } + + // Nil can be returned, as the next middleware has already been called + return nil + } +} + +func (c AppController) executeMiddlewares(ctx context.Context, msg *extensions.BrokerMessage, callback extensions.NextMiddleware) error { + // Wrap middleware to have 'next' function when calling them + wrapped := c.wrapMiddlewares(c.middlewares, callback) + + // Execute wrapped middlewares + return wrapped(ctx, msg) +} + +func addAppContextValues(ctx context.Context, path string) context.Context { + ctx = context.WithValue(ctx, extensions.ContextKeyIsVersion, "1.0.0") + ctx = context.WithValue(ctx, extensions.ContextKeyIsProvider, "app") + return context.WithValue(ctx, extensions.ContextKeyIsChannel, path) +} + +// Close will clean up any existing resources on the controller +func (c *AppController) Close(ctx context.Context) { + // Unsubscribing remaining channels +} + +// PublishTestCreated will publish messages to 'test.created' channel +func (c *AppController) PublishTestCreated( + ctx context.Context, + msg TestCreatedMessage, +) error { + // Get channel path + path := "test.created" + + // Set context + ctx = addAppContextValues(ctx, path) + ctx = context.WithValue(ctx, extensions.ContextKeyIsDirection, "publication") + + // Convert to BrokerMessage + brokerMsg, err := msg.toBrokerMessage() + if err != nil { + return err + } + + // Set broker message to context + ctx = context.WithValue(ctx, extensions.ContextKeyIsBrokerMessage, brokerMsg.String()) + + // Publish the message on event-broker through middlewares + return c.executeMiddlewares(ctx, &brokerMsg, func(ctx context.Context) error { + return c.broker.Publish(ctx, path, brokerMsg) + }) +} + +// UserSubscriber represents all handlers that are expecting messages for User +type UserSubscriber interface { + // TestCreated subscribes to messages placed on the 'test.created' channel + TestCreated(ctx context.Context, msg TestCreatedMessage) error +} + +// UserController is the structure that provides publishing capabilities to the +// developer and and connect the broker with the User +type UserController struct { + controller +} + +// NewUserController links the User to the broker +func NewUserController(bc extensions.BrokerController, options ...ControllerOption) (*UserController, error) { + // Check if broker controller has been provided + if bc == nil { + return nil, extensions.ErrNilBrokerController + } + + // Create default controller + controller := controller{ + broker: bc, + subscriptions: make(map[string]extensions.BrokerChannelSubscription), + logger: extensions.DummyLogger{}, + middlewares: make([]extensions.Middleware, 0), + errorHandler: extensions.DefaultErrorHandler(), + } + + // Apply options + for _, option := range options { + option(&controller) + } + + return &UserController{controller: controller}, nil +} + +func (c UserController) wrapMiddlewares( + middlewares []extensions.Middleware, + callback extensions.NextMiddleware, +) func(ctx context.Context, msg *extensions.BrokerMessage) error { + var called bool + + // If there is no more middleware + if len(middlewares) == 0 { + return func(ctx context.Context, msg *extensions.BrokerMessage) error { + // Call the callback if it exists and it has not been called already + if callback != nil && !called { + called = true + return callback(ctx) + } + + // Nil can be returned, as the callback has already been called + return nil + } + } + + // Get the next function to call from next middlewares or callback + next := c.wrapMiddlewares(middlewares[1:], callback) + + // Wrap middleware into a check function that will call execute the middleware + // and call the next wrapped middleware if the returned function has not been + // called already + return func(ctx context.Context, msg *extensions.BrokerMessage) error { + // Call the middleware and the following if it has not been done already + if !called { + // Create the next call with the context and the message + nextWithArgs := func(ctx context.Context) error { + return next(ctx, msg) + } + + // Call the middleware and register it as already called + called = true + if err := middlewares[0](ctx, msg, nextWithArgs); err != nil { + return err + } + + // If next has already been called in middleware, it should not be executed again + return nextWithArgs(ctx) + } + + // Nil can be returned, as the next middleware has already been called + return nil + } +} + +func (c UserController) executeMiddlewares(ctx context.Context, msg *extensions.BrokerMessage, callback extensions.NextMiddleware) error { + // Wrap middleware to have 'next' function when calling them + wrapped := c.wrapMiddlewares(c.middlewares, callback) + + // Execute wrapped middlewares + return wrapped(ctx, msg) +} + +func addUserContextValues(ctx context.Context, path string) context.Context { + ctx = context.WithValue(ctx, extensions.ContextKeyIsVersion, "1.0.0") + ctx = context.WithValue(ctx, extensions.ContextKeyIsProvider, "user") + return context.WithValue(ctx, extensions.ContextKeyIsChannel, path) +} + +// Close will clean up any existing resources on the controller +func (c *UserController) Close(ctx context.Context) { + // Unsubscribing remaining channels + c.UnsubscribeAll(ctx) + + c.logger.Info(ctx, "Closed user controller") +} + +// SubscribeAll will subscribe to channels without parameters on which the app is expecting messages. +// For channels with parameters, they should be subscribed independently. +func (c *UserController) SubscribeAll(ctx context.Context, as UserSubscriber) error { + if as == nil { + return extensions.ErrNilUserSubscriber + } + + if err := c.SubscribeTestCreated(ctx, as.TestCreated); err != nil { + return err + } + + return nil +} + +// UnsubscribeAll will unsubscribe all remaining subscribed channels +func (c *UserController) UnsubscribeAll(ctx context.Context) { + c.UnsubscribeTestCreated(ctx) +} + +// SubscribeTestCreated will subscribe to new messages from 'test.created' channel. +// +// Callback function 'fn' will be called each time a new message is received. +func (c *UserController) SubscribeTestCreated( + ctx context.Context, + fn func(ctx context.Context, msg TestCreatedMessage) error, +) error { + // Get channel path + path := "test.created" + + // Set context + ctx = addUserContextValues(ctx, path) + ctx = context.WithValue(ctx, extensions.ContextKeyIsDirection, "reception") + + // Check if there is already a subscription + _, exists := c.subscriptions[path] + if exists { + err := fmt.Errorf("%w: %q channel is already subscribed", extensions.ErrAlreadySubscribedChannel, path) + c.logger.Error(ctx, err.Error()) + return err + } + + // Subscribe to broker channel + sub, err := c.broker.Subscribe(ctx, path) + if err != nil { + c.logger.Error(ctx, err.Error()) + return err + } + c.logger.Info(ctx, "Subscribed to channel") + + // Asynchronously listen to new messages and pass them to app subscriber + go func() { + for { + // Listen to next message + stop, err := c.listenToTestCreatedNextMessage(path, sub, fn) + if err != nil { + c.logger.Error(ctx, err.Error()) + } + + // Stop if required + if stop { + return + } + } + }() + + // Add the cancel channel to the inside map + c.subscriptions[path] = sub + + return nil +} + +func (c *UserController) listenToTestCreatedNextMessage( + path string, + sub extensions.BrokerChannelSubscription, + fn func(ctx context.Context, msg TestCreatedMessage) error, +) (stop bool, err error) { + // Create a context for the received response + msgCtx, cancel := context.WithCancel(context.Background()) + msgCtx = addUserContextValues(msgCtx, path) + msgCtx = context.WithValue(msgCtx, extensions.ContextKeyIsDirection, "reception") + defer cancel() + + // Wait for next message + acknowledgeableBrokerMessage, open := <-sub.MessagesChannel() + + // If subscription is closed and there is no more message + // (i.e. uninitialized message), then exit the function + if !open && acknowledgeableBrokerMessage.IsUninitialized() { + return true, nil + } + + // Set broker message to context + msgCtx = context.WithValue(msgCtx, extensions.ContextKeyIsBrokerMessage, acknowledgeableBrokerMessage.String()) + + // Execute middlewares before handling the message + if err := c.executeMiddlewares(msgCtx, &acknowledgeableBrokerMessage.BrokerMessage, func(middlewareCtx context.Context) error { + // Process message + msg, err := brokerMessageToTestCreatedMessage(acknowledgeableBrokerMessage.BrokerMessage) + if err != nil { + return err + } + + // Execute the subscription function + if err := fn(middlewareCtx, msg); err != nil { + return err + } + + acknowledgeableBrokerMessage.Ack() + + return nil + }); err != nil { + c.errorHandler(msgCtx, path, &acknowledgeableBrokerMessage, err) + // On error execute the acknowledgeableBrokerMessage nack() function and + // let the BrokerAcknowledgment decide what is the right nack behavior for the broker + acknowledgeableBrokerMessage.Nak() + } + + return false, nil +} + +// UnsubscribeTestCreated will unsubscribe messages from 'test.created' channel. +// A timeout can be set in context to avoid blocking operation, if needed. +func (c *UserController) UnsubscribeTestCreated(ctx context.Context) { + // Get channel path + path := "test.created" + + // Check if there subscribers for this channel + sub, exists := c.subscriptions[path] + if !exists { + return + } + + // Set context + ctx = addUserContextValues(ctx, path) + + // Stop the subscription + sub.Cancel(ctx) + + // Remove if from the subscribers + delete(c.subscriptions, path) + + c.logger.Info(ctx, "Unsubscribed from channel") +} + +// AsyncAPIVersion is the version of the used AsyncAPI document +const AsyncAPIVersion = "1.0.0" + +// controller is the controller that will be used to communicate with the broker +// It will be used internally by AppController and UserController +type controller struct { + // broker is the broker controller that will be used to communicate + broker extensions.BrokerController + // subscriptions is a map of all subscriptions + subscriptions map[string]extensions.BrokerChannelSubscription + // logger is the logger that will be used² to log operations on controller + logger extensions.Logger + // middlewares are the middlewares that will be executed when sending or + // receiving messages + middlewares []extensions.Middleware + // handler to handle errors from consumers and middlewares + errorHandler extensions.ErrorHandler +} + +// ControllerOption is the type of the options that can be passed +// when creating a new Controller +type ControllerOption func(controller *controller) + +// WithLogger attaches a logger to the controller +func WithLogger(logger extensions.Logger) ControllerOption { + return func(controller *controller) { + controller.logger = logger + } +} + +// WithMiddlewares attaches middlewares that will be executed when sending or receiving messages +func WithMiddlewares(middlewares ...extensions.Middleware) ControllerOption { + return func(controller *controller) { + controller.middlewares = middlewares + } +} + +// WithErrorHandler attaches a errorhandler to handle errors from subscriber functions +func WithErrorHandler(handler extensions.ErrorHandler) ControllerOption { + return func(controller *controller) { + controller.errorHandler = handler + } +} + +type MessageWithCorrelationID interface { + CorrelationID() string + SetCorrelationID(id string) +} + +type Error struct { + Channel string + Err error +} + +func (e *Error) Error() string { + return fmt.Sprintf("channel %q: err %v", e.Channel, e.Err) +} + +// TestCreatedMessage is the message expected for 'TestCreatedMessage' channel. +type TestCreatedMessage struct { + // Payload will be inserted in the message payload + Payload TestEventSchema +} + +func NewTestCreatedMessage() TestCreatedMessage { + var msg TestCreatedMessage + + return msg +} + +// brokerMessageToTestCreatedMessage will fill a new TestCreatedMessage with data from generic broker message +func brokerMessageToTestCreatedMessage(bMsg extensions.BrokerMessage) (TestCreatedMessage, error) { + var msg TestCreatedMessage + + // Unmarshal payload to expected message payload format + err := json.Unmarshal(bMsg.Payload, &msg.Payload) + if err != nil { + return msg, err + } + + // TODO: run checks on msg type + + return msg, nil +} + +// toBrokerMessage will generate a generic broker message from TestCreatedMessage data +func (msg TestCreatedMessage) toBrokerMessage() (extensions.BrokerMessage, error) { + // TODO: implement checks on message + + // Marshal payload to JSON + payload, err := json.Marshal(msg.Payload) + if err != nil { + return extensions.BrokerMessage{}, err + } + + // There is no headers here + headers := make(map[string][]byte, 0) + + return extensions.BrokerMessage{ + Headers: headers, + Payload: payload, + }, nil +} + +// BaseEventSchema is a schema from the AsyncAPI specification required in messages +type BaseEventSchema struct { + Id *string `json:"id,omitempty"` + Timestamp *string `json:"timestamp,omitempty"` +} + +// TestEventSchema is a schema from the AsyncAPI specification required in messages +type TestEventSchema struct { + Data *string `json:"data,omitempty"` + Id *string `json:"id,omitempty"` + Timestamp *string `json:"timestamp,omitempty"` +} + +const ( + // TestCreatedPath is the constant representing the 'TestCreated' channel path. + TestCreatedPath = "test.created" +) + +// ChannelsPaths is an array of all channels paths +var ChannelsPaths = []string{ + TestCreatedPath, +} diff --git a/test/v2/issues/290/asyncapi.yaml b/test/v2/issues/290/asyncapi.yaml new file mode 100644 index 00000000..e8f558e9 --- /dev/null +++ b/test/v2/issues/290/asyncapi.yaml @@ -0,0 +1,29 @@ +asyncapi: 2.5.0 +info: + title: Issue 290 - allOf schema references in message payloads + version: 1.0.0 + +channels: + test.created: + subscribe: + message: + payload: + $ref: '#/components/schemas/TestEvent' + +components: + schemas: + TestEvent: + allOf: + - $ref: '#/components/schemas/BaseEvent' + - type: object + properties: + data: + type: string + + BaseEvent: + type: object + properties: + id: + type: string + timestamp: + type: string diff --git a/test/v2/issues/290/suite_test.go b/test/v2/issues/290/suite_test.go new file mode 100644 index 00000000..8d81df44 --- /dev/null +++ b/test/v2/issues/290/suite_test.go @@ -0,0 +1,42 @@ +//go:generate go run ../../../../cmd/asyncapi-codegen -p issue290 -i ./asyncapi.yaml -o ./asyncapi.gen.go + +package issue290 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +func TestSuite(t *testing.T) { + suite.Run(t, NewSuite()) +} + +type Suite struct { + suite.Suite +} + +func NewSuite() *Suite { + return &Suite{} +} + +func stringPtr(s string) *string { + return &s +} + +func (suite *Suite) TestMessagePayloadTypeGeneration() { + // Test that the message payload type is properly generated from allOf schema reference + // Before the fix, this would fail to compile because Payload had no type + msg := NewTestCreatedMessage() + + // Verify the message has a properly typed payload field from allOf composition + msg.Payload.Id = stringPtr("test-id") + msg.Payload.Timestamp = stringPtr("2023-01-01T00:00:00Z") + msg.Payload.Data = stringPtr("test-data") + + // Verify fields are accessible and properly typed + assert.Equal(suite.T(), "test-id", *msg.Payload.Id) + assert.Equal(suite.T(), "2023-01-01T00:00:00Z", *msg.Payload.Timestamp) + assert.Equal(suite.T(), "test-data", *msg.Payload.Data) +}