From 5f26fd659ca03c27f3a2875a83818eb1d5ee4fd1 Mon Sep 17 00:00:00 2001 From: gdsmith Date: Thu, 22 May 2025 17:37:28 +0100 Subject: [PATCH 1/2] Don't store plaintext passwords Rather store hashed passwords in the way mysql server does. here we: - update CredentialProvider to return new Credential struct - Credential includes the plugin that the user was created with - update InMemoryProvider to handle hashing of passwords and add default auth method to make usage backwards compatible - update server authentication to use mysql server methods of comparing hashes rather than relying on having the plaintext password available - rework the password negotiation to switch plugin type to match the stored credentials - add hashing and comparison functions for the above where missing from existing libraries --- client/auth.go | 4 +- mysql/util.go | 176 ++++++++++++++++++++++++++++-- server/auth.go | 103 ++++++++--------- server/auth_switch_response.go | 88 +++++---------- server/caching_sha2_cache_test.go | 4 +- server/conn.go | 2 +- server/credential_provider.go | 74 +++++++++++-- server/handshake_resp.go | 10 +- server/resp.go | 1 + server/server_test.go | 12 +- 10 files changed, 325 insertions(+), 149 deletions(-) diff --git a/client/auth.go b/client/auth.go index 952accec5..b7a52a0e9 100644 --- a/client/auth.go +++ b/client/auth.go @@ -154,9 +154,9 @@ func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) { // password hashing switch c.authPluginName { case mysql.AUTH_NATIVE_PASSWORD: - return mysql.CalcPassword(authData[:20], []byte(c.password)), false, nil + return mysql.CalcNativePassword(authData[:20], []byte(c.password)), false, nil case mysql.AUTH_CACHING_SHA2_PASSWORD: - return mysql.CalcCachingSha2Password(authData, c.password), false, nil + return mysql.CalcCachingSha2Password(authData, []byte(c.password)), false, nil case mysql.AUTH_CLEAR_PASSWORD: return []byte(c.password), true, nil case mysql.AUTH_SHA256_PASSWORD: diff --git a/mysql/util.go b/mysql/util.go index 2e426262f..ece636dcb 100644 --- a/mysql/util.go +++ b/mysql/util.go @@ -9,7 +9,9 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/sha512" + "crypto/subtle" "encoding/binary" + "encoding/hex" "fmt" "io" mrand "math/rand" @@ -29,7 +31,7 @@ func Pstack() string { return string(buf[0:n]) } -func CalcPassword(scramble, password []byte) []byte { +func CalcNativePassword(scramble, password []byte) []byte { if len(password) == 0 { return nil } @@ -39,27 +41,92 @@ func CalcPassword(scramble, password []byte) []byte { crypt.Write(password) stage1 := crypt.Sum(nil) - // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) - // inner Hash + // stage2Hash = SHA1(stage1Hash) crypt.Reset() crypt.Write(stage1) - hash := crypt.Sum(nil) + stage2 := crypt.Sum(nil) - // outer Hash + // scrambleHash = SHA1(scramble + stage2Hash) crypt.Reset() crypt.Write(scramble) - crypt.Write(hash) - scramble = crypt.Sum(nil) + crypt.Write(stage2) + scrambleHash := crypt.Sum(nil) // token = scrambleHash XOR stage1Hash - for i := range scramble { - scramble[i] ^= stage1[i] + return Xor(scrambleHash, stage1) +} + +func Xor(hash1 []byte, hash2 []byte) []byte { + for i := range hash1 { + hash1[i] ^= hash2[i] } - return scramble + return hash1 +} + +// hash_stage1 = xor(reply, sha1(public_seed, hash_stage2)) +func Stage1FromReply(scramble []byte, seed []byte, stage2 []byte) []byte { + crypt := sha1.New() + crypt.Write(seed) + crypt.Write(stage2) + seededHash := crypt.Sum(nil) + + return Xor(scramble, seededHash) +} + +// FROM vitess.io/vitess/go/mysql/auth_server.go +// DecodePasswordHex decodes the standard format used by MySQL +// for 4.1 style password hashes. It drops the optionally leading * before +// decoding the rest as a hex encoded string. +func DecodePasswordHex(hexEncodedPassword string) ([]byte, error) { + if hexEncodedPassword[0] == '*' { + hexEncodedPassword = hexEncodedPassword[1:] + } + return hex.DecodeString(hexEncodedPassword) +} + +// EncodePasswordHex encodes to the standard format used by MySQL +// adds the optionally leading * to the hashed password +func EncodePasswordHex(passwordHash []byte) string { + hexstr := strings.ToUpper(hex.EncodeToString(passwordHash)) + return "*" + hexstr +} + +// NativePasswordHash = sha1(sha1(password)) +func NativePasswordHash(password []byte) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write(password) + stage1 := crypt.Sum(nil) + + // stage2Hash = SHA1(stage1Hash) + crypt.Reset() + crypt.Write(stage1) + return crypt.Sum(nil) +} + +func CompareNativePassword(reply []byte, stored []byte, seed []byte) bool { + if len(stored) == 0 { + return false + } + + // hash_stage1 = xor(reply, sha1(public_seed, hash_stage2)) + stage1 := Stage1FromReply(reply, seed, stored) + // andidate_hash2 = sha1(hash_stage1) + crypt := sha1.New() + crypt.Write(stage1) + stage2 := crypt.Sum(nil) + + // check(candidate_hash2 == hash_stage2) + // use ConstantTimeCompare to mitigate timing based attacks + return subtle.ConstantTimeCompare(stage2, stored) == 1 } // CalcCachingSha2Password: Hash password using MySQL 8+ method (SHA256) -func CalcCachingSha2Password(scramble []byte, password string) []byte { +func CalcCachingSha2Password(scramble []byte, password []byte) []byte { if len(password) == 0 { return nil } @@ -67,7 +134,7 @@ func CalcCachingSha2Password(scramble []byte, password string) []byte { // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) crypt := sha256.New() - crypt.Write([]byte(password)) + crypt.Write(password) message1 := crypt.Sum(nil) crypt.Reset() @@ -135,6 +202,91 @@ func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, return rsa.EncryptOAEP(sha1v, rand.Reader, pub, plain, nil) } +const ( + SALT_LENGTH = 16 + ITERATION_MULTIPLIER = 1000 + SHA256_PASSWORD_ITERATIONS = 5 +) + +// generateUserSalt generate salt of given length for sha256_password hash +func generateUserSalt(length int) ([]byte, error) { + // Generate a random salt of the given length + // Implement this function for your project + salt := make([]byte, length) + _, err := rand.Read(salt) + if err != nil { + return []byte(""), err + } + + // Restrict to 7-bit to avoid multi-byte UTF-8 + for i := range salt { + salt[i] = salt[i] &^ 128 + for salt[i] == 36 || salt[i] == 0 { // '$' or NUL + newval := make([]byte, 1) + _, err := rand.Read(newval) + if err != nil { + return []byte(""), err + } + salt[i] = newval[0] &^ 128 + } + } + return salt, nil +} + +// hashCrypt256 salt and hash a password the given number of iterations +func hashCrypt256(source, salt string, iterations uint64) (string, error) { + actualIterations := iterations * ITERATION_MULTIPLIER + hashInput := []byte(source + salt) + var hash [32]byte + for i := uint64(0); i < actualIterations; i++ { + h := sha256.New() + h.Write(hashInput) + hash = sha256.Sum256(h.Sum(nil)) + hashInput = hash[:] + } + + hashHex := hex.EncodeToString(hash[:]) + digest := fmt.Sprintf("$%d$%s$%s", iterations, salt, hashHex) + return digest, nil +} + +// Check256HashingPassword compares a password to a hash for sha256_password +// rather than trying to recreate just the hash we recreate the full hash +// and use that for comparison +func Check256HashingPassword(pwhash []byte, password string) (bool, error) { + pwHashParts := bytes.Split(pwhash, []byte("$")) + if len(pwHashParts) != 4 { + return false, errors.New("failed to decode hash parts") + } + + iterationsPart := pwHashParts[1] + if len(iterationsPart) == 0 { + return false, errors.New("iterations part is empty") + } + + iterations, err := strconv.ParseUint(string(iterationsPart), 10, 64) + if err != nil { + return false, errors.New("failed to decode iterations") + } + salt := pwHashParts[2][:SALT_LENGTH] + + newHash, err := hashCrypt256(password, string(salt), iterations) + if err != nil { + return false, err + } + + return bytes.Equal(pwhash, []byte(newHash)), nil +} + +// NewSha256PasswordHash creates a new password hash for sha256_password +func NewSha256PasswordHash(pwd string) (string, error) { + salt, err := generateUserSalt(SALT_LENGTH) + if err != nil { + return "", err + } + return hashCrypt256(pwd, string(salt), SHA256_PASSWORD_ITERATIONS) +} + func DecompressMariadbData(data []byte) ([]byte, error) { // algorithm always 0=zlib // algorithm := (data[pos] & 0x07) >> 4 diff --git a/server/auth.go b/server/auth.go index 9d1f8101c..77b7c569a 100644 --- a/server/auth.go +++ b/server/auth.go @@ -19,26 +19,32 @@ var ( ) func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) error { - switch authPluginName { - case mysql.AUTH_NATIVE_PASSWORD: - if err := c.acquirePassword(); err != nil { + if authPluginName != c.credential.authPluginName { + err := c.writeAuthSwitchRequest(c.credential.authPluginName) + if err != nil { return err } - return c.compareNativePasswordAuthData(clientAuthData, c.password) + return c.handleAuthSwitchResponse() + } + + switch authPluginName { + case mysql.AUTH_NATIVE_PASSWORD: + return c.compareNativePasswordAuthData(clientAuthData, c.credential) case mysql.AUTH_CACHING_SHA2_PASSWORD: - if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil { - return err - } - if c.cachingSha2FullAuth { - return c.handleAuthSwitchResponse() + if !c.cachingSha2FullAuth { + if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil { + return err + } + if c.cachingSha2FullAuth { + return c.handleAuthSwitchResponse() + } + return nil } - return nil + // AuthMoreData packet already sent, do full auth + return c.handleCachingSha2PasswordFullAuth(clientAuthData) case mysql.AUTH_SHA256_PASSWORD: - if err := c.acquirePassword(); err != nil { - return err - } cont, err := c.handlePublicKeyRetrieval(clientAuthData) if err != nil { return err @@ -46,7 +52,7 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err if !cont { return nil } - return c.compareSha256PasswordAuthData(clientAuthData, c.password) + return c.compareSha256PasswordAuthData(clientAuthData, c.credential) default: return errors.Errorf("unknown authentication plugin name '%s'", authPluginName) @@ -54,19 +60,19 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err } func (c *Conn) acquirePassword() error { - password, found, err := c.credentialProvider.GetCredential(c.user) + credential, found, err := c.credentialProvider.GetCredential(c.user) if err != nil { return err } if !found { return mysql.NewDefaultError(mysql.ER_NO_SUCH_USER, c.user, c.RemoteAddr().String()) } - c.password = password + c.credential = credential return nil } -func errAccessDenied(password string) error { - if password == "" { +func errAccessDenied(credential Credential) error { + if credential.password == "" { return ErrAccessDeniedNoPassword } @@ -93,17 +99,21 @@ func scrambleValidation(cached, nonce, scramble []byte) bool { return bytes.Equal(m, cached) } -func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, password string) error { - if bytes.Equal(mysql.CalcPassword(c.salt, []byte(password)), clientAuthData) { +func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, credential Credential) error { + password, err := mysql.DecodePasswordHex(c.credential.password) + if err != nil { + return errAccessDenied(credential) + } + if mysql.CompareNativePassword(clientAuthData, password, c.salt) { return nil } - return errAccessDenied(password) + return errAccessDenied(credential) } -func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password string) error { +func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential Credential) error { // Empty passwords are not hashed, but sent as empty string if len(clientAuthData) == 0 { - if password == "" { + if credential.password == "" { return nil } return ErrAccessDenied @@ -117,10 +127,6 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password str if l := len(clientAuthData); l != 0 && clientAuthData[l-1] == 0x00 { clientAuthData = clientAuthData[:l-1] } - if bytes.Equal(clientAuthData, []byte(password)) { - return nil - } - return errAccessDenied(password) } else { // client should send encrypted password // decrypt @@ -128,46 +134,31 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password str if err != nil { return err } - plain := make([]byte, len(password)+1) - copy(plain, password) - for i := range plain { - j := i % len(c.salt) - plain[i] ^= c.salt[j] - } - if bytes.Equal(plain, dbytes) { - return nil + clientAuthData = mysql.Xor(dbytes, c.salt) + if l := len(clientAuthData); l != 0 && clientAuthData[l-1] == 0x00 { + clientAuthData = clientAuthData[:l-1] } - return errAccessDenied(password) } + check, err := mysql.Check256HashingPassword([]byte(credential.password), string(clientAuthData)) + if err != nil { + return err + } + if check { + return nil + } + return ErrAccessDenied } func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { // Empty passwords are not hashed, but sent as empty string if len(clientAuthData) == 0 { - if err := c.acquirePassword(); err != nil { - return err - } - if c.password == "" { + if c.credential.password == "" { return nil } return ErrAccessDenied } // the caching of 'caching_sha2_password' in MySQL, see: https://dev.mysql.com/worklog/task/?id=9591 - if _, ok := c.credentialProvider.(*InMemoryProvider); ok { - // since we have already kept the password in memory and calculate the scramble is not that high of cost, we eliminate - // the caching part. So our server will never ask the client to do a full authentication via RSA key exchange and it appears - // like the auth will always hit the cache. - if err := c.acquirePassword(); err != nil { - return err - } - if bytes.Equal(mysql.CalcCachingSha2Password(c.salt, c.password), clientAuthData) { - // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 - return c.writeAuthMoreDataFastAuth() - } - - return errAccessDenied(c.password) - } - // other type of credential provider, we use the cache + // check if we have a cached value cached, ok := c.serverConf.cacheShaPassword.Load(fmt.Sprintf("%s@%s", c.user, c.LocalAddr())) if ok { // Scramble validation @@ -176,7 +167,7 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { return c.writeAuthMoreDataFastAuth() } - return errAccessDenied(c.password) + return errAccessDenied(c.credential) } // cache miss, do full auth if err := c.writeAuthMoreDataFullAuth(); err != nil { diff --git a/server/auth_switch_response.go b/server/auth_switch_response.go index 9baccc03e..003016a01 100644 --- a/server/auth_switch_response.go +++ b/server/auth_switch_response.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "crypto/rand" "crypto/rsa" "crypto/sha1" @@ -9,6 +8,8 @@ import ( "crypto/tls" "fmt" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/errors" ) @@ -19,48 +20,7 @@ func (c *Conn) handleAuthSwitchResponse() error { return err } - switch c.authPluginName { - case mysql.AUTH_NATIVE_PASSWORD: - if err := c.acquirePassword(); err != nil { - return err - } - return c.compareNativePasswordAuthData(authData, c.password) - - case mysql.AUTH_CACHING_SHA2_PASSWORD: - if !c.cachingSha2FullAuth { - // Switched auth method but no MoreData packet send yet - if err := c.compareCacheSha2PasswordAuthData(authData); err != nil { - return err - } else { - if c.cachingSha2FullAuth { - return c.handleAuthSwitchResponse() - } - return nil - } - } - // AuthMoreData packet already sent, do full auth - if err := c.handleCachingSha2PasswordFullAuth(authData); err != nil { - return err - } - c.writeCachingSha2Cache() - return nil - - case mysql.AUTH_SHA256_PASSWORD: - cont, err := c.handlePublicKeyRetrieval(authData) - if err != nil { - return err - } - if !cont { - return nil - } - if err := c.acquirePassword(); err != nil { - return err - } - return c.compareSha256PasswordAuthData(authData, c.password) - - default: - return errors.Errorf("unknown authentication plugin name '%s'", c.authPluginName) - } + return c.compareAuthData(c.authPluginName, authData) } func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { @@ -76,10 +36,6 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { if l := len(authData); l != 0 && authData[l-1] == 0x00 { authData = authData[:l-1] } - if bytes.Equal(authData, []byte(c.password)) { - return nil - } - return errAccessDenied(c.password) } else { // client either request for the public key or send the encrypted password if len(authData) == 1 && authData[0] == 0x02 { @@ -99,27 +55,41 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { if err != nil { return err } - plain := make([]byte, len(c.password)+1) - copy(plain, c.password) - for i := range plain { - j := i % len(c.salt) - plain[i] ^= c.salt[j] - } - if bytes.Equal(plain, dbytes) { - return nil + authData = mysql.Xor(dbytes, c.salt) + if l := len(authData); l != 0 && authData[l-1] == 0x00 { + authData = authData[:l-1] } - return errAccessDenied(c.password) } + err := c.checkSha2CacheCredentials(authData, c.credential) + if err != nil { + return err + } + // write cache on successful auth - needs to be here as we have the decrypted password + // and we need to store an unsalted hashed version of the plaintext password in the cache + c.writeCachingSha2Cache(authData) + return nil +} + +func (c *Conn) checkSha2CacheCredentials(clientAuthData []byte, credential Credential) error { + match, err := auth.CheckHashingPassword([]byte(credential.password), string(clientAuthData), mysql.AUTH_CACHING_SHA2_PASSWORD) + if match && err == nil { + return nil + } + return errAccessDenied(credential) } -func (c *Conn) writeCachingSha2Cache() { +func (c *Conn) writeCachingSha2Cache(authData []byte) { // write cache - if c.password == "" { + if authData == nil { return } + + if l := len(authData); l != 0 && authData[l-1] == 0x00 { + authData = authData[:l-1] + } // SHA256(PASSWORD) crypt := sha256.New() - crypt.Write([]byte(c.password)) + crypt.Write(authData) m1 := crypt.Sum(nil) // SHA256(SHA256(PASSWORD)) crypt.Reset() diff --git a/server/caching_sha2_cache_test.go b/server/caching_sha2_cache_test.go index 9c9530811..4f018bba6 100644 --- a/server/caching_sha2_cache_test.go +++ b/server/caching_sha2_cache_test.go @@ -58,7 +58,7 @@ type RemoteThrottleProvider struct { getCredCallCount atomic.Int64 } -func (m *RemoteThrottleProvider) GetCredential(username string) (password string, found bool, err error) { +func (m *RemoteThrottleProvider) GetCredential(username string) (credential Credential, found bool, err error) { m.getCredCallCount.Add(1) return m.InMemoryProvider.GetCredential(username) } @@ -107,7 +107,7 @@ func (s *cacheTestSuite) onAccept() { func (s *cacheTestSuite) onConn(conn net.Conn) { // co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) - co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testCacheHandler{s}) + co, err := s.server.NewCustomizedConn(conn, s.credProvider, &testCacheHandler{s}) require.NoError(s.T(), err) for { err = co.HandleCommand() diff --git a/server/conn.go b/server/conn.go index ad3509aea..1f99bdfa2 100644 --- a/server/conn.go +++ b/server/conn.go @@ -26,7 +26,7 @@ type Conn struct { credentialProvider CredentialProvider user string - password string + credential Credential cachingSha2FullAuth bool h Handler diff --git a/server/credential_provider.go b/server/credential_provider.go index 11014d916..f8e0061f2 100644 --- a/server/credential_provider.go +++ b/server/credential_provider.go @@ -1,6 +1,12 @@ package server -import "sync" +import ( + "sync" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/parser/auth" +) // interface for user credential provider // hint: can be extended for more functionality @@ -13,18 +19,29 @@ type CredentialProvider interface { // check if the user exists CheckUsername(username string) (bool, error) // get user credential - GetCredential(username string) (password string, found bool, err error) + GetCredential(username string) (credential Credential, found bool, err error) } -func NewInMemoryProvider() *InMemoryProvider { +func NewInMemoryProvider(defaultAuthMethod ...string) *InMemoryProvider { + d := mysql.AUTH_CACHING_SHA2_PASSWORD + if len(defaultAuthMethod) > 0 { + d = defaultAuthMethod[0] + } return &InMemoryProvider{ - userPool: sync.Map{}, + userPool: sync.Map{}, + defaultAuthMethod: d, } } -// implements a in memory credential provider +type Credential struct { + password string + authPluginName string +} + +// implements an in memory credential provider type InMemoryProvider struct { - userPool sync.Map // username -> password + userPool sync.Map // username -> password + defaultAuthMethod string } func (m *InMemoryProvider) CheckUsername(username string) (found bool, err error) { @@ -32,16 +49,51 @@ func (m *InMemoryProvider) CheckUsername(username string) (found bool, err error return ok, nil } -func (m *InMemoryProvider) GetCredential(username string) (password string, found bool, err error) { +func (m *InMemoryProvider) GetCredential(username string) (credential Credential, found bool, err error) { v, ok := m.userPool.Load(username) if !ok { - return "", false, nil + return Credential{}, false, nil } - return v.(string), true, nil + c, valid := v.(Credential) + if !valid { + return Credential{}, true, errors.Errorf("invalid credential") + } + return c, true, nil } -func (m *InMemoryProvider) AddUser(username, password string) { - m.userPool.Store(username, password) +func (m *InMemoryProvider) AddUser(username, password string, authPluginName ...string) error { + c := Credential{ + authPluginName: m.defaultAuthMethod, + } + if len(authPluginName) > 0 { + c.authPluginName = authPluginName[0] + } + + if password == "" { + c.password = "" + m.userPool.Store(username, c) + return nil + } + + switch c.authPluginName { + case mysql.AUTH_NATIVE_PASSWORD: + c.password = mysql.EncodePasswordHex(mysql.NativePasswordHash([]byte(password))) + + case mysql.AUTH_CACHING_SHA2_PASSWORD: + c.password = auth.NewHashPassword(password, mysql.AUTH_CACHING_SHA2_PASSWORD) + + case mysql.AUTH_SHA256_PASSWORD: + hash, err := mysql.NewSha256PasswordHash(password) + if err != nil { + return err + } + c.password = hash + + default: + return errors.Errorf("unknown authentication plugin name '%s'", c.authPluginName) + } + m.userPool.Store(username, c) + return nil } type Provider InMemoryProvider diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 68cade122..1d3e088c5 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -200,12 +200,16 @@ func (c *Conn) handlePublicKeyRetrieval(authData []byte) (bool, error) { func (c *Conn) handleAuthMatch() (bool, error) { // if the client responds the handshake with a different auth method, the server will send the AuthSwitchRequest packet // to the client to ask the client to switch. + credential, _, err := c.credentialProvider.GetCredential(c.user) + if err != nil { + return false, err + } + c.credential = credential - if c.authPluginName != c.serverConf.defaultAuthMethod { - if err := c.writeAuthSwitchRequest(c.serverConf.defaultAuthMethod); err != nil { + if c.authPluginName != credential.authPluginName { + if err := c.writeAuthSwitchRequest(credential.authPluginName); err != nil { return false, err } - c.authPluginName = c.serverConf.defaultAuthMethod // handle AuthSwitchResponse return false, c.handleAuthSwitchResponse() } diff --git a/server/resp.go b/server/resp.go index c6f13fe9e..e4b21aad2 100644 --- a/server/resp.go +++ b/server/resp.go @@ -66,6 +66,7 @@ func (c *Conn) writeEOF() error { // see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html func (c *Conn) writeAuthSwitchRequest(newAuthPluginName string) error { + c.authPluginName = newAuthPluginName data := make([]byte, 4) data = append(data, mysql.EOF_HEADER) data = append(data, []byte(newAuthPluginName)...) diff --git a/server/server_test.go b/server/server_test.go index 2ad8f4e7b..c1e827f05 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -61,11 +61,13 @@ func prepareServerConf() []*Server { func Test(t *testing.T) { // general tests inMemProvider := NewInMemoryProvider() - inMemProvider.AddUser(*testUser, *testPassword) - servers := prepareServerConf() // no TLS for _, svr := range servers { + inMemProvider.userPool.Clear() + err := inMemProvider.AddUser(*testUser, *testPassword, svr.defaultAuthMethod) + require.NoError(t, err) + suite.Run(t, &serverTestSuite{ server: svr, credProvider: inMemProvider, @@ -76,6 +78,10 @@ func Test(t *testing.T) { // TLS if server supports for _, svr := range servers { if svr.tlsConfig != nil { + inMemProvider.userPool.Clear() + err := inMemProvider.AddUser(*testUser, *testPassword, svr.defaultAuthMethod) + require.NoError(t, err) + suite.Run(t, &serverTestSuite{ server: svr, credProvider: inMemProvider, @@ -138,7 +144,7 @@ func (s *serverTestSuite) onAccept() { func (s *serverTestSuite) onConn(conn net.Conn) { // co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) - co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testHandler{s}) + co, err := s.server.NewCustomizedConn(conn, s.credProvider, &testHandler{s}) require.NoError(s.T(), err) // set SSL if defined for { From 43cc679080cc1d842833afefc191e1384114a748 Mon Sep 17 00:00:00 2001 From: gdsmith Date: Fri, 23 May 2025 09:25:06 +0100 Subject: [PATCH 2/2] Update to ensure we only request the password once also update the caching test to allow for the fact that we request the password once on every auth attempt --- server/auth.go | 3 +++ server/caching_sha2_cache_test.go | 2 +- server/handshake_resp.go | 8 +++----- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/server/auth.go b/server/auth.go index 77b7c569a..e3905db39 100644 --- a/server/auth.go +++ b/server/auth.go @@ -60,6 +60,9 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err } func (c *Conn) acquirePassword() error { + if c.credential.password != "" { + return nil + } credential, found, err := c.credentialProvider.GetCredential(c.user) if err != nil { return err diff --git a/server/caching_sha2_cache_test.go b/server/caching_sha2_cache_test.go index 4f018bba6..9d53a7105 100644 --- a/server/caching_sha2_cache_test.go +++ b/server/caching_sha2_cache_test.go @@ -147,7 +147,7 @@ func (s *cacheTestSuite) TestCache() { s.db.SetMaxIdleConns(4) s.runSelect() got = s.credProvider.(*RemoteThrottleProvider).getCredCallCount.Load() - require.Equal(s.T(), int64(1), got) + require.Equal(s.T(), int64(2), got) if s.db != nil { s.db.Close() diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 1d3e088c5..05d3e6367 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -200,14 +200,12 @@ func (c *Conn) handlePublicKeyRetrieval(authData []byte) (bool, error) { func (c *Conn) handleAuthMatch() (bool, error) { // if the client responds the handshake with a different auth method, the server will send the AuthSwitchRequest packet // to the client to ask the client to switch. - credential, _, err := c.credentialProvider.GetCredential(c.user) - if err != nil { + if err := c.acquirePassword(); err != nil { return false, err } - c.credential = credential - if c.authPluginName != credential.authPluginName { - if err := c.writeAuthSwitchRequest(credential.authPluginName); err != nil { + if c.authPluginName != c.credential.authPluginName { + if err := c.writeAuthSwitchRequest(c.credential.authPluginName); err != nil { return false, err } // handle AuthSwitchResponse