Merge branch 'pgconnimport' into v5-dev
This commit is contained in:
Vendored
+81
@@ -0,0 +1,81 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ master ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ master ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
test:
|
||||||
|
name: Test
|
||||||
|
runs-on: ubuntu-18.04
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
go-version: [1.15, 1.16]
|
||||||
|
pg-version: [9.6, 10, 11, 12, 13, cockroachdb]
|
||||||
|
include:
|
||||||
|
- pg-version: 9.6
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
|
- pg-version: 10
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
|
- pg-version: 11
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
|
- pg-version: 12
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
|
- pg-version: 13
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
|
- pg-version: cockroachdb
|
||||||
|
pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||||
|
|
||||||
|
steps:
|
||||||
|
|
||||||
|
- name: Set up Go 1.x
|
||||||
|
uses: actions/setup-go@v2
|
||||||
|
with:
|
||||||
|
go-version: ${{ matrix.go-version }}
|
||||||
|
|
||||||
|
- name: Check out code into the Go module directory
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Setup database server for testing
|
||||||
|
run: ci/setup_test.bash
|
||||||
|
env:
|
||||||
|
PGVERSION: ${{ matrix.pg-version }}
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: go test -v -race ./...
|
||||||
|
env:
|
||||||
|
PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }}
|
||||||
|
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
||||||
|
PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
|
||||||
|
PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
|
||||||
|
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
|
||||||
|
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
.envrc
|
||||||
|
vendor/
|
||||||
|
.vscode
|
||||||
@@ -0,0 +1,129 @@
|
|||||||
|
# 1.10.1 (November 20, 2021)
|
||||||
|
|
||||||
|
* Close without waiting for response (Kei Kamikawa)
|
||||||
|
* Save waiting for network round-trip in CopyFrom (Rueian)
|
||||||
|
* Fix concurrency issue with ContextWatcher
|
||||||
|
* LRU.Get always checks context for cancellation / expiration (Georges Varouchas)
|
||||||
|
|
||||||
|
# 1.10.0 (July 24, 2021)
|
||||||
|
|
||||||
|
* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned.
|
||||||
|
|
||||||
|
# 1.9.0 (July 10, 2021)
|
||||||
|
|
||||||
|
* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr)
|
||||||
|
* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle)
|
||||||
|
* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard)
|
||||||
|
* Fix default host when parsing URL without host but with port
|
||||||
|
* Allow dbname query parameter in URL conn string
|
||||||
|
* Update underlying dependencies
|
||||||
|
|
||||||
|
# 1.8.1 (March 25, 2021)
|
||||||
|
|
||||||
|
* Better connection string sanitization (ip.novikov)
|
||||||
|
* Use proper pgpass location on Windows (Moshe Katz)
|
||||||
|
* Use errors instead of golang.org/x/xerrors
|
||||||
|
* Resume fallback on server error in Connect (Andrey Borodin)
|
||||||
|
|
||||||
|
# 1.8.0 (December 3, 2020)
|
||||||
|
|
||||||
|
* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes)
|
||||||
|
|
||||||
|
# 1.7.2 (November 3, 2020)
|
||||||
|
|
||||||
|
* Fix data value slices into work buffer with capacities larger than length.
|
||||||
|
|
||||||
|
# 1.7.1 (October 31, 2020)
|
||||||
|
|
||||||
|
* Do not asyncClose after receiving FATAL error from PostgreSQL server
|
||||||
|
|
||||||
|
# 1.7.0 (September 26, 2020)
|
||||||
|
|
||||||
|
* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded
|
||||||
|
* Add ReceiveResults (Sebastiaan Mannem)
|
||||||
|
* Fix parsing DSN connection with bad backslash
|
||||||
|
* Add PgConn.CleanupDone so connection pools can determine when async close is complete
|
||||||
|
|
||||||
|
# 1.6.4 (July 29, 2020)
|
||||||
|
|
||||||
|
* Fix deadlock on error after CommandComplete but before ReadyForQuery
|
||||||
|
* Fix panic on parsing DSN with trailing '='
|
||||||
|
|
||||||
|
# 1.6.3 (July 22, 2020)
|
||||||
|
|
||||||
|
* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo)
|
||||||
|
|
||||||
|
# 1.6.2 (July 14, 2020)
|
||||||
|
|
||||||
|
* Update pgservicefile library
|
||||||
|
|
||||||
|
# 1.6.1 (June 27, 2020)
|
||||||
|
|
||||||
|
* Update golang.org/x/crypto to latest
|
||||||
|
* Update golang.org/x/text to 0.3.3
|
||||||
|
* Fix error handling for bad PGSERVICE definition
|
||||||
|
* Redact passwords in ParseConfig errors (Lukas Vogel)
|
||||||
|
|
||||||
|
# 1.6.0 (June 6, 2020)
|
||||||
|
|
||||||
|
* Fix panic when closing conn during cancellable query
|
||||||
|
* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný)
|
||||||
|
* Fix field descriptions available after command concluded (Tobias Salzmann)
|
||||||
|
* Support connect_timeout (georgysavva)
|
||||||
|
* Handle IPv6 in connection URLs (Lukas Vogel)
|
||||||
|
* Fix ValidateConnect with cancelable context
|
||||||
|
* Improve CopyFrom performance
|
||||||
|
* Add Config.Copy (georgysavva)
|
||||||
|
|
||||||
|
# 1.5.0 (March 30, 2020)
|
||||||
|
|
||||||
|
* Update golang.org/x/crypto for security fix
|
||||||
|
* Implement "verify-ca" SSL mode (Greg Curtis)
|
||||||
|
|
||||||
|
# 1.4.0 (March 7, 2020)
|
||||||
|
|
||||||
|
* Fix ExecParams and ExecPrepared handling of empty query.
|
||||||
|
* Support reading config from PostgreSQL service files.
|
||||||
|
|
||||||
|
# 1.3.2 (February 14, 2020)
|
||||||
|
|
||||||
|
* Update chunkreader to v2.0.1 for optimized default buffer size.
|
||||||
|
|
||||||
|
# 1.3.1 (February 5, 2020)
|
||||||
|
|
||||||
|
* Fix CopyFrom deadlock when multiple NoticeResponse received during copy
|
||||||
|
|
||||||
|
# 1.3.0 (January 23, 2020)
|
||||||
|
|
||||||
|
* Add Hijack and Construct.
|
||||||
|
* Update pgproto3 to v2.0.1.
|
||||||
|
|
||||||
|
# 1.2.1 (January 13, 2020)
|
||||||
|
|
||||||
|
* Fix data race in context cancellation introduced in v1.2.0.
|
||||||
|
|
||||||
|
# 1.2.0 (January 11, 2020)
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag.
|
||||||
|
* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases.
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
* Improve performance when context.Background() is used. (bakape)
|
||||||
|
* CommandTag.RowsAffected is faster and does not allocate.
|
||||||
|
|
||||||
|
## Fixes
|
||||||
|
|
||||||
|
* Try to cancel any in-progress query when a conn is closed by ctx cancel.
|
||||||
|
* Handle NoticeResponse during CopyFrom.
|
||||||
|
* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish.
|
||||||
|
|
||||||
|
# 1.1.0 (October 12, 2019)
|
||||||
|
|
||||||
|
* Add PgConn.IsBusy() method.
|
||||||
|
|
||||||
|
# 1.0.1 (September 19, 2019)
|
||||||
|
|
||||||
|
* Fix statement cache not properly cleaning discarded statements.
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
Copyright (c) 2019-2021 Jack Christensen
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
a copy of this software and associated documentation files (the
|
||||||
|
"Software"), to deal in the Software without restriction, including
|
||||||
|
without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be
|
||||||
|
included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
[](https://godoc.org/github.com/jackc/pgconn)
|
||||||
|

|
||||||
|
|
||||||
|
# pgconn
|
||||||
|
|
||||||
|
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
|
||||||
|
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
|
||||||
|
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
|
||||||
|
low-level access to PostgreSQL functionality.
|
||||||
|
|
||||||
|
## Example Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("pgconn failed to connect:", err)
|
||||||
|
}
|
||||||
|
defer pgConn.Close(context.Background())
|
||||||
|
|
||||||
|
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
|
||||||
|
for result.NextRow() {
|
||||||
|
fmt.Println("User 123 has email:", string(result.Values()[0]))
|
||||||
|
}
|
||||||
|
_, err = result.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("failed reading result:", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
|
||||||
|
environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
|
||||||
|
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
|
||||||
|
environment variable handling.
|
||||||
|
|
||||||
|
### Example Test Environment
|
||||||
|
|
||||||
|
Connect to your PostgreSQL server and run:
|
||||||
|
|
||||||
|
```
|
||||||
|
create database pgx_test;
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you can run the tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Connection and Authentication Tests
|
||||||
|
|
||||||
|
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
|
||||||
|
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
|
||||||
|
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
|
||||||
|
authentication code.
|
||||||
@@ -0,0 +1,266 @@
|
|||||||
|
// SCRAM-SHA-256 authentication
|
||||||
|
//
|
||||||
|
// Resources:
|
||||||
|
// https://tools.ietf.org/html/rfc5802
|
||||||
|
// https://tools.ietf.org/html/rfc8265
|
||||||
|
// https://www.postgresql.org/docs/current/sasl-authentication.html
|
||||||
|
//
|
||||||
|
// Inspiration drawn from other implementations:
|
||||||
|
// https://github.com/lib/pq/pull/608
|
||||||
|
// https://github.com/lib/pq/pull/788
|
||||||
|
// https://github.com/lib/pq/pull/833
|
||||||
|
|
||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
|
"golang.org/x/crypto/pbkdf2"
|
||||||
|
"golang.org/x/text/secure/precis"
|
||||||
|
)
|
||||||
|
|
||||||
|
const clientNonceLen = 18
|
||||||
|
|
||||||
|
// Perform SCRAM authentication.
|
||||||
|
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||||
|
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send client-first-message in a SASLInitialResponse
|
||||||
|
saslInitialResponse := &pgproto3.SASLInitialResponse{
|
||||||
|
AuthMechanism: "SCRAM-SHA-256",
|
||||||
|
Data: sc.clientFirstMessage(),
|
||||||
|
}
|
||||||
|
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||||
|
saslContinue, err := c.rxSASLContinue()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = sc.recvServerFirstMessage(saslContinue.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send client-final-message in a SASLResponse
|
||||||
|
saslResponse := &pgproto3.SASLResponse{
|
||||||
|
Data: []byte(sc.clientFinalMessage()),
|
||||||
|
}
|
||||||
|
_, err = c.conn.Write(saslResponse.Encode(nil))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||||
|
saslFinal, err := c.rxSASLFinal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sc.recvServerFinalMessage(saslFinal.Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
|
||||||
|
msg, err := c.receiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
|
||||||
|
if ok {
|
||||||
|
return saslContinue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||||
|
msg, err := c.receiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
|
||||||
|
if ok {
|
||||||
|
return saslFinal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
|
||||||
|
}
|
||||||
|
|
||||||
|
type scramClient struct {
|
||||||
|
serverAuthMechanisms []string
|
||||||
|
password []byte
|
||||||
|
clientNonce []byte
|
||||||
|
|
||||||
|
clientFirstMessageBare []byte
|
||||||
|
|
||||||
|
serverFirstMessage []byte
|
||||||
|
clientAndServerNonce []byte
|
||||||
|
salt []byte
|
||||||
|
iterations int
|
||||||
|
|
||||||
|
saltedPassword []byte
|
||||||
|
authMessage []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
|
||||||
|
sc := &scramClient{
|
||||||
|
serverAuthMechanisms: serverAuthMechanisms,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure server supports SCRAM-SHA-256
|
||||||
|
hasScramSHA256 := false
|
||||||
|
for _, mech := range sc.serverAuthMechanisms {
|
||||||
|
if mech == "SCRAM-SHA-256" {
|
||||||
|
hasScramSHA256 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasScramSHA256 {
|
||||||
|
return nil, errors.New("server does not support SCRAM-SHA-256")
|
||||||
|
}
|
||||||
|
|
||||||
|
// precis.OpaqueString is equivalent to SASLprep for password.
|
||||||
|
var err error
|
||||||
|
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
|
||||||
|
if err != nil {
|
||||||
|
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
|
||||||
|
sc.password = []byte(password)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, clientNonceLen)
|
||||||
|
_, err = rand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
|
||||||
|
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
|
||||||
|
|
||||||
|
return sc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) clientFirstMessage() []byte {
|
||||||
|
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
|
||||||
|
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
|
||||||
|
sc.serverFirstMessage = serverFirstMessage
|
||||||
|
buf := serverFirstMessage
|
||||||
|
if !bytes.HasPrefix(buf, []byte("r=")) {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
|
||||||
|
}
|
||||||
|
buf = buf[2:]
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(buf, ',')
|
||||||
|
if idx == -1 {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||||
|
}
|
||||||
|
sc.clientAndServerNonce = buf[:idx]
|
||||||
|
buf = buf[idx+1:]
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(buf, []byte("s=")) {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||||
|
}
|
||||||
|
buf = buf[2:]
|
||||||
|
|
||||||
|
idx = bytes.IndexByte(buf, ',')
|
||||||
|
if idx == -1 {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||||
|
}
|
||||||
|
saltStr := buf[:idx]
|
||||||
|
buf = buf[idx+1:]
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(buf, []byte("i=")) {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||||
|
}
|
||||||
|
buf = buf[2:]
|
||||||
|
iterationsStr := buf
|
||||||
|
|
||||||
|
var err error
|
||||||
|
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.iterations, err = strconv.Atoi(string(iterationsStr))
|
||||||
|
if err != nil || sc.iterations <= 0 {
|
||||||
|
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
|
||||||
|
return errors.New("invalid SCRAM nonce: did not start with client nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
|
||||||
|
return errors.New("invalid SCRAM nonce: did not include server nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) clientFinalMessage() string {
|
||||||
|
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
|
||||||
|
|
||||||
|
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
|
||||||
|
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
|
||||||
|
|
||||||
|
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
|
||||||
|
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
|
||||||
|
return errors.New("invalid SCRAM server-final-message received from server")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverSignature := serverFinalMessage[2:]
|
||||||
|
|
||||||
|
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
|
||||||
|
return errors.New("invalid SCRAM ServerSignature received from server")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeHMAC(key, msg []byte) []byte {
|
||||||
|
mac := hmac.New(sha256.New, key)
|
||||||
|
mac.Write(msg)
|
||||||
|
return mac.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
||||||
|
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||||
|
storedKey := sha256.Sum256(clientKey)
|
||||||
|
clientSignature := computeHMAC(storedKey[:], authMessage)
|
||||||
|
|
||||||
|
clientProof := make([]byte, len(clientSignature))
|
||||||
|
for i := 0; i < len(clientSignature); i++ {
|
||||||
|
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
|
||||||
|
base64.StdEncoding.Encode(buf, clientProof)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
||||||
|
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||||
|
serverSignature := computeHMAC(serverKey, authMessage)
|
||||||
|
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||||
|
base64.StdEncoding.Encode(buf, serverSignature)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
@@ -0,0 +1,322 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkConnect(b *testing.B) {
|
||||||
|
benchmarks := []struct {
|
||||||
|
name string
|
||||||
|
env string
|
||||||
|
}{
|
||||||
|
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
||||||
|
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
bm := bm
|
||||||
|
b.Run(bm.name, func(b *testing.B) {
|
||||||
|
connString := os.Getenv(bm.env)
|
||||||
|
if connString == "" {
|
||||||
|
b.Skipf("Skipping due to missing environment variable %v", bm.env)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conn, err := pgconn.Connect(context.Background(), connString)
|
||||||
|
require.Nil(b, err)
|
||||||
|
|
||||||
|
err = conn.Close(context.Background())
|
||||||
|
require.Nil(b, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExec(b *testing.B) {
|
||||||
|
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||||
|
benchmarks := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
}{
|
||||||
|
// Using an empty context other than context.Background() to compare
|
||||||
|
// performance
|
||||||
|
{"background context", context.Background()},
|
||||||
|
{"empty context", context.TODO()},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
bm := bm
|
||||||
|
b.Run(bm.name, func(b *testing.B) {
|
||||||
|
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.Nil(b, err)
|
||||||
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
|
||||||
|
|
||||||
|
for mrr.NextResult() {
|
||||||
|
rr := mrr.ResultReader()
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
for rr.NextRow() {
|
||||||
|
rowCount++
|
||||||
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
|
}
|
||||||
|
for i := range rr.Values() {
|
||||||
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = rr.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if rowCount != 1 {
|
||||||
|
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mrr.Close()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExecPossibleToCancel(b *testing.B) {
|
||||||
|
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.Nil(b, err)
|
||||||
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
|
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
|
||||||
|
|
||||||
|
for mrr.NextResult() {
|
||||||
|
rr := mrr.ResultReader()
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
for rr.NextRow() {
|
||||||
|
rowCount++
|
||||||
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
|
}
|
||||||
|
for i := range rr.Values() {
|
||||||
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = rr.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if rowCount != 1 {
|
||||||
|
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mrr.Close()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExecPrepared(b *testing.B) {
|
||||||
|
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||||
|
|
||||||
|
benchmarks := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
}{
|
||||||
|
// Using an empty context other than context.Background() to compare
|
||||||
|
// performance
|
||||||
|
{"background context", context.Background()},
|
||||||
|
{"empty context", context.TODO()},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
bm := bm
|
||||||
|
b.Run(bm.name, func(b *testing.B) {
|
||||||
|
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.Nil(b, err)
|
||||||
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
|
_, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||||
|
require.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil)
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
for rr.NextRow() {
|
||||||
|
rowCount++
|
||||||
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
|
}
|
||||||
|
for i := range rr.Values() {
|
||||||
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = rr.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if rowCount != 1 {
|
||||||
|
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
||||||
|
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.Nil(b, err)
|
||||||
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||||
|
require.Nil(b, err)
|
||||||
|
|
||||||
|
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil)
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
for rr.NextRow() {
|
||||||
|
rowCount += 1
|
||||||
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
|
}
|
||||||
|
for i := range rr.Values() {
|
||||||
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = rr.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if rowCount != 1 {
|
||||||
|
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
|
||||||
|
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
// require.Nil(b, err)
|
||||||
|
// defer closeConn(b, conn)
|
||||||
|
|
||||||
|
// ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
// defer cancel()
|
||||||
|
|
||||||
|
// b.ResetTimer()
|
||||||
|
|
||||||
|
// for i := 0; i < b.N; i++ {
|
||||||
|
// conn.ChanToSetDeadline().Watch(ctx)
|
||||||
|
// conn.ChanToSetDeadline().Ignore()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
func BenchmarkCommandTagRowsAffected(b *testing.B) {
|
||||||
|
benchmarks := []struct {
|
||||||
|
commandTag string
|
||||||
|
rowsAffected int64
|
||||||
|
}{
|
||||||
|
{"UPDATE 1", 1},
|
||||||
|
{"UPDATE 123456789", 123456789},
|
||||||
|
{"INSERT 0 1", 1},
|
||||||
|
{"INSERT 0 123456789", 123456789},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
ct := pgconn.CommandTag(bm.commandTag)
|
||||||
|
b.Run(bm.commandTag, func(b *testing.B) {
|
||||||
|
var n int64
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
n = ct.RowsAffected()
|
||||||
|
}
|
||||||
|
if n != bm.rowsAffected {
|
||||||
|
b.Errorf("expected %d got %d", bm.rowsAffected, n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCommandTagTypeFromString(b *testing.B) {
|
||||||
|
ct := pgconn.CommandTag("UPDATE 1")
|
||||||
|
|
||||||
|
var update bool
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
update = strings.HasPrefix(ct.String(), "UPDATE")
|
||||||
|
}
|
||||||
|
if !update {
|
||||||
|
b.Error("expected update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCommandTagInsert(b *testing.B) {
|
||||||
|
benchmarks := []struct {
|
||||||
|
commandTag string
|
||||||
|
is bool
|
||||||
|
}{
|
||||||
|
{"INSERT 1", true},
|
||||||
|
{"INSERT 1234567890", true},
|
||||||
|
{"UPDATE 1", false},
|
||||||
|
{"UPDATE 1234567890", false},
|
||||||
|
{"DELETE 1", false},
|
||||||
|
{"DELETE 1234567890", false},
|
||||||
|
{"SELECT 1", false},
|
||||||
|
{"SELECT 1234567890", false},
|
||||||
|
{"UNKNOWN 1234567890", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
ct := pgconn.CommandTag(bm.commandTag)
|
||||||
|
b.Run(bm.commandTag, func(b *testing.B) {
|
||||||
|
var is bool
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
is = ct.Insert()
|
||||||
|
}
|
||||||
|
if is != bm.is {
|
||||||
|
b.Errorf("expected %v got %v", bm.is, is)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Executable
+10
@@ -0,0 +1,10 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -eux
|
||||||
|
|
||||||
|
if [ "${PGVERSION-}" != "" ]
|
||||||
|
then
|
||||||
|
go test -v -race ./...
|
||||||
|
elif [ "${CRATEVERSION-}" != "" ]
|
||||||
|
then
|
||||||
|
go test -v -race -run 'TestCrateDBConnect'
|
||||||
|
fi
|
||||||
Executable
+59
@@ -0,0 +1,59 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -eux
|
||||||
|
|
||||||
|
if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]]
|
||||||
|
then
|
||||||
|
sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common
|
||||||
|
sudo rm -rf /var/lib/postgresql
|
||||||
|
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
|
||||||
|
sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list"
|
||||||
|
sudo apt-get update -qq
|
||||||
|
sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION
|
||||||
|
sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||||
|
if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then
|
||||||
|
echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||||
|
echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||||
|
echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||||
|
fi
|
||||||
|
sudo /etc/init.d/postgresql restart
|
||||||
|
|
||||||
|
# The tricky test user, below, has to actually exist so that it can be used in a test
|
||||||
|
# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
|
||||||
|
psql -U postgres -c 'create database pgx_test'
|
||||||
|
psql -U postgres pgx_test -c 'create extension hstore'
|
||||||
|
psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)'
|
||||||
|
psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
|
||||||
|
psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
|
||||||
|
psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
|
||||||
|
psql -U postgres -c "create user `whoami`"
|
||||||
|
psql -U postgres -c "create user pgx_replication with replication password 'secret'"
|
||||||
|
psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "${PGVERSION-}" =~ ^cockroach ]]
|
||||||
|
then
|
||||||
|
wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz
|
||||||
|
sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/
|
||||||
|
cockroach start-single-node --insecure --background --listen-addr=localhost
|
||||||
|
cockroach sql --insecure -e 'create database pgx_test'
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "${CRATEVERSION-}" != "" ]
|
||||||
|
then
|
||||||
|
docker run \
|
||||||
|
-p "6543:5432" \
|
||||||
|
-d \
|
||||||
|
crate:"$CRATEVERSION" \
|
||||||
|
crate \
|
||||||
|
-Cnetwork.host=0.0.0.0 \
|
||||||
|
-Ctransport.host=localhost \
|
||||||
|
-Clicense.enterprise=false
|
||||||
|
fi
|
||||||
@@ -0,0 +1,729 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/chunkreader/v2"
|
||||||
|
"github.com/jackc/pgpassfile"
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
|
"github.com/jackc/pgservicefile"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
|
||||||
|
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
|
||||||
|
// manually initialized Config will cause ConnectConfig to panic.
|
||||||
|
type Config struct {
|
||||||
|
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||||
|
Port uint16
|
||||||
|
Database string
|
||||||
|
User string
|
||||||
|
Password string
|
||||||
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
|
ConnectTimeout time.Duration
|
||||||
|
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||||
|
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||||
|
BuildFrontend BuildFrontendFunc
|
||||||
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
|
|
||||||
|
Fallbacks []*FallbackConfig
|
||||||
|
|
||||||
|
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||||
|
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||||
|
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||||
|
ValidateConnect ValidateConnectFunc
|
||||||
|
|
||||||
|
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
|
||||||
|
// or prepare statements). If this returns an error the connection attempt fails.
|
||||||
|
AfterConnect AfterConnectFunc
|
||||||
|
|
||||||
|
// OnNotice is a callback function called when a notice response is received.
|
||||||
|
OnNotice NoticeHandler
|
||||||
|
|
||||||
|
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||||
|
OnNotification NotificationHandler
|
||||||
|
|
||||||
|
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy returns a deep copy of the config that is safe to use and modify.
|
||||||
|
// The only exception is the TLSConfig field:
|
||||||
|
// according to the tls.Config docs it must not be modified after creation.
|
||||||
|
func (c *Config) Copy() *Config {
|
||||||
|
newConf := new(Config)
|
||||||
|
*newConf = *c
|
||||||
|
if newConf.TLSConfig != nil {
|
||||||
|
newConf.TLSConfig = c.TLSConfig.Clone()
|
||||||
|
}
|
||||||
|
if newConf.RuntimeParams != nil {
|
||||||
|
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
|
||||||
|
for k, v := range c.RuntimeParams {
|
||||||
|
newConf.RuntimeParams[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newConf.Fallbacks != nil {
|
||||||
|
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
|
||||||
|
for i, fallback := range c.Fallbacks {
|
||||||
|
newFallback := new(FallbackConfig)
|
||||||
|
*newFallback = *fallback
|
||||||
|
if newFallback.TLSConfig != nil {
|
||||||
|
newFallback.TLSConfig = fallback.TLSConfig.Clone()
|
||||||
|
}
|
||||||
|
newConf.Fallbacks[i] = newFallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return newConf
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||||
|
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||||
|
type FallbackConfig struct {
|
||||||
|
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||||
|
Port uint16
|
||||||
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||||
|
// net.Dial.
|
||||||
|
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||||
|
if strings.HasPrefix(host, "/") {
|
||||||
|
network = "unix"
|
||||||
|
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||||
|
} else {
|
||||||
|
network = "tcp"
|
||||||
|
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||||
|
}
|
||||||
|
return network, address
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same
|
||||||
|
// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches
|
||||||
|
// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See
|
||||||
|
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
|
||||||
|
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||||
|
//
|
||||||
|
// # Example DSN
|
||||||
|
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||||
|
//
|
||||||
|
// # Example URL
|
||||||
|
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||||
|
//
|
||||||
|
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
|
||||||
|
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
|
||||||
|
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
|
||||||
|
// not be modified individually. They should all be modified or all left unchanged.
|
||||||
|
//
|
||||||
|
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
||||||
|
// values that will be tried in order. This can be used as part of a high availability system. See
|
||||||
|
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||||
|
//
|
||||||
|
// # Example URL
|
||||||
|
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||||
|
//
|
||||||
|
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
|
||||||
|
// via database URL or DSN:
|
||||||
|
//
|
||||||
|
// PGHOST
|
||||||
|
// PGPORT
|
||||||
|
// PGDATABASE
|
||||||
|
// PGUSER
|
||||||
|
// PGPASSWORD
|
||||||
|
// PGPASSFILE
|
||||||
|
// PGSERVICE
|
||||||
|
// PGSERVICEFILE
|
||||||
|
// PGSSLMODE
|
||||||
|
// PGSSLCERT
|
||||||
|
// PGSSLKEY
|
||||||
|
// PGSSLROOTCERT
|
||||||
|
// PGAPPNAME
|
||||||
|
// PGCONNECT_TIMEOUT
|
||||||
|
// PGTARGETSESSIONATTRS
|
||||||
|
//
|
||||||
|
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
||||||
|
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
||||||
|
//
|
||||||
|
// Important Security Notes:
|
||||||
|
//
|
||||||
|
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
||||||
|
// not set.
|
||||||
|
//
|
||||||
|
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
||||||
|
// security each sslmode provides.
|
||||||
|
//
|
||||||
|
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
|
||||||
|
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
|
||||||
|
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
|
||||||
|
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
|
||||||
|
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
|
||||||
|
// TLCConfig.
|
||||||
|
//
|
||||||
|
// Other known differences with libpq:
|
||||||
|
//
|
||||||
|
// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first.
|
||||||
|
//
|
||||||
|
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
|
||||||
|
// does not.
|
||||||
|
//
|
||||||
|
// In addition, ParseConfig accepts the following options:
|
||||||
|
//
|
||||||
|
// min_read_buffer_size
|
||||||
|
// The minimum size of the internal read buffer. Default 8192.
|
||||||
|
// servicefile
|
||||||
|
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
||||||
|
// part of the connection string.
|
||||||
|
func ParseConfig(connString string) (*Config, error) {
|
||||||
|
defaultSettings := defaultSettings()
|
||||||
|
envSettings := parseEnvSettings()
|
||||||
|
|
||||||
|
connStringSettings := make(map[string]string)
|
||||||
|
if connString != "" {
|
||||||
|
var err error
|
||||||
|
// connString may be a database URL or a DSN
|
||||||
|
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||||
|
connStringSettings, err = parseURLSettings(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
connStringSettings, err = parseDSNSettings(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
|
||||||
|
if service, present := settings["service"]; present {
|
||||||
|
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
|
||||||
|
}
|
||||||
|
|
||||||
|
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
createdByParseConfig: true,
|
||||||
|
Database: settings["database"],
|
||||||
|
User: settings["user"],
|
||||||
|
Password: settings["password"],
|
||||||
|
RuntimeParams: make(map[string]string),
|
||||||
|
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
|
||||||
|
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||||
|
}
|
||||||
|
config.ConnectTimeout = connectTimeout
|
||||||
|
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
|
||||||
|
} else {
|
||||||
|
defaultDialer := makeDefaultDialer()
|
||||||
|
config.DialFunc = defaultDialer.DialContext
|
||||||
|
}
|
||||||
|
|
||||||
|
config.LookupFunc = makeDefaultResolver().LookupHost
|
||||||
|
|
||||||
|
notRuntimeParams := map[string]struct{}{
|
||||||
|
"host": struct{}{},
|
||||||
|
"port": struct{}{},
|
||||||
|
"database": struct{}{},
|
||||||
|
"user": struct{}{},
|
||||||
|
"password": struct{}{},
|
||||||
|
"passfile": struct{}{},
|
||||||
|
"connect_timeout": struct{}{},
|
||||||
|
"sslmode": struct{}{},
|
||||||
|
"sslkey": struct{}{},
|
||||||
|
"sslcert": struct{}{},
|
||||||
|
"sslrootcert": struct{}{},
|
||||||
|
"target_session_attrs": struct{}{},
|
||||||
|
"min_read_buffer_size": struct{}{},
|
||||||
|
"service": struct{}{},
|
||||||
|
"servicefile": struct{}{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range settings {
|
||||||
|
if _, present := notRuntimeParams[k]; present {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
config.RuntimeParams[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbacks := []*FallbackConfig{}
|
||||||
|
|
||||||
|
hosts := strings.Split(settings["host"], ",")
|
||||||
|
ports := strings.Split(settings["port"], ",")
|
||||||
|
|
||||||
|
for i, host := range hosts {
|
||||||
|
var portStr string
|
||||||
|
if i < len(ports) {
|
||||||
|
portStr = ports[i]
|
||||||
|
} else {
|
||||||
|
portStr = ports[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := parsePort(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tlsConfigs []*tls.Config
|
||||||
|
|
||||||
|
// Ignore TLS settings if Unix domain socket like libpq
|
||||||
|
if network, _ := NetworkAddress(host, port); network == "unix" {
|
||||||
|
tlsConfigs = append(tlsConfigs, nil)
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
tlsConfigs, err = configTLS(settings, host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tlsConfig := range tlsConfigs {
|
||||||
|
fallbacks = append(fallbacks, &FallbackConfig{
|
||||||
|
Host: host,
|
||||||
|
Port: port,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Host = fallbacks[0].Host
|
||||||
|
config.Port = fallbacks[0].Port
|
||||||
|
config.TLSConfig = fallbacks[0].TLSConfig
|
||||||
|
config.Fallbacks = fallbacks[1:]
|
||||||
|
|
||||||
|
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||||
|
if err == nil {
|
||||||
|
if config.Password == "" {
|
||||||
|
host := config.Host
|
||||||
|
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||||
|
host = "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings["target_session_attrs"] == "read-write" {
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||||
|
} else if settings["target_session_attrs"] != "any" {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])}
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeSettings(settingSets ...map[string]string) map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
for _, s2 := range settingSets {
|
||||||
|
for k, v := range s2 {
|
||||||
|
settings[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEnvSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"PGHOST": "host",
|
||||||
|
"PGPORT": "port",
|
||||||
|
"PGDATABASE": "database",
|
||||||
|
"PGUSER": "user",
|
||||||
|
"PGPASSWORD": "password",
|
||||||
|
"PGPASSFILE": "passfile",
|
||||||
|
"PGAPPNAME": "application_name",
|
||||||
|
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||||
|
"PGSSLMODE": "sslmode",
|
||||||
|
"PGSSLKEY": "sslkey",
|
||||||
|
"PGSSLCERT": "sslcert",
|
||||||
|
"PGSSLROOTCERT": "sslrootcert",
|
||||||
|
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||||
|
"PGSERVICE": "service",
|
||||||
|
"PGSERVICEFILE": "servicefile",
|
||||||
|
}
|
||||||
|
|
||||||
|
for envname, realname := range nameMap {
|
||||||
|
value := os.Getenv(envname)
|
||||||
|
if value != "" {
|
||||||
|
settings[realname] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseURLSettings(connString string) (map[string]string, error) {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
url, err := url.Parse(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.User != nil {
|
||||||
|
settings["user"] = url.User.Username()
|
||||||
|
if password, present := url.User.Password(); present {
|
||||||
|
settings["password"] = password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||||
|
var hosts []string
|
||||||
|
var ports []string
|
||||||
|
for _, host := range strings.Split(url.Host, ",") {
|
||||||
|
if host == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isIPOnly(host) {
|
||||||
|
hosts = append(hosts, strings.Trim(host, "[]"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h, p, err := net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
|
||||||
|
}
|
||||||
|
if h != "" {
|
||||||
|
hosts = append(hosts, h)
|
||||||
|
}
|
||||||
|
if p != "" {
|
||||||
|
ports = append(ports, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(hosts) > 0 {
|
||||||
|
settings["host"] = strings.Join(hosts, ",")
|
||||||
|
}
|
||||||
|
if len(ports) > 0 {
|
||||||
|
settings["port"] = strings.Join(ports, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
database := strings.TrimLeft(url.Path, "/")
|
||||||
|
if database != "" {
|
||||||
|
settings["database"] = database
|
||||||
|
}
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range url.Query() {
|
||||||
|
if k2, present := nameMap[k]; present {
|
||||||
|
k = k2
|
||||||
|
}
|
||||||
|
|
||||||
|
settings[k] = v[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIPOnly(host string) bool {
|
||||||
|
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
|
||||||
|
}
|
||||||
|
|
||||||
|
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||||
|
|
||||||
|
func parseDSNSettings(s string) (map[string]string, error) {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(s) > 0 {
|
||||||
|
var key, val string
|
||||||
|
eqIdx := strings.IndexRune(s, '=')
|
||||||
|
if eqIdx < 0 {
|
||||||
|
return nil, errors.New("invalid dsn")
|
||||||
|
}
|
||||||
|
|
||||||
|
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
||||||
|
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
|
||||||
|
if len(s) == 0 {
|
||||||
|
} else if s[0] != '\'' {
|
||||||
|
end := 0
|
||||||
|
for ; end < len(s); end++ {
|
||||||
|
if asciiSpace[s[end]] == 1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[end] == '\\' {
|
||||||
|
end++
|
||||||
|
if end == len(s) {
|
||||||
|
return nil, errors.New("invalid backslash")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||||
|
if end == len(s) {
|
||||||
|
s = ""
|
||||||
|
} else {
|
||||||
|
s = s[end+1:]
|
||||||
|
}
|
||||||
|
} else { // quoted string
|
||||||
|
s = s[1:]
|
||||||
|
end := 0
|
||||||
|
for ; end < len(s); end++ {
|
||||||
|
if s[end] == '\'' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[end] == '\\' {
|
||||||
|
end++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if end == len(s) {
|
||||||
|
return nil, errors.New("unterminated quoted string in connection info string")
|
||||||
|
}
|
||||||
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||||
|
if end == len(s) {
|
||||||
|
s = ""
|
||||||
|
} else {
|
||||||
|
s = s[end+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if k, ok := nameMap[key]; ok {
|
||||||
|
key = k
|
||||||
|
}
|
||||||
|
|
||||||
|
if key == "" {
|
||||||
|
return nil, errors.New("invalid dsn")
|
||||||
|
}
|
||||||
|
|
||||||
|
settings[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
|
||||||
|
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := servicefile.GetService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to find service: %v", serviceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := make(map[string]string, len(service.Settings))
|
||||||
|
for k, v := range service.Settings {
|
||||||
|
if k2, present := nameMap[k]; present {
|
||||||
|
k = k2
|
||||||
|
}
|
||||||
|
settings[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||||
|
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||||
|
// "prefer" allow fallback.
|
||||||
|
func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) {
|
||||||
|
host := thisHost
|
||||||
|
sslmode := settings["sslmode"]
|
||||||
|
sslrootcert := settings["sslrootcert"]
|
||||||
|
sslcert := settings["sslcert"]
|
||||||
|
sslkey := settings["sslkey"]
|
||||||
|
|
||||||
|
// Match libpq default behavior
|
||||||
|
if sslmode == "" {
|
||||||
|
sslmode = "prefer"
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "disable":
|
||||||
|
return []*tls.Config{nil}, nil
|
||||||
|
case "allow", "prefer":
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
case "require":
|
||||||
|
// According to PostgreSQL documentation, if a root CA file exists,
|
||||||
|
// the behavior of sslmode=require should be the same as that of verify-ca
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/12/libpq-ssl.html
|
||||||
|
if sslrootcert != "" {
|
||||||
|
goto nextCase
|
||||||
|
}
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
break
|
||||||
|
nextCase:
|
||||||
|
fallthrough
|
||||||
|
case "verify-ca":
|
||||||
|
// Don't perform the default certificate verification because it
|
||||||
|
// will verify the hostname. Instead, verify the server's
|
||||||
|
// certificate chain ourselves in VerifyPeerCertificate and
|
||||||
|
// ignore the server name. This emulates libpq's verify-ca
|
||||||
|
// behavior.
|
||||||
|
//
|
||||||
|
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
|
||||||
|
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
|
||||||
|
// for more info.
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
|
||||||
|
certs := make([]*x509.Certificate, len(certificates))
|
||||||
|
for i, asn1Data := range certificates {
|
||||||
|
cert, err := x509.ParseCertificate(asn1Data)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("failed to parse certificate from server: " + err.Error())
|
||||||
|
}
|
||||||
|
certs[i] = cert
|
||||||
|
}
|
||||||
|
|
||||||
|
// Leave DNSName empty to skip hostname verification.
|
||||||
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: tlsConfig.RootCAs,
|
||||||
|
Intermediates: x509.NewCertPool(),
|
||||||
|
}
|
||||||
|
// Skip the first cert because it's the leaf. All others
|
||||||
|
// are intermediates.
|
||||||
|
for _, cert := range certs[1:] {
|
||||||
|
opts.Intermediates.AddCert(cert)
|
||||||
|
}
|
||||||
|
_, err := certs[0].Verify(opts)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "verify-full":
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
default:
|
||||||
|
return nil, errors.New("sslmode is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslrootcert != "" {
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
|
||||||
|
caPath := sslrootcert
|
||||||
|
caCert, err := ioutil.ReadFile(caPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read CA file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
return nil, errors.New("unable to add CA to cert pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
tlsConfig.ClientCAs = caCertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||||
|
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslcert != "" && sslkey != "" {
|
||||||
|
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "allow":
|
||||||
|
return []*tls.Config{nil, tlsConfig}, nil
|
||||||
|
case "prefer":
|
||||||
|
return []*tls.Config{tlsConfig, nil}, nil
|
||||||
|
case "require", "verify-ca", "verify-full":
|
||||||
|
return []*tls.Config{tlsConfig}, nil
|
||||||
|
default:
|
||||||
|
panic("BUG: bad sslmode should already have been caught")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePort(s string) (uint16, error) {
|
||||||
|
port, err := strconv.ParseUint(s, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if port < 1 || port > math.MaxUint16 {
|
||||||
|
return 0, errors.New("outside range")
|
||||||
|
}
|
||||||
|
return uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDefaultDialer() *net.Dialer {
|
||||||
|
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDefaultResolver() *net.Resolver {
|
||||||
|
return net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
|
||||||
|
return func(r io.Reader, w io.Writer) Frontend {
|
||||||
|
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
|
||||||
|
}
|
||||||
|
frontend := pgproto3.NewFrontend(cr, w)
|
||||||
|
|
||||||
|
return frontend
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
|
||||||
|
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if timeout < 0 {
|
||||||
|
return 0, errors.New("negative timeout")
|
||||||
|
}
|
||||||
|
return time.Duration(timeout) * time.Second, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
|
||||||
|
d := makeDefaultDialer()
|
||||||
|
d.Timeout = timeout
|
||||||
|
return d.DialContext
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=read-write.
|
||||||
|
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) == "on" {
|
||||||
|
return errors.New("read only connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,900 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var osUserName string
|
||||||
|
osUser, err := user.Current()
|
||||||
|
if err == nil {
|
||||||
|
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
|
||||||
|
// but the libpq default is just the `user` portion, so we strip off the first part.
|
||||||
|
if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") {
|
||||||
|
osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:]
|
||||||
|
} else {
|
||||||
|
osUserName = osUser.Username
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig("")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defaultHost := config.Host
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
connString string
|
||||||
|
config *pgconn.Config
|
||||||
|
}{
|
||||||
|
// Test all sslmodes
|
||||||
|
{
|
||||||
|
name: "sslmode not set (prefer)",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode disable",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode allow",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode prefer",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode require",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode verify-ca",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode verify-full",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{ServerName: "localhost"},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url everything",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
RuntimeParams: map[string]string{
|
||||||
|
"application_name": "pgxtest",
|
||||||
|
"search_path": "myschema",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url missing password",
|
||||||
|
connString: "postgres://jack@localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url missing user and password",
|
||||||
|
connString: "postgres://localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: osUserName,
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url missing port",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url unix domain socket host",
|
||||||
|
connString: "postgres:///foo?host=/tmp",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: osUserName,
|
||||||
|
Host: "/tmp",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "foo",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url dbname",
|
||||||
|
connString: "postgres://localhost/?dbname=foo&sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: osUserName,
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "foo",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url postgresql protocol",
|
||||||
|
connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url IPv4 with port",
|
||||||
|
connString: "postgresql://jack@127.0.0.1:5433/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 5433,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url IPv6 with port",
|
||||||
|
connString: "postgresql://jack@[2001:db8::1]:5433/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Host: "2001:db8::1",
|
||||||
|
Port: 5433,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url IPv6 no port",
|
||||||
|
connString: "postgresql://jack@[2001:db8::1]/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Host: "2001:db8::1",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN everything",
|
||||||
|
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
RuntimeParams: map[string]string{
|
||||||
|
"application_name": "pgxtest",
|
||||||
|
"search_path": "myschema",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN with escaped single quote",
|
||||||
|
connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack's",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN with escaped backslash",
|
||||||
|
connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "sooper\\secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN with single quoted values",
|
||||||
|
connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN with single quoted value with escaped single quote",
|
||||||
|
connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack's",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN with empty single quoted value",
|
||||||
|
connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN with space between key and value",
|
||||||
|
connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL multiple hosts",
|
||||||
|
connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "foo",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "bar",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "baz",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL multiple hosts and ports",
|
||||||
|
connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "foo",
|
||||||
|
Port: 1,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "bar",
|
||||||
|
Port: 2,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "baz",
|
||||||
|
Port: 3,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// https://github.com/jackc/pgconn/issues/72
|
||||||
|
{
|
||||||
|
name: "URL without host but with port still uses default host",
|
||||||
|
connString: "postgres://jack:secret@:1/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: defaultHost,
|
||||||
|
Port: 1,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN multiple hosts one port",
|
||||||
|
connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "foo",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "bar",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "baz",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN multiple hosts multiple ports",
|
||||||
|
connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "foo",
|
||||||
|
Port: 1,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "bar",
|
||||||
|
Port: 2,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "baz",
|
||||||
|
Port: 3,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple hosts and fallback tsl",
|
||||||
|
connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "foo",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "foo",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "bar",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}},
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "bar",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "baz",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}},
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "baz",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "target_session_attrs",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
config, err := pgconn.ParseConfig(tt.connString)
|
||||||
|
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgconn/issues/47
|
||||||
|
func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) {
|
||||||
|
_, err := pgconn.ParseConfig("host= user= password= port= database=")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigDSNLeadingEqual(t *testing.T) {
|
||||||
|
_, err := pgconn.ParseConfig("= user=jack")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgconn/issues/49
|
||||||
|
func TestParseConfigDSNTrailingBackslash(t *testing.T) {
|
||||||
|
_, err := pgconn.ParseConfig(`x=x\`)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid backslash")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigCopyReturnsEqualConfig(t *testing.T) {
|
||||||
|
connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5"
|
||||||
|
original, err := pgconn.ParseConfig(connString)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
copied := original.Copy()
|
||||||
|
assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) {
|
||||||
|
connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5&sslmode=prefer"
|
||||||
|
original, err := pgconn.ParseConfig(connString)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
copied := original.Copy()
|
||||||
|
assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config")
|
||||||
|
|
||||||
|
copied.Port = uint16(5433)
|
||||||
|
copied.RuntimeParams["foo"] = "bar"
|
||||||
|
copied.Fallbacks[0].Port = uint16(5433)
|
||||||
|
|
||||||
|
assert.Equal(t, uint16(5432), original.Port)
|
||||||
|
assert.Equal(t, "", original.RuntimeParams["foo"])
|
||||||
|
assert.Equal(t, uint16(5432), original.Fallbacks[0].Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
|
||||||
|
connString := os.Getenv("PGX_TEST_CONN_STRING")
|
||||||
|
original, err := pgconn.ParseConfig(connString)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
copied := original.Copy()
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
_, err = pgconn.ConnectConfig(context.Background(), copied)
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
||||||
|
if !assert.NotNil(t, expected) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !assert.NotNil(t, actual) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
||||||
|
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
||||||
|
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
||||||
|
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
||||||
|
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
||||||
|
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
|
||||||
|
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||||
|
|
||||||
|
// Can't test function equality, so just test that they are set or not.
|
||||||
|
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
|
||||||
|
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
|
||||||
|
|
||||||
|
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
||||||
|
if expected.TLSConfig != nil {
|
||||||
|
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
|
||||||
|
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
|
||||||
|
for i := range expected.Fallbacks {
|
||||||
|
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
|
||||||
|
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
|
||||||
|
|
||||||
|
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
|
||||||
|
if expected.Fallbacks[i].TLSConfig != nil {
|
||||||
|
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
|
||||||
|
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigEnvLibpq(t *testing.T) {
|
||||||
|
var osUserName string
|
||||||
|
osUser, err := user.Current()
|
||||||
|
if err == nil {
|
||||||
|
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
|
||||||
|
// but the libpq default is just the `user` portion, so we strip off the first part.
|
||||||
|
if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") {
|
||||||
|
osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:]
|
||||||
|
} else {
|
||||||
|
osUserName = osUser.Username
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
|
||||||
|
|
||||||
|
savedEnv := make(map[string]string)
|
||||||
|
for _, n := range pgEnvvars {
|
||||||
|
savedEnv[n] = os.Getenv(n)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
for k, v := range savedEnv {
|
||||||
|
err := os.Setenv(k, v)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to restore environment: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
envvars map[string]string
|
||||||
|
config *pgconn.Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// not testing no environment at all as that would use default host and that can vary.
|
||||||
|
name: "PGHOST only",
|
||||||
|
envvars: map[string]string{"PGHOST": "123.123.123.123"},
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: osUserName,
|
||||||
|
Host: "123.123.123.123",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "123.123.123.123",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "All non-TLS environment",
|
||||||
|
envvars: map[string]string{
|
||||||
|
"PGHOST": "123.123.123.123",
|
||||||
|
"PGPORT": "7777",
|
||||||
|
"PGDATABASE": "foo",
|
||||||
|
"PGUSER": "bar",
|
||||||
|
"PGPASSWORD": "baz",
|
||||||
|
"PGCONNECT_TIMEOUT": "10",
|
||||||
|
"PGSSLMODE": "disable",
|
||||||
|
"PGAPPNAME": "pgxtest",
|
||||||
|
},
|
||||||
|
config: &pgconn.Config{
|
||||||
|
Host: "123.123.123.123",
|
||||||
|
Port: 7777,
|
||||||
|
Database: "foo",
|
||||||
|
User: "bar",
|
||||||
|
Password: "baz",
|
||||||
|
ConnectTimeout: 10 * time.Second,
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
for _, n := range pgEnvvars {
|
||||||
|
err := os.Unsetenv(n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range tt.envvars {
|
||||||
|
err := os.Setenv(k, v)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig("")
|
||||||
|
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigReadsPgPassfile(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tf, err := ioutil.TempFile("", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer tf.Close()
|
||||||
|
defer os.Remove(tf.Name())
|
||||||
|
|
||||||
|
_, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name())
|
||||||
|
expected := &pgconn.Config{
|
||||||
|
User: "curly",
|
||||||
|
Password: "nyuknyuknyuk",
|
||||||
|
Host: "test1",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "curlydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
actual, err := pgconn.ParseConfig(connString)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assertConfigsEqual(t, expected, actual, "passfile")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigReadsPgServiceFile(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tf, err := ioutil.TempFile("", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer tf.Close()
|
||||||
|
defer os.Remove(tf.Name())
|
||||||
|
|
||||||
|
_, err = tf.Write([]byte(`
|
||||||
|
[abc]
|
||||||
|
host=abc.example.com
|
||||||
|
port=9999
|
||||||
|
dbname=abcdb
|
||||||
|
user=abcuser
|
||||||
|
|
||||||
|
[def]
|
||||||
|
host = def.example.com
|
||||||
|
dbname = defdb
|
||||||
|
user = defuser
|
||||||
|
application_name = spaced string
|
||||||
|
`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
connString string
|
||||||
|
config *pgconn.Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "abc",
|
||||||
|
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"),
|
||||||
|
config: &pgconn.Config{
|
||||||
|
Host: "abc.example.com",
|
||||||
|
Database: "abcdb",
|
||||||
|
User: "abcuser",
|
||||||
|
Port: 9999,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "abc.example.com",
|
||||||
|
Port: 9999,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "def",
|
||||||
|
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"),
|
||||||
|
config: &pgconn.Config{
|
||||||
|
Host: "def.example.com",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "defdb",
|
||||||
|
User: "defuser",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{"application_name": "spaced string"},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "def.example.com",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "conn string has precedence",
|
||||||
|
connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"),
|
||||||
|
config: &pgconn.Config{
|
||||||
|
Host: "other.example.com",
|
||||||
|
Database: "abcdb",
|
||||||
|
User: "abcuser",
|
||||||
|
Port: 7777,
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
config, err := pgconn.ParseConfig(tt.connString)
|
||||||
|
if !assert.NoErrorf(t, err, "Test %d (%s)", i, tt.name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigExtractsMinReadBufferSize(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig("min_read_buffer_size=0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, present := config.RuntimeParams["min_read_buffer_size"]
|
||||||
|
require.False(t, present)
|
||||||
|
|
||||||
|
// The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param
|
||||||
|
// was removed.
|
||||||
|
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
settings["host"] = defaultHost()
|
||||||
|
settings["port"] = "5432"
|
||||||
|
|
||||||
|
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||||
|
// OS. The client application will simply have to specify the user in that
|
||||||
|
// case (which they typically will be doing anyway).
|
||||||
|
user, err := user.Current()
|
||||||
|
if err == nil {
|
||||||
|
settings["user"] = user.Username
|
||||||
|
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||||
|
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||||
|
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
|
||||||
|
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
|
||||||
|
if _, err := os.Stat(sslcert); err == nil {
|
||||||
|
if _, err := os.Stat(sslkey); err == nil {
|
||||||
|
// Both the cert and key must be present to use them, or do not use either
|
||||||
|
settings["sslcert"] = sslcert
|
||||||
|
settings["sslkey"] = sslkey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
|
||||||
|
if _, err := os.Stat(sslrootcert); err == nil {
|
||||||
|
settings["sslrootcert"] = sslrootcert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["target_session_attrs"] = "any"
|
||||||
|
|
||||||
|
settings["min_read_buffer_size"] = "8192"
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||||
|
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||||
|
// checks the existence of common locations.
|
||||||
|
func defaultHost() string {
|
||||||
|
candidatePaths := []string{
|
||||||
|
"/var/run/postgresql", // Debian
|
||||||
|
"/private/tmp", // OSX - homebrew
|
||||||
|
"/tmp", // standard PostgreSQL
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range candidatePaths {
|
||||||
|
if _, err := os.Stat(path); err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "localhost"
|
||||||
|
}
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
settings["host"] = defaultHost()
|
||||||
|
settings["port"] = "5432"
|
||||||
|
|
||||||
|
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||||
|
// OS. The client application will simply have to specify the user in that
|
||||||
|
// case (which they typically will be doing anyway).
|
||||||
|
user, err := user.Current()
|
||||||
|
appData := os.Getenv("APPDATA")
|
||||||
|
if err == nil {
|
||||||
|
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
|
||||||
|
// but the libpq default is just the `user` portion, so we strip off the first part.
|
||||||
|
username := user.Username
|
||||||
|
if strings.Contains(username, "\\") {
|
||||||
|
username = username[strings.LastIndex(username, "\\")+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["user"] = username
|
||||||
|
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
|
||||||
|
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||||
|
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
|
||||||
|
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
|
||||||
|
if _, err := os.Stat(sslcert); err == nil {
|
||||||
|
if _, err := os.Stat(sslkey); err == nil {
|
||||||
|
// Both the cert and key must be present to use them, or do not use either
|
||||||
|
settings["sslcert"] = sslcert
|
||||||
|
settings["sslkey"] = sslkey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
|
||||||
|
if _, err := os.Stat(sslrootcert); err == nil {
|
||||||
|
settings["sslrootcert"] = sslrootcert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["target_session_attrs"] = "any"
|
||||||
|
|
||||||
|
settings["min_read_buffer_size"] = "8192"
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||||
|
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||||
|
// checks the existence of common locations.
|
||||||
|
func defaultHost() string {
|
||||||
|
return "localhost"
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
// Package pgconn is a low-level PostgreSQL database driver.
|
||||||
|
/*
|
||||||
|
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
|
||||||
|
nearly the same level is the C library libpq.
|
||||||
|
|
||||||
|
Establishing a Connection
|
||||||
|
|
||||||
|
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
|
||||||
|
libpq style environment variables.
|
||||||
|
|
||||||
|
Executing a Query
|
||||||
|
|
||||||
|
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
|
||||||
|
reads all rows into memory.
|
||||||
|
|
||||||
|
Executing Multiple Queries in a Single Round Trip
|
||||||
|
|
||||||
|
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
|
||||||
|
result. The ReadAll method reads all query results into memory.
|
||||||
|
|
||||||
|
Context Support
|
||||||
|
|
||||||
|
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
|
||||||
|
method immediately returns. In most circumstances, this will close the underlying connection.
|
||||||
|
|
||||||
|
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
|
||||||
|
client to abort.
|
||||||
|
*/
|
||||||
|
package pgconn
|
||||||
@@ -0,0 +1,221 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
|
||||||
|
func SafeToRetry(err error) bool {
|
||||||
|
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
|
||||||
|
return e.SafeToRetry()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
|
||||||
|
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
|
||||||
|
func Timeout(err error) bool {
|
||||||
|
var timeoutErr *errTimeout
|
||||||
|
return errors.As(err, &timeoutErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PgError represents an error reported by the PostgreSQL server. See
|
||||||
|
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||||
|
// detailed field description.
|
||||||
|
type PgError struct {
|
||||||
|
Severity string
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pe *PgError) Error() string {
|
||||||
|
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLState returns the SQLState of the error.
|
||||||
|
func (pe *PgError) SQLState() string {
|
||||||
|
return pe.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
type connectError struct {
|
||||||
|
config *Config
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connectError) Error() string {
|
||||||
|
sb := &strings.Builder{}
|
||||||
|
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
|
||||||
|
if e.err != nil {
|
||||||
|
fmt.Fprintf(sb, " (%s)", e.err.Error())
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connectError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type connLockError struct {
|
||||||
|
status string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connLockError) SafeToRetry() bool {
|
||||||
|
return true // a lock failure by definition happens before the connection is used.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connLockError) Error() string {
|
||||||
|
return e.status
|
||||||
|
}
|
||||||
|
|
||||||
|
type parseConfigError struct {
|
||||||
|
connString string
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *parseConfigError) Error() string {
|
||||||
|
connString := redactPW(e.connString)
|
||||||
|
if e.err == nil {
|
||||||
|
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *parseConfigError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
|
||||||
|
// true. Otherwise returns err.
|
||||||
|
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
|
||||||
|
return &errTimeout{err: ctx.Err()}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type pgconnError struct {
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
safeToRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *pgconnError) Error() string {
|
||||||
|
if e.msg == "" {
|
||||||
|
return e.err.Error()
|
||||||
|
}
|
||||||
|
if e.err == nil {
|
||||||
|
return e.msg
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *pgconnError) SafeToRetry() bool {
|
||||||
|
return e.safeToRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *pgconnError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
|
||||||
|
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
|
||||||
|
type errTimeout struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errTimeout) Error() string {
|
||||||
|
return fmt.Sprintf("timeout: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errTimeout) SafeToRetry() bool {
|
||||||
|
return SafeToRetry(e.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errTimeout) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextAlreadyDoneError struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *contextAlreadyDoneError) Error() string {
|
||||||
|
return fmt.Sprintf("context already done: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *contextAlreadyDoneError) SafeToRetry() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *contextAlreadyDoneError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
|
||||||
|
func newContextAlreadyDoneError(ctx context.Context) (err error) {
|
||||||
|
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type writeError struct {
|
||||||
|
err error
|
||||||
|
safeToRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *writeError) Error() string {
|
||||||
|
return fmt.Sprintf("write failed: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *writeError) SafeToRetry() bool {
|
||||||
|
return e.safeToRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *writeError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func redactPW(connString string) string {
|
||||||
|
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||||
|
if u, err := url.Parse(connString); err == nil {
|
||||||
|
return redactURL(u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
quotedDSN := regexp.MustCompile(`password='[^']*'`)
|
||||||
|
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||||
|
plainDSN := regexp.MustCompile(`password=[^ ]*`)
|
||||||
|
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||||
|
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
|
||||||
|
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
|
||||||
|
return connString
|
||||||
|
}
|
||||||
|
|
||||||
|
func redactURL(u *url.URL) string {
|
||||||
|
if u == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if _, pwSet := u.User.Password(); pwSet {
|
||||||
|
u.User = url.UserPassword(u.User.Username(), "xxxxx")
|
||||||
|
}
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfigError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
expectedMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "url with password",
|
||||||
|
err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dsn with password unquoted",
|
||||||
|
err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dsn with password quoted",
|
||||||
|
err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "weird url",
|
||||||
|
err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "weird url with slash in password",
|
||||||
|
err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "url without password",
|
||||||
|
err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `postgresql://other@host/db`: msg",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert.EqualError(t, tt.err, tt.expectedMsg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
// File export_test exports some methods for better testing.
|
||||||
|
|
||||||
|
package pgconn
|
||||||
|
|
||||||
|
func NewParseConfigError(conn, msg string, err error) error {
|
||||||
|
return &parseConfigError{
|
||||||
|
connString: conn,
|
||||||
|
msg: msg,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// frontendWrapper allows to hijack a regular frontend, and inject a specific response
|
||||||
|
type frontendWrapper struct {
|
||||||
|
front pgconn.Frontend
|
||||||
|
|
||||||
|
msg pgproto3.BackendMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// frontendWrapper implements the pgconn.Frontend interface
|
||||||
|
var _ pgconn.Frontend = (*frontendWrapper)(nil)
|
||||||
|
|
||||||
|
func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) {
|
||||||
|
if f.msg != nil {
|
||||||
|
return f.msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.front.Receive()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrontendFatalErrExec(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buildFrontend := config.BuildFrontend
|
||||||
|
var front *frontendWrapper
|
||||||
|
|
||||||
|
config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend {
|
||||||
|
wrapped := buildFrontend(r, w)
|
||||||
|
front = &frontendWrapper{wrapped, nil}
|
||||||
|
|
||||||
|
return front
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.NotNil(t, front)
|
||||||
|
|
||||||
|
// set frontend to return a "FATAL" message on next call
|
||||||
|
front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"}
|
||||||
|
|
||||||
|
_, err = conn.Exec(context.Background(), "SELECT 1").ReadAll()
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
err = conn.Close(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-conn.CleanupDone():
|
||||||
|
t.Log("ok, CleanupDone() is not blocking")
|
||||||
|
|
||||||
|
default:
|
||||||
|
assert.Fail(t, "connection closed but CleanupDone() still blocking")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
module github.com/jackc/pgconn
|
||||||
|
|
||||||
|
go 1.12
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.1
|
||||||
|
github.com/jackc/pgio v1.0.0
|
||||||
|
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65
|
||||||
|
github.com/jackc/pgpassfile v1.0.0
|
||||||
|
github.com/jackc/pgproto3/v2 v2.1.1
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b
|
||||||
|
github.com/stretchr/testify v1.7.0
|
||||||
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
|
||||||
|
golang.org/x/text v0.3.6
|
||||||
|
)
|
||||||
+130
@@ -0,0 +1,130 @@
|
|||||||
|
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
|
||||||
|
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||||
|
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||||
|
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
|
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
||||||
|
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||||
|
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
|
||||||
|
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
|
||||||
|
github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
|
||||||
|
github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o=
|
||||||
|
github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY=
|
||||||
|
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
||||||
|
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
||||||
|
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
|
||||||
|
github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c=
|
||||||
|
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc=
|
||||||
|
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
|
||||||
|
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
||||||
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
||||||
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
||||||
|
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||||
|
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||||
|
github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
||||||
|
github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI=
|
||||||
|
github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
||||||
|
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
|
||||||
|
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
|
||||||
|
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
|
||||||
|
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
|
||||||
|
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
|
||||||
|
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
|
||||||
|
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||||
|
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||||
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
|
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
|
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||||
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
|
||||||
|
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
|
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||||
|
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||||
|
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||||
|
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
||||||
|
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||||
|
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||||
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||||
|
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||||
|
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
|
||||||
|
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||||
|
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
|
||||||
|
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
|
||||||
|
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||||
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
|
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||||
|
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||||
|
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||||
|
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||||
|
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
||||||
|
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||||
|
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
||||||
|
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
|
golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||||
|
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI=
|
||||||
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
|
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
|
||||||
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||||
|
golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func closeConn(t testing.TB, conn *pgconn.PgConn) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
require.NoError(t, conn.Close(ctx))
|
||||||
|
select {
|
||||||
|
case <-conn.CleanupDone():
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Connection cleanup exceeded maximum time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do a simple query to ensure the connection is still usable
|
||||||
|
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
require.Nil(t, result.Err)
|
||||||
|
assert.Equal(t, 3, len(result.Rows))
|
||||||
|
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||||
|
assert.Equal(t, "2", string(result.Rows[1][0]))
|
||||||
|
assert.Equal(t, "3", string(result.Rows[2][0]))
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package ctxwatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||||
|
// time.
|
||||||
|
type ContextWatcher struct {
|
||||||
|
onCancel func()
|
||||||
|
onUnwatchAfterCancel func()
|
||||||
|
unwatchChan chan struct{}
|
||||||
|
|
||||||
|
lock sync.Mutex
|
||||||
|
watchInProgress bool
|
||||||
|
onCancelWasCalled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||||
|
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
||||||
|
// onCancel called.
|
||||||
|
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
|
||||||
|
cw := &ContextWatcher{
|
||||||
|
onCancel: onCancel,
|
||||||
|
onUnwatchAfterCancel: onUnwatchAfterCancel,
|
||||||
|
unwatchChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return cw
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
|
||||||
|
func (cw *ContextWatcher) Watch(ctx context.Context) {
|
||||||
|
cw.lock.Lock()
|
||||||
|
defer cw.lock.Unlock()
|
||||||
|
|
||||||
|
if cw.watchInProgress {
|
||||||
|
panic("Watch already in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.onCancelWasCalled = false
|
||||||
|
|
||||||
|
if ctx.Done() != nil {
|
||||||
|
cw.watchInProgress = true
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
cw.onCancel()
|
||||||
|
cw.onCancelWasCalled = true
|
||||||
|
<-cw.unwatchChan
|
||||||
|
case <-cw.unwatchChan:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
cw.watchInProgress = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
|
||||||
|
// called then onUnwatchAfterCancel will also be called.
|
||||||
|
func (cw *ContextWatcher) Unwatch() {
|
||||||
|
cw.lock.Lock()
|
||||||
|
defer cw.lock.Unlock()
|
||||||
|
|
||||||
|
if cw.watchInProgress {
|
||||||
|
cw.unwatchChan <- struct{}{}
|
||||||
|
if cw.onCancelWasCalled {
|
||||||
|
cw.onUnwatchAfterCancel()
|
||||||
|
}
|
||||||
|
cw.watchInProgress = false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
package ctxwatch_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn/internal/ctxwatch"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextWatcherContextCancelled(t *testing.T) {
|
||||||
|
canceledChan := make(chan struct{})
|
||||||
|
cleanupCalled := false
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {
|
||||||
|
canceledChan <- struct{}{}
|
||||||
|
}, func() {
|
||||||
|
cleanupCalled = true
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-canceledChan:
|
||||||
|
case <-time.NewTimer(time.Second).C:
|
||||||
|
t.Fatal("Timed out waiting for cancel func to be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.Unwatch()
|
||||||
|
|
||||||
|
require.True(t, cleanupCalled, "Cleanup func was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {
|
||||||
|
t.Error("cancel func should not have been called")
|
||||||
|
}, func() {
|
||||||
|
t.Error("cleanup func should not have been called")
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cw.Unwatch()
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
cw.Watch(ctx)
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel2()
|
||||||
|
require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
cw.Unwatch() // unwatch when not / never watching
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cw.Unwatch()
|
||||||
|
cw.Unwatch() // double unwatch
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
cw.Watch(ctx)
|
||||||
|
|
||||||
|
go cw.Unwatch()
|
||||||
|
go cw.Unwatch()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherStress(t *testing.T) {
|
||||||
|
var cancelFuncCalls int64
|
||||||
|
var cleanupFuncCalls int64
|
||||||
|
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {
|
||||||
|
atomic.AddInt64(&cancelFuncCalls, 1)
|
||||||
|
}, func() {
|
||||||
|
atomic.AddInt64(&cleanupFuncCalls, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
cycleCount := 100000
|
||||||
|
|
||||||
|
for i := 0; i < cycleCount; i++ {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cw.Watch(ctx)
|
||||||
|
if i%2 == 0 {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
|
||||||
|
if i%3 == 0 {
|
||||||
|
time.Sleep(time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.Unwatch()
|
||||||
|
if i%2 == 1 {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
|
||||||
|
actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)
|
||||||
|
|
||||||
|
if actualCancelFuncCalls == 0 {
|
||||||
|
t.Fatal("actualCancelFuncCalls == 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
maxCancelFuncCalls := int64(cycleCount) / 2
|
||||||
|
if actualCancelFuncCalls > maxCancelFuncCalls {
|
||||||
|
t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actualCancelFuncCalls != actualCleanupFuncCalls {
|
||||||
|
t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
cw.Watch(context.Background())
|
||||||
|
cw.Unwatch()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContextWatcherCancelled(b *testing.B) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cancel()
|
||||||
|
cw.Unwatch()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContextWatcherCancellable(b *testing.B) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cw.Unwatch()
|
||||||
|
}
|
||||||
|
}
|
||||||
+1703
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,90 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnStress(t *testing.T) {
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
actionCount := 10000
|
||||||
|
if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
|
||||||
|
stressFactor, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
|
||||||
|
actionCount *= int(stressFactor)
|
||||||
|
}
|
||||||
|
|
||||||
|
setupStressDB(t, pgConn)
|
||||||
|
|
||||||
|
actions := []struct {
|
||||||
|
name string
|
||||||
|
fn func(*pgconn.PgConn) error
|
||||||
|
}{
|
||||||
|
{"Exec Select", stressExecSelect},
|
||||||
|
{"ExecParams Select", stressExecParamsSelect},
|
||||||
|
{"Batch", stressBatch},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < actionCount; i++ {
|
||||||
|
action := actions[rand.Intn(len(actions))]
|
||||||
|
err := action.fn(pgConn)
|
||||||
|
require.Nilf(t, err, "%d: %s", i, action.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled.
|
||||||
|
numGoroutine := runtime.NumGoroutine()
|
||||||
|
require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
|
||||||
|
_, err := pgConn.Exec(context.Background(), `
|
||||||
|
create temporary table widgets(
|
||||||
|
id serial primary key,
|
||||||
|
name varchar not null,
|
||||||
|
description text,
|
||||||
|
creation_time timestamptz default now()
|
||||||
|
);
|
||||||
|
|
||||||
|
insert into widgets(name, description) values
|
||||||
|
('Foo', 'bar'),
|
||||||
|
('baz', 'Something really long Something really long Something really long Something really long Something really long'),
|
||||||
|
('a', 'b')`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func stressExecSelect(pgConn *pgconn.PgConn) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
_, err := pgConn.Exec(ctx, "select * from widgets").ReadAll()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func stressBatch(pgConn *pgconn.PgConn) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
batch := &pgconn.Batch{}
|
||||||
|
|
||||||
|
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
|
||||||
|
batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
|
||||||
|
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||||
|
return err
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,165 @@
|
|||||||
|
package stmtcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
var lruCount uint64
|
||||||
|
|
||||||
|
// LRU implements Cache with a Least Recently Used (LRU) cache.
|
||||||
|
type LRU struct {
|
||||||
|
conn *pgconn.PgConn
|
||||||
|
mode int
|
||||||
|
cap int
|
||||||
|
prepareCount int
|
||||||
|
m map[string]*list.Element
|
||||||
|
l *list.List
|
||||||
|
psNamePrefix string
|
||||||
|
stmtsToClear []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
|
||||||
|
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
|
||||||
|
mustBeValidMode(mode)
|
||||||
|
mustBeValidCap(cap)
|
||||||
|
|
||||||
|
n := atomic.AddUint64(&lruCount, 1)
|
||||||
|
|
||||||
|
return &LRU{
|
||||||
|
conn: conn,
|
||||||
|
mode: mode,
|
||||||
|
cap: cap,
|
||||||
|
m: make(map[string]*list.Element),
|
||||||
|
l: list.New(),
|
||||||
|
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||||
|
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||||
|
if ctx != context.Background() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// flush an outstanding bad statements
|
||||||
|
txStatus := c.conn.TxStatus()
|
||||||
|
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
|
||||||
|
for _, stmt := range c.stmtsToClear {
|
||||||
|
err := c.clearStmt(ctx, stmt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if el, ok := c.m[sql]; ok {
|
||||||
|
c.l.MoveToFront(el)
|
||||||
|
return el.Value.(*pgconn.StatementDescription), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.l.Len() == c.cap {
|
||||||
|
err := c.removeOldest(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
psd, err := c.prepare(ctx, sql)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
el := c.l.PushFront(psd)
|
||||||
|
c.m[sql] = el
|
||||||
|
|
||||||
|
return psd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||||
|
func (c *LRU) Clear(ctx context.Context) error {
|
||||||
|
for c.l.Len() > 0 {
|
||||||
|
err := c.removeOldest(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRU) StatementErrored(sql string, err error) {
|
||||||
|
pgErr, ok := err.(*pgconn.PgError)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isInvalidCachedPlanError := pgErr.Severity == "ERROR" &&
|
||||||
|
pgErr.Code == "0A000" &&
|
||||||
|
pgErr.Message == "cached plan must not change result type"
|
||||||
|
if isInvalidCachedPlanError {
|
||||||
|
c.stmtsToClear = append(c.stmtsToClear, sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
|
||||||
|
elem, inMap := c.m[sql]
|
||||||
|
if !inMap {
|
||||||
|
// The statement probably fell off the back of the list. In that case, we've
|
||||||
|
// ensured that it isn't in the cache, so we can declare victory.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.l.Remove(elem)
|
||||||
|
|
||||||
|
psd := elem.Value.(*pgconn.StatementDescription)
|
||||||
|
delete(c.m, psd.SQL)
|
||||||
|
if c.mode == ModePrepare {
|
||||||
|
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
func (c *LRU) Len() int {
|
||||||
|
return c.l.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
|
func (c *LRU) Cap() int {
|
||||||
|
return c.cap
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||||
|
func (c *LRU) Mode() int {
|
||||||
|
return c.mode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||||
|
var name string
|
||||||
|
if c.mode == ModePrepare {
|
||||||
|
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
|
||||||
|
c.prepareCount += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.conn.Prepare(ctx, name, sql, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRU) removeOldest(ctx context.Context) error {
|
||||||
|
oldest := c.l.Back()
|
||||||
|
c.l.Remove(oldest)
|
||||||
|
psd := oldest.Value.(*pgconn.StatementDescription)
|
||||||
|
delete(c.m, psd.SQL)
|
||||||
|
if c.mode == ModePrepare {
|
||||||
|
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,292 @@
|
|||||||
|
package stmtcache_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/jackc/pgconn/stmtcache"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLRUModePrepare(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||||
|
require.EqualValues(t, 0, cache.Len())
|
||||||
|
require.EqualValues(t, 2, cache.Cap())
|
||||||
|
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
|
||||||
|
|
||||||
|
psd, err := cache.Get(ctx, "select 1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
require.EqualValues(t, 1, cache.Len())
|
||||||
|
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
psd, err = cache.Get(ctx, "select 1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
require.EqualValues(t, 1, cache.Len())
|
||||||
|
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
psd, err = cache.Get(ctx, "select 2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
require.EqualValues(t, 2, cache.Len())
|
||||||
|
require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
psd, err = cache.Get(ctx, "select 3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
require.EqualValues(t, 2, cache.Len())
|
||||||
|
require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
err = cache.Clear(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 0, cache.Len())
|
||||||
|
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLRUStmtInvalidation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
// we construct a fake error because its not super straightforward to actually call
|
||||||
|
// a prepared statement from the LRU cache without the helper routines which live
|
||||||
|
// in pgx proper.
|
||||||
|
fakeInvalidCachePlanError := &pgconn.PgError{
|
||||||
|
Severity: "ERROR",
|
||||||
|
Code: "0A000",
|
||||||
|
Message: "cached plan must not change result type",
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||||
|
|
||||||
|
//
|
||||||
|
// outside of a transaction, we eagerly flush the statement
|
||||||
|
//
|
||||||
|
|
||||||
|
_, err = cache.Get(ctx, "select 1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 1, cache.Len())
|
||||||
|
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
|
||||||
|
_, err = cache.Get(ctx, "select 2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 1, cache.Len())
|
||||||
|
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
err = cache.Clear(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
//
|
||||||
|
// within an errored transaction, we defer the flush to after the first get
|
||||||
|
// that happens after the transaction is rolled back
|
||||||
|
//
|
||||||
|
|
||||||
|
_, err = cache.Get(ctx, "select 1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 1, cache.Len())
|
||||||
|
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
res := conn.Exec(ctx, "begin")
|
||||||
|
require.NoError(t, res.Close())
|
||||||
|
require.Equal(t, byte('T'), conn.TxStatus())
|
||||||
|
|
||||||
|
res = conn.Exec(ctx, "selec")
|
||||||
|
require.Error(t, res.Close())
|
||||||
|
require.Equal(t, byte('E'), conn.TxStatus())
|
||||||
|
|
||||||
|
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
|
||||||
|
require.EqualValues(t, 1, cache.Len())
|
||||||
|
|
||||||
|
res = conn.Exec(ctx, "rollback")
|
||||||
|
require.NoError(t, res.Close())
|
||||||
|
|
||||||
|
_, err = cache.Get(ctx, "select 2")
|
||||||
|
require.EqualValues(t, 1, cache.Len())
|
||||||
|
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLRUStmtInvalidationIntegration(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||||
|
|
||||||
|
result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
|
||||||
|
sql := "select * from stmtcache_table"
|
||||||
|
sd1, err := cache.Get(ctx, sql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
|
||||||
|
result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
|
||||||
|
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
|
||||||
|
require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)")
|
||||||
|
|
||||||
|
cache.StatementErrored(sql, result.Err)
|
||||||
|
|
||||||
|
sd2, err := cache.Get(ctx, sql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEqual(t, sd1.Name, sd2.Name)
|
||||||
|
|
||||||
|
result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLRUModePrepareStress(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8)
|
||||||
|
require.EqualValues(t, 0, cache.Len())
|
||||||
|
require.EqualValues(t, 8, cache.Cap())
|
||||||
|
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLRUModeDescribe(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
|
||||||
|
require.EqualValues(t, 0, cache.Len())
|
||||||
|
require.EqualValues(t, 2, cache.Cap())
|
||||||
|
require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())
|
||||||
|
|
||||||
|
psd, err := cache.Get(ctx, "select 1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
require.EqualValues(t, 1, cache.Len())
|
||||||
|
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
psd, err = cache.Get(ctx, "select 1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
require.EqualValues(t, 1, cache.Len())
|
||||||
|
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
psd, err = cache.Get(ctx, "select 2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
require.EqualValues(t, 2, cache.Len())
|
||||||
|
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
psd, err = cache.Get(ctx, "select 3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
require.EqualValues(t, 2, cache.Len())
|
||||||
|
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||||
|
|
||||||
|
err = cache.Clear(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 0, cache.Len())
|
||||||
|
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLRUContext(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
|
||||||
|
|
||||||
|
// test 1 : getting a value for the first time with a cancelled context returns an error
|
||||||
|
ctx1, cancel1 := context.WithCancel(ctx)
|
||||||
|
cancel1()
|
||||||
|
|
||||||
|
desc, err := cache.Get(ctx1, "SELECT 1")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, desc)
|
||||||
|
|
||||||
|
// test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error
|
||||||
|
ctx2, cancel2 := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
desc, err = cache.Get(ctx2, "SELECT 2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, desc)
|
||||||
|
|
||||||
|
cancel2()
|
||||||
|
|
||||||
|
desc, err = cache.Get(ctx2, "SELECT 2")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, desc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
|
||||||
|
result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
var statements []string
|
||||||
|
for _, r := range result.Rows {
|
||||||
|
statement := string(r[0])
|
||||||
|
if conn.ParameterStatus("crdb_version") != "" {
|
||||||
|
if statement == "PREPARE AS select statement from pg_prepared_statements" {
|
||||||
|
// CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended
|
||||||
|
// protocol will PostgreSQL does not. Normalize the statement.
|
||||||
|
re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `)
|
||||||
|
statement = re.ReplaceAllString(statement, "")
|
||||||
|
}
|
||||||
|
statements = append(statements, statement)
|
||||||
|
}
|
||||||
|
return statements
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
|
||||||
|
package stmtcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModePrepare = iota // Cache should prepare named statements.
|
||||||
|
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cache prepares and caches prepared statement descriptions.
|
||||||
|
type Cache interface {
|
||||||
|
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||||
|
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
|
||||||
|
|
||||||
|
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||||
|
Clear(ctx context.Context) error
|
||||||
|
|
||||||
|
// StatementErrored informs the cache that the given statement resulted in an error when it
|
||||||
|
// was last used against the database. In some cases, this will cause the cache to maer that
|
||||||
|
// statement as bad. The bad statement will instead be flushed during the next call to Get
|
||||||
|
// that occurs outside of a failed transaction.
|
||||||
|
StatementErrored(sql string, err error)
|
||||||
|
|
||||||
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
Len() int
|
||||||
|
|
||||||
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
|
Cap() int
|
||||||
|
|
||||||
|
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||||
|
Mode() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
|
||||||
|
// the maximum size of the cache.
|
||||||
|
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
|
||||||
|
mustBeValidMode(mode)
|
||||||
|
mustBeValidCap(cap)
|
||||||
|
|
||||||
|
return NewLRU(conn, mode, cap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustBeValidMode(mode int) {
|
||||||
|
if mode != ModePrepare && mode != ModeDescribe {
|
||||||
|
panic("mode must be ModePrepare or ModeDescribe")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustBeValidCap(cap int) {
|
||||||
|
if cap < 1 {
|
||||||
|
panic("cache must have cap of >= 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user