@@ -1 +1,2 @@
|
|||||||
.envrc
|
.envrc
|
||||||
|
vendor/
|
||||||
@@ -3,9 +3,29 @@
|
|||||||
|
|
||||||
# pgconn
|
# pgconn
|
||||||
|
|
||||||
Package pgconn is a low-level PostgreSQL database driver.
|
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level is the C library libpq.
|
||||||
|
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
|
||||||
|
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
|
||||||
|
low-level access to PostgreSQL functionality.
|
||||||
|
|
||||||
It is intended to serve as the foundation for the next generation of https://github.com/jackc/pgx.
|
## Example Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("pgconn failed to connect:", err)
|
||||||
|
}
|
||||||
|
defer pgConn.Close()
|
||||||
|
|
||||||
|
result := pgConn.ExecParams(context.Background(), "select email from users where id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
|
||||||
|
for result.NextRow() {
|
||||||
|
fmt.Println("User 123 has email:", string(result.Values()[0]))
|
||||||
|
}
|
||||||
|
_, err := result.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("failed reading result:", err)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -249,7 +249,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
|||||||
|
|
||||||
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
||||||
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||||
serverSignature := computeHMAC(serverKey[:], authMessage)
|
serverSignature := computeHMAC(serverKey, authMessage)
|
||||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||||
base64.StdEncoding.Encode(buf, serverSignature)
|
base64.StdEncoding.Encode(buf, serverSignature)
|
||||||
return buf
|
return buf
|
||||||
|
|||||||
+8
-7
@@ -20,6 +20,7 @@ func BenchmarkConnect(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, bm := range benchmarks {
|
for _, bm := range benchmarks {
|
||||||
|
bm := bm
|
||||||
b.Run(bm.name, func(b *testing.B) {
|
b.Run(bm.name, func(b *testing.B) {
|
||||||
connString := os.Getenv(bm.env)
|
connString := os.Getenv(bm.env)
|
||||||
if connString == "" {
|
if connString == "" {
|
||||||
@@ -54,12 +55,12 @@ func BenchmarkExec(b *testing.B) {
|
|||||||
|
|
||||||
rowCount := 0
|
rowCount := 0
|
||||||
for rr.NextRow() {
|
for rr.NextRow() {
|
||||||
rowCount += 1
|
rowCount++
|
||||||
if len(rr.Values()) != len(expectedValues) {
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
}
|
}
|
||||||
for i := range rr.Values() {
|
for i := range rr.Values() {
|
||||||
if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 {
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,12 +102,12 @@ func BenchmarkExecPossibleToCancel(b *testing.B) {
|
|||||||
|
|
||||||
rowCount := 0
|
rowCount := 0
|
||||||
for rr.NextRow() {
|
for rr.NextRow() {
|
||||||
rowCount += 1
|
rowCount++
|
||||||
if len(rr.Values()) != len(expectedValues) {
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
}
|
}
|
||||||
for i := range rr.Values() {
|
for i := range rr.Values() {
|
||||||
if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 {
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -145,12 +146,12 @@ func BenchmarkExecPrepared(b *testing.B) {
|
|||||||
|
|
||||||
rowCount := 0
|
rowCount := 0
|
||||||
for rr.NextRow() {
|
for rr.NextRow() {
|
||||||
rowCount += 1
|
rowCount++
|
||||||
if len(rr.Values()) != len(expectedValues) {
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
}
|
}
|
||||||
for i := range rr.Values() {
|
for i := range rr.Values() {
|
||||||
if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 {
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -191,7 +192,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
|||||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
}
|
}
|
||||||
for i := range rr.Values() {
|
for i := range rr.Values() {
|
||||||
if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 {
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
|
||||||
// Config is the settings used to establish a connection to a PostgreSQL server.
|
// Config is the settings used to establish a connection to a PostgreSQL server.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -36,10 +37,15 @@ type Config struct {
|
|||||||
|
|
||||||
Fallbacks []*FallbackConfig
|
Fallbacks []*FallbackConfig
|
||||||
|
|
||||||
// AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that
|
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||||
// server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This
|
// It can be used validate that server is acceptable. If this returns an error the connection is closed and the next
|
||||||
// allows implementing high availability behavior such as libpq does with target_session_attrs.
|
// fallback config is tried. This allows implementing high availability behavior such as libpq does with
|
||||||
AfterConnectFunc AfterConnectFunc
|
// target_session_attrs.
|
||||||
|
ValidateConnect ValidateConnectFunc
|
||||||
|
|
||||||
|
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
|
||||||
|
// or prepare statements). If this returns an error the connection attempt fails.
|
||||||
|
AfterConnect AfterConnectFunc
|
||||||
|
|
||||||
// OnNotice is a callback function called when a notice response is received.
|
// OnNotice is a callback function called when a notice response is received.
|
||||||
OnNotice NoticeHandler
|
OnNotice NoticeHandler
|
||||||
@@ -121,6 +127,13 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||||||
// security guarantees than it would with libpq. Do not rely on this behavior as it
|
// security guarantees than it would with libpq. Do not rely on this behavior as it
|
||||||
// may be possible to match libpq in the future. If you need full security use
|
// may be possible to match libpq in the future. If you need full security use
|
||||||
// "verify-full".
|
// "verify-full".
|
||||||
|
//
|
||||||
|
// Other known differences with libpq:
|
||||||
|
//
|
||||||
|
// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first.
|
||||||
|
//
|
||||||
|
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
|
||||||
|
// does not.
|
||||||
func ParseConfig(connString string) (*Config, error) {
|
func ParseConfig(connString string) (*Config, error) {
|
||||||
settings := defaultSettings()
|
settings := defaultSettings()
|
||||||
addEnvSettings(settings)
|
addEnvSettings(settings)
|
||||||
@@ -238,7 +251,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if settings["target_session_attrs"] == "read-write" {
|
if settings["target_session_attrs"] == "read-write" {
|
||||||
config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||||
} else if settings["target_session_attrs"] != "any" {
|
} else if settings["target_session_attrs"] != "any" {
|
||||||
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"])
|
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"])
|
||||||
}
|
}
|
||||||
@@ -474,9 +487,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
|||||||
return d.DialContext, nil
|
return d.DialContext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible
|
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||||
// target_session_attrs=read-write.
|
// target_session_attrs=read-write.
|
||||||
func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
return result.Err
|
return result.Err
|
||||||
|
|||||||
+10
-9
@@ -378,14 +378,14 @@ func TestParseConfig(t *testing.T) {
|
|||||||
name: "target_session_attrs",
|
name: "target_session_attrs",
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
||||||
config: &pgconn.Config{
|
config: &pgconn.Config{
|
||||||
User: "jack",
|
User: "jack",
|
||||||
Password: "secret",
|
Password: "secret",
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 5432,
|
Port: 5432,
|
||||||
Database: "mydb",
|
Database: "mydb",
|
||||||
TLSConfig: nil,
|
TLSConfig: nil,
|
||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite,
|
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -416,7 +416,8 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
|
|||||||
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||||
|
|
||||||
// Can't test function equality, so just test that they are set or not.
|
// Can't test function equality, so just test that they are set or not.
|
||||||
assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName)
|
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
|
||||||
|
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
|
||||||
|
|
||||||
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
||||||
if expected.TLSConfig != nil {
|
if expected.TLSConfig != nil {
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ go 1.12
|
|||||||
require (
|
require (
|
||||||
github.com/jackc/pgio v1.0.0
|
github.com/jackc/pgio v1.0.0
|
||||||
github.com/jackc/pgpassfile v1.0.0
|
github.com/jackc/pgpassfile v1.0.0
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711
|
||||||
github.com/stretchr/testify v1.3.0
|
github.com/stretchr/testify v1.3.0
|
||||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a
|
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a
|
||||||
golang.org/x/text v0.3.0
|
golang.org/x/text v0.3.0
|
||||||
|
|||||||
@@ -2,12 +2,16 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
||||||
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs=
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||||
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
||||||
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0=
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0=
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
||||||
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg=
|
||||||
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
||||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
|||||||
@@ -122,13 +122,25 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
|||||||
for _, fc := range fallbackConfigs {
|
for _, fc := range fallbackConfigs {
|
||||||
pgConn, err = connect(ctx, config, fc)
|
pgConn, err = connect(ctx, config, fc)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return pgConn, nil
|
break
|
||||||
} else if err, ok := err.(*PgError); ok {
|
} else if err, ok := err.(*PgError); ok {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.AfterConnect != nil {
|
||||||
|
err := config.AfterConnect(ctx, pgConn)
|
||||||
|
if err != nil {
|
||||||
|
pgConn.conn.Close()
|
||||||
|
return nil, errors.Errorf("AfterConnect: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pgConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||||
@@ -201,11 +213,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
}
|
}
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
pgConn.status = connStatusIdle
|
pgConn.status = connStatusIdle
|
||||||
if config.AfterConnectFunc != nil {
|
if config.ValidateConnect != nil {
|
||||||
err := config.AfterConnectFunc(ctx, pgConn)
|
err := config.ValidateConnect(ctx, pgConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, errors.Errorf("AfterConnectFunc: %v", err)
|
return nil, errors.Errorf("ValidateConnect: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return pgConn, nil
|
return pgConn, nil
|
||||||
@@ -241,16 +253,16 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
|
func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case pgproto3.AuthTypeOk:
|
case pgproto3.AuthTypeOk:
|
||||||
case pgproto3.AuthTypeCleartextPassword:
|
case pgproto3.AuthTypeCleartextPassword:
|
||||||
err = c.txPasswordMessage(c.Config.Password)
|
err = pgConn.txPasswordMessage(pgConn.Config.Password)
|
||||||
case pgproto3.AuthTypeMD5Password:
|
case pgproto3.AuthTypeMD5Password:
|
||||||
digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:]))
|
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:]))
|
||||||
err = c.txPasswordMessage(digestedPassword)
|
err = pgConn.txPasswordMessage(digestedPassword)
|
||||||
case pgproto3.AuthTypeSASL:
|
case pgproto3.AuthTypeSASL:
|
||||||
err = c.scramAuth(msg.SASLAuthMechanisms)
|
err = pgConn.scramAuth(msg.SASLAuthMechanisms)
|
||||||
default:
|
default:
|
||||||
err = errors.New("Received unknown authentication message")
|
err = errors.New("Received unknown authentication message")
|
||||||
}
|
}
|
||||||
@@ -390,7 +402,7 @@ func (pgConn *PgConn) hardClose() error {
|
|||||||
return pgConn.conn.Close()
|
return pgConn.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of
|
// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of
|
||||||
// underlying connection.
|
// underlying connection.
|
||||||
func (pgConn *PgConn) IsAlive() bool {
|
func (pgConn *PgConn) IsAlive() bool {
|
||||||
return pgConn.status >= connStatusIdle
|
return pgConn.status >= connStatusIdle
|
||||||
@@ -514,11 +526,11 @@ readloop:
|
|||||||
|
|
||||||
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
||||||
return &PgError{
|
return &PgError{
|
||||||
Severity: string(msg.Severity),
|
Severity: msg.Severity,
|
||||||
Code: string(msg.Code),
|
Code: string(msg.Code),
|
||||||
Message: string(msg.Message),
|
Message: string(msg.Message),
|
||||||
Detail: string(msg.Detail),
|
Detail: string(msg.Detail),
|
||||||
Hint: string(msg.Hint),
|
Hint: msg.Hint,
|
||||||
Position: msg.Position,
|
Position: msg.Position,
|
||||||
InternalPosition: msg.InternalPosition,
|
InternalPosition: msg.InternalPosition,
|
||||||
InternalQuery: string(msg.InternalQuery),
|
InternalQuery: string(msg.InternalQuery),
|
||||||
@@ -527,7 +539,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
|||||||
TableName: string(msg.TableName),
|
TableName: string(msg.TableName),
|
||||||
ColumnName: string(msg.ColumnName),
|
ColumnName: string(msg.ColumnName),
|
||||||
DataTypeName: string(msg.DataTypeName),
|
DataTypeName: string(msg.DataTypeName),
|
||||||
ConstraintName: string(msg.ConstraintName),
|
ConstraintName: msg.ConstraintName,
|
||||||
File: string(msg.File),
|
File: string(msg.File),
|
||||||
Line: msg.Line,
|
Line: msg.Line,
|
||||||
Routine: string(msg.Routine),
|
Routine: string(msg.Routine),
|
||||||
@@ -919,7 +931,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
copyDone := &pgproto3.CopyDone{}
|
copyDone := &pgproto3.CopyDone{}
|
||||||
buf = copyDone.Encode(buf)
|
buf = copyDone.Encode(buf)
|
||||||
} else {
|
} else {
|
||||||
copyFail := &pgproto3.CopyFail{Error: readErr.Error()}
|
copyFail := &pgproto3.CopyFail{Message: readErr.Error()}
|
||||||
buf = copyFail.Encode(buf)
|
buf = copyFail.Encode(buf)
|
||||||
}
|
}
|
||||||
_, err = pgConn.conn.Write(buf)
|
_, err = pgConn.conn.Write(buf)
|
||||||
|
|||||||
+31
-9
@@ -37,6 +37,7 @@ func TestConnect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
connString := os.Getenv(tt.env)
|
connString := os.Getenv(tt.env)
|
||||||
if connString == "" {
|
if connString == "" {
|
||||||
@@ -186,7 +187,7 @@ func TestConnectWithFallback(t *testing.T) {
|
|||||||
closeConn(t, conn)
|
closeConn(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectWithAfterConnectFunc(t *testing.T) {
|
func TestConnectWithValidateConnect(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
@@ -194,13 +195,13 @@ func TestConnectWithAfterConnectFunc(t *testing.T) {
|
|||||||
|
|
||||||
dialCount := 0
|
dialCount := 0
|
||||||
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
dialCount += 1
|
dialCount++
|
||||||
return net.Dial(network, address)
|
return net.Dial(network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
acceptConnCount := 0
|
acceptConnCount := 0
|
||||||
config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error {
|
config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
|
||||||
acceptConnCount += 1
|
acceptConnCount++
|
||||||
if acceptConnCount < 2 {
|
if acceptConnCount < 2 {
|
||||||
return errors.New("reject first conn")
|
return errors.New("reject first conn")
|
||||||
}
|
}
|
||||||
@@ -225,13 +226,13 @@ func TestConnectWithAfterConnectFunc(t *testing.T) {
|
|||||||
assert.True(t, acceptConnCount > 1)
|
assert.True(t, acceptConnCount > 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
|
func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite
|
config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
||||||
config.RuntimeParams["default_transaction_read_only"] = "on"
|
config.RuntimeParams["default_transaction_read_only"] = "on"
|
||||||
|
|
||||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
@@ -240,6 +241,27 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectWithAfterConnect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
|
||||||
|
_, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err := conn.Exec(context.Background(), "show search_path;").ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnPrepareSyntaxError(t *testing.T) {
|
func TestConnPrepareSyntaxError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -302,7 +324,7 @@ func TestConnExecEmpty(t *testing.T) {
|
|||||||
|
|
||||||
resultCount := 0
|
resultCount := 0
|
||||||
for multiResult.NextResult() {
|
for multiResult.NextResult() {
|
||||||
resultCount += 1
|
resultCount++
|
||||||
multiResult.ResultReader().Close()
|
multiResult.ResultReader().Close()
|
||||||
}
|
}
|
||||||
assert.Equal(t, 0, resultCount)
|
assert.Equal(t, 0, resultCount)
|
||||||
@@ -753,12 +775,12 @@ func TestConnLocking(t *testing.T) {
|
|||||||
defer closeConn(t, pgConn)
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
|
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
|
||||||
results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
|
_, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.True(t, errors.Is(err, pgconn.ErrConnBusy))
|
assert.True(t, errors.Is(err, pgconn.ErrConnBusy))
|
||||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||||
|
|
||||||
results, err = mrr.ReadAll()
|
results, err := mrr.ReadAll()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, results, 1)
|
assert.Len(t, results, 1)
|
||||||
assert.Nil(t, results[0].Err)
|
assert.Nil(t, results[0].Err)
|
||||||
|
|||||||
Reference in New Issue
Block a user