Skip to content
Open
45 changes: 42 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package neffos
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"sync/atomic"
"time"

"github.com/gorilla/websocket"
)

type (
Expand Down Expand Up @@ -324,7 +327,22 @@ func (c *Conn) startReader() {
for {
b, msgTyp, err := c.socket.ReadData(c.readTimeout)
if err != nil {
if c.server.Logger != nil {
if !c.isAcknowledged() {
c.server.Logger.Error(fmt.Errorf("unacknowledged read data err. id:%s, err:%s", c.ID(), err.Error()))
} else if websocket.IsCloseError(err, websocket.CloseNoStatusReceived) {
c.server.Logger.Debug(fmt.Sprintf("read data err. id:%s, err:%s", c.ID(), err.Error()))
} else {
c.server.Logger.Error(fmt.Errorf("read data err. id:%s, err:%s", c.ID(), err.Error()))
}
}

c.readiness.unwait(err)
// if websocket.IsCloseError(err, websocket.CloseNoStatusReceived) {
// c.server.Logger.Debug(fmt.Sprintf("read data err. id:%s, err:%s", c.ID(), err.Error()))
// } else {
// c.server.Logger.Error(fmt.Errorf("read data err. id:%s, err:%s", c.ID(), err.Error()))
// }
return
}

Expand All @@ -346,13 +364,19 @@ func (c *Conn) startReader() {
}

func (c *Conn) handleACK(msgTyp MessageType, b []byte) bool {
if c.server.Logger != nil {
c.server.Logger.Debug(fmt.Sprintf("handle ACK, id:%s, msgTyp:%d, b:%s", c.ID(), msgTyp, string(b)))
}
switch typ := b[0]; typ {
case ackBinary:
// from client startup to server.
err := c.readiness.wait()
if err != nil {
// it's not Ok, send error which client's Dial should return.
c.write(append(ackNotOKBinaryB, []byte(err.Error())...), false)
if c.server.Logger != nil {
c.server.Logger.Error(fmt.Errorf("ackBinary err. id:%s, err:%s", c.ID(), err.Error()))
}
return false
}
atomic.StoreUint32(c.acknowledged, 1)
Expand Down Expand Up @@ -488,7 +512,20 @@ func (c *Conn) DeserializeMessage(msgTyp MessageType, payload []byte) Message {

// HandlePayload fires manually a local event based on the "payload".
func (c *Conn) HandlePayload(msgTyp MessageType, payload []byte) error {
return c.handleMessage(c.DeserializeMessage(msgTyp, payload))
msg := c.DeserializeMessage(msgTyp, payload)
if err := c.handleMessage(msg); err != nil {
if c.server.Logger != nil {
if err == ErrInvalidPayload {
c.server.Logger.Error(fmt.Errorf("handle payload err. id:%s, msgType:%d, payload:%s, err:%s", c.ID(), msgTyp, string(payload), err.Error()))
} else {
c.server.Logger.Error(fmt.Errorf("handle payload err. id:%s, msgType:%d, namespace:%s, room:%s, event:%s, body:%s, err:%s", c.ID(), msgTyp, msg.Namespace, msg.Room, msg.Event, string(msg.Body), err.Error()))
}
}

return err
}

return nil
}

const syncWaitDur = 15 * time.Millisecond
Expand Down Expand Up @@ -909,7 +946,7 @@ func (c *Conn) Write(msg Message) bool {
}

msg.FromExplicit = ""
return c.write(serializeMessage(msg), msg.SetBinary)
return c.write(serializeMessage(msg, false), msg.SetBinary)
}

// used when `Ask` caller cares only for successful call and not the message, for performance reasons we just use raw bytes.
Expand Down Expand Up @@ -1006,6 +1043,7 @@ func (c *Conn) ask(ctx context.Context, msg Message, mustWaitOnlyTheNextMessage
// After this method call the `Conn` is not usable anymore, a new `Dial` call is required.
func (c *Conn) Close() {
if atomic.CompareAndSwapUint32(c.closed, 0, 1) {
var disconnectNamspaces []string
if !c.shouldHandleOnlyNativeMessages {
disconnectMsg := Message{Event: OnNamespaceDisconnect, IsForced: true, IsLocal: true}
c.connectedNamespacesMutex.Lock()
Expand All @@ -1016,6 +1054,7 @@ func (c *Conn) Close() {
disconnectMsg.Namespace = ns.namespace
ns.events.fireEvent(ns, disconnectMsg)
delete(c.connectedNamespaces, namespace)
disconnectNamspaces = append(disconnectNamspaces, namespace)
}
c.connectedNamespacesMutex.Unlock()

Expand All @@ -1030,7 +1069,7 @@ func (c *Conn) Close() {

if !c.IsClient() {
go func() {
c.server.disconnect <- c
c.server.disconnect <- disconnectAction{conn: c, namespaces: disconnectNamspaces}
}()
}

Expand Down
9 changes: 5 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ module github.com/kataras/neffos
go 1.21

require (
github.com/bytedance/gopkg v0.1.2-0.20241212062526-165e60aa2d41
github.com/gobwas/ws v1.3.2
github.com/gorilla/websocket v1.5.1
github.com/iris-contrib/go.uuid v2.0.0+incompatible
github.com/mediocregopher/radix/v3 v3.8.1
github.com/nats-io/nats.go v1.31.0
golang.org/x/sync v0.6.0
golang.org/x/sync v0.8.0
)

require (
Expand All @@ -17,9 +18,9 @@ require (
github.com/klauspost/compress v1.17.0 // indirect
github.com/nats-io/nkeys v0.4.7 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
golang.org/x/crypto v0.18.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.16.0 // indirect
golang.org/x/crypto v0.22.0 // indirect
golang.org/x/net v0.24.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)
23 changes: 14 additions & 9 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package neffos

type logger interface {
// 调试日志
Debug(msg string)

// 提示
Info(msg string)

// 警告
Warn(msg string)

// 错误日志
Error(err error)
}
99 changes: 69 additions & 30 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ func (m *Message) isRoomLeft() bool {
}

// Serialize returns this message's transport format.
func (m Message) Serialize() []byte {
return serializeMessage(m)
func (m Message) Serialize(publishToMultiplexStackExchange bool) []byte {
return serializeMessage(m, publishToMultiplexStackExchange)
}

type (
Expand Down Expand Up @@ -274,7 +274,7 @@ func unescape(s string) string {
return strings.Replace(s, messageFieldSeparatorReplacement, messageSeparatorString, -1)
}

func serializeMessage(msg Message) (out []byte) {
func serializeMessage(msg Message, publishToMultiplexStackExchange bool) (out []byte) {
if msg.IsNative && msg.wait == "" {
out = msg.Body
} else {
Expand All @@ -286,13 +286,19 @@ func serializeMessage(msg Message) (out []byte) {

msg.wait = msg.FromExplicit
}
out = serializeOutput(msg.wait, escape(msg.Namespace), escape(msg.Room), escape(msg.Event), msg.Body, msg.Err, msg.isNoOp)

var to string
if publishToMultiplexStackExchange && msg.To != "" {
to = msg.To
}

out = serializeOutput(msg.wait, escape(msg.Namespace), escape(msg.Room), escape(to), escape(msg.Event), msg.Body, msg.Err, msg.isNoOp)
}

return out
}

func serializeOutput(wait, namespace, room, event string,
func serializeOutput(wait, namespace, room, to, event string,
body []byte,
err error,
isNoOp bool,
Expand Down Expand Up @@ -321,15 +327,29 @@ func serializeOutput(wait, namespace, room, event string,
waitByte = []byte(wait)
}

msg := bytes.Join([][]byte{ // this number of fields should match the deserializer's, see `validMessageSepCount`.
waitByte,
[]byte(namespace),
[]byte(room),
[]byte(event),
isErrorByte,
isNoOpByte,
body,
}, messageSeparator)
var msg []byte
if to == "" {
msg = bytes.Join([][]byte{ // this number of fields should match the deserializer's, see `validMessageSepCount`.
waitByte,
[]byte(namespace),
[]byte(room),
[]byte(event),
isErrorByte,
isNoOpByte,
body,
}, messageSeparator)
} else {
msg = bytes.Join([][]byte{ // this number of fields should match the deserializer's, see `validMessageSepCount`.
waitByte,
[]byte(namespace),
[]byte(room),
[]byte(to),
[]byte(event),
isErrorByte,
isNoOpByte,
body,
}, messageSeparator)
}

return msg
}
Expand All @@ -338,7 +358,7 @@ func serializeOutput(wait, namespace, room, event string,
// and returns a neffos Message.
// When allowNativeMessages only Body is filled and check about message format is skipped.
func DeserializeMessage(msgTyp MessageType, b []byte, allowNativeMessages, shouldHandleOnlyNativeMessages bool) Message {
wait, namespace, room, event, body, err, isNoOp, isInvalid := deserializeInput(b, allowNativeMessages, shouldHandleOnlyNativeMessages)
wait, namespace, room, to, event, body, err, isNoOp, isInvalid := deserializeInput(b, allowNativeMessages, shouldHandleOnlyNativeMessages)

fromExplicit := ""
if isServerConnID(wait) {
Expand Down Expand Up @@ -366,7 +386,7 @@ func DeserializeMessage(msgTyp MessageType, b []byte, allowNativeMessages, shoul
from: "",
FromExplicit: fromExplicit,
FromStackExchange: fromStackExchange,
To: "",
To: unescape(to),
IsForced: false,
IsLocal: false,
IsNative: allowNativeMessages && event == OnNativeMessage,
Expand Down Expand Up @@ -419,6 +439,7 @@ func deserializeInput(b []byte, allowNativeMessages, shouldHandleOnlyNativeMessa
wait,
namespace,
room,
to,
event string,
body []byte,
err error,
Expand All @@ -438,8 +459,8 @@ func deserializeInput(b []byte, allowNativeMessages, shouldHandleOnlyNativeMessa
}

// Note: Go's SplitN returns the remainder in[6] but JavasSript's string.split behaves differently.
dts := bytes.SplitN(b, messageSeparator, validMessageSepCount)
if len(dts) != validMessageSepCount {
dts := bytes.SplitN(b, messageSeparator, validMessageSepCount+1)
if len(dts) != validMessageSepCount && len(dts) != validMessageSepCount+1 {
if !allowNativeMessages {
isInvalid = true
return
Expand All @@ -450,18 +471,36 @@ func deserializeInput(b []byte, allowNativeMessages, shouldHandleOnlyNativeMessa
return
}

wait = string(dts[0])
namespace = string(dts[1])
room = string(dts[2])
event = string(dts[3])
isError := bytes.Equal(dts[4], trueByte)
isNoOp = bytes.Equal(dts[5], trueByte)
if b := dts[6]; len(b) > 0 {
if isError {
errorText := string(b)
err = resolveError(errorText)
} else {
body = b // keep it like that.
if len(dts) == validMessageSepCount {
wait = string(dts[0])
namespace = string(dts[1])
room = string(dts[2])
event = string(dts[3])
isError := bytes.Equal(dts[4], trueByte)
isNoOp = bytes.Equal(dts[5], trueByte)
if b := dts[6]; len(b) > 0 {
if isError {
errorText := string(b)
err = resolveError(errorText)
} else {
body = b // keep it like that.
}
}
} else {
wait = string(dts[0])
namespace = string(dts[1])
room = string(dts[2])
to = string(dts[3])
event = string(dts[4])
isError := bytes.Equal(dts[5], trueByte)
isNoOp = bytes.Equal(dts[6], trueByte)
if b := dts[7]; len(b) > 0 {
if isError {
errorText := string(b)
err = resolveError(errorText)
} else {
body = b // keep it like that.
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestMessageSerialization(t *testing.T) {
}

for i, tt := range tests {
got := serializeMessage(tt.msg)
got := serializeMessage(tt.msg, false)
if !bytes.Equal(got, tt.serialized) {
t.Fatalf("[%d] serialize: expected %s but got %s", i, tt.serialized, got)
}
Expand Down Expand Up @@ -129,7 +129,7 @@ func TestMessageSerialization(t *testing.T) {
expectedSerialized := []byte(fmt.Sprintf(";contains%ssemi;%sthis%sfor sure%s;thatdoesnot;0;0;",
messageFieldSeparatorReplacement, messageFieldSeparatorReplacement, messageFieldSeparatorReplacement, messageFieldSeparatorReplacement))

gotSerialized := serializeMessage(msg)
gotSerialized := serializeMessage(msg, false)

if !bytes.Equal(expectedSerialized, gotSerialized) {
t.Fatalf("expected escaped serialized to be: %s but got: %s", string(expectedSerialized), string(gotSerialized))
Expand Down
Loading