diff --git a/.gitignore b/.gitignore index 7a6353d6..6eb9d442 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .envrc +vendor/ \ No newline at end of file diff --git a/README.md b/README.md index 05cfedf1..9e35a0f5 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,29 @@ # pgconn -Package pgconn is a low-level PostgreSQL database driver. +Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level is the C library libpq. +It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. +Applications should handle normal queries with a higher level library and only use pgconn directly when required for +low-level access to PostgreSQL functionality. -It is intended to serve as the foundation for the next generation of https://github.com/jackc/pgx. +## Example Usage + +```go +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +if err != nil { + log.Fatalln("pgconn failed to connect:", err) +} +defer pgConn.Close() + +result := pgConn.ExecParams(context.Background(), "select email from users where id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +for result.NextRow() { + fmt.Println("User 123 has email:", string(result.Values()[0])) +} +_, err := result.Close() +if err != nil { + log.Fatalln("failed reading result:", err) +}) +``` ## Testing diff --git a/auth_scram.go b/auth_scram.go index d102d305..bdaf3e92 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -249,7 +249,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { serverKey := computeHMAC(saltedPassword, []byte("Server Key")) - serverSignature := computeHMAC(serverKey[:], authMessage) + serverSignature := computeHMAC(serverKey, authMessage) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) base64.StdEncoding.Encode(buf, serverSignature) return buf diff --git a/benchmark_test.go b/benchmark_test.go index 51e11e24..8067c985 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -20,6 +20,7 @@ func BenchmarkConnect(b *testing.B) { } for _, bm := range benchmarks { + bm := bm b.Run(bm.name, func(b *testing.B) { connString := os.Getenv(bm.env) if connString == "" { @@ -54,12 +55,12 @@ func BenchmarkExec(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -101,12 +102,12 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -145,12 +146,12 @@ func BenchmarkExecPrepared(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -191,7 +192,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } diff --git a/config.go b/config.go index c751cc0d..9b74945e 100644 --- a/config.go +++ b/config.go @@ -22,6 +22,7 @@ import ( ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error +type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -36,10 +37,15 @@ type Config struct { Fallbacks []*FallbackConfig - // AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that - // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This - // allows implementing high availability behavior such as libpq does with target_session_attrs. - AfterConnectFunc AfterConnectFunc + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. + // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with + // target_session_attrs. + ValidateConnect ValidateConnectFunc + + // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables + // or prepare statements). If this returns an error the connection attempt fails. + AfterConnect AfterConnectFunc // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler @@ -121,6 +127,13 @@ func NetworkAddress(host string, port uint16) (network, address string) { // security guarantees than it would with libpq. Do not rely on this behavior as it // may be possible to match libpq in the future. If you need full security use // "verify-full". +// +// Other known differences with libpq: +// +// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. +// +// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn +// does not. func ParseConfig(connString string) (*Config, error) { settings := defaultSettings() addEnvSettings(settings) @@ -238,7 +251,7 @@ func ParseConfig(connString string) (*Config, error) { } if settings["target_session_attrs"] == "read-write" { - config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) } @@ -474,9 +487,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { return d.DialContext, nil } -// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-write. -func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { +func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err diff --git a/config_test.go b/config_test.go index ce6f3957..23d86529 100644 --- a/config_test.go +++ b/config_test.go @@ -378,14 +378,14 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite, }, }, } @@ -416,7 +416,8 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. - assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName) + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { diff --git a/go.mod b/go.mod index 9401dce8..b1c84049 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 1b6862a0..50dfc2fd 100644 --- a/go.sum +++ b/go.sum @@ -2,12 +2,16 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pgconn.go b/pgconn.go index c51742ae..6e1fb7e3 100644 --- a/pgconn.go +++ b/pgconn.go @@ -122,13 +122,25 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err for _, fc := range fallbackConfigs { pgConn, err = connect(ctx, config, fc) if err == nil { - return pgConn, nil + break } else if err, ok := err.(*PgError); ok { return nil, err } } - return nil, err + if err != nil { + return nil, err + } + + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, errors.Errorf("AfterConnect: %v", err) + } + } + + return pgConn, nil } func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { @@ -201,11 +213,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle - if config.AfterConnectFunc != nil { - err := config.AfterConnectFunc(ctx, pgConn) + if config.ValidateConnect != nil { + err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnectFunc: %v", err) + return nil, errors.Errorf("ValidateConnect: %v", err) } } return pgConn, nil @@ -241,16 +253,16 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { +func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: - err = c.txPasswordMessage(c.Config.Password) + err = pgConn.txPasswordMessage(pgConn.Config.Password) case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) - err = c.txPasswordMessage(digestedPassword) + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) case pgproto3.AuthTypeSASL: - err = c.scramAuth(msg.SASLAuthMechanisms) + err = pgConn.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } @@ -390,7 +402,7 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } -// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of +// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of // underlying connection. func (pgConn *PgConn) IsAlive() bool { return pgConn.status >= connStatusIdle @@ -514,11 +526,11 @@ readloop: func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: string(msg.Severity), + Severity: msg.Severity, Code: string(msg.Code), Message: string(msg.Message), Detail: string(msg.Detail), - Hint: string(msg.Hint), + Hint: msg.Hint, Position: msg.Position, InternalPosition: msg.InternalPosition, InternalQuery: string(msg.InternalQuery), @@ -527,7 +539,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { TableName: string(msg.TableName), ColumnName: string(msg.ColumnName), DataTypeName: string(msg.DataTypeName), - ConstraintName: string(msg.ConstraintName), + ConstraintName: msg.ConstraintName, File: string(msg.File), Line: msg.Line, Routine: string(msg.Routine), @@ -919,7 +931,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { - copyFail := &pgproto3.CopyFail{Error: readErr.Error()} + copyFail := &pgproto3.CopyFail{Message: readErr.Error()} buf = copyFail.Encode(buf) } _, err = pgConn.conn.Write(buf) diff --git a/pgconn_test.go b/pgconn_test.go index 310b387b..feb78641 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -37,6 +37,7 @@ func TestConnect(t *testing.T) { } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { connString := os.Getenv(tt.env) if connString == "" { @@ -186,7 +187,7 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } -func TestConnectWithAfterConnectFunc(t *testing.T) { +func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) @@ -194,13 +195,13 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { dialCount := 0 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount += 1 + dialCount++ return net.Dial(network, address) } acceptConnCount := 0 - config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount += 1 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") } @@ -225,13 +226,13 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { assert.True(t, acceptConnCount > 1) } -func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { +func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" conn, err := pgconn.ConnectConfig(context.Background(), config) @@ -240,6 +241,27 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } +func TestConnectWithAfterConnect(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) + + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) +} + func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() @@ -302,7 +324,7 @@ func TestConnExecEmpty(t *testing.T) { resultCount := 0 for multiResult.NextResult() { - resultCount += 1 + resultCount++ multiResult.ResultReader().Close() } assert.Equal(t, 0, resultCount) @@ -753,12 +775,12 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) - results, err = mrr.ReadAll() + results, err := mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err)