diff --git a/auth_scram.go b/auth_scram.go index 6d6d0651..665fc2c2 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-first-message payload in a AuthenticationSASLContinue. - authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) + saslContinue, err := c.rxSASLContinue() if err != nil { return err } - err = sc.recvServerFirstMessage(authMsg.SASLData) + err = sc.recvServerFirstMessage(saslContinue.Data) if err != nil { return err } @@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-final-message payload in a AuthenticationSASLFinal. - authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) + saslFinal, err := c.rxSASLFinal() if err != nil { return err } - return sc.recvServerFinalMessage(authMsg.SASLData) + return sc.recvServerFinalMessage(saslFinal.Data) } -func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { +func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { msg, err := c.receiveMessage() if err != nil { return nil, err } - authMsg, ok := msg.(*pgproto3.Authentication) - if !ok { - return nil, errors.New("unexpected message type") - } - if authMsg.Type != typ { - return nil, errors.New("unexpected auth type") + saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue) + if ok { + return saslContinue, nil } - return authMsg, nil + return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message") +} + +func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal) + if ok { + return saslFinal, nil + } + + return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message") } type scramClient struct { diff --git a/config.go b/config.go index d24d0202..d1267621 100644 --- a/config.go +++ b/config.go @@ -501,7 +501,7 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { if err != nil { panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) } - frontend, _ := pgproto3.NewFrontend(cr, w) + frontend := pgproto3.NewFrontend(cr, w) return frontend } diff --git a/go.mod b/go.mod index b54607b6..6e270cd6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-rc2 + github.com/jackc/pgproto3/v2 v2.0.0-rc3 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index d7a6d087..ed8eb401 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,13 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-rc2 h1:u+jUsxBxiLY2C6mhr8cZhSy71n/y8Id2STOzJ7bl2Mg= -github.com/jackc/pgproto3/v2 v2.0.0-rc2/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pgconn.go b/pgconn.go index 1e3f9515..d51eb76a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -210,11 +210,28 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.BackendKeyData: pgConn.pid = msg.ProcessID pgConn.secretKey = msg.SecretKey - case *pgproto3.Authentication: - if err = pgConn.rxAuthenticationX(msg); err != nil { + + case *pgproto3.AuthenticationOk: + case *pgproto3.AuthenticationCleartextPassword: + err = pgConn.txPasswordMessage(pgConn.config.Password) + if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed handle authentication message", err: err} + return nil, &connectError{config: config, msg: "failed to write password message", err: err} } + case *pgproto3.AuthenticationMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationSASL: + err = pgConn.scramAuth(msg.AuthMechanisms) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + } + case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { @@ -257,23 +274,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { - switch msg.Type { - case pgproto3.AuthTypeOk: - case pgproto3.AuthTypeCleartextPassword: - err = pgConn.txPasswordMessage(pgConn.config.Password) - case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) - err = pgConn.txPasswordMessage(digestedPassword) - case pgproto3.AuthTypeSASL: - err = pgConn.scramAuth(msg.SASLAuthMechanisms) - default: - err = errors.New("Received unknown authentication message") - } - - return -} - func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))