diff --git a/README.md b/README.md index b3bd42c..dffa296 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,8 @@ func main() { * **Publish()** * **SubscribeAsync()** * **SubscribeOnceAsync()** +* **SubscribeReplyAsync()** +* **Request()** * **WaitAsync()** #### New() @@ -116,6 +118,28 @@ Transactional determines whether subsequent callbacks for a topic are run serial #### SubscribeOnceAsync(topic string, args ...interface{}) SubscribeOnceAsync works like SubscribeOnce except the callback to executed asynchronously +#### SubscribeReplyAsync(topic string, fn interface{}) +SubscribeReplyAsync works like SubscribeAsync except the callback is expected to return a value. The value is returned to the caller of Publish. + +#### Request(topic string, handler interface{}, timeoutMs time.Duration, args ...interface{}) +Request is a function that allows you to make a request to a topic and wait for a response. The response is returned to the caller of `Request` as an interface{}. + +```go +bus := EventBus.New() + +func slowCalculator(reply string, a, b int) { + time.Sleep(3 * time.Second) + bus.Publish(reply, a + b) +} + +bus.SubscribeReplyAsync("main:slow_calculator", slowCalculator) + +reply := bus.Request("main:slow_calculator", func(rslt int) { + fmt.Printf("Result: %d\n", rslt) +}, 20, 60) + +``` + #### WaitAsync() WaitAsync waits for all async callbacks to complete. diff --git a/event_bus.go b/event_bus.go index dedc7fd..a80725d 100644 --- a/event_bus.go +++ b/event_bus.go @@ -1,32 +1,43 @@ package EventBus import ( + "errors" "fmt" + "github.com/google/uuid" "reflect" "sync" + "time" ) -//BusSubscriber defines subscription-related bus behavior +const ( + ReplyTopicPrefix = "_INBOX:" +) + +type Void struct{} + +// BusSubscriber defines subscription-related bus behavior type BusSubscriber interface { Subscribe(topic string, fn interface{}) error SubscribeAsync(topic string, fn interface{}, transactional bool) error SubscribeOnce(topic string, fn interface{}) error SubscribeOnceAsync(topic string, fn interface{}) error + SubscribeReplyAsync(topic string, fn interface{}) error Unsubscribe(topic string, handler interface{}) error } -//BusPublisher defines publishing-related bus behavior +// BusPublisher defines publishing-related bus behavior type BusPublisher interface { Publish(topic string, args ...interface{}) + Request(topic string, handler interface{}, timeout time.Duration, args ...interface{}) error } -//BusController defines bus control behavior (checking handler's presence, synchronization) +// BusController defines bus control behavior (checking handler's presence, synchronization) type BusController interface { HasCallback(topic string) bool WaitAsync() } -//Bus englobes global (subscribe, publish, control) bus behavior +// Bus englobes global (subscribe, publish, control) bus behavior type Bus interface { BusController BusSubscriber @@ -35,9 +46,9 @@ type Bus interface { // EventBus - box for handlers and callbacks. type EventBus struct { - handlers map[string][]*eventHandler - lock sync.Mutex // a lock for the map - wg sync.WaitGroup + mapHandlers sync.Map + lock sync.Mutex // a lock for the map + wg sync.WaitGroup } type eventHandler struct { @@ -51,7 +62,7 @@ type eventHandler struct { // New returns new EventBus with empty handlers. func New() Bus { b := &EventBus{ - make(map[string][]*eventHandler), + sync.Map{}, sync.Mutex{}, sync.WaitGroup{}, } @@ -60,12 +71,15 @@ func New() Bus { // doSubscribe handles the subscription logic and is utilized by the public Subscribe functions func (bus *EventBus) doSubscribe(topic string, fn interface{}, handler *eventHandler) error { - bus.lock.Lock() - defer bus.lock.Unlock() if !(reflect.TypeOf(fn).Kind() == reflect.Func) { return fmt.Errorf("%s is not of type reflect.Func", reflect.TypeOf(fn).Kind()) } - bus.handlers[topic] = append(bus.handlers[topic], handler) + // rewrite in sync.Map + if _, ok := bus.mapHandlers.Load(topic); !ok { + bus.mapHandlers.Store(topic, []*eventHandler{}) + } + handlers, _ := bus.mapHandlers.Load(topic) + bus.mapHandlers.Store(topic, append(handlers.([]*eventHandler), handler)) return nil } @@ -104,13 +118,35 @@ func (bus *EventBus) SubscribeOnceAsync(topic string, fn interface{}) error { }) } +// SubcribeReplyAsync subscribes to a topic with an asynchronous callback +func (bus *EventBus) SubscribeReplyAsync(topic string, fn interface{}) error { + fnValue := reflect.ValueOf(fn) + if fnValue.Kind() != reflect.Func { + return errors.New("fn must be a function") + } + + fnType := fnValue.Type() + if fnType.NumIn() == 0 { + return errors.New("fn must have at least one input parameter") + } + + if fnType.In(0).Kind() != reflect.String { + return errors.New("fn's first parameter (reply topic) must be a string") + } + + return bus.doSubscribe(topic, fn, &eventHandler{ + reflect.ValueOf(fn), false, true, false, sync.Mutex{}, + }) +} + // HasCallback returns true if exists any callback subscribed to the topic. func (bus *EventBus) HasCallback(topic string) bool { bus.lock.Lock() - defer bus.lock.Unlock() - _, ok := bus.handlers[topic] - if ok { - return len(bus.handlers[topic]) > 0 + defer func() { + bus.lock.Unlock() + }() + if handlers, ok := bus.mapHandlers.Load(topic); ok { + return len(handlers.([]*eventHandler)) > 0 } return false } @@ -118,22 +154,70 @@ func (bus *EventBus) HasCallback(topic string) bool { // Unsubscribe removes callback defined for a topic. // Returns error if there are no callbacks subscribed to the topic. func (bus *EventBus) Unsubscribe(topic string, handler interface{}) error { - bus.lock.Lock() - defer bus.lock.Unlock() - if _, ok := bus.handlers[topic]; ok && len(bus.handlers[topic]) > 0 { - bus.removeHandler(topic, bus.findHandlerIdx(topic, reflect.ValueOf(handler))) - return nil + if iHandlers, ok := bus.mapHandlers.Load(topic); ok { + handlers := iHandlers.([]*eventHandler) + for i, h := range handlers { + if h.callBack.Type() == reflect.ValueOf(handler).Type() && + h.callBack.Pointer() == reflect.ValueOf(handler).Pointer() { + handlers = append(handlers[:i], handlers[i+1:]...) + bus.mapHandlers.Store(topic, handlers) + return nil + } + } } return fmt.Errorf("topic %s doesn't exist", topic) } +func (bus *EventBus) Request(topic string, handler interface{}, timeout time.Duration, args ...interface{}) error { + inboxStr := fmt.Sprintf("%v%v:%v", ReplyTopicPrefix, topic, uuid.NewString()) + if !bus.HasCallback(topic) { + return fmt.Errorf("no responder on topic: %v", topic) + } + chResult := make(chan Void) + + wrapperHandler := func(args ...interface{}) { + chResult <- Void{} + handlerValue := reflect.ValueOf(handler) + if handlerValue.Kind() != reflect.Func { + fmt.Printf("handler is not a function: %v\n", handler) + return + } + handlerArgs := make([]reflect.Value, len(args)) + for i, arg := range args { + handlerArgs[i] = reflect.ValueOf(arg) + } + handlerValue.Call(handlerArgs) + } + err := bus.SubscribeOnce(inboxStr, wrapperHandler) + // fmt.Printf("subscribing: %v\n", inboxStr) + if err != nil { + fmt.Println("failed to subscribe to reply topic: %w", err) + } + newArgs := append([]interface{}{inboxStr}, args...) + bus.Publish(topic, newArgs...) + + timer := time.NewTimer(timeout) + select { + case <-chResult: + return nil + case <-timer.C: + err = bus.Unsubscribe(inboxStr, wrapperHandler) + if err != nil { + err = fmt.Errorf("failed to unsubscribe: %v %w", inboxStr, err) + } + if err != nil { + err = fmt.Errorf("request timed out %w", err) + } else { + err = fmt.Errorf("request timed out") + } + return err + } +} + // Publish executes callback defined for a topic. Any additional argument will be transferred to the callback. func (bus *EventBus) Publish(topic string, args ...interface{}) { - bus.lock.Lock() // will unlock if handler is not found or always after setUpPublish - defer bus.lock.Unlock() - if handlers, ok := bus.handlers[topic]; ok && 0 < len(handlers) { - // Handlers slice may be changed by removeHandler and Unsubscribe during iteration, - // so make a copy and iterate the copied slice. + if iHandlers, ok := bus.mapHandlers.Load(topic); ok { + handlers := iHandlers.([]*eventHandler) copyHandlers := make([]*eventHandler, len(handlers)) copy(copyHandlers, handlers) for i, handler := range copyHandlers { @@ -145,9 +229,7 @@ func (bus *EventBus) Publish(topic string, args ...interface{}) { } else { bus.wg.Add(1) if handler.transactional { - bus.lock.Unlock() handler.Lock() - bus.lock.Lock() } go bus.doPublishAsync(handler, topic, args...) } @@ -169,28 +251,25 @@ func (bus *EventBus) doPublishAsync(handler *eventHandler, topic string, args .. } func (bus *EventBus) removeHandler(topic string, idx int) { - if _, ok := bus.handlers[topic]; !ok { - return - } - l := len(bus.handlers[topic]) - - if !(0 <= idx && idx < l) { - return + if iHandlers, ok := bus.mapHandlers.Load(topic); ok { + handlers := iHandlers.([]*eventHandler) + if len(handlers) > idx && idx >= 0 { + bus.mapHandlers.Store(topic, append(handlers[:idx], handlers[idx+1:]...)) + } } - - copy(bus.handlers[topic][idx:], bus.handlers[topic][idx+1:]) - bus.handlers[topic][l-1] = nil // or the zero value of T - bus.handlers[topic] = bus.handlers[topic][:l-1] } func (bus *EventBus) findHandlerIdx(topic string, callback reflect.Value) int { - if _, ok := bus.handlers[topic]; ok { - for idx, handler := range bus.handlers[topic] { + // rewrite in sync.Map + if iHandlers, ok := bus.mapHandlers.Load(topic); ok { + handlers := iHandlers.([]*eventHandler) + for idx, handler := range handlers { if handler.callBack.Type() == callback.Type() && handler.callBack.Pointer() == callback.Pointer() { return idx } } + } return -1 } diff --git a/event_bus_test.go b/event_bus_test.go index 0cdb579..970e0c0 100644 --- a/event_bus_test.go +++ b/event_bus_test.go @@ -1,6 +1,9 @@ package EventBus import ( + "fmt" + "github.com/stretchr/testify/assert" + "sync/atomic" "testing" "time" ) @@ -163,7 +166,7 @@ func TestSubscribeAsync(t *testing.T) { results := make(chan int) bus := New() - bus.SubscribeAsync("topic", func(a int, out chan<- int) { + _ = bus.SubscribeAsync("topic", func(a int, out chan<- int) { out <- a }, false) @@ -188,3 +191,129 @@ func TestSubscribeAsync(t *testing.T) { // t.Fail() //} } + +func TestRequestReply(t *testing.T) { + bus := New() + _ = bus.SubscribeReplyAsync("topic", func(replyTopic string, action string, in1 float64, in2 float64) { + var result float64 + switch action { + case "add": + result = in1 + in2 + case "sub": + result = in1 - in2 + case "mul": + result = in1 * in2 + case "div": + result = in1 / in2 + } + bus.Publish(replyTopic, result) + }) + + counter := 0 + + replyHandler := func(data float64) { + switch counter { + case 0: + assert.Equal(t, 22.0, data) + case 1: + assert.Equal(t, 2.0, data) + case 2: + assert.Equal(t, 120.0, data) + case 3: + assert.Equal(t, 1.2, data) + default: + assert.Fail(t, "unexpected response") + } + counter++ + } + + _ = bus.Request("topic", replyHandler, 10*time.Millisecond, "add", 12.0, 10.0) + _ = bus.Request("topic", replyHandler, 10*time.Millisecond, "sub", 12.0, 10.0) + _ = bus.Request("topic", replyHandler, 10*time.Millisecond, "mul", 12.0, 10.0) + _ = bus.Request("topic", replyHandler, 10*time.Millisecond, "div", 12.0, 10.0) + + time.Sleep(10 * time.Millisecond) +} + +func TestConcurrencyReply(t *testing.T) { + bus := New() + err := bus.SubscribeReplyAsync("concurrency", func(replyTopic string, in1 float64, in2 float64) { + time.Sleep(100 * time.Microsecond) + bus.Publish(replyTopic, in1, in2, in1+in2) + }) + if err != nil { + assert.Fail(t, "failed to subscribe") + } + counter := atomic.Uint64{} + replyHandler := func(in1, in2, data float64) { + assert.Equal(t, in1+in2, data, "wrong value") + counter.Add(1) + } + errCounter := atomic.Uint64{} + for i := 0; i < 10000; i++ { + go func() { + err = bus.Request("concurrency", replyHandler, 200*time.Microsecond, float64(i), float64(10*i)) + if err != nil { + errCounter.Add(1) + } + }() + } + + time.Sleep(2 * time.Second) + fmt.Printf("counter: %d error: %d\n", counter.Load(), errCounter.Load()) + assert.Equal(t, 10000, int(counter.Load()+errCounter.Load()), "wrong counter") +} + +func TestConcurrentPubSub(t *testing.T) { + bus := New() + counter := atomic.Int64{} + end := make(chan Void) + total := 10000 + err := bus.SubscribeAsync("concurrent", func() { + counter.Add(1) + if counter.Load() == int64(total) { + end <- Void{} + } + }, false) + if err != nil { + assert.Fail(t, "failed to subscribe") + } + for i := 0; i < 10000; i++ { + go bus.Publish("concurrent") + } + <-end + bus.WaitAsync() + assert.Equal(t, 10000, int(counter.Load()), "wrong counter") +} + +func TestFailedRequestReply(t *testing.T) { + bus := New() + err := bus.SubscribeReplyAsync("topic", func(replyTopic int, action string, in1 float64, in2 float64) {}) + if err != nil { + fmt.Println(err) + } else { + t.Fail() + } +} + +func TestRequestReplyTimeout(t *testing.T) { + bus := New() + slowCalculator := func(reply string, a, b int) { + time.Sleep(1 * time.Second) + bus.Publish(reply, a+b) + } + + _ = bus.SubscribeReplyAsync("main:slow_calculator", slowCalculator) + + err := bus.Request("main:slow_calculator", func(rslt int) { + fmt.Printf("Result: %d\n", rslt) + }, 10*time.Millisecond, 20, 60) + assert.NotNil(t, err, "Request should return timeout error") + + err = bus.Request("main:slow_calculator", func(rslt int) { + fmt.Printf("Result: %d\n", rslt) + }, 2*time.Second, 20, 90) + assert.Nil(t, err, "Request should not return an error") + + time.Sleep(100 * time.Millisecond) +}