From fa7e06489bda50794a89e7a6e60446c4cc1c2ba5 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Fri, 26 Jul 2019 11:14:07 +0300 Subject: [PATCH 01/27] Add MinReadBufferSize option to Config Signed-off-by: Artemiy Ryabinkov --- config.go | 3 +++ pgconn.go | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 9b74945e..bbd458e3 100644 --- a/config.go +++ b/config.go @@ -37,6 +37,9 @@ type Config struct { Fallbacks []*FallbackConfig + // MinReadBufferSize used to configure size of connection read buffer. + MinReadBufferSize int + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with diff --git a/pgconn.go b/pgconn.go index 6e1fb7e3..5077ccae 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/jackc/chunkreader/v2" "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" @@ -170,7 +171,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) + cr, err := chunkreader.NewConfig(pgConn.conn, chunkreader.Config{MinBufLen: config.MinReadBufferSize}) + if err != nil { + return nil, err + } + + pgConn.Frontend, err = pgproto3.NewFrontend(cr, pgConn.conn) if err != nil { return nil, err } From f0b479097a4868d74e83c938131f5a24d25c49e8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 6 Aug 2019 17:07:11 -0500 Subject: [PATCH 02/27] Fix missing deferred constraint violations in certain conditions See https://github.com/jackc/pgx/issues/570. --- pgconn.go | 5 ++- pgconn_test.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 6e1fb7e3..3157f17e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1151,7 +1151,10 @@ func (rr *ResultReader) Close() (CommandTag, error) { return nil, rr.err } - switch msg.(type) { + switch msg := msg.(type) { + // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() diff --git a/pgconn_test.go b/pgconn_test.go index feb78641..1b90b9d2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -381,6 +381,34 @@ func TestConnExecMultipleQueriesError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) + + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecContextCanceled(t *testing.T) { t.Parallel() @@ -437,6 +465,33 @@ func TestConnExecParams(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() @@ -683,6 +738,36 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnExecBatchDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatchPrecanceled(t *testing.T) { t.Parallel() From 0a99b543c007eab4dd3eb284e0206eb7d8144346 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 11:46:25 +0300 Subject: [PATCH 03/27] Add BuildFrontendFunc in Config Signed-off-by: Artemiy Ryabinkov --- config.go | 30 +++++++++++++++++++----------- go.sum | 4 ---- pgconn.go | 32 +++++++++++++++++--------------- 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/config.go b/config.go index bbd458e3..be8bdab4 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "io/ioutil" "math" "net" @@ -18,6 +19,7 @@ import ( "time" "github.com/jackc/pgpassfile" + "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" ) @@ -26,20 +28,18 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - Database string - User string - Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + BuildFrontendFunc BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig - // MinReadBufferSize used to configure size of connection read buffer. - MinReadBufferSize int - // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with @@ -476,6 +476,14 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } +func makeDefaultBuildFrontendFunc() BuildFrontendFunc { + return func(r io.Reader) Frontend { + frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), nil) + + return frontend + } +} + func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { timeout, err := strconv.ParseInt(s, 10, 64) if err != nil { diff --git a/go.sum b/go.sum index 50dfc2fd..0e853203 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,6 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= @@ -25,7 +23,5 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqY golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pgconn.go b/pgconn.go index 5077ccae..e7833c1f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,7 +15,6 @@ import ( "sync" "time" - "github.com/jackc/chunkreader/v2" "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" @@ -41,9 +40,12 @@ type Notification struct { Payload string } -// DialFunc is a function that can be used to connect to a PostgreSQL server +// DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. +type BuildFrontendFunc func(r io.Reader) Frontend + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -56,6 +58,11 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) +// Frontend used to receive messages from backend. +type Frontend interface { + Receive() (pgproto3.BackendMessage, error) +} + // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -63,7 +70,7 @@ type PgConn struct { secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte - Frontend *pgproto3.Frontend + frontend Frontend Config *Config @@ -106,6 +113,9 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if config.DialFunc == nil { config.DialFunc = makeDefaultDialer().DialContext } + if config.BuildFrontendFunc == nil { + config.BuildFrontendFunc = makeDefaultBuildFrontendFunc() + } if config.RuntimeParams == nil { config.RuntimeParams = make(map[string]string) } @@ -171,15 +181,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - cr, err := chunkreader.NewConfig(pgConn.conn, chunkreader.Config{MinBufLen: config.MinReadBufferSize}) - if err != nil { - return nil, err - } - - pgConn.Frontend, err = pgproto3.NewFrontend(cr, pgConn.conn) - if err != nil { - return nil, err - } + pgConn.frontend = config.BuildFrontendFunc(pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -298,7 +300,7 @@ func (pgConn *PgConn) signalMessage() chan struct{} { ch := make(chan struct{}) go func() { - pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive() + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() pgConn.bufferingReceiveMux.Unlock() close(ch) }() @@ -318,10 +320,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { // If a timeout error happened in the background try the read again. if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - msg, err = pgConn.Frontend.Receive() + msg, err = pgConn.frontend.Receive() } } else { - msg, err = pgConn.Frontend.Receive() + msg, err = pgConn.frontend.Receive() } if err != nil { From dbb7aa8fd51b866cf601df8daf11306a9bb7c707 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 12:52:04 +0300 Subject: [PATCH 04/27] Add GOPROXY to travis builds to mitigate problems with github and etc Signed-off-by: Artemiy Ryabinkov --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index e5ed43a8..1687adad 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,7 @@ before_install: env: global: - GO111MODULE=on + - GOPROXY=https://proxy.golang.org - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test From c9660e30c8b4f7903eaa7814789656ea79b6d173 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 13:12:27 +0300 Subject: [PATCH 05/27] Use go mod download to install deps on travis-ci. Add cache for travis-ci. Signed-off-by: Artemiy Ryabinkov --- .travis.yml | 12 ++++++++++-- travis/install.bash | 14 -------------- 2 files changed, 10 insertions(+), 16 deletions(-) delete mode 100755 travis/install.bash diff --git a/.travis.yml b/.travis.yml index 1687adad..2c547abf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,9 @@ go: - 1.x - tip +git: + depth: 1 + # Derived from https://github.com/lib/pq/blob/master/.travis.yml before_install: - ./travis/before_install.bash @@ -12,6 +15,7 @@ env: global: - GO111MODULE=on - GOPROXY=https://proxy.golang.org + - GOFLAGS=-mod=readonly - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test @@ -26,11 +30,15 @@ env: - PGVERSION=9.4 - PGVERSION=9.3 +cache: + directories: + - $HOME/.cache/go-build + - $HOME/gopath/pkg/mod + before_script: - ./travis/before_script.bash -install: - - ./travis/install.bash +install: go mod download script: - ./travis/script.bash diff --git a/travis/install.bash b/travis/install.bash deleted file mode 100755 index 63ba875d..00000000 --- a/travis/install.bash +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env bash -set -eux - -go get -u github.com/cockroachdb/apd -go get -u github.com/shopspring/decimal -go get -u gopkg.in/inconshreveable/log15.v2 -go get -u github.com/jackc/fake -go get -u github.com/lib/pq -go get -u github.com/hashicorp/go-version -go get -u github.com/satori/go.uuid -go get -u github.com/sirupsen/logrus -go get -u github.com/pkg/errors -go get -u go.uber.org/zap -go get -u github.com/rs/zerolog From d364370a31359546fb19828f737073b19a56f812 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 20 Aug 2019 14:11:16 -0500 Subject: [PATCH 06/27] Add SendBytes and ReceiveMessage --- auth_scram.go | 2 +- pgconn.go | 77 +++++++++++++++++++++++++++++++++++++++++++------- pgconn_test.go | 40 ++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 11 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index bdaf3e92..4409a080 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -74,7 +74,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { - msg, err := c.ReceiveMessage() + msg, err := c.receiveMessage() if err != nil { return nil, err } diff --git a/pgconn.go b/pgconn.go index 63e19ed1..abbc2d10 100644 --- a/pgconn.go +++ b/pgconn.go @@ -204,7 +204,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() return nil, err @@ -308,7 +308,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} { return ch } -func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { +// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as +// error to call SendBytes while reading the result of a query. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { + if err := pgConn.lock(); err != nil { + return linkErrors(err, ErrNoBytesSent) + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return linkErrors(ctx.Err(), ErrNoBytesSent) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } + return linkErrors(ctx.Err(), err) + } + + return nil +} + +// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the +// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger +// the OnNotification callback. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { + if err := pgConn.lock(); err != nil { + return nil, linkErrors(err, ErrNoBytesSent) + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + msg, err := pgConn.receiveMessage() + return msg, err +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { var msg pgproto3.BackendMessage var err error if pgConn.bufferingReceive { @@ -506,7 +563,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ readloop: for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -616,7 +673,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { defer pgConn.contextWatcher.Unwatch() for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { return linkErrors(ctx.Err(), err) } @@ -821,7 +878,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var commandTag CommandTag var pgErr error for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -882,7 +939,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var pgErr error pendingCopyInResponse := true for pendingCopyInResponse { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -920,7 +977,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co select { case <-signalMessageChan: - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -950,7 +1007,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Read results for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -991,7 +1048,7 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { } func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { - msg, err := mrr.pgConn.ReceiveMessage() + msg, err := mrr.pgConn.receiveMessage() if err != nil { mrr.pgConn.contextWatcher.Unwatch() @@ -1176,7 +1233,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { - msg, err = rr.pgConn.ReceiveMessage() + msg, err = rr.pgConn.receiveMessage() } else { msg, err = rr.multiResultReader.receiveMessage() } diff --git a/pgconn_test.go b/pgconn_test.go index 1b90b9d2..f385bc19 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" "github.com/stretchr/testify/assert" @@ -1416,6 +1417,45 @@ func TestConnCancelRequest(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnSendBytesAndReceiveMessage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + queryMsg := pgproto3.Query{String: "select 42"} + buf := queryMsg.Encode(nil) + + err = pgConn.SendBytes(ctx, buf) + require.NoError(t, err) + + msg, err := pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok := msg.(*pgproto3.RowDescription) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.DataRow) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.CommandComplete) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.ReadyForQuery) + require.True(t, ok) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From 11255efe7af4e7c2ab77e863f245f42f4ca6b4c5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 20 Aug 2019 15:49:57 -0500 Subject: [PATCH 07/27] Make ErrorResponseToPgError public --- pgconn.go | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/pgconn.go b/pgconn.go index abbc2d10..e51d40e8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -233,7 +233,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() - return nil, errorResponseToPgError(msg) + return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() return nil, errors.New("unexpected message") @@ -400,7 +400,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { pgConn.hardClose() - return nil, errorResponseToPgError(msg) + return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: if pgConn.Config.OnNotice != nil { @@ -577,7 +577,7 @@ readloop: psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: - parseErr = errorResponseToPgError(msg) + parseErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } @@ -589,7 +589,8 @@ readloop: return psd, nil } -func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { +// ErrorResponseToPgError converts a wire protocol error message to a *PgError. +func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, Code: string(msg.Code), @@ -612,7 +613,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { } func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { - pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) return (*Notice)(pgerr) } @@ -898,7 +899,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } } } @@ -949,7 +950,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.CopyInResponse: pendingCopyInResponse = false case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: return commandTag, pgErr } @@ -985,7 +986,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } default: } @@ -1019,7 +1020,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } } } @@ -1064,7 +1065,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: - mrr.err = errorResponseToPgError(msg) + mrr.err = ErrorResponseToPgError(msg) } return msg, nil @@ -1219,7 +1220,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg := msg.(type) { // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. case *pgproto3.ErrorResponse: - rr.err = errorResponseToPgError(msg) + rr.err = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() @@ -1255,7 +1256,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(nil, errorResponseToPgError(msg)) + rr.concludeCommand(nil, ErrorResponseToPgError(msg)) } return msg, nil From 1558987979c58286747e7c90ab181adc1560f027 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 22 Aug 2019 20:11:27 -0500 Subject: [PATCH 08/27] ReceiveMessage returns context error instead of io error on cancel --- pgconn.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgconn.go b/pgconn.go index e51d40e8..5d84871b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -361,6 +361,9 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa defer pgConn.contextWatcher.Unwatch() msg, err := pgConn.receiveMessage() + if err != nil { + err = linkErrors(ctx.Err(), err) + } return msg, err } From 760dd75542eb13b37333e0e134b3463efade7cb4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 09:28:44 -0500 Subject: [PATCH 09/27] Require Config to be created by ParseConfig --- config.go | 15 ++++++++++----- pgconn.go | 19 ++++++------------- pgconn_test.go | 8 ++++++++ 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/config.go b/config.go index be8bdab4..a861ff5f 100644 --- a/config.go +++ b/config.go @@ -26,7 +26,8 @@ import ( type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error -// Config is the settings used to establish a connection to a PostgreSQL server. +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and +// then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) Port uint16 @@ -55,6 +56,8 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a @@ -157,10 +160,12 @@ func ParseConfig(connString string) (*Config, error) { } config := &Config{ - Database: settings["database"], - User: settings["user"], - Password: settings["password"], - RuntimeParams: make(map[string]string), + createdByParseConfig: true, + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + BuildFrontendFunc: makeDefaultBuildFrontendFunc(), } if connectTimeout, present := settings["connect_timeout"]; present { diff --git a/pgconn.go b/pgconn.go index 5d84871b..b0e4cfd2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -99,25 +99,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } -// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with +// ParseConfig. ctx can be used to cancel a connect attempt. // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, // if all attempts fail the last error is returned. func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { - // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. - if config.Port == 0 { - config.Port = 5432 - } - if config.DialFunc == nil { - config.DialFunc = makeDefaultDialer().DialContext - } - if config.BuildFrontendFunc == nil { - config.BuildFrontendFunc = makeDefaultBuildFrontendFunc() - } - if config.RuntimeParams == nil { - config.RuntimeParams = make(map[string]string) + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") } // Simplify usage by treating primary config and fallbacks the same. diff --git a/pgconn_test.go b/pgconn_test.go index f385bc19..1cd74024 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -263,6 +263,14 @@ func TestConnectWithAfterConnect(t *testing.T) { assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) } +func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { + t.Parallel() + + config := &pgconn.Config{} + + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) +} + func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() From e540a0576006af74ed45bea905dbb4d8a5e320bc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 14:16:38 -0500 Subject: [PATCH 10/27] Fix typo in docs --- doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc.go b/doc.go index d36eb0fd..cde58cd8 100644 --- a/doc.go +++ b/doc.go @@ -15,7 +15,7 @@ reads all rows into memory. Executing Multiple Queries in a Single Round Trip -Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query +Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query result. The ReadAll method reads all query results into memory. Context Support From e6bd7390678ab23b1fded5035d8364e6fa704f28 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 16:02:27 -0500 Subject: [PATCH 11/27] Add pscache package --- pscache/lrucache.go | 111 ++++++++++++++++++++++++++++++++++++++ pscache/lrucache_test.go | 113 +++++++++++++++++++++++++++++++++++++++ pscache/pscache.go | 52 ++++++++++++++++++ 3 files changed, 276 insertions(+) create mode 100644 pscache/lrucache.go create mode 100644 pscache/lrucache_test.go create mode 100644 pscache/pscache.go diff --git a/pscache/lrucache.go b/pscache/lrucache.go new file mode 100644 index 00000000..d5d6062f --- /dev/null +++ b/pscache/lrucache.go @@ -0,0 +1,111 @@ +package pscache + +import ( + "container/list" + "context" + "fmt" + "sync/atomic" + + "github.com/jackc/pgconn" +) + +var lruCacheCount uint64 + +// LRUCache implements cache with a Least Recently Used (LRU) cache. +type LRUCache struct { + conn *pgconn.PgConn + mode int + cap int + prepareCount int + m map[string]*list.Element + l *list.List + psNamePrefix string +} + +// NewLRUCache creates a new LRUCache. mode is either PrepareMode or DescribeMode. cap is the maximum size of the cache. +func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { + mustBeValidMode(mode) + mustBeValidCap(cap) + + n := atomic.AddUint64(&lruCacheCount, 1) + + return &LRUCache{ + conn: conn, + mode: mode, + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + psNamePrefix: fmt.Sprintf("lrupsc_%d", n), + } +} + +// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. +func (c *LRUCache) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { + if el, ok := c.m[sql]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.PreparedStatementDescription), nil + } + + if c.l.Len() == c.cap { + err := c.removeOldest(ctx) + if err != nil { + return nil, err + } + } + + psd, err := c.prepare(ctx, sql) + if err != nil { + return nil, err + } + + el := c.l.PushFront(psd) + c.m[sql] = el + + return psd, nil +} + +// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. +func (c *LRUCache) Clear(ctx context.Context) error { + for c.l.Len() > 0 { + err := c.removeOldest(ctx) + if err != nil { + return err + } + } + + return nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRUCache) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRUCache) Cap() int { + return c.cap +} + +// Mode returns the mode of the cache (PrepareMode or DescribeMode) +func (c *LRUCache) Mode() int { + return c.mode +} + +func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { + var name string + if c.mode == PrepareMode { + name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) + c.prepareCount += 1 + } + + return c.conn.Prepare(ctx, name, sql, nil) +} + +func (c *LRUCache) removeOldest(ctx context.Context) error { + oldest := c.l.Back() + c.l.Remove(oldest) + if c.mode == PrepareMode { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.PreparedStatementDescription).Name)).Close() + } + return nil +} diff --git a/pscache/lrucache_test.go b/pscache/lrucache_test.go new file mode 100644 index 00000000..bf2fcbe0 --- /dev/null +++ b/pscache/lrucache_test.go @@ -0,0 +1,113 @@ +package pscache_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgconn/pscache" + + "github.com/stretchr/testify/require" +) + +func TestLRUCachePrepareMode(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := pscache.NewLRUCache(conn, pscache.PrepareMode, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, pscache.PrepareMode, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func TestLRUCacheDescribeMode(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := pscache.NewLRUCache(conn, pscache.DescribeMode, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, pscache.DescribeMode, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { + result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + var statements []string + for _, r := range result.Rows { + statements = append(statements, string(r[0])) + } + return statements +} diff --git a/pscache/pscache.go b/pscache/pscache.go new file mode 100644 index 00000000..bfd51e81 --- /dev/null +++ b/pscache/pscache.go @@ -0,0 +1,52 @@ +// Package pscache is a cache that can be used to implement lazy, automatic prepared statements. +package pscache + +import ( + "context" + + "github.com/jackc/pgconn" +) + +const ( + PrepareMode = iota // Cache should prepare named statements. + DescribeMode // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. +) + +// Cache prepares and caches prepared statement descriptions. +type Cache interface { + // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. + Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) + + // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. + Clear(ctx context.Context) error + + // Len returns the number of cached prepared statement descriptions. + Len() int + + // Cap returns the maximum number of cached prepared statement descriptions. + Cap() int + + // Mode returns the mode of the cache (PrepareMode or DescribeMode) + Mode() int +} + +// New returns the preferred cache implementation for mode and cap. mode is either PrepareMode or DescribeMode. cap is +// the maximum size of the cache. +func New(conn *pgconn.PgConn, mode int, cap int) Cache { + mustBeValidMode(mode) + mustBeValidCap(cap) + + return NewLRUCache(conn, mode, cap) +} + +func mustBeValidMode(mode int) { + if mode != PrepareMode && mode != DescribeMode { + panic("mode must be PrepareMode or DescribeMode") + } +} + +func mustBeValidCap(cap int) { + if cap < 1 { + panic("cache must have cap of >= 1") + } +} From 797a44bf048f27e5db5c79dcbf7e406969ca6904 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 16:18:01 -0500 Subject: [PATCH 12/27] Rename BuildFrontendFunc to BuildFrontend For consistency with other functions supplied in Config. --- config.go | 20 ++++++++++---------- pgconn.go | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/config.go b/config.go index a861ff5f..b5c119f5 100644 --- a/config.go +++ b/config.go @@ -29,15 +29,15 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and // then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - Database string - User string - Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext - BuildFrontendFunc BuildFrontendFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + BuildFrontend BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig @@ -165,7 +165,7 @@ func ParseConfig(connString string) (*Config, error) { User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), - BuildFrontendFunc: makeDefaultBuildFrontendFunc(), + BuildFrontend: makeDefaultBuildFrontendFunc(), } if connectTimeout, present := settings["connect_timeout"]; present { diff --git a/pgconn.go b/pgconn.go index b0e4cfd2..fe2f304e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -174,7 +174,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.frontend = config.BuildFrontendFunc(pgConn.conn) + pgConn.frontend = config.BuildFrontend(pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, From 2209d2e36aea43ee17610489a2644af2212a4bc3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 16:27:54 -0500 Subject: [PATCH 13/27] Rename mode constants --- pscache/lrucache.go | 8 ++++---- pscache/lrucache_test.go | 12 ++++++------ pscache/pscache.go | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pscache/lrucache.go b/pscache/lrucache.go index d5d6062f..cdcec63c 100644 --- a/pscache/lrucache.go +++ b/pscache/lrucache.go @@ -22,7 +22,7 @@ type LRUCache struct { psNamePrefix string } -// NewLRUCache creates a new LRUCache. mode is either PrepareMode or DescribeMode. cap is the maximum size of the cache. +// NewLRUCache creates a new LRUCache. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { mustBeValidMode(mode) mustBeValidCap(cap) @@ -86,14 +86,14 @@ func (c *LRUCache) Cap() int { return c.cap } -// Mode returns the mode of the cache (PrepareMode or DescribeMode) +// Mode returns the mode of the cache (ModePrepare or ModeDescribe) func (c *LRUCache) Mode() int { return c.mode } func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { var name string - if c.mode == PrepareMode { + if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) c.prepareCount += 1 } @@ -104,7 +104,7 @@ func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedSta func (c *LRUCache) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) - if c.mode == PrepareMode { + if c.mode == ModePrepare { return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.PreparedStatementDescription).Name)).Close() } return nil diff --git a/pscache/lrucache_test.go b/pscache/lrucache_test.go index bf2fcbe0..a5d413e3 100644 --- a/pscache/lrucache_test.go +++ b/pscache/lrucache_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLRUCachePrepareMode(t *testing.T) { +func TestLRUCacheModePrepare(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -22,10 +22,10 @@ func TestLRUCachePrepareMode(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.PrepareMode, 2) + cache := pscache.NewLRUCache(conn, pscache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.PrepareMode, cache.Mode()) + require.EqualValues(t, pscache.ModePrepare, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) @@ -57,7 +57,7 @@ func TestLRUCachePrepareMode(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } -func TestLRUCacheDescribeMode(t *testing.T) { +func TestLRUCacheModeDescribe(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -67,10 +67,10 @@ func TestLRUCacheDescribeMode(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.DescribeMode, 2) + cache := pscache.NewLRUCache(conn, pscache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.DescribeMode, cache.Mode()) + require.EqualValues(t, pscache.ModeDescribe, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) diff --git a/pscache/pscache.go b/pscache/pscache.go index bfd51e81..4f8cf723 100644 --- a/pscache/pscache.go +++ b/pscache/pscache.go @@ -8,8 +8,8 @@ import ( ) const ( - PrepareMode = iota // Cache should prepare named statements. - DescribeMode // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. + ModePrepare = iota // Cache should prepare named statements. + ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. ) // Cache prepares and caches prepared statement descriptions. @@ -26,11 +26,11 @@ type Cache interface { // Cap returns the maximum number of cached prepared statement descriptions. Cap() int - // Mode returns the mode of the cache (PrepareMode or DescribeMode) + // Mode returns the mode of the cache (ModePrepare or ModeDescribe) Mode() int } -// New returns the preferred cache implementation for mode and cap. mode is either PrepareMode or DescribeMode. cap is +// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is // the maximum size of the cache. func New(conn *pgconn.PgConn, mode int, cap int) Cache { mustBeValidMode(mode) @@ -40,8 +40,8 @@ func New(conn *pgconn.PgConn, mode int, cap int) Cache { } func mustBeValidMode(mode int) { - if mode != PrepareMode && mode != DescribeMode { - panic("mode must be PrepareMode or DescribeMode") + if mode != ModePrepare && mode != ModeDescribe { + panic("mode must be ModePrepare or ModeDescribe") } } From beba629bb5d526f8d7de6ec8754090d39b476757 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 17:18:29 -0500 Subject: [PATCH 14/27] Fix result reader returned by locked conn --- pgconn.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pgconn.go b/pgconn.go index fe2f304e..797080bd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -791,19 +791,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { - if err := pgConn.lock(); err != nil { - return &ResultReader{ - closed: true, - err: linkErrors(err, ErrNoBytesSent), - } - } - pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, } result := &pgConn.resultReader + if err := pgConn.lock(); err != nil { + result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent)) + result.closed = true + return result + } + if len(paramValues) > math.MaxUint16 { result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true From bcd6b9244ab8fc80e85b75b604bf214f82345e59 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 19:46:14 -0500 Subject: [PATCH 15/27] Rename pscache to stmtcache --- {pscache => stmtcache}/lrucache.go | 2 +- {pscache => stmtcache}/lrucache_test.go | 12 ++++++------ pscache/pscache.go => stmtcache/stmtcache.go | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) rename {pscache => stmtcache}/lrucache.go (99%) rename {pscache => stmtcache}/lrucache_test.go (90%) rename pscache/pscache.go => stmtcache/stmtcache.go (92%) diff --git a/pscache/lrucache.go b/stmtcache/lrucache.go similarity index 99% rename from pscache/lrucache.go rename to stmtcache/lrucache.go index cdcec63c..9c4d046d 100644 --- a/pscache/lrucache.go +++ b/stmtcache/lrucache.go @@ -1,4 +1,4 @@ -package pscache +package stmtcache import ( "container/list" diff --git a/pscache/lrucache_test.go b/stmtcache/lrucache_test.go similarity index 90% rename from pscache/lrucache_test.go rename to stmtcache/lrucache_test.go index a5d413e3..ed8ebdc3 100644 --- a/pscache/lrucache_test.go +++ b/stmtcache/lrucache_test.go @@ -1,4 +1,4 @@ -package pscache_test +package stmtcache_test import ( "context" @@ -7,7 +7,7 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgconn/pscache" + "github.com/jackc/pgconn/stmtcache" "github.com/stretchr/testify/require" ) @@ -22,10 +22,10 @@ func TestLRUCacheModePrepare(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.ModePrepare, 2) + cache := stmtcache.NewLRUCache(conn, stmtcache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.ModePrepare, cache.Mode()) + require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) @@ -67,10 +67,10 @@ func TestLRUCacheModeDescribe(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.ModeDescribe, 2) + cache := stmtcache.NewLRUCache(conn, stmtcache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.ModeDescribe, cache.Mode()) + require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) diff --git a/pscache/pscache.go b/stmtcache/stmtcache.go similarity index 92% rename from pscache/pscache.go rename to stmtcache/stmtcache.go index 4f8cf723..d70f277b 100644 --- a/pscache/pscache.go +++ b/stmtcache/stmtcache.go @@ -1,5 +1,5 @@ -// Package pscache is a cache that can be used to implement lazy, automatic prepared statements. -package pscache +// Package stmtcache is a cache that can be used to implement lazy prepared statements. +package stmtcache import ( "context" From 78abbdf1d7eef6b2aa78831c31141057876537f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 19:48:43 -0500 Subject: [PATCH 16/27] Rename LRUCache to LRU --- stmtcache/{lrucache.go => lru.go} | 28 ++++++++++----------- stmtcache/{lrucache_test.go => lru_test.go} | 8 +++--- stmtcache/stmtcache.go | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) rename stmtcache/{lrucache.go => lru.go} (70%) rename stmtcache/{lrucache_test.go => lru_test.go} (93%) diff --git a/stmtcache/lrucache.go b/stmtcache/lru.go similarity index 70% rename from stmtcache/lrucache.go rename to stmtcache/lru.go index 9c4d046d..432a70b4 100644 --- a/stmtcache/lrucache.go +++ b/stmtcache/lru.go @@ -9,10 +9,10 @@ import ( "github.com/jackc/pgconn" ) -var lruCacheCount uint64 +var lruCount uint64 -// LRUCache implements cache with a Least Recently Used (LRU) cache. -type LRUCache struct { +// LRU implements Cache with a Least Recently Used (LRU) cache. +type LRU struct { conn *pgconn.PgConn mode int cap int @@ -22,14 +22,14 @@ type LRUCache struct { psNamePrefix string } -// NewLRUCache creates a new LRUCache. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. -func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { +// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. +func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { mustBeValidMode(mode) mustBeValidCap(cap) - n := atomic.AddUint64(&lruCacheCount, 1) + n := atomic.AddUint64(&lruCount, 1) - return &LRUCache{ + return &LRU{ conn: conn, mode: mode, cap: cap, @@ -40,7 +40,7 @@ func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { } // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. -func (c *LRUCache) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.PreparedStatementDescription), nil @@ -65,7 +65,7 @@ func (c *LRUCache) Get(ctx context.Context, sql string) (*pgconn.PreparedStateme } // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. -func (c *LRUCache) Clear(ctx context.Context) error { +func (c *LRU) Clear(ctx context.Context) error { for c.l.Len() > 0 { err := c.removeOldest(ctx) if err != nil { @@ -77,21 +77,21 @@ func (c *LRUCache) Clear(ctx context.Context) error { } // Len returns the number of cached prepared statement descriptions. -func (c *LRUCache) Len() int { +func (c *LRU) Len() int { return c.l.Len() } // Cap returns the maximum number of cached prepared statement descriptions. -func (c *LRUCache) Cap() int { +func (c *LRU) Cap() int { return c.cap } // Mode returns the mode of the cache (ModePrepare or ModeDescribe) -func (c *LRUCache) Mode() int { +func (c *LRU) Mode() int { return c.mode } -func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { var name string if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) @@ -101,7 +101,7 @@ func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedSta return c.conn.Prepare(ctx, name, sql, nil) } -func (c *LRUCache) removeOldest(ctx context.Context) error { +func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) if c.mode == ModePrepare { diff --git a/stmtcache/lrucache_test.go b/stmtcache/lru_test.go similarity index 93% rename from stmtcache/lrucache_test.go rename to stmtcache/lru_test.go index ed8ebdc3..b518364e 100644 --- a/stmtcache/lrucache_test.go +++ b/stmtcache/lru_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLRUCacheModePrepare(t *testing.T) { +func TestLRUModePrepare(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -22,7 +22,7 @@ func TestLRUCacheModePrepare(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := stmtcache.NewLRUCache(conn, stmtcache.ModePrepare, 2) + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) @@ -57,7 +57,7 @@ func TestLRUCacheModePrepare(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } -func TestLRUCacheModeDescribe(t *testing.T) { +func TestLRUModeDescribe(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -67,7 +67,7 @@ func TestLRUCacheModeDescribe(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := stmtcache.NewLRUCache(conn, stmtcache.ModeDescribe, 2) + cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index d70f277b..9bedf549 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -36,7 +36,7 @@ func New(conn *pgconn.PgConn, mode int, cap int) Cache { mustBeValidMode(mode) mustBeValidCap(cap) - return NewLRUCache(conn, mode, cap) + return NewLRU(conn, mode, cap) } func mustBeValidMode(mode int) { From da9fc85c4404a53f910e2f8210be5add1bc50454 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 20:39:01 -0500 Subject: [PATCH 17/27] Rename PreparedStatementDescription to StatementDescription PreparedStatementDescription was too long. It also no longer entirely represents its purpose now that it is also intended for use with described statements. --- pgconn.go | 9 +++++---- stmtcache/lru.go | 8 ++++---- stmtcache/stmtcache.go | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index 797080bd..8f3291f1 100644 --- a/pgconn.go +++ b/pgconn.go @@ -517,15 +517,16 @@ func (ct CommandTag) String() string { return string(ct) } -type PreparedStatementDescription struct { +type StatementDescription struct { Name string SQL string ParamOIDs []uint32 Fields []pgproto3.FieldDescription } -// Prepare creates a prepared statement. -func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { +// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This +// allows Prepare to also to describe statements without creating a server-side prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { return nil, linkErrors(err, ErrNoBytesSent) } @@ -553,7 +554,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, linkErrors(ctx.Err(), err) } - psd := &PreparedStatementDescription{Name: name, SQL: sql} + psd := &StatementDescription{Name: name, SQL: sql} var parseErr error diff --git a/stmtcache/lru.go b/stmtcache/lru.go index 432a70b4..fff4d0b7 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -40,10 +40,10 @@ func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { } // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. -func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) - return el.Value.(*pgconn.PreparedStatementDescription), nil + return el.Value.(*pgconn.StatementDescription), nil } if c.l.Len() == c.cap { @@ -91,7 +91,7 @@ func (c *LRU) Mode() int { return c.mode } -func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { var name string if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) @@ -105,7 +105,7 @@ func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.PreparedStatementDescription).Name)).Close() + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close() } return nil } diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index 9bedf549..96215799 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -15,7 +15,7 @@ const ( // Cache prepares and caches prepared statement descriptions. type Cache interface { // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. - Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) + Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. Clear(ctx context.Context) error From 6feea0c1c57d8ec5ff0cd806354437ed03b415f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 23:43:26 -0500 Subject: [PATCH 18/27] Replace IsAlive with IsClosed IsAlive is ambiguous because the connection may be dead and we do not know it. It implies the possibility of a ping. IsClosed is clearer -- it does not promise the connection is alive only that it hasn't been closed. fixes #2 --- pgconn.go | 7 +++---- pgconn_test.go | 10 +++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index 8f3291f1..153829ca 100644 --- a/pgconn.go +++ b/pgconn.go @@ -463,10 +463,9 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } -// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of -// underlying connection. -func (pgConn *PgConn) IsAlive() bool { - return pgConn.status >= connStatusIdle +// IsClosed reports if the connection has been closed. +func (pgConn *PgConn) IsClosed() bool { + return pgConn.status < connStatusIdle } // lock locks the connection. It panics if the connection is already locked or is closed. diff --git a/pgconn_test.go b/pgconn_test.go index 1cd74024..64628262 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -433,7 +433,7 @@ func TestConnExecContextCanceled(t *testing.T) { } err = multiResult.Close() assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecContextPrecanceled(t *testing.T) { @@ -566,7 +566,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecParamsPrecanceled(t *testing.T) { @@ -692,7 +692,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecPreparedPrecanceled(t *testing.T) { @@ -1142,7 +1142,7 @@ func TestConnCopyToCanceled(t *testing.T) { assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Equal(t, pgconn.CommandTag(nil), res) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnCopyToPrecanceled(t *testing.T) { @@ -1233,7 +1233,7 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.Equal(t, int64(0), ct.RowsAffected()) assert.True(t, errors.Is(err, context.DeadlineExceeded)) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnCopyFromPrecanceled(t *testing.T) { From 595d09d6f1bfba423db8d00f61efebf0aaa6a85a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 23:57:24 -0500 Subject: [PATCH 19/27] Build fully operational Frontend --- config.go | 4 ++-- pgconn.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/config.go b/config.go index b5c119f5..e078a061 100644 --- a/config.go +++ b/config.go @@ -482,8 +482,8 @@ func makeDefaultDialer() *net.Dialer { } func makeDefaultBuildFrontendFunc() BuildFrontendFunc { - return func(r io.Reader) Frontend { - frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), nil) + return func(r io.Reader, w io.Writer) Frontend { + frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), w) return frontend } diff --git a/pgconn.go b/pgconn.go index 153829ca..7d301af2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -44,7 +44,7 @@ type Notification struct { type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. -type BuildFrontendFunc func(r io.Reader) Frontend +type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin @@ -174,7 +174,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.frontend = config.BuildFrontend(pgConn.conn) + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, From e6cf51b304f1d6961663ede4ba89be363fc54237 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 25 Aug 2019 00:22:32 -0500 Subject: [PATCH 20/27] Expose min_read_buffer_size config param --- config.go | 24 +++++++++++++++++++++--- config_test.go | 12 ++++++++++++ go.mod | 1 + 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index e078a061..cb153c77 100644 --- a/config.go +++ b/config.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -140,6 +141,11 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn // does not. +// +// In addition, ParseConfig accepts the following options: +// +// min_read_buffer_size +// The minimum size of the internal read buffer. Default 8192. func ParseConfig(connString string) (*Config, error) { settings := defaultSettings() addEnvSettings(settings) @@ -159,13 +165,18 @@ func ParseConfig(connString string) (*Config, error) { } } + minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) + if err != nil { + return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err) + } + config := &Config{ createdByParseConfig: true, Database: settings["database"], User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), - BuildFrontend: makeDefaultBuildFrontendFunc(), + BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), } if connectTimeout, present := settings["connect_timeout"]; present { @@ -192,6 +203,7 @@ func ParseConfig(connString string) (*Config, error) { "sslcert": struct{}{}, "sslrootcert": struct{}{}, "target_session_attrs": struct{}{}, + "min_read_buffer_size": struct{}{}, } for k, v := range settings { @@ -284,6 +296,8 @@ func defaultSettings() map[string]string { settings["target_session_attrs"] = "any" + settings["min_read_buffer_size"] = "8192" + return settings } @@ -481,9 +495,13 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } -func makeDefaultBuildFrontendFunc() BuildFrontendFunc { +func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { return func(r io.Reader, w io.Writer) Frontend { - frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), w) + cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) + if err != nil { + panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) + } + frontend, _ := pgproto3.NewFrontend(cr, w) return frontend } diff --git a/config_test.go b/config_test.go index 23d86529..af42094d 100644 --- a/config_test.go +++ b/config_test.go @@ -561,3 +561,15 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { assertConfigsEqual(t, expected, actual, "passfile") } + +func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig("min_read_buffer_size=0") + require.NoError(t, err) + _, present := config.RuntimeParams["min_read_buffer_size"] + require.False(t, present) + + // The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param + // was removed. +} diff --git a/go.mod b/go.mod index b1c84049..cbeef02a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/jackc/pgconn go 1.12 require ( + github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 From 138254da5b02b80a548f7858f01636f9a426b918 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 27 Aug 2019 18:01:59 -0500 Subject: [PATCH 21/27] Refactor errors - Use strongly typed errors internally - SafeToRetry(error) streamlines retry logic over ErrNoBytesSent - Timeout(error) removes the need to choose between returning a context and an i/o error --- config.go | 14 ++--- errors.go | 156 ++++++++++++++++++++++++++++++++++++------------- pgconn.go | 125 +++++++++++++++++---------------------- pgconn_test.go | 41 ++++++------- 4 files changed, 195 insertions(+), 141 deletions(-) diff --git a/config.go b/config.go index cb153c77..d24d0202 100644 --- a/config.go +++ b/config.go @@ -155,19 +155,19 @@ func ParseConfig(connString string) (*Config, error) { if strings.HasPrefix(connString, "postgres://") { err := addURLSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} } } else { err := addDSNSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} } } } minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) if err != nil { - return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err) + return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} } config := &Config{ @@ -182,7 +182,7 @@ func ParseConfig(connString string) (*Config, error) { if connectTimeout, present := settings["connect_timeout"]; present { dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} } config.DialFunc = dialFunc } else { @@ -228,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) { port, err := parsePort(portStr) if err != nil { - return nil, errors.Errorf("invalid port: %w", err) + return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} } var tlsConfigs []*tls.Config @@ -240,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) { var err error tlsConfigs, err = configTLS(settings) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } } @@ -273,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) { if settings["target_session_attrs"] == "read-write" { config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { - return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])} } return config, nil diff --git a/errors.go b/errors.go index 4f8af407..a088dcdd 100644 --- a/errors.go +++ b/errors.go @@ -2,22 +2,31 @@ package pgconn import ( "context" + "fmt" "net" + "strings" errors "golang.org/x/xerrors" ) -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") +// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. +func SafeToRetry(err error) bool { + if e, ok := err.(interface{ SafeToRetry() bool }); ok { + return e.SafeToRetry() + } + return false +} -// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another -// action is attempted. -var ErrConnBusy = errors.New("conn is busy") +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a +// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true. +func Timeout(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } -// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used -// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. -var ErrNoBytesSent = errors.New("no bytes sent to server") + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for @@ -46,44 +55,107 @@ func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } -// linkedError connects two errors as if err wrapped next. -type linkedError struct { - err error - next error +type connectError struct { + config *Config + msg string + err error } -func (le *linkedError) Error() string { - return le.err.Error() -} - -func (le *linkedError) Is(target error) bool { - return errors.Is(le.err, target) -} - -func (le *linkedError) As(target interface{}) bool { - return errors.As(le.err, target) -} - -func (le *linkedError) Unwrap() error { - return le.next -} - -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return ctx.Err() +func (e *connectError) Error() string { + sb := &strings.Builder{} + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) + if e.err != nil { + fmt.Fprintf(sb, " (%s)", e.err.Error()) } - return err + return sb.String() } -// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. -func linkErrors(outer, inner error) error { - if outer == nil { - return inner +func (e *connectError) Unwrap() error { + return e.err +} + +type connLockError struct { + status string +} + +func (e *connLockError) SafeToRetry() bool { + return true // a lock failure by definition happens before the connection is used. +} + +func (e *connLockError) Error() string { + return e.status +} + +type parseConfigError struct { + connString string + msg string + err error +} + +func (e *parseConfigError) Error() string { + if e.err == nil { + return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg) } - if inner == nil { - return outer + return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error()) +} + +func (e *parseConfigError) Unwrap() error { + return e.err +} + +type pgconnError struct { + msg string + err error + safeToRetry bool +} + +func (e *pgconnError) Error() string { + if e.msg == "" { + return e.err.Error() } - return &linkedError{err: outer, next: inner} + if e.err == nil { + return e.msg + } + return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) +} + +func (e *pgconnError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *pgconnError) Unwrap() error { + return e.err +} + +type contextAlreadyDoneError struct { + err error +} + +func (e *contextAlreadyDoneError) Error() string { + return fmt.Sprintf("context already done: %s", e.err.Error()) +} + +func (e *contextAlreadyDoneError) SafeToRetry() bool { + return true +} + +func (e *contextAlreadyDoneError) Unwrap() error { + return e.err +} + +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err } diff --git a/pgconn.go b/pgconn.go index 7d301af2..347acf80 100644 --- a/pgconn.go +++ b/pgconn.go @@ -128,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if err == nil { break } else if err, ok := err.(*PgError); ok { - return nil, err + return nil, &connectError{config: config, msg: "server error", err: err} } } if err != nil { - return nil, err + return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnect: %v", err) + return nil, &connectError{config: config, msg: "AfterConnect error", err: err} } } @@ -156,7 +156,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { - return nil, err + return nil, &connectError{config: config, msg: "dial error", err: err} } pgConn.parameterStatuses = make(map[string]string) @@ -164,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if fallbackConfig.TLSConfig != nil { if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "tls error", err: err} } } @@ -193,14 +193,17 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } for { msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() - return nil, err + if err, ok := err.(*PgError); ok { + return nil, err + } + return nil, &connectError{config: config, msg: "failed to receive message", err: err} } switch msg := msg.(type) { @@ -210,7 +213,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed handle authentication message", err: err} } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -218,7 +221,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("ValidateConnect: %v", err) + return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} } } return pgConn, nil @@ -229,7 +232,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() - return nil, errors.New("unexpected message") + return nil, &connectError{config: config, msg: "received unexpected message", err: err} } } } @@ -246,7 +249,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { } if response[0] != 'S' { - return ErrTLSRefused + return errors.New("server refused TLS connection") } pgConn.conn = tls.Client(pgConn.conn, tlsConfig) @@ -308,13 +311,13 @@ func (pgConn *PgConn) signalMessage() chan struct{} { // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { if err := pgConn.lock(); err != nil { - return linkErrors(err, ErrNoBytesSent) + return err } defer pgConn.unlock() select { case <-ctx.Done(): - return linkErrors(ctx.Err(), ErrNoBytesSent) + return &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -323,10 +326,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return linkErrors(ctx.Err(), err) + return &writeError{err: err, safeToRetry: n == 0} } return nil @@ -341,13 +341,13 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -355,7 +355,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa msg, err := pgConn.receiveMessage() if err != nil { - err = linkErrors(ctx.Err(), err) + err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true} } return msg, err } @@ -442,12 +442,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error { _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { - return linkErrors(ctx.Err(), err) + return err } return pgConn.conn.Close() @@ -468,15 +468,15 @@ func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } -// lock locks the connection. It panics if the connection is already locked or is closed. +// lock locks the connection. func (pgConn *PgConn) lock() error { switch pgConn.status { case connStatusBusy: - return ErrConnBusy // This only should be possible in case of an application bug. + return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. case connStatusClosed: - return errors.New("conn closed") + return &connLockError{status: "conn closed"} case connStatusUninitialized: - return errors.New("conn uninitialized") + return &connLockError{status: "conn uninitialized"} } pgConn.status = connStatusBusy return nil @@ -527,13 +527,13 @@ type StatementDescription struct { // allows Prepare to also to describe statements without creating a server-side prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -547,10 +547,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} } psd := &StatementDescription{Name: name, SQL: sql} @@ -562,7 +559,7 @@ readloop: msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -641,12 +638,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = cancelConn.Read(buf) if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) + return err } return nil @@ -672,7 +669,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.receiveMessage() if err != nil { - return linkErrors(ctx.Err(), err) + return err } switch msg.(type) { @@ -691,7 +688,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -704,7 +701,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: @@ -719,10 +716,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.hardClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - multiResult.err = linkErrors(ctx.Err(), err) + multiResult.err = &writeError{err: err, safeToRetry: n == 0} pgConn.unlock() return multiResult } @@ -798,7 +792,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if err := pgConn.lock(); err != nil { - result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent)) + result.concludeCommand(nil, err) result.closed = true return result } @@ -812,7 +806,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) result.closed = true pgConn.unlock() return result @@ -831,10 +825,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - result.concludeCommand(nil, linkErrors(ctx.Err(), err)) + result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() @@ -844,13 +835,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } select { case <-ctx.Done(): pgConn.unlock() - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -864,10 +855,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm if err != nil { pgConn.hardClose() pgConn.unlock() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read results @@ -877,7 +865,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -905,13 +893,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -924,10 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read until copy in response or error. @@ -938,7 +923,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -967,7 +952,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } } @@ -976,7 +961,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -998,7 +983,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } // Read results @@ -1006,7 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -1048,7 +1033,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.err = err mrr.closed = true mrr.pgConn.hardClose() return nil, mrr.err @@ -1263,7 +1248,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { } rr.commandTag = commandTag - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.err = err rr.fieldDescriptions = nil rr.rowValues = nil rr.commandConcluded = true @@ -1293,7 +1278,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -1306,7 +1291,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: diff --git a/pgconn_test.go b/pgconn_test.go index 64628262..3fbdf8df 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -86,14 +86,11 @@ func TestConnectInvalidUser(t *testing.T) { config.User = "pgxinvalidusertest" - conn, err := pgconn.ConnectConfig(context.Background(), config) - if err == nil { - conn.Close(context.Background()) - t.Fatal("expected err but got none") - } - pgErr, ok := err.(*pgconn.PgError) + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) if !ok { - t.Fatalf("Expected to receive a PgError, instead received: %v", err) + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) } if pgErr.Code != "28000" && pgErr.Code != "28P01" { t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) @@ -298,7 +295,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { assert.Nil(t, psd) assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -432,7 +429,7 @@ func TestConnExecContextCanceled(t *testing.T) { for multiResult.NextResult() { } err = multiResult.Close() - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -448,7 +445,7 @@ func TestConnExecContextPrecanceled(t *testing.T) { _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -564,7 +561,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -581,7 +578,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) { result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -691,7 +688,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -710,7 +707,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -798,7 +795,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) { _, err = pgConn.ExecBatch(ctx, batch).ReadAll() require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -871,8 +868,8 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) results, err := mrr.ReadAll() assert.NoError(t, err) @@ -1029,7 +1026,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) err = pgConn.WaitForNotification(ctx) cancel() - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.True(t, pgconn.Timeout(err)) ensureConnValid(t, pgConn) } @@ -1139,7 +1136,7 @@ func TestConnCopyToCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) assert.Equal(t, pgconn.CommandTag(nil), res) assert.True(t, pgConn.IsClosed()) @@ -1159,7 +1156,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) @@ -1231,7 +1228,7 @@ func TestConnCopyFromCanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") cancel() assert.Equal(t, int64(0), ct.RowsAffected()) - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) assert.True(t, pgConn.IsClosed()) } @@ -1267,7 +1264,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) From 66aaed7c9eb0751b2936dbdbf278963dda8804fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 27 Aug 2019 18:11:50 -0500 Subject: [PATCH 22/27] Remove public fields from PgConn - Access TxStatus via method - Make Config private fixes #7 --- auth_scram.go | 2 +- pgconn.go | 27 ++++++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 4409a080..6d6d0651 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -31,7 +31,7 @@ const clientNonceLen = 18 // Perform SCRAM authentication. func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { - sc, err := newScramClient(serverAuthMechanisms, c.Config.Password) + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) if err != nil { return err } diff --git a/pgconn.go b/pgconn.go index 347acf80..1e3f9515 100644 --- a/pgconn.go +++ b/pgconn.go @@ -69,10 +69,10 @@ type PgConn struct { pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server - TxStatus byte + txStatus byte frontend Frontend - Config *Config + config *Config status byte // One of connStatus* constants @@ -149,7 +149,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) - pgConn.Config = config + pgConn.config = config pgConn.wbuf = make([]byte, 0, 1024) var err error @@ -261,9 +261,9 @@ func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: - err = pgConn.txPasswordMessage(pgConn.Config.Password) + err = pgConn.txPasswordMessage(pgConn.config.Password) case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) case pgproto3.AuthTypeSASL: err = pgConn.scramAuth(msg.SASLAuthMechanisms) @@ -390,7 +390,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - pgConn.TxStatus = msg.TxStatus + pgConn.txStatus = msg.TxStatus case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: @@ -399,12 +399,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: - if pgConn.Config.OnNotice != nil { - pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) + if pgConn.config.OnNotice != nil { + pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg)) } case *pgproto3.NotificationResponse: - if pgConn.Config.OnNotification != nil { - pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + if pgConn.config.OnNotification != nil { + pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) } } @@ -421,6 +421,11 @@ func (pgConn *PgConn) PID() uint32 { return pgConn.pid } +// TxStatus returns the current TxStatus as reported by the server. +func (pgConn *PgConn) TxStatus() byte { + return pgConn.txStatus +} + // SecretKey returns the backend secret key used to send a cancel query message to the server. func (pgConn *PgConn) SecretKey() uint32 { return pgConn.secretKey @@ -618,7 +623,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } From 6bba3c4810ce93171830696896238f19911b7ca3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 11:55:02 -0500 Subject: [PATCH 23/27] Update pgproto3 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index cbeef02a..b54607b6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 + github.com/jackc/pgproto3/v2 v2.0.0-rc2 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 0e853203..d7a6d087 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc2 h1:u+jUsxBxiLY2C6mhr8cZhSy71n/y8Id2STOzJ7bl2Mg= +github.com/jackc/pgproto3/v2 v2.0.0-rc2/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= From 2fabfa3c18b7bcb4f204c365f2f0d2e09d4564eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 15:44:54 -0500 Subject: [PATCH 24/27] Update to newest pgproto3 --- auth_scram.go | 34 ++++++++++++++++++++++------------ config.go | 2 +- go.mod | 2 +- go.sum | 8 ++------ pgconn.go | 40 ++++++++++++++++++++-------------------- 5 files changed, 46 insertions(+), 40 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 6d6d0651..665fc2c2 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-first-message payload in a AuthenticationSASLContinue. - authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) + saslContinue, err := c.rxSASLContinue() if err != nil { return err } - err = sc.recvServerFirstMessage(authMsg.SASLData) + err = sc.recvServerFirstMessage(saslContinue.Data) if err != nil { return err } @@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-final-message payload in a AuthenticationSASLFinal. - authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) + saslFinal, err := c.rxSASLFinal() if err != nil { return err } - return sc.recvServerFinalMessage(authMsg.SASLData) + return sc.recvServerFinalMessage(saslFinal.Data) } -func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { +func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { msg, err := c.receiveMessage() if err != nil { return nil, err } - authMsg, ok := msg.(*pgproto3.Authentication) - if !ok { - return nil, errors.New("unexpected message type") - } - if authMsg.Type != typ { - return nil, errors.New("unexpected auth type") + saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue) + if ok { + return saslContinue, nil } - return authMsg, nil + return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message") +} + +func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal) + if ok { + return saslFinal, nil + } + + return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message") } type scramClient struct { diff --git a/config.go b/config.go index d24d0202..d1267621 100644 --- a/config.go +++ b/config.go @@ -501,7 +501,7 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { if err != nil { panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) } - frontend, _ := pgproto3.NewFrontend(cr, w) + frontend := pgproto3.NewFrontend(cr, w) return frontend } diff --git a/go.mod b/go.mod index b54607b6..6e270cd6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-rc2 + github.com/jackc/pgproto3/v2 v2.0.0-rc3 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index d7a6d087..ed8eb401 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,13 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-rc2 h1:u+jUsxBxiLY2C6mhr8cZhSy71n/y8Id2STOzJ7bl2Mg= -github.com/jackc/pgproto3/v2 v2.0.0-rc2/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pgconn.go b/pgconn.go index 1e3f9515..d51eb76a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -210,11 +210,28 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.BackendKeyData: pgConn.pid = msg.ProcessID pgConn.secretKey = msg.SecretKey - case *pgproto3.Authentication: - if err = pgConn.rxAuthenticationX(msg); err != nil { + + case *pgproto3.AuthenticationOk: + case *pgproto3.AuthenticationCleartextPassword: + err = pgConn.txPasswordMessage(pgConn.config.Password) + if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed handle authentication message", err: err} + return nil, &connectError{config: config, msg: "failed to write password message", err: err} } + case *pgproto3.AuthenticationMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationSASL: + err = pgConn.scramAuth(msg.AuthMechanisms) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + } + case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { @@ -257,23 +274,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { - switch msg.Type { - case pgproto3.AuthTypeOk: - case pgproto3.AuthTypeCleartextPassword: - err = pgConn.txPasswordMessage(pgConn.config.Password) - case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) - err = pgConn.txPasswordMessage(digestedPassword) - case pgproto3.AuthTypeSASL: - err = pgConn.scramAuth(msg.SASLAuthMechanisms) - default: - err = errors.New("Received unknown authentication message") - } - - return -} - func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) From 2f6b8f3f5665228c0800e66b05e797ef119f3ef2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 17:01:54 -0500 Subject: [PATCH 25/27] Fix context timeout on connect --- go.mod | 11 ++++--- go.sum | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 3 ++ pgconn_test.go | 62 +++++++++++++++++++++++++++++++++++ 4 files changed, 159 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 6e270cd6..11692c10 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,11 @@ go 1.12 require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 + github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-rc3 - github.com/stretchr/testify v1.3.0 - golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a - golang.org/x/text v0.3.0 - golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 + github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 + github.com/stretchr/testify v1.4.0 + golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 + golang.org/x/text v0.3.2 + golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index ed8eb401..c1b3d405 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,111 @@ +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pgconn.go b/pgconn.go index d51eb76a..5c01d1dc 100644 --- a/pgconn.go +++ b/pgconn.go @@ -174,6 +174,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ diff --git a/pgconn_test.go b/pgconn_test.go index 3fbdf8df..4a67a2e0 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/jackc/pgconn" + "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -73,6 +74,67 @@ func TestConnectTLS(t *testing.T) { closeConn(t, conn) } +type pgmockWaitStep time.Duration + +func (s pgmockWaitStep) Step(*pgproto3.Backend) error { + time.Sleep(time.Duration(s)) + return nil +} + +func TestConnectWithContextThatTimesOut(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationOk{}), + pgmockWaitStep(time.Millisecond * 500), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }, + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + tooLate := time.Now().Add(time.Millisecond * 500) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err = pgconn.Connect(ctx, connStr) + require.True(t, pgconn.Timeout(err), err) + require.True(t, time.Now().Before(tooLate)) +} + func TestConnectInvalidUser(t *testing.T) { t.Parallel() From a8362ef96d23eb9e53a9eb57bb12889f8cbaa1c2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2019 17:14:04 -0500 Subject: [PATCH 26/27] Parse postgresql:// protocol --- config.go | 2 +- config_test.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index d1267621..2ec6ae3f 100644 --- a/config.go +++ b/config.go @@ -152,7 +152,7 @@ func ParseConfig(connString string) (*Config, error) { if connString != "" { // connString may be a database URL or a DSN - if strings.HasPrefix(connString, "postgres://") { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { err := addURLSettings(settings, connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} diff --git a/config_test.go b/config_test.go index af42094d..090302a2 100644 --- a/config_test.go +++ b/config_test.go @@ -214,6 +214,18 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url postgresql protocol", + connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "DSN everything", connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", From f8be2b60ce34bf79b747009b9cc7fb718b918734 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2019 17:25:25 -0500 Subject: [PATCH 27/27] go.sum changes --- go.sum | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/go.sum b/go.sum index c1b3d405..d0a917fc 100644 --- a/go.sum +++ b/go.sum @@ -2,11 +2,11 @@ github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMe github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= @@ -19,10 +19,10 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= @@ -48,7 +48,6 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -63,7 +62,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -74,7 +72,6 @@ go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -86,12 +83,10 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -99,7 +94,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=