Update to newest pgproto3
This commit is contained in:
+22
-12
@@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
}
|
||||
|
||||
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||
authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue)
|
||||
saslContinue, err := c.rxSASLContinue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = sc.recvServerFirstMessage(authMsg.SASLData)
|
||||
err = sc.recvServerFirstMessage(saslContinue.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
}
|
||||
|
||||
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||
authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal)
|
||||
saslFinal, err := c.rxSASLFinal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sc.recvServerFinalMessage(authMsg.SASLData)
|
||||
return sc.recvServerFinalMessage(saslFinal.Data)
|
||||
}
|
||||
|
||||
func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
|
||||
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authMsg, ok := msg.(*pgproto3.Authentication)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected message type")
|
||||
}
|
||||
if authMsg.Type != typ {
|
||||
return nil, errors.New("unexpected auth type")
|
||||
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
|
||||
if ok {
|
||||
return saslContinue, nil
|
||||
}
|
||||
|
||||
return authMsg, nil
|
||||
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
|
||||
}
|
||||
|
||||
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
|
||||
if ok {
|
||||
return saslFinal, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
|
||||
}
|
||||
|
||||
type scramClient struct {
|
||||
|
||||
@@ -501,7 +501,7 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
|
||||
}
|
||||
frontend, _ := pgproto3.NewFrontend(cr, w)
|
||||
frontend := pgproto3.NewFrontend(cr, w)
|
||||
|
||||
return frontend
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ require (
|
||||
github.com/jackc/chunkreader/v2 v2.0.0
|
||||
github.com/jackc/pgio v1.0.0
|
||||
github.com/jackc/pgpassfile v1.0.0
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc2
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3
|
||||
github.com/stretchr/testify v1.3.0
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a
|
||||
golang.org/x/text v0.3.0
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
||||
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||
github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs=
|
||||
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
||||
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc2 h1:u+jUsxBxiLY2C6mhr8cZhSy71n/y8Id2STOzJ7bl2Mg=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc2/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
|
||||
@@ -210,11 +210,28 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||
case *pgproto3.BackendKeyData:
|
||||
pgConn.pid = msg.ProcessID
|
||||
pgConn.secretKey = msg.SecretKey
|
||||
case *pgproto3.Authentication:
|
||||
if err = pgConn.rxAuthenticationX(msg); err != nil {
|
||||
|
||||
case *pgproto3.AuthenticationOk:
|
||||
case *pgproto3.AuthenticationCleartextPassword:
|
||||
err = pgConn.txPasswordMessage(pgConn.config.Password)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, &connectError{config: config, msg: "failed handle authentication message", err: err}
|
||||
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
|
||||
}
|
||||
case *pgproto3.AuthenticationMD5Password:
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
|
||||
err = pgConn.txPasswordMessage(digestedPassword)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
|
||||
}
|
||||
case *pgproto3.AuthenticationSASL:
|
||||
err = pgConn.scramAuth(msg.AuthMechanisms)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
|
||||
}
|
||||
|
||||
case *pgproto3.ReadyForQuery:
|
||||
pgConn.status = connStatusIdle
|
||||
if config.ValidateConnect != nil {
|
||||
@@ -257,23 +274,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
|
||||
switch msg.Type {
|
||||
case pgproto3.AuthTypeOk:
|
||||
case pgproto3.AuthTypeCleartextPassword:
|
||||
err = pgConn.txPasswordMessage(pgConn.config.Password)
|
||||
case pgproto3.AuthTypeMD5Password:
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
|
||||
err = pgConn.txPasswordMessage(digestedPassword)
|
||||
case pgproto3.AuthTypeSASL:
|
||||
err = pgConn.scramAuth(msg.SASLAuthMechanisms)
|
||||
default:
|
||||
err = errors.New("Received unknown authentication message")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
||||
msg := &pgproto3.PasswordMessage{Password: password}
|
||||
_, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))
|
||||
|
||||
Reference in New Issue
Block a user