Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ import (
const errThreshold = 3

type Conn struct {
conn net.Conn
text *textproto.Conn
server *Server
helo string
conn net.Conn
text *textproto.Conn
server *Server
helo string
rejected bool

// Number of errors witnessed on this connection
errCount int
Expand Down Expand Up @@ -105,6 +106,17 @@ func (c *Conn) handle(cmd string, arg string) {
return
}

//as per RFC5321 3.1
if c.rejected {
if cmd == "QUIT" {
c.writeResponse(221, NoEnhancedCode, "OK")
c.Close()
} else {
c.protocolError(503, NoEnhancedCode, "bad sequence of commands")
}
return
}

cmd = strings.ToUpper(cmd)
switch cmd {
case "SEND", "SOML", "SAML", "EXPN", "HELP", "TURN":
Expand Down Expand Up @@ -1276,7 +1288,13 @@ func (c *Conn) greet() {
if c.server.LMTP {
protocol = "LMTP"
}
c.writeResponse(220, NoEnhancedCode, fmt.Sprintf("%v %s Service Ready", c.server.Domain, protocol))
domain, err := c.server.GetDomain(c)
if err != nil {
c.writeResponse(554, NoEnhancedCode, "Error: "+err.Error())
c.rejected = true
return
}
c.writeResponse(220, NoEnhancedCode, fmt.Sprintf("%v %s Service Ready", domain, protocol))
}

func (c *Conn) writeResponse(code int, enhCode EnhancedCode, text ...string) {
Expand Down
14 changes: 14 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type Server struct {
ErrorLog Logger
ReadTimeout time.Duration
WriteTimeout time.Duration
domainFn func(*Conn) (string, error)

// Advertise SMTPUTF8 (RFC 6531) capability.
// Should be used only if backend supports it.
Expand Down Expand Up @@ -148,6 +149,19 @@ func (s *Server) Serve(l net.Listener) error {
}
}

// Getter for the optional dynamic domain string generator function
func (s *Server) GetDomain(c *Conn) (string, error) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would better be implemented at the backend level.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if so, how? Initial greeting banner with the domain is presented by Conn.greet() called by the server that has just Accept()ed the connection and the backend is not even initialized before client issues an EHLO/HELO/LHLO in Conn.handleGreet(). I miss the point where the backend could have any control of this in its lifetime. Please advise.

if s.domainFn != nil {
return s.domainFn(c)
}
return s.Domain, nil
}

// Setter for the dynamic domain string generator function
func (s *Server) SetDomainFunc(fn func(*Conn) (string, error)) {
s.domainFn = fn
}

func (s *Server) handleConn(c *Conn) error {
s.locker.Lock()
s.conns[c] = struct{}{}
Expand Down
92 changes: 92 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1712,3 +1712,95 @@ func TestServerMTPRIORITY(t *testing.T) {
t.Fatal("Incorrect MtPriority parameter value:", fmt.Sprintf("expected %d, got %d", expectedPriority, *priority))
}
}

func getDynamicDomainResponse(c *smtp.Conn) (string, error) {
return "dynamichost.local", nil
}

func getNegativeDynamicDomainResponse(c *smtp.Conn) (string, error) {
return "", fmt.Errorf("no service")
}

func testServerDynamicDomain(t *testing.T, fn ...serverConfigureFunc) (be *backend, s *smtp.Server, c net.Conn, scanner *bufio.Scanner) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}

be = new(backend)
s = smtp.NewServer(be)
s.Domain = "localhost"
s.SetDomainFunc(getDynamicDomainResponse)
s.AllowInsecureAuth = true
for _, f := range fn {
f(s)
}

go s.Serve(l)

c, err = net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}

scanner = bufio.NewScanner(c)
return
}

func testServerNegativeDynamicDomain(t *testing.T, fn ...serverConfigureFunc) (be *backend, s *smtp.Server, c net.Conn, scanner *bufio.Scanner) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}

be = new(backend)
s = smtp.NewServer(be)
s.Domain = "localhost"
s.SetDomainFunc(getNegativeDynamicDomainResponse)
s.AllowInsecureAuth = true
for _, f := range fn {
f(s)
}

go s.Serve(l)

c, err = net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}

scanner = bufio.NewScanner(c)
return
}

func TestServerDynamicDomainGreeted(t *testing.T) {
_, _, _, scanner := testServerDynamicDomain(t)

scanner.Scan()
if scanner.Text() != "220 dynamichost.local ESMTP Service Ready" {
t.Fatal("Invalid greeting:", scanner.Text())
}
}

func TestServerNegativeDynamicDomainGreeted(t *testing.T) {
_, _, c, scanner := testServerNegativeDynamicDomain(t)

scanner.Scan()
if scanner.Text() != "554 Error: no service" {
t.Fatal("Invalid greeting:", scanner.Text())
}

//Now test for 503 error as per RFC521 3.1
io.WriteString(c, "HELO localhost\r\n")

scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "503 bad sequence of commands") {
t.Fatal("Invalid HELO response:", scanner.Text())
}
io.WriteString(c, "QUIT\r\n")

scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "221 OK") {
t.Fatal("Invalid HELO response:", scanner.Text())
}
}