2
0

Update to newest pgproto3

This commit is contained in:
Jack Christensen
2019-08-31 15:44:54 -05:00
parent 6bba3c4810
commit 2fabfa3c18
5 changed files with 46 additions and 40 deletions
+22 -12
View File
@@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
}
// Receive server-first-message payload in a AuthenticationSASLContinue.
authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue)
saslContinue, err := c.rxSASLContinue()
if err != nil {
return err
}
err = sc.recvServerFirstMessage(authMsg.SASLData)
err = sc.recvServerFirstMessage(saslContinue.Data)
if err != nil {
return err
}
@@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
}
// Receive server-final-message payload in a AuthenticationSASLFinal.
authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal)
saslFinal, err := c.rxSASLFinal()
if err != nil {
return err
}
return sc.recvServerFinalMessage(authMsg.SASLData)
return sc.recvServerFinalMessage(saslFinal.Data)
}
func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
authMsg, ok := msg.(*pgproto3.Authentication)
if !ok {
return nil, errors.New("unexpected message type")
}
if authMsg.Type != typ {
return nil, errors.New("unexpected auth type")
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
if ok {
return saslContinue, nil
}
return authMsg, nil
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
if ok {
return saslFinal, nil
}
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
}
type scramClient struct {
+1 -1
View File
@@ -501,7 +501,7 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
if err != nil {
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
}
frontend, _ := pgproto3.NewFrontend(cr, w)
frontend := pgproto3.NewFrontend(cr, w)
return frontend
}
+1 -1
View File
@@ -6,7 +6,7 @@ require (
github.com/jackc/chunkreader/v2 v2.0.0
github.com/jackc/pgio v1.0.0
github.com/jackc/pgpassfile v1.0.0
github.com/jackc/pgproto3/v2 v2.0.0-rc2
github.com/jackc/pgproto3/v2 v2.0.0-rc3
github.com/stretchr/testify v1.3.0
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a
golang.org/x/text v0.3.0
+2 -6
View File
@@ -1,17 +1,13 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3/v2 v2.0.0-rc2 h1:u+jUsxBxiLY2C6mhr8cZhSy71n/y8Id2STOzJ7bl2Mg=
github.com/jackc/pgproto3/v2 v2.0.0-rc2/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c=
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+20 -20
View File
@@ -210,11 +210,28 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
case *pgproto3.BackendKeyData:
pgConn.pid = msg.ProcessID
pgConn.secretKey = msg.SecretKey
case *pgproto3.Authentication:
if err = pgConn.rxAuthenticationX(msg); err != nil {
case *pgproto3.AuthenticationOk:
case *pgproto3.AuthenticationCleartextPassword:
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed handle authentication message", err: err}
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
if config.ValidateConnect != nil {
@@ -257,23 +274,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
return nil
}
func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
switch msg.Type {
case pgproto3.AuthTypeOk:
case pgproto3.AuthTypeCleartextPassword:
err = pgConn.txPasswordMessage(pgConn.config.Password)
case pgproto3.AuthTypeMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
case pgproto3.AuthTypeSASL:
err = pgConn.scramAuth(msg.SASLAuthMechanisms)
default:
err = errors.New("Received unknown authentication message")
}
return
}
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
msg := &pgproto3.PasswordMessage{Password: password}
_, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))