diff --git a/pgconn/.github/workflows/ci.yml b/pgconn/.github/workflows/ci.yml new file mode 100644 index 00000000..d84462da --- /dev/null +++ b/pgconn/.github/workflows/ci.yml @@ -0,0 +1,81 @@ +name: CI + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + + test: + name: Test + runs-on: ubuntu-18.04 + + strategy: + matrix: + go-version: [1.15, 1.16] + pg-version: [9.6, 10, 11, 12, 13, cockroachdb] + include: + - pg-version: 9.6 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 10 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 11 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 12 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 13 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: cockroachdb + pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" + + steps: + + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go-version }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: Setup database server for testing + run: ci/setup_test.bash + env: + PGVERSION: ${{ matrix.pg-version }} + + - name: Test + run: go test -v -race ./... + env: + PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }} + PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} + PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} + PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} + PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} + PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} diff --git a/pgconn/.gitignore b/pgconn/.gitignore new file mode 100644 index 00000000..e980f555 --- /dev/null +++ b/pgconn/.gitignore @@ -0,0 +1,3 @@ +.envrc +vendor/ +.vscode diff --git a/pgconn/CHANGELOG.md b/pgconn/CHANGELOG.md new file mode 100644 index 00000000..63933a3a --- /dev/null +++ b/pgconn/CHANGELOG.md @@ -0,0 +1,129 @@ +# 1.10.1 (November 20, 2021) + +* Close without waiting for response (Kei Kamikawa) +* Save waiting for network round-trip in CopyFrom (Rueian) +* Fix concurrency issue with ContextWatcher +* LRU.Get always checks context for cancellation / expiration (Georges Varouchas) + +# 1.10.0 (July 24, 2021) + +* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned. + +# 1.9.0 (July 10, 2021) + +* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr) +* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle) +* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard) +* Fix default host when parsing URL without host but with port +* Allow dbname query parameter in URL conn string +* Update underlying dependencies + +# 1.8.1 (March 25, 2021) + +* Better connection string sanitization (ip.novikov) +* Use proper pgpass location on Windows (Moshe Katz) +* Use errors instead of golang.org/x/xerrors +* Resume fallback on server error in Connect (Andrey Borodin) + +# 1.8.0 (December 3, 2020) + +* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes) + +# 1.7.2 (November 3, 2020) + +* Fix data value slices into work buffer with capacities larger than length. + +# 1.7.1 (October 31, 2020) + +* Do not asyncClose after receiving FATAL error from PostgreSQL server + +# 1.7.0 (September 26, 2020) + +* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded +* Add ReceiveResults (Sebastiaan Mannem) +* Fix parsing DSN connection with bad backslash +* Add PgConn.CleanupDone so connection pools can determine when async close is complete + +# 1.6.4 (July 29, 2020) + +* Fix deadlock on error after CommandComplete but before ReadyForQuery +* Fix panic on parsing DSN with trailing '=' + +# 1.6.3 (July 22, 2020) + +* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo) + +# 1.6.2 (July 14, 2020) + +* Update pgservicefile library + +# 1.6.1 (June 27, 2020) + +* Update golang.org/x/crypto to latest +* Update golang.org/x/text to 0.3.3 +* Fix error handling for bad PGSERVICE definition +* Redact passwords in ParseConfig errors (Lukas Vogel) + +# 1.6.0 (June 6, 2020) + +* Fix panic when closing conn during cancellable query +* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný) +* Fix field descriptions available after command concluded (Tobias Salzmann) +* Support connect_timeout (georgysavva) +* Handle IPv6 in connection URLs (Lukas Vogel) +* Fix ValidateConnect with cancelable context +* Improve CopyFrom performance +* Add Config.Copy (georgysavva) + +# 1.5.0 (March 30, 2020) + +* Update golang.org/x/crypto for security fix +* Implement "verify-ca" SSL mode (Greg Curtis) + +# 1.4.0 (March 7, 2020) + +* Fix ExecParams and ExecPrepared handling of empty query. +* Support reading config from PostgreSQL service files. + +# 1.3.2 (February 14, 2020) + +* Update chunkreader to v2.0.1 for optimized default buffer size. + +# 1.3.1 (February 5, 2020) + +* Fix CopyFrom deadlock when multiple NoticeResponse received during copy + +# 1.3.0 (January 23, 2020) + +* Add Hijack and Construct. +* Update pgproto3 to v2.0.1. + +# 1.2.1 (January 13, 2020) + +* Fix data race in context cancellation introduced in v1.2.0. + +# 1.2.0 (January 11, 2020) + +## Features + +* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag. +* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases. + +## Performance + +* Improve performance when context.Background() is used. (bakape) +* CommandTag.RowsAffected is faster and does not allocate. + +## Fixes + +* Try to cancel any in-progress query when a conn is closed by ctx cancel. +* Handle NoticeResponse during CopyFrom. +* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish. + +# 1.1.0 (October 12, 2019) + +* Add PgConn.IsBusy() method. + +# 1.0.1 (September 19, 2019) + +* Fix statement cache not properly cleaning discarded statements. diff --git a/pgconn/LICENSE b/pgconn/LICENSE new file mode 100644 index 00000000..aebadd6c --- /dev/null +++ b/pgconn/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019-2021 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pgconn/README.md b/pgconn/README.md new file mode 100644 index 00000000..1c698a11 --- /dev/null +++ b/pgconn/README.md @@ -0,0 +1,56 @@ +[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) +![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg) + +# pgconn + +Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq. +It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. +Applications should handle normal queries with a higher level library and only use pgconn directly when required for +low-level access to PostgreSQL functionality. + +## Example Usage + +```go +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +if err != nil { + log.Fatalln("pgconn failed to connect:", err) +} +defer pgConn.Close(context.Background()) + +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +for result.NextRow() { + fmt.Println("User 123 has email:", string(result.Values()[0])) +} +_, err = result.Close() +if err != nil { + log.Fatalln("failed reading result:", err) +} +``` + +## Testing + +The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` +environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*` +environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify +environment variable handling. + +### Example Test Environment + +Connect to your PostgreSQL server and run: + +``` +create database pgx_test; +``` + +Now you can run the tests: + +```bash +PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... +``` + +### Connection and Authentication Tests + +Pgconn supports multiple connection types and means of authentication. These tests are optional. They +will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being +skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change +authentication code. diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go new file mode 100644 index 00000000..6a143fcd --- /dev/null +++ b/pgconn/auth_scram.go @@ -0,0 +1,266 @@ +// SCRAM-SHA-256 authentication +// +// Resources: +// https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc8265 +// https://www.postgresql.org/docs/current/sasl-authentication.html +// +// Inspiration drawn from other implementations: +// https://github.com/lib/pq/pull/608 +// https://github.com/lib/pq/pull/788 +// https://github.com/lib/pq/pull/833 + +package pgconn + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "strconv" + + "github.com/jackc/pgproto3/v2" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/text/secure/precis" +) + +const clientNonceLen = 18 + +// Perform SCRAM authentication. +func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) + if err != nil { + return err + } + + // Send client-first-message in a SASLInitialResponse + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: sc.clientFirstMessage(), + } + _, err = c.conn.Write(saslInitialResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-first-message payload in a AuthenticationSASLContinue. + saslContinue, err := c.rxSASLContinue() + if err != nil { + return err + } + err = sc.recvServerFirstMessage(saslContinue.Data) + if err != nil { + return err + } + + // Send client-final-message in a SASLResponse + saslResponse := &pgproto3.SASLResponse{ + Data: []byte(sc.clientFinalMessage()), + } + _, err = c.conn.Write(saslResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-final-message payload in a AuthenticationSASLFinal. + saslFinal, err := c.rxSASLFinal() + if err != nil { + return err + } + return sc.recvServerFinalMessage(saslFinal.Data) +} + +func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue) + if ok { + return saslContinue, nil + } + + return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message") +} + +func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal) + if ok { + return saslFinal, nil + } + + return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message") +} + +type scramClient struct { + serverAuthMechanisms []string + password []byte + clientNonce []byte + + clientFirstMessageBare []byte + + serverFirstMessage []byte + clientAndServerNonce []byte + salt []byte + iterations int + + saltedPassword []byte + authMessage []byte +} + +func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { + sc := &scramClient{ + serverAuthMechanisms: serverAuthMechanisms, + } + + // Ensure server supports SCRAM-SHA-256 + hasScramSHA256 := false + for _, mech := range sc.serverAuthMechanisms { + if mech == "SCRAM-SHA-256" { + hasScramSHA256 = true + break + } + } + if !hasScramSHA256 { + return nil, errors.New("server does not support SCRAM-SHA-256") + } + + // precis.OpaqueString is equivalent to SASLprep for password. + var err error + sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + if err != nil { + // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. + sc.password = []byte(password) + } + + buf := make([]byte, clientNonceLen) + _, err = rand.Read(buf) + if err != nil { + return nil, err + } + sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) + base64.RawStdEncoding.Encode(sc.clientNonce, buf) + + return sc, nil +} + +func (sc *scramClient) clientFirstMessage() []byte { + sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) + return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) +} + +func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { + sc.serverFirstMessage = serverFirstMessage + buf := serverFirstMessage + if !bytes.HasPrefix(buf, []byte("r=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include r=") + } + buf = buf[2:] + + idx := bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + sc.clientAndServerNonce = buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("s=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + buf = buf[2:] + + idx = bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + saltStr := buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("i=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + buf = buf[2:] + iterationsStr := buf + + var err error + sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) + if err != nil { + return fmt.Errorf("invalid SCRAM salt received from server: %w", err) + } + + sc.iterations, err = strconv.Atoi(string(iterationsStr)) + if err != nil || sc.iterations <= 0 { + return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) + } + + if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not start with client nonce") + } + + if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not include server nonce") + } + + return nil +} + +func (sc *scramClient) clientFinalMessage() string { + clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + + sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) + + clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + + return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) +} + +func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { + if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { + return errors.New("invalid SCRAM server-final-message received from server") + } + + serverSignature := serverFinalMessage[2:] + + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + return errors.New("invalid SCRAM ServerSignature received from server") + } + + return nil +} + +func computeHMAC(key, msg []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(msg) + return mac.Sum(nil) +} + +func computeClientProof(saltedPassword, authMessage []byte) []byte { + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + storedKey := sha256.Sum256(clientKey) + clientSignature := computeHMAC(storedKey[:], authMessage) + + clientProof := make([]byte, len(clientSignature)) + for i := 0; i < len(clientSignature); i++ { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) + base64.StdEncoding.Encode(buf, clientProof) + return buf +} + +func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + serverSignature := computeHMAC(serverKey, authMessage) + buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) + base64.StdEncoding.Encode(buf, serverSignature) + return buf +} diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go new file mode 100644 index 00000000..ced785b6 --- /dev/null +++ b/pgconn/benchmark_test.go @@ -0,0 +1,322 @@ +package pgconn_test + +import ( + "bytes" + "context" + "os" + "strings" + "testing" + + "github.com/jackc/pgconn" + "github.com/stretchr/testify/require" +) + +func BenchmarkConnect(b *testing.B) { + benchmarks := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + connString := os.Getenv(bm.env) + if connString == "" { + b.Skipf("Skipping due to missing environment variable %v", bm.env) + } + + for i := 0; i < b.N; i++ { + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(b, err) + + err = conn.Close(context.Background()) + require.Nil(b, err) + } + }) + } +} + +func BenchmarkExec(b *testing.B) { + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + benchmarks := []struct { + name string + ctx context.Context + }{ + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, + {"empty context", context.TODO()}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkExecPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) + + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + + b.ResetTimer() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkExecPrepared(b *testing.B) { + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + + benchmarks := []struct { + name string + ctx context.Context + }{ + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, + {"empty context", context.TODO()}, + } + + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) + + _, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + }) + } +} + +func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } +} + +// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) { +// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) +// require.Nil(b, err) +// defer closeConn(b, conn) + +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() + +// b.ResetTimer() + +// for i := 0; i < b.N; i++ { +// conn.ChanToSetDeadline().Watch(ctx) +// conn.ChanToSetDeadline().Ignore() +// } +// } + +func BenchmarkCommandTagRowsAffected(b *testing.B) { + benchmarks := []struct { + commandTag string + rowsAffected int64 + }{ + {"UPDATE 1", 1}, + {"UPDATE 123456789", 123456789}, + {"INSERT 0 1", 1}, + {"INSERT 0 123456789", 123456789}, + } + + for _, bm := range benchmarks { + ct := pgconn.CommandTag(bm.commandTag) + b.Run(bm.commandTag, func(b *testing.B) { + var n int64 + for i := 0; i < b.N; i++ { + n = ct.RowsAffected() + } + if n != bm.rowsAffected { + b.Errorf("expected %d got %d", bm.rowsAffected, n) + } + }) + } +} + +func BenchmarkCommandTagTypeFromString(b *testing.B) { + ct := pgconn.CommandTag("UPDATE 1") + + var update bool + for i := 0; i < b.N; i++ { + update = strings.HasPrefix(ct.String(), "UPDATE") + } + if !update { + b.Error("expected update") + } +} + +func BenchmarkCommandTagInsert(b *testing.B) { + benchmarks := []struct { + commandTag string + is bool + }{ + {"INSERT 1", true}, + {"INSERT 1234567890", true}, + {"UPDATE 1", false}, + {"UPDATE 1234567890", false}, + {"DELETE 1", false}, + {"DELETE 1234567890", false}, + {"SELECT 1", false}, + {"SELECT 1234567890", false}, + {"UNKNOWN 1234567890", false}, + } + + for _, bm := range benchmarks { + ct := pgconn.CommandTag(bm.commandTag) + b.Run(bm.commandTag, func(b *testing.B) { + var is bool + for i := 0; i < b.N; i++ { + is = ct.Insert() + } + if is != bm.is { + b.Errorf("expected %v got %v", bm.is, is) + } + }) + } +} diff --git a/pgconn/ci/script.bash b/pgconn/ci/script.bash new file mode 100755 index 00000000..5bf1b77e --- /dev/null +++ b/pgconn/ci/script.bash @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +set -eux + +if [ "${PGVERSION-}" != "" ] +then + go test -v -race ./... +elif [ "${CRATEVERSION-}" != "" ] +then + go test -v -race -run 'TestCrateDBConnect' +fi diff --git a/pgconn/ci/setup_test.bash b/pgconn/ci/setup_test.bash new file mode 100755 index 00000000..f71bd98c --- /dev/null +++ b/pgconn/ci/setup_test.bash @@ -0,0 +1,59 @@ +#!/usr/bin/env bash +set -eux + +if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]] +then + sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common + sudo rm -rf /var/lib/postgresql + wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" + sudo apt-get update -qq + sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION + sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf + if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then + echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + fi + sudo /etc/init.d/postgresql restart + + # The tricky test user, below, has to actually exist so that it can be used in a test + # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + psql -U postgres -c 'create database pgx_test' + psql -U postgres pgx_test -c 'create extension hstore' + psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' + psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user `whoami`" + psql -U postgres -c "create user pgx_replication with replication password 'secret'" + psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" +fi + +if [[ "${PGVERSION-}" =~ ^cockroach ]] +then + wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz + sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/ + cockroach start-single-node --insecure --background --listen-addr=localhost + cockroach sql --insecure -e 'create database pgx_test' +fi + +if [ "${CRATEVERSION-}" != "" ] +then + docker run \ + -p "6543:5432" \ + -d \ + crate:"$CRATEVERSION" \ + crate \ + -Cnetwork.host=0.0.0.0 \ + -Ctransport.host=localhost \ + -Clicense.enterprise=false +fi diff --git a/pgconn/config.go b/pgconn/config.go new file mode 100644 index 00000000..172e7478 --- /dev/null +++ b/pgconn/config.go @@ -0,0 +1,729 @@ +package pgconn + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "io/ioutil" + "math" + "net" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/jackc/chunkreader/v2" + "github.com/jackc/pgpassfile" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgservicefile" +) + +type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error +type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error + +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A +// manually initialized Config will cause ConnectConfig to panic. +type Config struct { + Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + ConnectTimeout time.Duration + DialFunc DialFunc // e.g. net.Dialer.DialContext + LookupFunc LookupFunc // e.g. net.Resolver.LookupHost + BuildFrontend BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + Fallbacks []*FallbackConfig + + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. + // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. + ValidateConnect ValidateConnectFunc + + // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables + // or prepare statements). If this returns an error the connection attempt fails. + AfterConnect AfterConnectFunc + + // OnNotice is a callback function called when a notice response is received. + OnNotice NoticeHandler + + // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. + OnNotification NotificationHandler + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. +} + +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the TLSConfig field: +// according to the tls.Config docs it must not be modified after creation. +func (c *Config) Copy() *Config { + newConf := new(Config) + *newConf = *c + if newConf.TLSConfig != nil { + newConf.TLSConfig = c.TLSConfig.Clone() + } + if newConf.RuntimeParams != nil { + newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) + for k, v := range c.RuntimeParams { + newConf.RuntimeParams[k] = v + } + } + if newConf.Fallbacks != nil { + newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) + for i, fallback := range c.Fallbacks { + newFallback := new(FallbackConfig) + *newFallback = *fallback + if newFallback.TLSConfig != nil { + newFallback.TLSConfig = fallback.TLSConfig.Clone() + } + newConf.Fallbacks[i] = newFallback + } + } + return newConf +} + +// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a +// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. +type FallbackConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + TLSConfig *tls.Config // nil disables TLS +} + +// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with +// net.Dial. +func NetworkAddress(host string, port uint16) (network, address string) { + if strings.HasPrefix(host, "/") { + network = "unix" + address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) + } else { + network = "tcp" + address = net.JoinHostPort(host, strconv.Itoa(int(port))) + } + return network, address +} + +// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same +// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches +// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be +// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. +// +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done +// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be +// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should +// not be modified individually. They should all be modified or all left unchanged. +// +// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated +// values that will be tried in order. This can be used as part of a high availability system. See +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed +// via database URL or DSN: +// +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSERVICE +// PGSERVICEFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGAPPNAME +// PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS +// +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. +// +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are +// usually but not always the environment variable name downcased and without the "PG" prefix. +// +// Important Security Notes: +// +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if +// not set. +// +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// security each sslmode provides. +// +// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of +// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of +// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback +// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually +// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting +// TLCConfig. +// +// Other known differences with libpq: +// +// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. +// +// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn +// does not. +// +// In addition, ParseConfig accepts the following options: +// +// min_read_buffer_size +// The minimum size of the internal read buffer. Default 8192. +// servicefile +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. +func ParseConfig(connString string) (*Config, error) { + defaultSettings := defaultSettings() + envSettings := parseEnvSettings() + + connStringSettings := make(map[string]string) + if connString != "" { + var err error + // connString may be a database URL or a DSN + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + connStringSettings, err = parseURLSettings(connString) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} + } + } else { + connStringSettings, err = parseDSNSettings(connString) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} + } + } + } + + settings := mergeSettings(defaultSettings, envSettings, connStringSettings) + if service, present := settings["service"]; present { + serviceSettings, err := parseServiceSettings(settings["servicefile"], service) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} + } + + settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) + } + + minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} + } + + config := &Config{ + createdByParseConfig: true, + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), + } + + if connectTimeoutSetting, present := settings["connect_timeout"]; present { + connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} + } + config.ConnectTimeout = connectTimeout + config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) + } else { + defaultDialer := makeDefaultDialer() + config.DialFunc = defaultDialer.DialContext + } + + config.LookupFunc = makeDefaultResolver().LookupHost + + notRuntimeParams := map[string]struct{}{ + "host": struct{}{}, + "port": struct{}{}, + "database": struct{}{}, + "user": struct{}{}, + "password": struct{}{}, + "passfile": struct{}{}, + "connect_timeout": struct{}{}, + "sslmode": struct{}{}, + "sslkey": struct{}{}, + "sslcert": struct{}{}, + "sslrootcert": struct{}{}, + "target_session_attrs": struct{}{}, + "min_read_buffer_size": struct{}{}, + "service": struct{}{}, + "servicefile": struct{}{}, + } + + for k, v := range settings { + if _, present := notRuntimeParams[k]; present { + continue + } + config.RuntimeParams[k] = v + } + + fallbacks := []*FallbackConfig{} + + hosts := strings.Split(settings["host"], ",") + ports := strings.Split(settings["port"], ",") + + for i, host := range hosts { + var portStr string + if i < len(ports) { + portStr = ports[i] + } else { + portStr = ports[0] + } + + port, err := parsePort(portStr) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(host, port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings, host) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} + } + } + + for _, tlsConfig := range tlsConfigs { + fallbacks = append(fallbacks, &FallbackConfig{ + Host: host, + Port: port, + TLSConfig: tlsConfig, + }) + } + } + + config.Host = fallbacks[0].Host + config.Port = fallbacks[0].Port + config.TLSConfig = fallbacks[0].TLSConfig + config.Fallbacks = fallbacks[1:] + + passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) + if err == nil { + if config.Password == "" { + host := config.Host + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + host = "localhost" + } + + config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) + } + } + + if settings["target_session_attrs"] == "read-write" { + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite + } else if settings["target_session_attrs"] != "any" { + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])} + } + + return config, nil +} + +func mergeSettings(settingSets ...map[string]string) map[string]string { + settings := make(map[string]string) + + for _, s2 := range settingSets { + for k, v := range s2 { + settings[k] = v + } + } + + return settings +} + +func parseEnvSettings() map[string]string { + settings := make(map[string]string) + + nameMap := map[string]string{ + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLROOTCERT": "sslrootcert", + "PGTARGETSESSIONATTRS": "target_session_attrs", + "PGSERVICE": "service", + "PGSERVICEFILE": "servicefile", + } + + for envname, realname := range nameMap { + value := os.Getenv(envname) + if value != "" { + settings[realname] = value + } + } + + return settings +} + +func parseURLSettings(connString string) (map[string]string, error) { + settings := make(map[string]string) + + url, err := url.Parse(connString) + if err != nil { + return nil, err + } + + if url.User != nil { + settings["user"] = url.User.Username() + if password, present := url.User.Password(); present { + settings["password"] = password + } + } + + // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. + var hosts []string + var ports []string + for _, host := range strings.Split(url.Host, ",") { + if host == "" { + continue + } + if isIPOnly(host) { + hosts = append(hosts, strings.Trim(host, "[]")) + continue + } + h, p, err := net.SplitHostPort(host) + if err != nil { + return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) + } + if h != "" { + hosts = append(hosts, h) + } + if p != "" { + ports = append(ports, p) + } + } + if len(hosts) > 0 { + settings["host"] = strings.Join(hosts, ",") + } + if len(ports) > 0 { + settings["port"] = strings.Join(ports, ",") + } + + database := strings.TrimLeft(url.Path, "/") + if database != "" { + settings["database"] = database + } + + nameMap := map[string]string{ + "dbname": "database", + } + + for k, v := range url.Query() { + if k2, present := nameMap[k]; present { + k = k2 + } + + settings[k] = v[0] + } + + return settings, nil +} + +func isIPOnly(host string) bool { + return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") +} + +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} + +func parseDSNSettings(s string) (map[string]string, error) { + settings := make(map[string]string) + + nameMap := map[string]string{ + "dbname": "database", + } + + for len(s) > 0 { + var key, val string + eqIdx := strings.IndexRune(s, '=') + if eqIdx < 0 { + return nil, errors.New("invalid dsn") + } + + key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") + s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") + if len(s) == 0 { + } else if s[0] != '\'' { + end := 0 + for ; end < len(s); end++ { + if asciiSpace[s[end]] == 1 { + break + } + if s[end] == '\\' { + end++ + if end == len(s) { + return nil, errors.New("invalid backslash") + } + } + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } else { // quoted string + s = s[1:] + end := 0 + for ; end < len(s); end++ { + if s[end] == '\'' { + break + } + if s[end] == '\\' { + end++ + } + } + if end == len(s) { + return nil, errors.New("unterminated quoted string in connection info string") + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } + + if k, ok := nameMap[key]; ok { + key = k + } + + if key == "" { + return nil, errors.New("invalid dsn") + } + + settings[key] = val + } + + return settings, nil +} + +func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { + servicefile, err := pgservicefile.ReadServicefile(servicefilePath) + if err != nil { + return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) + } + + service, err := servicefile.GetService(serviceName) + if err != nil { + return nil, fmt.Errorf("unable to find service: %v", serviceName) + } + + nameMap := map[string]string{ + "dbname": "database", + } + + settings := make(map[string]string, len(service.Settings)) + for k, v := range service.Settings { + if k2, present := nameMap[k]; present { + k = k2 + } + settings[k] = v + } + + return settings, nil +} + +// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is +// necessary to allow returning multiple TLS configs as sslmode "allow" and +// "prefer" allow fallback. +func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) { + host := thisHost + sslmode := settings["sslmode"] + sslrootcert := settings["sslrootcert"] + sslcert := settings["sslcert"] + sslkey := settings["sslkey"] + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + + tlsConfig := &tls.Config{} + + switch sslmode { + case "disable": + return []*tls.Config{nil}, nil + case "allow", "prefer": + tlsConfig.InsecureSkipVerify = true + case "require": + // According to PostgreSQL documentation, if a root CA file exists, + // the behavior of sslmode=require should be the same as that of verify-ca + // + // See https://www.postgresql.org/docs/12/libpq-ssl.html + if sslrootcert != "" { + goto nextCase + } + tlsConfig.InsecureSkipVerify = true + break + nextCase: + fallthrough + case "verify-ca": + // Don't perform the default certificate verification because it + // will verify the hostname. Instead, verify the server's + // certificate chain ourselves in VerifyPeerCertificate and + // ignore the server name. This emulates libpq's verify-ca + // behavior. + // + // See https://github.com/golang/go/issues/21971#issuecomment-332693931 + // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate + // for more info. + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return errors.New("failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + // Leave DNSName empty to skip hostname verification. + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + Intermediates: x509.NewCertPool(), + } + // Skip the first cert because it's the leaf. All others + // are intermediates. + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err + } + case "verify-full": + tlsConfig.ServerName = host + default: + return nil, errors.New("sslmode is invalid") + } + + if sslrootcert != "" { + caCertPool := x509.NewCertPool() + + caPath := sslrootcert + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("unable to read CA file: %w", err) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.New("unable to add CA to cert pool") + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return nil, errors.New(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return nil, fmt.Errorf("unable to read cert: %w", err) + } + + tlsConfig.Certificates = []tls.Certificate{cert} + } + + switch sslmode { + case "allow": + return []*tls.Config{nil, tlsConfig}, nil + case "prefer": + return []*tls.Config{tlsConfig, nil}, nil + case "require", "verify-ca", "verify-full": + return []*tls.Config{tlsConfig}, nil + default: + panic("BUG: bad sslmode should already have been caught") + } +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +func makeDefaultDialer() *net.Dialer { + return &net.Dialer{KeepAlive: 5 * time.Minute} +} + +func makeDefaultResolver() *net.Resolver { + return net.DefaultResolver +} + +func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { + return func(r io.Reader, w io.Writer) Frontend { + cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) + if err != nil { + panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) + } + frontend := pgproto3.NewFrontend(cr, w) + + return frontend + } +} + +func parseConnectTimeoutSetting(s string) (time.Duration, error) { + timeout, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, err + } + if timeout < 0 { + return 0, errors.New("negative timeout") + } + return time.Duration(timeout) * time.Second, nil +} + +func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { + d := makeDefaultDialer() + d.Timeout = timeout + return d.DialContext +} + +// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=read-write. +func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) == "on" { + return errors.New("read only connection") + } + + return nil +} diff --git a/pgconn/config_test.go b/pgconn/config_test.go new file mode 100644 index 00000000..d29173d1 --- /dev/null +++ b/pgconn/config_test.go @@ -0,0 +1,900 @@ +package pgconn_test + +import ( + "context" + "crypto/tls" + "fmt" + "io/ioutil" + "os" + "os/user" + "runtime" + "strings" + "testing" + "time" + + "github.com/jackc/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseConfig(t *testing.T) { + t.Parallel() + + var osUserName string + osUser, err := user.Current() + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } + } + + config, err := pgconn.ParseConfig("") + require.NoError(t, err) + defaultHost := config.Host + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + // Test all sslmodes + { + name: "sslmode not set (prefer)", + connString: "postgres://jack:secret@localhost:5432/mydb", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode disable", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode allow", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + }, + }, + { + name: "sslmode prefer", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + config: &pgconn.Config{ + + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode require", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-ca", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-full", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url everything", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "database url missing password", + connString: "postgres://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing user and password", + connString: "postgres://localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing port", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url unix domain socket host", + connString: "postgres:///foo?host=/tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "/tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url dbname", + connString: "postgres://localhost/?dbname=foo&sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url postgresql protocol", + connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv4 with port", + connString: "postgresql://jack@127.0.0.1:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "127.0.0.1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 with port", + connString: "postgresql://jack@[2001:db8::1]:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 no port", + connString: "postgresql://jack@[2001:db8::1]/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN everything", + connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "DSN with escaped single quote", + connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack's", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with escaped backslash", + connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "sooper\\secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with single quoted values", + connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with single quoted value with escaped single quote", + connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack's", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with empty single quoted value", + connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with space between key and value", + connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "URL multiple hosts", + connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "URL multiple hosts and ports", + connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + // https://github.com/jackc/pgconn/issues/72 + { + name: "URL without host but with port still uses default host", + connString: "postgres://jack:secret@:1/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: defaultHost, + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN multiple hosts one port", + connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "DSN multiple hosts multiple ports", + connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "multiple hosts and fallback tsl", + connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "foo", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }}, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }}, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "target_session_attrs", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +// https://github.com/jackc/pgconn/issues/47 +func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) { + _, err := pgconn.ParseConfig("host= user= password= port= database=") + require.NoError(t, err) +} + +func TestParseConfigDSNLeadingEqual(t *testing.T) { + _, err := pgconn.ParseConfig("= user=jack") + require.Error(t, err) +} + +// https://github.com/jackc/pgconn/issues/49 +func TestParseConfigDSNTrailingBackslash(t *testing.T) { + _, err := pgconn.ParseConfig(`x=x\`) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid backslash") +} + +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") +} + +func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5&sslmode=prefer" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") + + copied.Port = uint16(5433) + copied.RuntimeParams["foo"] = "bar" + copied.Fallbacks[0].Port = uint16(5433) + + assert.Equal(t, uint16(5432), original.Port) + assert.Equal(t, "", original.RuntimeParams["foo"]) + assert.Equal(t, uint16(5432), original.Fallbacks[0].Port) +} + +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_CONN_STRING") + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgconn.ConnectConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + +func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} + +func TestParseConfigEnvLibpq(t *testing.T) { + var osUserName string + osUser, err := user.Current() + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } + } + + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} + + savedEnv := make(map[string]string) + for _, n := range pgEnvvars { + savedEnv[n] = os.Getenv(n) + } + defer func() { + for k, v := range savedEnv { + err := os.Setenv(k, v) + if err != nil { + t.Fatalf("Unable to restore environment: %v", err) + } + } + }() + + tests := []struct { + name string + envvars map[string]string + config *pgconn.Config + }{ + { + // not testing no environment at all as that would use default host and that can vary. + name: "PGHOST only", + envvars: map[string]string{"PGHOST": "123.123.123.123"}, + config: &pgconn.Config{ + User: osUserName, + Host: "123.123.123.123", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "123.123.123.123", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "All non-TLS environment", + envvars: map[string]string{ + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + "PGCONNECT_TIMEOUT": "10", + "PGSSLMODE": "disable", + "PGAPPNAME": "pgxtest", + }, + config: &pgconn.Config{ + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + ConnectTimeout: 10 * time.Second, + TLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest"}, + }, + }, + } + + for i, tt := range tests { + for _, n := range pgEnvvars { + err := os.Unsetenv(n) + require.NoError(t, err) + } + + for k, v := range tt.envvars { + err := os.Setenv(k, v) + require.NoError(t, err) + } + + config, err := pgconn.ParseConfig("") + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func TestParseConfigReadsPgPassfile(t *testing.T) { + t.Parallel() + + tf, err := ioutil.TempFile("", "") + require.NoError(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) + require.NoError(t, err) + + connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) + expected := &pgconn.Config{ + User: "curly", + Password: "nyuknyuknyuk", + Host: "test1", + Port: 5432, + Database: "curlydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + } + + actual, err := pgconn.ParseConfig(connString) + assert.NoError(t, err) + + assertConfigsEqual(t, expected, actual, "passfile") +} + +func TestParseConfigReadsPgServiceFile(t *testing.T) { + t.Parallel() + + tf, err := ioutil.TempFile("", "") + require.NoError(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte(` +[abc] +host=abc.example.com +port=9999 +dbname=abcdb +user=abcuser + +[def] +host = def.example.com +dbname = defdb +user = defuser +application_name = spaced string +`)) + require.NoError(t, err) + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + { + name: "abc", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "abc.example.com", + Database: "abcdb", + User: "abcuser", + Port: 9999, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "abc.example.com", + Port: 9999, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "def", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"), + config: &pgconn.Config{ + Host: "def.example.com", + Port: 5432, + Database: "defdb", + User: "defuser", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{"application_name": "spaced string"}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "def.example.com", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "conn string has precedence", + connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "other.example.com", + Database: "abcdb", + User: "abcuser", + Port: 7777, + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.NoErrorf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig("min_read_buffer_size=0") + require.NoError(t, err) + _, present := config.RuntimeParams["min_read_buffer_size"] + require.False(t, present) + + // The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param + // was removed. +} diff --git a/pgconn/defaults.go b/pgconn/defaults.go new file mode 100644 index 00000000..f69cad31 --- /dev/null +++ b/pgconn/defaults.go @@ -0,0 +1,64 @@ +// +build !windows + +package pgconn + +import ( + "os" + "os/user" + "path/filepath" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } + } + + settings["target_session_attrs"] = "any" + + settings["min_read_buffer_size"] = "8192" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} diff --git a/pgconn/defaults_windows.go b/pgconn/defaults_windows.go new file mode 100644 index 00000000..71eb77db --- /dev/null +++ b/pgconn/defaults_windows.go @@ -0,0 +1,59 @@ +package pgconn + +import ( + "os" + "os/user" + "path/filepath" + "strings" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + appData := os.Getenv("APPDATA") + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + username := user.Username + if strings.Contains(username, "\\") { + username = username[strings.LastIndex(username, "\\")+1:] + } + + settings["user"] = username + settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(appData, "postgresql", "postgresql.crt") + sslkey := filepath.Join(appData, "postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(appData, "postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } + } + + settings["target_session_attrs"] = "any" + + settings["min_read_buffer_size"] = "8192" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + return "localhost" +} diff --git a/pgconn/doc.go b/pgconn/doc.go new file mode 100644 index 00000000..cde58cd8 --- /dev/null +++ b/pgconn/doc.go @@ -0,0 +1,29 @@ +// Package pgconn is a low-level PostgreSQL database driver. +/* +pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at +nearly the same level is the C library libpq. + +Establishing a Connection + +Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for +libpq style environment variables. + +Executing a Query + +ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method +reads all rows into memory. + +Executing Multiple Queries in a Single Round Trip + +Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query +result. The ReadAll method reads all query results into memory. + +Context Support + +All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the +method immediately returns. In most circumstances, this will close the underlying connection. + +The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the +client to abort. +*/ +package pgconn diff --git a/pgconn/errors.go b/pgconn/errors.go new file mode 100644 index 00000000..a32b29c9 --- /dev/null +++ b/pgconn/errors.go @@ -0,0 +1,221 @@ +package pgconn + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "regexp" + "strings" +) + +// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. +func SafeToRetry(err error) bool { + if e, ok := err.(interface{ SafeToRetry() bool }); ok { + return e.SafeToRetry() + } + return false +} + +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a +// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. +func Timeout(err error) bool { + var timeoutErr *errTimeout + return errors.As(err, &timeoutErr) +} + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe *PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +// SQLState returns the SQLState of the error. +func (pe *PgError) SQLState() string { + return pe.Code +} + +type connectError struct { + config *Config + msg string + err error +} + +func (e *connectError) Error() string { + sb := &strings.Builder{} + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) + if e.err != nil { + fmt.Fprintf(sb, " (%s)", e.err.Error()) + } + return sb.String() +} + +func (e *connectError) Unwrap() error { + return e.err +} + +type connLockError struct { + status string +} + +func (e *connLockError) SafeToRetry() bool { + return true // a lock failure by definition happens before the connection is used. +} + +func (e *connLockError) Error() string { + return e.status +} + +type parseConfigError struct { + connString string + msg string + err error +} + +func (e *parseConfigError) Error() string { + connString := redactPW(e.connString) + if e.err == nil { + return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) + } + return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) +} + +func (e *parseConfigError) Unwrap() error { + return e.err +} + +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return &errTimeout{err: ctx.Err()} + } + return err +} + +type pgconnError struct { + msg string + err error + safeToRetry bool +} + +func (e *pgconnError) Error() string { + if e.msg == "" { + return e.err.Error() + } + if e.err == nil { + return e.msg + } + return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) +} + +func (e *pgconnError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *pgconnError) Unwrap() error { + return e.err +} + +// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is +// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. +type errTimeout struct { + err error +} + +func (e *errTimeout) Error() string { + return fmt.Sprintf("timeout: %s", e.err.Error()) +} + +func (e *errTimeout) SafeToRetry() bool { + return SafeToRetry(e.err) +} + +func (e *errTimeout) Unwrap() error { + return e.err +} + +type contextAlreadyDoneError struct { + err error +} + +func (e *contextAlreadyDoneError) Error() string { + return fmt.Sprintf("context already done: %s", e.err.Error()) +} + +func (e *contextAlreadyDoneError) SafeToRetry() bool { + return true +} + +func (e *contextAlreadyDoneError) Unwrap() error { + return e.err +} + +// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`. +func newContextAlreadyDoneError(ctx context.Context) (err error) { + return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} +} + +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err +} + +func redactPW(connString string) string { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + if u, err := url.Parse(connString); err == nil { + return redactURL(u) + } + } + quotedDSN := regexp.MustCompile(`password='[^']*'`) + connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + plainDSN := regexp.MustCompile(`password=[^ ]*`) + connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + brokenURL := regexp.MustCompile(`:[^:@]+?@`) + connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") + return connString +} + +func redactURL(u *url.URL) string { + if u == nil { + return "" + } + if _, pwSet := u.User.Password(); pwSet { + u.User = url.UserPassword(u.User.Username(), "xxxxx") + } + return u.String() +} diff --git a/pgconn/errors_test.go b/pgconn/errors_test.go new file mode 100644 index 00000000..1bff3656 --- /dev/null +++ b/pgconn/errors_test.go @@ -0,0 +1,54 @@ +package pgconn_test + +import ( + "testing" + + "github.com/jackc/pgconn" + "github.com/stretchr/testify/assert" +) + +func TestConfigError(t *testing.T) { + tests := []struct { + name string + err error + expectedMsg string + }{ + { + name: "url with password", + err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg", + }, + { + name: "dsn with password unquoted", + err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "dsn with password quoted", + err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "weird url", + err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", + }, + { + name: "weird url with slash in password", + err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil), + expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg", + }, + { + name: "url without password", + err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil), + expectedMsg: "cannot parse `postgresql://other@host/db`: msg", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.EqualError(t, tt.err, tt.expectedMsg) + }) + } +} diff --git a/pgconn/export_test.go b/pgconn/export_test.go new file mode 100644 index 00000000..2a0bad8b --- /dev/null +++ b/pgconn/export_test.go @@ -0,0 +1,11 @@ +// File export_test exports some methods for better testing. + +package pgconn + +func NewParseConfigError(conn, msg string, err error) error { + return &parseConfigError{ + connString: conn, + msg: msg, + err: err, + } +} diff --git a/pgconn/frontend_test.go b/pgconn/frontend_test.go new file mode 100644 index 00000000..b82552bf --- /dev/null +++ b/pgconn/frontend_test.go @@ -0,0 +1,70 @@ +package pgconn_test + +import ( + "context" + "io" + "os" + "testing" + + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// frontendWrapper allows to hijack a regular frontend, and inject a specific response +type frontendWrapper struct { + front pgconn.Frontend + + msg pgproto3.BackendMessage +} + +// frontendWrapper implements the pgconn.Frontend interface +var _ pgconn.Frontend = (*frontendWrapper)(nil) + +func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) { + if f.msg != nil { + return f.msg, nil + } + + return f.front.Receive() +} + +func TestFrontendFatalErrExec(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + buildFrontend := config.BuildFrontend + var front *frontendWrapper + + config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend { + wrapped := buildFrontend(r, w) + front = &frontendWrapper{wrapped, nil} + + return front + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, conn) + require.NotNil(t, front) + + // set frontend to return a "FATAL" message on next call + front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"} + + _, err = conn.Exec(context.Background(), "SELECT 1").ReadAll() + assert.Error(t, err) + + err = conn.Close(context.Background()) + assert.NoError(t, err) + + select { + case <-conn.CleanupDone(): + t.Log("ok, CleanupDone() is not blocking") + + default: + assert.Fail(t, "connection closed but CleanupDone() still blocking") + } +} diff --git a/pgconn/go.mod b/pgconn/go.mod new file mode 100644 index 00000000..6fdd0e97 --- /dev/null +++ b/pgconn/go.mod @@ -0,0 +1,15 @@ +module github.com/jackc/pgconn + +go 1.12 + +require ( + github.com/jackc/chunkreader/v2 v2.0.1 + github.com/jackc/pgio v1.0.0 + github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 + github.com/jackc/pgpassfile v1.0.0 + github.com/jackc/pgproto3/v2 v2.1.1 + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b + github.com/stretchr/testify v1.7.0 + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 + golang.org/x/text v0.3.6 +) diff --git a/pgconn/go.sum b/pgconn/go.sum new file mode 100644 index 00000000..3c77ee21 --- /dev/null +++ b/pgconn/go.sum @@ -0,0 +1,130 @@ +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go new file mode 100644 index 00000000..87613dc9 --- /dev/null +++ b/pgconn/helper_test.go @@ -0,0 +1,36 @@ +package pgconn_test + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgconn" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func closeConn(t testing.TB, conn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.NoError(t, conn.Close(ctx)) + select { + case <-conn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read() + cancel() + + require.Nil(t, result.Err) + assert.Equal(t, 3, len(result.Rows)) + assert.Equal(t, "1", string(result.Rows[0][0])) + assert.Equal(t, "2", string(result.Rows[1][0])) + assert.Equal(t, "3", string(result.Rows[2][0])) +} diff --git a/pgconn/internal/ctxwatch/context_watcher.go b/pgconn/internal/ctxwatch/context_watcher.go new file mode 100644 index 00000000..b39cb3ee --- /dev/null +++ b/pgconn/internal/ctxwatch/context_watcher.go @@ -0,0 +1,73 @@ +package ctxwatch + +import ( + "context" + "sync" +) + +// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a +// time. +type ContextWatcher struct { + onCancel func() + onUnwatchAfterCancel func() + unwatchChan chan struct{} + + lock sync.Mutex + watchInProgress bool + onCancelWasCalled bool +} + +// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. +// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and +// onCancel called. +func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { + cw := &ContextWatcher{ + onCancel: onCancel, + onUnwatchAfterCancel: onUnwatchAfterCancel, + unwatchChan: make(chan struct{}), + } + + return cw +} + +// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. +func (cw *ContextWatcher) Watch(ctx context.Context) { + cw.lock.Lock() + defer cw.lock.Unlock() + + if cw.watchInProgress { + panic("Watch already in progress") + } + + cw.onCancelWasCalled = false + + if ctx.Done() != nil { + cw.watchInProgress = true + go func() { + select { + case <-ctx.Done(): + cw.onCancel() + cw.onCancelWasCalled = true + <-cw.unwatchChan + case <-cw.unwatchChan: + } + }() + } else { + cw.watchInProgress = false + } +} + +// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was +// called then onUnwatchAfterCancel will also be called. +func (cw *ContextWatcher) Unwatch() { + cw.lock.Lock() + defer cw.lock.Unlock() + + if cw.watchInProgress { + cw.unwatchChan <- struct{}{} + if cw.onCancelWasCalled { + cw.onUnwatchAfterCancel() + } + cw.watchInProgress = false + } +} diff --git a/pgconn/internal/ctxwatch/context_watcher_test.go b/pgconn/internal/ctxwatch/context_watcher_test.go new file mode 100644 index 00000000..289606c3 --- /dev/null +++ b/pgconn/internal/ctxwatch/context_watcher_test.go @@ -0,0 +1,163 @@ +package ctxwatch_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/jackc/pgconn/internal/ctxwatch" + "github.com/stretchr/testify/require" +) + +func TestContextWatcherContextCancelled(t *testing.T) { + canceledChan := make(chan struct{}) + cleanupCalled := false + cw := ctxwatch.NewContextWatcher(func() { + canceledChan <- struct{}{} + }, func() { + cleanupCalled = true + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + + select { + case <-canceledChan: + case <-time.NewTimer(time.Second).C: + t.Fatal("Timed out waiting for cancel func to be called") + } + + cw.Unwatch() + + require.True(t, cleanupCalled, "Cleanup func was not called") +} + +func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() { + t.Error("cancel func should not have been called") + }, func() { + t.Error("cleanup func should not have been called") + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cw.Unwatch() + cancel() +} + +func TestContextWatcherMultipleWatchPanics(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") +} + +func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw.Unwatch() // unwatch when not / never watching + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + cw.Unwatch() + cw.Unwatch() // double unwatch +} + +func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + cw.Watch(ctx) + + go cw.Unwatch() + go cw.Unwatch() + + <-ctx.Done() +} + +func TestContextWatcherStress(t *testing.T) { + var cancelFuncCalls int64 + var cleanupFuncCalls int64 + + cw := ctxwatch.NewContextWatcher(func() { + atomic.AddInt64(&cancelFuncCalls, 1) + }, func() { + atomic.AddInt64(&cleanupFuncCalls, 1) + }) + + cycleCount := 100000 + + for i := 0; i < cycleCount; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + if i%2 == 0 { + cancel() + } + + // Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix. + if i%3 == 0 { + time.Sleep(time.Nanosecond) + } + + cw.Unwatch() + if i%2 == 1 { + cancel() + } + } + + actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) + actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls) + + if actualCancelFuncCalls == 0 { + t.Fatal("actualCancelFuncCalls == 0") + } + + maxCancelFuncCalls := int64(cycleCount) / 2 + if actualCancelFuncCalls > maxCancelFuncCalls { + t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls) + } + + if actualCancelFuncCalls != actualCleanupFuncCalls { + t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls) + } +} + +func BenchmarkContextWatcherUncancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + cw.Watch(context.Background()) + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancelled(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + cw.Watch(ctx) + cw.Unwatch() + } +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go new file mode 100644 index 00000000..382ad33c --- /dev/null +++ b/pgconn/pgconn.go @@ -0,0 +1,1703 @@ +package pgconn + +import ( + "context" + "crypto/md5" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "net" + "strings" + "sync" + "time" + + "github.com/jackc/pgconn/internal/ctxwatch" + "github.com/jackc/pgio" + "github.com/jackc/pgproto3/v2" +) + +const ( + connStatusUninitialized = iota + connStatusConnecting + connStatusClosed + connStatusIdle + connStatusBusy +) + +const wbufLen = 1024 + +// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from +// LISTEN/NOTIFY notification. +type Notice PgError + +// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system +type Notification struct { + PID uint32 // backend pid that sent the notification + Channel string // channel from which notification was received + Payload string +} + +// DialFunc is a function that can be used to connect to a PostgreSQL server. +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +// LookupFunc is a function that can be used to lookup IPs addrs from host. +type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) + +// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. +type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend + +// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at +// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin +// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY +// notification. +type NoticeHandler func(*PgConn, *Notice) + +// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications +// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is +// aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a +// notice event. +type NotificationHandler func(*PgConn, *Notification) + +// Frontend used to receive messages from backend. +type Frontend interface { + Receive() (pgproto3.BackendMessage, error) +} + +// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. +type PgConn struct { + conn net.Conn // the underlying TCP or unix domain socket connection + pid uint32 // backend pid + secretKey uint32 // key to use to send a cancel query message to the server + parameterStatuses map[string]string // parameters that have been reported by the server + txStatus byte + frontend Frontend + + config *Config + + status byte // One of connStatus* constants + + bufferingReceive bool + bufferingReceiveMux sync.Mutex + bufferingReceiveMsg pgproto3.BackendMessage + bufferingReceiveErr error + + peekedMsg pgproto3.BackendMessage + + // Reusable / preallocated resources + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader + contextWatcher *ctxwatch.ContextWatcher + + cleanupDone chan struct{} +} + +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt. +func Connect(ctx context.Context, connString string) (*PgConn, error) { + config, err := ParseConfig(connString) + if err != nil { + return nil, err + } + + return ConnectConfig(ctx, config) +} + +// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with +// ParseConfig. ctx can be used to cancel a connect attempt. +// +// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An +// authentication error will terminate the chain of attempts (like libpq: +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, +// if all attempts fail the last error is returned. +func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") + } + + // ConnectTimeout restricts the whole connection process. + if config.ConnectTimeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout) + defer cancel() + } + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) + if err != nil { + return nil, &connectError{config: config, msg: "hostname resolving error", err: err} + } + + if len(fallbackConfigs) == 0 { + return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} + } + + for _, fc := range fallbackConfigs { + pgConn, err = connect(ctx, config, fc) + if err == nil { + break + } else if pgerr, ok := err.(*PgError); ok { + err = &connectError{config: config, msg: "server error", err: pgerr} + ERRCODE_INVALID_PASSWORD := "28P01" // worng password + ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION := "28000" // db does not exist + if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION { + break + } + } + } + + if err != nil { + return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError + } + + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "AfterConnect error", err: err} + } + } + + return pgConn, nil +} + +func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { + var configs []*FallbackConfig + + for _, fb := range fallbacks { + // skip resolve for unix sockets + if strings.HasPrefix(fb.Host, "/") { + configs = append(configs, &FallbackConfig{ + Host: fb.Host, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + + continue + } + + ips, err := lookupFn(ctx, fb.Host) + if err != nil { + return nil, err + } + + for _, ip := range ips { + configs = append(configs, &FallbackConfig{ + Host: ip, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + } + } + + return configs, nil +} + +func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { + pgConn := new(PgConn) + pgConn.config = config + pgConn.wbuf = make([]byte, 0, wbufLen) + pgConn.cleanupDone = make(chan struct{}) + + var err error + network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) + pgConn.conn, err = config.DialFunc(ctx, network, address) + if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + err = &errTimeout{err: err} + } + return nil, &connectError{config: config, msg: "dial error", err: err} + } + + pgConn.parameterStatuses = make(map[string]string) + + if fallbackConfig.TLSConfig != nil { + if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "tls error", err: err} + } + } + + pgConn.status = connStatusConnecting + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) + + startupMsg := pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: make(map[string]string), + } + + // Copy default run-time params + for k, v := range config.RuntimeParams { + startupMsg.Parameters[k] = v + } + + startupMsg.Parameters["user"] = config.User + if config.Database != "" { + startupMsg.Parameters["database"] = config.Database + } + + if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write startup message", err: err} + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.conn.Close() + if err, ok := err.(*PgError); ok { + return nil, err + } + return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)} + } + + switch msg := msg.(type) { + case *pgproto3.BackendKeyData: + pgConn.pid = msg.ProcessID + pgConn.secretKey = msg.SecretKey + + case *pgproto3.AuthenticationOk: + case *pgproto3.AuthenticationCleartextPassword: + err = pgConn.txPasswordMessage(pgConn.config.Password) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationSASL: + err = pgConn.scramAuth(msg.AuthMechanisms) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + } + + case *pgproto3.ReadyForQuery: + pgConn.status = connStatusIdle + if config.ValidateConnect != nil { + // ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid + // the watch already in progress panic. This is that last thing done by this method so there is no need to + // restart the watch after ValidateConnect returns. + // + // See https://github.com/jackc/pgconn/issues/40. + pgConn.contextWatcher.Unwatch() + + err := config.ValidateConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} + } + } + return pgConn, nil + case *pgproto3.ParameterStatus: + // handled by ReceiveMessage + case *pgproto3.ErrorResponse: + pgConn.conn.Close() + return nil, ErrorResponseToPgError(msg) + default: + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "received unexpected message", err: err} + } + } +} + +func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { + err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) + if err != nil { + return + } + + response := make([]byte, 1) + if _, err = io.ReadFull(pgConn.conn, response); err != nil { + return + } + + if response[0] != 'S' { + return errors.New("server refused TLS connection") + } + + pgConn.conn = tls.Client(pgConn.conn, tlsConfig) + + return nil +} + +func (pgConn *PgConn) txPasswordMessage(password string) (err error) { + msg := &pgproto3.PasswordMessage{Password: password} + _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) + return err +} + +func hexMD5(s string) string { + hash := md5.New() + io.WriteString(hash, s) + return hex.EncodeToString(hash.Sum(nil)) +} + +func (pgConn *PgConn) signalMessage() chan struct{} { + if pgConn.bufferingReceive { + panic("BUG: signalMessage when already in progress") + } + + pgConn.bufferingReceive = true + pgConn.bufferingReceiveMux.Lock() + + ch := make(chan struct{}) + go func() { + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() + pgConn.bufferingReceiveMux.Unlock() + close(ch) + }() + + return ch +} + +// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as +// error to call SendBytes while reading the result of a query. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil +} + +// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the +// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger +// the OnNotification callback. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + msg, err := pgConn.receiveMessage() + if err != nil { + err = &pgconnError{ + msg: "receive message failed", + err: preferContextOverNetTimeoutError(ctx, err), + safeToRetry: true} + } + return msg, err +} + +// peekMessage peeks at the next message without setting up context cancellation. +func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { + if pgConn.peekedMsg != nil { + return pgConn.peekedMsg, nil + } + + var msg pgproto3.BackendMessage + var err error + if pgConn.bufferingReceive { + pgConn.bufferingReceiveMux.Lock() + msg = pgConn.bufferingReceiveMsg + err = pgConn.bufferingReceiveErr + pgConn.bufferingReceiveMux.Unlock() + pgConn.bufferingReceive = false + + // If a timeout error happened in the background try the read again. + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + msg, err = pgConn.frontend.Receive() + } + } else { + msg, err = pgConn.frontend.Receive() + } + + if err != nil { + // Close on anything other than timeout error - everything else is fatal + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { + pgConn.asyncClose() + } + + return nil, err + } + + pgConn.peekedMsg = msg + return msg, nil +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.peekMessage() + if err != nil { + // Close on anything other than timeout error - everything else is fatal + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { + pgConn.asyncClose() + } + + return nil, err + } + pgConn.peekedMsg = nil + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + pgConn.txStatus = msg.TxStatus + case *pgproto3.ParameterStatus: + pgConn.parameterStatuses[msg.Name] = msg.Value + case *pgproto3.ErrorResponse: + if msg.Severity == "FATAL" { + pgConn.status = connStatusClosed + pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. + close(pgConn.cleanupDone) + return nil, ErrorResponseToPgError(msg) + } + case *pgproto3.NoticeResponse: + if pgConn.config.OnNotice != nil { + pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg)) + } + case *pgproto3.NotificationResponse: + if pgConn.config.OnNotification != nil { + pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + } + } + + return msg, nil +} + +// Conn returns the underlying net.Conn. +func (pgConn *PgConn) Conn() net.Conn { + return pgConn.conn +} + +// PID returns the backend PID. +func (pgConn *PgConn) PID() uint32 { + return pgConn.pid +} + +// TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message. +// +// Possible return values: +// 'I' - idle / not in transaction +// 'T' - in a transaction +// 'E' - in a failed transaction +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html. +func (pgConn *PgConn) TxStatus() byte { + return pgConn.txStatus +} + +// SecretKey returns the backend secret key used to send a cancel query message to the server. +func (pgConn *PgConn) SecretKey() uint32 { + return pgConn.secretKey +} + +// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by +// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The +// underlying net.Conn.Close() will always be called regardless of any other errors. +func (pgConn *PgConn) Close(ctx context.Context) error { + if pgConn.status == connStatusClosed { + return nil + } + pgConn.status = connStatusClosed + + defer close(pgConn.cleanupDone) + defer pgConn.conn.Close() + + if ctx != context.Background() { + // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when + // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any + // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress. + // + // See https://github.com/jackc/pgconn/issues/29 + pgConn.contextWatcher.Unwatch() + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Ignore any errors sending Terminate message and waiting for server to close connection. + // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully + // ignores errors. + // + // See https://github.com/jackc/pgx/issues/637 + pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + + return pgConn.conn.Close() +} + +// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying +// connection. +func (pgConn *PgConn) asyncClose() { + if pgConn.status == connStatusClosed { + return + } + pgConn.status = connStatusClosed + + go func() { + defer close(pgConn.cleanupDone) + defer pgConn.conn.Close() + + deadline := time.Now().Add(time.Second * 15) + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + pgConn.CancelRequest(ctx) + + pgConn.conn.SetDeadline(deadline) + + pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + }() +} + +// CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed +// connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing +// yet. This is because certain errors such as a context cancellation require that the interrupted function call return +// immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are +// closed asynchronously. +// +// This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while +// an old connection is still being cleaned up and thereby exceeding the maximum pool size. +func (pgConn *PgConn) CleanupDone() chan (struct{}) { + return pgConn.cleanupDone +} + +// IsClosed reports if the connection has been closed. +// +// CleanupDone() can be used to determine if all cleanup has been completed. +func (pgConn *PgConn) IsClosed() bool { + return pgConn.status < connStatusIdle +} + +// IsBusy reports if the connection is busy. +func (pgConn *PgConn) IsBusy() bool { + return pgConn.status == connStatusBusy +} + +// lock locks the connection. +func (pgConn *PgConn) lock() error { + switch pgConn.status { + case connStatusBusy: + return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. + case connStatusClosed: + return &connLockError{status: "conn closed"} + case connStatusUninitialized: + return &connLockError{status: "conn uninitialized"} + } + pgConn.status = connStatusBusy + return nil +} + +func (pgConn *PgConn) unlock() { + switch pgConn.status { + case connStatusBusy: + pgConn.status = connStatusIdle + case connStatusClosed: + default: + panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. + } +} + +// ParameterStatus returns the value of a parameter reported by the server (e.g. +// server_version). Returns an empty string for unknown parameters. +func (pgConn *PgConn) ParameterStatus(key string) string { + return pgConn.parameterStatuses[key] +} + +// CommandTag is the result of an Exec function +type CommandTag []byte + +// RowsAffected returns the number of rows affected. If the CommandTag was not +// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. +func (ct CommandTag) RowsAffected() int64 { + // Find last non-digit + idx := -1 + for i := len(ct) - 1; i >= 0; i-- { + if ct[i] >= '0' && ct[i] <= '9' { + idx = i + } else { + break + } + } + + if idx == -1 { + return 0 + } + + var n int64 + for _, b := range ct[idx:] { + n = n*10 + int64(b-'0') + } + + return n +} + +func (ct CommandTag) String() string { + return string(ct) +} + +// Insert is true if the command tag starts with "INSERT". +func (ct CommandTag) Insert() bool { + return len(ct) >= 6 && + ct[0] == 'I' && + ct[1] == 'N' && + ct[2] == 'S' && + ct[3] == 'E' && + ct[4] == 'R' && + ct[5] == 'T' +} + +// Update is true if the command tag starts with "UPDATE". +func (ct CommandTag) Update() bool { + return len(ct) >= 6 && + ct[0] == 'U' && + ct[1] == 'P' && + ct[2] == 'D' && + ct[3] == 'A' && + ct[4] == 'T' && + ct[5] == 'E' +} + +// Delete is true if the command tag starts with "DELETE". +func (ct CommandTag) Delete() bool { + return len(ct) >= 6 && + ct[0] == 'D' && + ct[1] == 'E' && + ct[2] == 'L' && + ct[3] == 'E' && + ct[4] == 'T' && + ct[5] == 'E' +} + +// Select is true if the command tag starts with "SELECT". +func (ct CommandTag) Select() bool { + return len(ct) >= 6 && + ct[0] == 'S' && + ct[1] == 'E' && + ct[2] == 'L' && + ct[3] == 'E' && + ct[4] == 'C' && + ct[5] == 'T' +} + +type StatementDescription struct { + Name string + SQL string + ParamOIDs []uint32 + Fields []pgproto3.FieldDescription +} + +// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This +// allows Prepare to also to describe statements without creating a server-side prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + buf := pgConn.wbuf + buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + return nil, &writeError{err: err, safeToRetry: n == 0} + } + + psd := &StatementDescription{Name: name, SQL: sql} + + var parseErr error + +readloop: + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.RowDescription: + psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) + copy(psd.Fields, msg.Fields) + case *pgproto3.ErrorResponse: + parseErr = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + break readloop + } + } + + if parseErr != nil { + return nil, parseErr + } + return psd, nil +} + +// ErrorResponseToPgError converts a wire protocol error message to a *PgError. +func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { + return &PgError{ + Severity: msg.Severity, + Code: string(msg.Code), + Message: string(msg.Message), + Detail: string(msg.Detail), + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: string(msg.InternalQuery), + Where: string(msg.Where), + SchemaName: string(msg.SchemaName), + TableName: string(msg.TableName), + ColumnName: string(msg.ColumnName), + DataTypeName: string(msg.DataTypeName), + ConstraintName: msg.ConstraintName, + File: string(msg.File), + Line: msg.Line, + Routine: string(msg.Routine), + } +} + +func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { + pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + return (*Notice)(pgerr) +} + +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { + // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing + // the connection config. This is important in high availability configurations where fallback connections may be + // specified or DNS may be used to load balance. + serverAddr := pgConn.conn.RemoteAddr() + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + if err != nil { + return err + } + defer cancelConn.Close() + + if ctx != context.Background() { + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return err + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return err + } + + return nil +} + +// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not +// received. +func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return newContextAlreadyDoneError(ctx) + default: + } + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + switch msg.(type) { + case *pgproto3.NotificationResponse: + return nil + } + } +} + +// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is +// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control +// statements. +// +// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + buf := pgConn.wbuf + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + pgConn.contextWatcher.Unwatch() + multiResult.closed = true + multiResult.err = &writeError{err: err, safeToRetry: n == 0} + pgConn.unlock() + return multiResult + } + + return multiResult +} + +// ReceiveResults reads the result that might be returned by Postgres after a SendBytes +// (e.a. after sending a CopyDone in a copy-both situation). +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return multiResult +} + +// ExecParams executes a command via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, +// etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + buf := pgConn.wbuf + buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + pgConn.execExtendedSuffix(buf, result) + + return result +} + +// ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + buf := pgConn.wbuf + buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + pgConn.execExtendedSuffix(buf, result) + + return result +} + +func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { + pgConn.resultReader = ResultReader{ + pgConn: pgConn, + ctx: ctx, + } + result := &pgConn.resultReader + + if err := pgConn.lock(); err != nil { + result.concludeCommand(nil, err) + result.closed = true + return result + } + + if len(paramValues) > math.MaxUint16 { + result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + + if ctx != context.Background() { + select { + case <-ctx.Done(): + result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) + result.closed = true + pgConn.unlock() + return result + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return result +} + +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { + buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) + buf = (&pgproto3.Execute{}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return + } + + result.readUntilRowDescription() +} + +// CopyTo executes the copy command sql and copies the results to w. +func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + + if ctx != context.Background() { + select { + case <-ctx.Done(): + pgConn.unlock() + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Send copy to command + buf := pgConn.wbuf + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + pgConn.unlock() + return nil, &writeError{err: err, safeToRetry: n == 0} + } + + // Read results + var commandTag CommandTag + var pgErr error + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyDone: + case *pgproto3.CopyData: + _, err := w.Write(msg.Data) + if err != nil { + pgConn.asyncClose() + return nil, err + } + case *pgproto3.ReadyForQuery: + pgConn.unlock() + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + } + } +} + +// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. +// +// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r +// could still block. +func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Send copy to command + buf := pgConn.wbuf + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + return nil, &writeError{err: err, safeToRetry: n == 0} + } + + // Send copy data + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error, 1) + signalMessageChan := pgConn.signalMessage() + + go func() { + buf := make([]byte, 0, 65536) + buf = append(buf, 'd') + sp := len(buf) + + for { + n, readErr := r.Read(buf[5:cap(buf)]) + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) + + _, writeErr := pgConn.conn.Write(buf) + if writeErr != nil { + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. + pgConn.conn.Close() + + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return + } + + select { + case <-abortCopyChan: + return + default: + } + } + }() + + var pgErr error + var copyErr error + for copyErr == nil && pgErr == nil { + select { + case copyErr = <-copyErrChan: + case <-signalMessageChan: + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + default: + signalMessageChan = pgConn.signalMessage() + } + } + } + close(abortCopyChan) + + buf = buf[:0] + if copyErr == io.EOF || pgErr != nil { + copyDone := &pgproto3.CopyDone{} + buf = copyDone.Encode(buf) + } else { + copyFail := &pgproto3.CopyFail{Message: copyErr.Error()} + buf = copyFail.Encode(buf) + } + _, err = pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + return nil, err + } + + // Read results + var commandTag CommandTag + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + } + } +} + +// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. +type MultiResultReader struct { + pgConn *PgConn + ctx context.Context + + rr *ResultReader + + closed bool + err error +} + +// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods. +func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { + var results []*Result + + for mrr.NextResult() { + results = append(results, mrr.ResultReader().Read()) + } + err := mrr.Close() + + return results, err +} + +func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := mrr.pgConn.receiveMessage() + + if err != nil { + mrr.pgConn.contextWatcher.Unwatch() + mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.closed = true + mrr.pgConn.asyncClose() + return nil, mrr.err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + mrr.pgConn.contextWatcher.Unwatch() + mrr.closed = true + mrr.pgConn.unlock() + case *pgproto3.ErrorResponse: + mrr.err = ErrorResponseToPgError(msg) + } + + return msg, nil +} + +// NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. +func (mrr *MultiResultReader) NextResult() bool { + for !mrr.closed && mrr.err == nil { + msg, err := mrr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + mrr.pgConn.resultReader = ResultReader{ + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + fieldDescriptions: msg.Fields, + } + mrr.rr = &mrr.pgConn.resultReader + return true + case *pgproto3.CommandComplete: + mrr.pgConn.resultReader = ResultReader{ + commandTag: CommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + mrr.rr = &mrr.pgConn.resultReader + return true + case *pgproto3.EmptyQueryResponse: + return false + } + } + + return false +} + +// ResultReader returns the current ResultReader. +func (mrr *MultiResultReader) ResultReader() *ResultReader { + return mrr.rr +} + +// Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. +func (mrr *MultiResultReader) Close() error { + for !mrr.closed { + _, err := mrr.receiveMessage() + if err != nil { + return mrr.err + } + } + + return mrr.err +} + +// ResultReader is a reader for the result of a single query. +type ResultReader struct { + pgConn *PgConn + multiResultReader *MultiResultReader + ctx context.Context + + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + commandConcluded bool + closed bool + err error +} + +// Result is the saved query response that is returned by calling Read on a ResultReader. +type Result struct { + FieldDescriptions []pgproto3.FieldDescription + Rows [][][]byte + CommandTag CommandTag + Err error +} + +// Read saves the query response to a Result. +func (rr *ResultReader) Read() *Result { + br := &Result{} + + for rr.NextRow() { + if br.FieldDescriptions == nil { + br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) + copy(br.FieldDescriptions, rr.FieldDescriptions()) + } + + row := make([][]byte, len(rr.Values())) + copy(row, rr.Values()) + br.Rows = append(br.Rows, row) + } + + br.CommandTag, br.Err = rr.Close() + + return br +} + +// NextRow advances the ResultReader to the next row and returns true if a row is available. +func (rr *ResultReader) NextRow() bool { + for !rr.commandConcluded { + msg, err := rr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + } + } + + return false +} + +// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until +// the ResultReader is closed. +func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { + return rr.fieldDescriptions +} + +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to +// retain a reference to and mutate. +func (rr *ResultReader) Values() [][]byte { + return rr.rowValues +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *ResultReader) Close() (CommandTag, error) { + if rr.closed { + return rr.commandTag, rr.err + } + rr.closed = true + + for !rr.commandConcluded { + _, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + } + + if rr.multiResultReader == nil { + for { + msg, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + + switch msg := msg.(type) { + // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. + case *pgproto3.ErrorResponse: + rr.err = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + rr.pgConn.contextWatcher.Unwatch() + rr.pgConn.unlock() + return rr.commandTag, rr.err + } + } + } + + return rr.commandTag, rr.err +} + +// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any +// error will be stored in the ResultReader. +func (rr *ResultReader) readUntilRowDescription() { + for !rr.commandConcluded { + // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. + // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are + // manually used to construct a query that does not issue a describe statement. + msg, _ := rr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + return + } + + // Consume the message + msg, _ = rr.receiveMessage() + if _, ok := msg.(*pgproto3.RowDescription); ok { + return + } + } +} + +func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { + if rr.multiResultReader == nil { + msg, err = rr.pgConn.receiveMessage() + } else { + msg, err = rr.multiResultReader.receiveMessage() + } + + if err != nil { + err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.concludeCommand(nil, err) + rr.pgConn.contextWatcher.Unwatch() + rr.closed = true + if rr.multiResultReader == nil { + rr.pgConn.asyncClose() + } + + return nil, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = msg.Fields + case *pgproto3.CommandComplete: + rr.concludeCommand(CommandTag(msg.CommandTag), nil) + case *pgproto3.EmptyQueryResponse: + rr.concludeCommand(nil, nil) + case *pgproto3.ErrorResponse: + rr.concludeCommand(nil, ErrorResponseToPgError(msg)) + } + + return msg, nil +} + +func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { + // Keep the first error that is recorded. Store the error before checking if the command is already concluded to + // allow for receiving an error after CommandComplete but before ReadyForQuery. + if err != nil && rr.err == nil { + rr.err = err + } + + if rr.commandConcluded { + return + } + + rr.commandTag = commandTag + rr.rowValues = nil + rr.commandConcluded = true +} + +// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. +type Batch struct { + buf []byte +} + +// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + batch.ExecPrepared("", paramValues, paramFormats, resultFormats) +} + +// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) +} + +// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a +// transaction is already in progress or SQL contains transaction control statements. +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + + // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is + // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication + // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. + // The error the code reading the batch results receives will be a closed connection error. + // + // See https://github.com/jackc/pgx/issues/374. + go func() { + _, err := pgConn.conn.Write(batch.buf) + if err != nil { + pgConn.conn.Close() + } + }() + + return multiResult +} + +// EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include +// the surrounding single quotes. +// +// The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these +// conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. +func (pgConn *PgConn) EscapeString(s string) (string, error) { + if pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("EscapeString must be run with standard_conforming_strings=on") + } + + if pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("EscapeString must be run with client_encoding=UTF8") + } + + return strings.Replace(s, "'", "''", -1), nil +} + +// HijackedConn is the result of hijacking a connection. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +type HijackedConn struct { + Conn net.Conn // the underlying TCP or unix domain socket connection + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + ParameterStatuses map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend Frontend + Config *Config +} + +// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. +// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the +// raw connection after that (e.g. a load balancer or proxy). +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func (pgConn *PgConn) Hijack() (*HijackedConn, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + pgConn.status = connStatusClosed + + return &HijackedConn{ + Conn: pgConn.conn, + PID: pgConn.pid, + SecretKey: pgConn.secretKey, + ParameterStatuses: pgConn.parameterStatuses, + TxStatus: pgConn.txStatus, + Frontend: pgConn.frontend, + Config: pgConn.config, + }, nil +} + +// Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of +// PgConn.Hijack. The connection must be in an idle state. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func Construct(hc *HijackedConn) (*PgConn, error) { + pgConn := &PgConn{ + conn: hc.Conn, + pid: hc.PID, + secretKey: hc.SecretKey, + parameterStatuses: hc.ParameterStatuses, + txStatus: hc.TxStatus, + frontend: hc.Frontend, + config: hc.Config, + + status: connStatusIdle, + + wbuf: make([]byte, 0, wbufLen), + cleanupDone: make(chan struct{}), + } + + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + + return pgConn, nil +} diff --git a/pgconn/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go new file mode 100644 index 00000000..356b529a --- /dev/null +++ b/pgconn/pgconn_stress_test.go @@ -0,0 +1,90 @@ +package pgconn_test + +import ( + "context" + "math/rand" + "os" + "runtime" + "strconv" + "testing" + + "github.com/jackc/pgconn" + + "github.com/stretchr/testify/require" +) + +func TestConnStress(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + actionCount := 10000 + if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { + stressFactor, err := strconv.ParseInt(s, 10, 64) + require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") + actionCount *= int(stressFactor) + } + + setupStressDB(t, pgConn) + + actions := []struct { + name string + fn func(*pgconn.PgConn) error + }{ + {"Exec Select", stressExecSelect}, + {"ExecParams Select", stressExecParamsSelect}, + {"Batch", stressBatch}, + } + + for i := 0; i < actionCount; i++ { + action := actions[rand.Intn(len(actions))] + err := action.fn(pgConn) + require.Nilf(t, err, "%d: %s", i, action.name) + } + + // Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled. + numGoroutine := runtime.NumGoroutine() + require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine) +} + +func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { + _, err := pgConn.Exec(context.Background(), ` + create temporary table widgets( + id serial primary key, + name varchar not null, + description text, + creation_time timestamptz default now() + ); + + insert into widgets(name, description) values + ('Foo', 'bar'), + ('baz', 'Something really long Something really long Something really long Something really long Something really long'), + ('a', 'b')`).ReadAll() + require.NoError(t, err) +} + +func stressExecSelect(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := pgConn.Exec(ctx, "select * from widgets").ReadAll() + return err +} + +func stressExecParamsSelect(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() + return result.Err +} + +func stressBatch(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + batch := &pgconn.Batch{} + + batch.ExecParams("select * from widgets", nil, nil, nil, nil) + batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + _, err := pgConn.ExecBatch(ctx, batch).ReadAll() + return err +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go new file mode 100644 index 00000000..c20b7425 --- /dev/null +++ b/pgconn/pgconn_test.go @@ -0,0 +1,2012 @@ +package pgconn_test + +import ( + "bytes" + "compress/gzip" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "math" + "net" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgmock" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConnect(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) + + closeConn(t, conn) + }) + } +} + +// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure +// connection. +func TestConnectTLS(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } + + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) + + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } + + closeConn(t, conn) +} + +type pgmockWaitStep time.Duration + +func (s pgmockWaitStep) Step(*pgproto3.Backend) error { + time.Sleep(time.Duration(s)) + return nil +} + +func TestConnectTimeout(t *testing.T) { + t.Parallel() + tests := []struct { + name string + connect func(connStr string) error + }{ + { + name: "via context that times out", + connect: func(connStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err := pgconn.Connect(ctx, connStr) + return err + }, + }, + { + name: "via config ConnectTimeout", + connect: func(connStr string) error { + conf, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + conf.ConnectTimeout = time.Microsecond * 50 + _, err = pgconn.ConnectConfig(context.Background(), conf) + return err + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationOk{}), + pgmockWaitStep(time.Millisecond * 500), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }, + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + tooLate := time.Now().Add(time.Millisecond * 500) + + err = tt.connect(connStr) + require.True(t, pgconn.Timeout(err), err) + require.True(t, time.Now().Before(tooLate)) + }) + } +} + +func TestConnectInvalidUser(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + config.User = "pgxinvalidusertest" + + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } +} + +func TestConnectWithConnectionRefused(t *testing.T) { + t.Parallel() + + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") + if err == nil { + conn.Close(context.Background()) + t.Fatal("Expected error establishing connection to bad port") + } +} + +func TestConnectCustomDialer(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) +} + +func TestConnectCustomLookup(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) +} + +func TestConnectWithRuntimeParams(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, conn) + + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) +} + +func TestConnectWithFallback(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here + + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) +} + +func TestConnectWithValidateConnect(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) +} + +func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := pgconn.ConnectConfig(ctx, config) + if !assert.NotNil(t, err) { + conn.Close(ctx) + } +} + +func TestConnectWithAfterConnect(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) + + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) +} + +func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { + t.Parallel() + + config := &pgconn.Config{} + + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) +} + +func TestConnPrepareSyntaxError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnPrepareContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil) + assert.Nil(t, psd) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExec(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +func TestConnExecEmpty(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + multiResult := pgConn.Exec(context.Background(), ";") + + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecMultipleQueries(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 2) + + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num") + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.False(t, mrr.NextResult()) + + require.NoError(t, mrr.Close()) + + ensureConnValid(t, pgConn) +} + +func TestConnExecMultipleQueriesError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } + + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB starts the second query result set and then sends the divide by zero error. + require.Len(t, results, 2) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results[1].Rows, 0) + } else { + // PostgreSQL sends the divide by zero and never sends the second query result set. + require.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + } + + ensureConnValid(t, pgConn) +} + +func TestConnExecDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) + + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnExecContextCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)") + + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnExecContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Equal(t, pgconn.CommandTag(nil), commandTag) + assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnExecParamsPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() + require.Error(t, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(result.Err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read() + assert.Nil(t, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgx/issues/859 +func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPrepared(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB rejects preparing a statement with more than 65535 parameters. + require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)") + } else { + // PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol. + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters") + } + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Equal(t, pgconn.CommandTag(nil), commandTag) + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnExecPreparedPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + require.Error(t, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(result.Err)) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "", nil) + require.NoError(t, err) + + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + assert.Nil(t, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + +func TestConnExecBatch(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) + + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) +} + +func TestConnExecBatchDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnExecBatchPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + + ensureConnValid(t, pgConn) +} + +// Without concurrent reading and writing large batches can deadlock. +// +// See https://github.com/jackc/pgx/issues/374. +func TestConnExecBatchHuge(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + + queryCount := 100000 + args := make([]string, queryCount) + + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } + + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) + + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } +} + +func TestConnExecBatchImplicitTransaction(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)") + } + + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.Error(t, err) + + result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) +} + +func TestConnLocking(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) + + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +func TestCommandTag(t *testing.T) { + t.Parallel() + + var tests = []struct { + commandTag pgconn.CommandTag + rowsAffected int64 + isInsert bool + isUpdate bool + isDelete bool + isSelect bool + }{ + {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true}, + {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true}, + {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true}, + {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true}, + {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true}, + {commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true}, + {commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true}, + {commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true}, + {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0}, + } + + for i, tt := range tests { + ct := tt.commandTag + assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) + } +} + +func TestConnOnNotice(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)") + } + + multiResult := pgConn.Exec(context.Background(), `do $$ +begin + raise notice 'hello, world'; +end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnOnNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() + require.NoError(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) + + err = pgConn.WaitForNotification(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotificationPrecanceled(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = pgConn.WaitForNotification(ctx) + require.ErrorIs(t, err, context.Canceled) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotificationTimeout(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + err = pgConn.WaitForNotification(ctx) + cancel() + assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToSmall(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) + + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToLarge(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) + + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + require.NoError(t, err) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToQueryError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + outputWriter := bytes.NewBuffer(make([]byte, 0)) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + + outputWriter := &bytes.Buffer{} + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") + assert.Error(t, err) + assert.Equal(t, pgconn.CommandTag(nil), res) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnCopyToPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + outputWriter := &bytes.Buffer{} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + assert.Equal(t, pgconn.CommandTag(nil), res) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFrom(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + cancel() + assert.Equal(t, int64(0), ct.RowsAffected()) + assert.Error(t, err) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +func TestConnCopyFromPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, pgconn.SafeToRetry(err)) + assert.Equal(t, pgconn.CommandTag(nil), ct) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromGzipReader(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") + } + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + err = gw.Close() + require.NoError(t, err) + + _, err = f.Seek(0, 0) + require.NoError(t, err) + + gr, err := gzip.NewReader(f) + require.NoError(t, err) + + ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + err = gr.Close() + require.NoError(t, err) + + err = f.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQuerySyntaxError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQueryNoTableError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgconn/issues/21 +func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) { + t.Parallel() + + ctx := context.Background() + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)") + } + + _, err = pgConn.Exec(ctx, `create temporary table sentences( + t text, + ts tsvector + )`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$ + begin + new.ts := to_tsvector(new.t); + return new; + end + $$ language plpgsql;`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll() + require.NoError(t, err) + + longString := make([]byte, 10001) + for i := range longString { + longString[i] = 'x' + } + + buf := &bytes.Buffer{} + for i := 0; i < 1000; i++ { + buf.Write([]byte(fmt.Sprintf("%s\n", string(longString)))) + } + + _, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) +} + +func TestConnEscapeString(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, + } + + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) + } + } + + ensureConnValid(t, pgConn) +} + +func TestConnCancelRequest(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") + + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) + + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) + + for multiResult.NextResult() { + } + err = multiResult.Close() + + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) + + ensureConnValid(t, pgConn) +} + +// https://github.com/jackc/pgx/issues/659 +func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pid := pgConn.PID() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(30)") + + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } + + otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, otherConn) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + for { + result := otherConn.ExecParams(ctx, + `select 1 from pg_stat_activity where pid=$1`, + [][]byte{[]byte(strconv.FormatInt(int64(pid), 10))}, + nil, + nil, + nil, + ).Read() + require.NoError(t, result.Err) + + if len(result.Rows) == 0 { + break + } + } +} + +func TestConnSendBytesAndReceiveMessage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + queryMsg := pgproto3.Query{String: "select 42"} + buf := queryMsg.Encode(nil) + + err = pgConn.SendBytes(ctx, buf) + require.NoError(t, err) + + msg, err := pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok := msg.(*pgproto3.RowDescription) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.DataRow) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.CommandComplete) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.ReadyForQuery) + require.True(t, ok) + + ensureConnValid(t, pgConn) +} + +func TestHijackAndConstruct(t *testing.T) { + t.Parallel() + + origConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + hc, err := origConn.Hijack() + require.NoError(t, err) + + _, err = origConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + require.Error(t, err) + + newConn, err := pgconn.Construct(hc) + require.NoError(t, err) + + defer closeConn(t, newConn) + + results, err := newConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, newConn) +} + +func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + ctx, _ := context.WithCancel(context.Background()) + pgConn.Exec(ctx, "select n from generate_series(1,10) n") + + closeCtx, _ := context.WithCancel(context.Background()) + pgConn.Close(closeCtx) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + +// https://github.com/jackc/pgx/issues/800 +func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { + t.Parallel() + + steps := pgmock.AcceptUnauthenticatedConnRequestSteps() + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + {Name: []byte("mock")}, + }})) + steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + + script := &pgmock.Script{Steps: steps} + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + conn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + + rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil) + + for rr.NextRow() { + } + + _, err = rr.Close() + require.Error(t, err) +} + +func Example() { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + if err != nil { + log.Fatalln(err) + } + defer pgConn.Close(context.Background()) + + result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() + if result.Err != nil { + log.Fatalln(result.Err) + } + + for _, row := range result.Rows { + fmt.Println(string(row[0])) + } + + fmt.Println(result.CommandTag) + // Output: + // 1 + // 2 + // 3 + // SELECT 3 +} diff --git a/pgconn/stmtcache/lru.go b/pgconn/stmtcache/lru.go new file mode 100644 index 00000000..90fb76c2 --- /dev/null +++ b/pgconn/stmtcache/lru.go @@ -0,0 +1,165 @@ +package stmtcache + +import ( + "container/list" + "context" + "fmt" + "sync/atomic" + + "github.com/jackc/pgconn" +) + +var lruCount uint64 + +// LRU implements Cache with a Least Recently Used (LRU) cache. +type LRU struct { + conn *pgconn.PgConn + mode int + cap int + prepareCount int + m map[string]*list.Element + l *list.List + psNamePrefix string + stmtsToClear []string +} + +// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. +func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { + mustBeValidMode(mode) + mustBeValidCap(cap) + + n := atomic.AddUint64(&lruCount, 1) + + return &LRU{ + conn: conn, + mode: mode, + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + psNamePrefix: fmt.Sprintf("lrupsc_%d", n), + } +} + +// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. +func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + + // flush an outstanding bad statements + txStatus := c.conn.TxStatus() + if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { + for _, stmt := range c.stmtsToClear { + err := c.clearStmt(ctx, stmt) + if err != nil { + return nil, err + } + } + } + + if el, ok := c.m[sql]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.StatementDescription), nil + } + + if c.l.Len() == c.cap { + err := c.removeOldest(ctx) + if err != nil { + return nil, err + } + } + + psd, err := c.prepare(ctx, sql) + if err != nil { + return nil, err + } + + el := c.l.PushFront(psd) + c.m[sql] = el + + return psd, nil +} + +// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. +func (c *LRU) Clear(ctx context.Context) error { + for c.l.Len() > 0 { + err := c.removeOldest(ctx) + if err != nil { + return err + } + } + + return nil +} + +func (c *LRU) StatementErrored(sql string, err error) { + pgErr, ok := err.(*pgconn.PgError) + if !ok { + return + } + + isInvalidCachedPlanError := pgErr.Severity == "ERROR" && + pgErr.Code == "0A000" && + pgErr.Message == "cached plan must not change result type" + if isInvalidCachedPlanError { + c.stmtsToClear = append(c.stmtsToClear, sql) + } +} + +func (c *LRU) clearStmt(ctx context.Context, sql string) error { + elem, inMap := c.m[sql] + if !inMap { + // The statement probably fell off the back of the list. In that case, we've + // ensured that it isn't in the cache, so we can declare victory. + return nil + } + + c.l.Remove(elem) + + psd := elem.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) + if c.mode == ModePrepare { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() + } + return nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRU) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRU) Cap() int { + return c.cap +} + +// Mode returns the mode of the cache (ModePrepare or ModeDescribe) +func (c *LRU) Mode() int { + return c.mode +} + +func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + var name string + if c.mode == ModePrepare { + name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) + c.prepareCount += 1 + } + + return c.conn.Prepare(ctx, name, sql, nil) +} + +func (c *LRU) removeOldest(ctx context.Context) error { + oldest := c.l.Back() + c.l.Remove(oldest) + psd := oldest.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) + if c.mode == ModePrepare { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() + } + return nil +} diff --git a/pgconn/stmtcache/lru_test.go b/pgconn/stmtcache/lru_test.go new file mode 100644 index 00000000..f594ceac --- /dev/null +++ b/pgconn/stmtcache/lru_test.go @@ -0,0 +1,292 @@ +package stmtcache_test + +import ( + "context" + "fmt" + "math/rand" + "os" + "regexp" + "testing" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgconn/stmtcache" + + "github.com/stretchr/testify/require" +) + +func TestLRUModePrepare(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func TestLRUStmtInvalidation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + // we construct a fake error because its not super straightforward to actually call + // a prepared statement from the LRU cache without the helper routines which live + // in pgx proper. + fakeInvalidCachePlanError := &pgconn.PgError{ + Severity: "ERROR", + Code: "0A000", + Message: "cached plan must not change result type", + } + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) + + // + // outside of a transaction, we eagerly flush the statement + // + + _, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + cache.StatementErrored("select 1", fakeInvalidCachePlanError) + _, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + + // + // within an errored transaction, we defer the flush to after the first get + // that happens after the transaction is rolled back + // + + _, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + res := conn.Exec(ctx, "begin") + require.NoError(t, res.Close()) + require.Equal(t, byte('T'), conn.TxStatus()) + + res = conn.Exec(ctx, "selec") + require.Error(t, res.Close()) + require.Equal(t, byte('E'), conn.TxStatus()) + + cache.StatementErrored("select 1", fakeInvalidCachePlanError) + require.EqualValues(t, 1, cache.Len()) + + res = conn.Exec(ctx, "rollback") + require.NoError(t, res.Close()) + + _, err = cache.Get(ctx, "select 2") + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) +} + +func TestLRUStmtInvalidationIntegration(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) + + result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + sql := "select * from stmtcache_table" + sd1, err := cache.Get(ctx, sql) + require.NoError(t, err) + + result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) + + result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() + require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)") + + cache.StatementErrored(sql, result.Err) + + sd2, err := cache.Get(ctx, sql) + require.NoError(t, err) + require.NotEqual(t, sd1.Name, sd2.Name) + + result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) +} + +func TestLRUModePrepareStress(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 8, cache.Cap()) + require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) + + for i := 0; i < 1000; i++ { + psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50))) + require.NoError(t, err) + require.NotNil(t, psd) + result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) + } +} + +func TestLRUModeDescribe(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func TestLRUContext(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) + + // test 1 : getting a value for the first time with a cancelled context returns an error + ctx1, cancel1 := context.WithCancel(ctx) + cancel1() + + desc, err := cache.Get(ctx1, "SELECT 1") + require.Error(t, err) + require.Nil(t, desc) + + // test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error + ctx2, cancel2 := context.WithCancel(ctx) + + desc, err = cache.Get(ctx2, "SELECT 2") + require.NoError(t, err) + require.NotNil(t, desc) + + cancel2() + + desc, err = cache.Get(ctx2, "SELECT 2") + require.Error(t, err) + require.Nil(t, desc) +} + +func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { + result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + var statements []string + for _, r := range result.Rows { + statement := string(r[0]) + if conn.ParameterStatus("crdb_version") != "" { + if statement == "PREPARE AS select statement from pg_prepared_statements" { + // CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it. + continue + } + + // CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended + // protocol will PostgreSQL does not. Normalize the statement. + re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `) + statement = re.ReplaceAllString(statement, "") + } + statements = append(statements, statement) + } + return statements +} diff --git a/pgconn/stmtcache/stmtcache.go b/pgconn/stmtcache/stmtcache.go new file mode 100644 index 00000000..d083e1b4 --- /dev/null +++ b/pgconn/stmtcache/stmtcache.go @@ -0,0 +1,58 @@ +// Package stmtcache is a cache that can be used to implement lazy prepared statements. +package stmtcache + +import ( + "context" + + "github.com/jackc/pgconn" +) + +const ( + ModePrepare = iota // Cache should prepare named statements. + ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. +) + +// Cache prepares and caches prepared statement descriptions. +type Cache interface { + // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. + Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) + + // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. + Clear(ctx context.Context) error + + // StatementErrored informs the cache that the given statement resulted in an error when it + // was last used against the database. In some cases, this will cause the cache to maer that + // statement as bad. The bad statement will instead be flushed during the next call to Get + // that occurs outside of a failed transaction. + StatementErrored(sql string, err error) + + // Len returns the number of cached prepared statement descriptions. + Len() int + + // Cap returns the maximum number of cached prepared statement descriptions. + Cap() int + + // Mode returns the mode of the cache (ModePrepare or ModeDescribe) + Mode() int +} + +// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is +// the maximum size of the cache. +func New(conn *pgconn.PgConn, mode int, cap int) Cache { + mustBeValidMode(mode) + mustBeValidCap(cap) + + return NewLRU(conn, mode, cap) +} + +func mustBeValidMode(mode int) { + if mode != ModePrepare && mode != ModeDescribe { + panic("mode must be ModePrepare or ModeDescribe") + } +} + +func mustBeValidCap(cap int) { + if cap < 1 { + panic("cache must have cap of >= 1") + } +}