Merge branch 'v5-dev'
This commit is contained in:
@@ -2,7 +2,7 @@ name: CI
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ master ]
|
branches: [ master, v5-dev ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ master ]
|
branches: [ master ]
|
||||||
|
|
||||||
@@ -14,21 +14,52 @@ jobs:
|
|||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: [1.16, 1.17]
|
go-version: [1.18]
|
||||||
pg-version: [10, 11, 12, 13, 14, cockroachdb]
|
pg-version: [10, 11, 12, 13, 14, cockroachdb]
|
||||||
include:
|
include:
|
||||||
- pg-version: 10
|
- pg-version: 10
|
||||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
- pg-version: 11
|
- pg-version: 11
|
||||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
- pg-version: 12
|
- pg-version: 12
|
||||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
- pg-version: 13
|
- pg-version: 13
|
||||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
- pg-version: 14
|
- pg-version: 14
|
||||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||||
|
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||||
|
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||||
- pg-version: cockroachdb
|
- pg-version: cockroachdb
|
||||||
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||||
|
pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
@@ -49,3 +80,9 @@ jobs:
|
|||||||
run: go test -race ./...
|
run: go test -race ./...
|
||||||
env:
|
env:
|
||||||
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
||||||
|
PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }}
|
||||||
|
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
||||||
|
PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
|
||||||
|
PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
|
||||||
|
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
|
||||||
|
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
language: go
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.x
|
||||||
|
- tip
|
||||||
|
|
||||||
|
matrix:
|
||||||
|
allow_failures:
|
||||||
|
- go: tip
|
||||||
+115
-206
@@ -1,268 +1,177 @@
|
|||||||
# 4.17.2 (September 3, 2022)
|
# v5.0.0
|
||||||
|
|
||||||
* Fix panic when logging batch error (Tom Möller)
|
## Merged Packages
|
||||||
|
|
||||||
# 4.17.1 (August 27, 2022)
|
`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main
|
||||||
|
`github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional
|
||||||
|
release work due to releasing multiple packages, and less clear changelogs.
|
||||||
|
|
||||||
* Upgrade puddle to v1.3.0 - fixes context failing to cancel Acquire when acquire is creating resource which was introduced in v4.17.0 (James Hartig)
|
## pgconn
|
||||||
* Fix atomic alignment on 32-bit platforms
|
|
||||||
|
|
||||||
# 4.17.0 (August 6, 2022)
|
`CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`.
|
||||||
|
|
||||||
* Upgrade pgconn to v1.13.0
|
The return value `ResultReader.Values()` is no longer safe to retain a reference to after a subsequent call to `NextRow()` or `Close()`.
|
||||||
* Upgrade pgproto3 to v2.3.1
|
|
||||||
* Upgrade pgtype to v1.12.0
|
|
||||||
* Allow background pool connections to continue even if cause is canceled (James Hartig)
|
|
||||||
* Add LoggerFunc (Gabor Szabad)
|
|
||||||
* pgxpool: health check should avoid going below minConns (James Hartig)
|
|
||||||
* Add pgxpool.Conn.Hijack()
|
|
||||||
* Logging improvements (Stepan Rabotkin)
|
|
||||||
|
|
||||||
# 4.16.1 (May 7, 2022)
|
`Trace()` method adds low level message tracing similar to the `PQtrace` function in `libpq`.
|
||||||
|
|
||||||
* Upgrade pgconn to v1.12.1
|
pgconn now uses non-blocking IO. This is a significant internal restructuring, but it should not cause any visible changes on its own. However, it is important in implementing other new features.
|
||||||
* Fix explicitly prepared statements with describe statement cache mode
|
|
||||||
|
|
||||||
# 4.16.0 (April 21, 2022)
|
`CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping.
|
||||||
|
|
||||||
* Upgrade pgconn to v1.12.0
|
pgconn now supports pipeline mode.
|
||||||
* Upgrade pgproto3 to v2.3.0
|
|
||||||
* Upgrade pgtype to v1.11.0
|
|
||||||
* Fix: Do not panic when context cancelled while getting statement from cache.
|
|
||||||
* Fix: Less memory pinning from old Rows.
|
|
||||||
* Fix: Support '\r' line ending when sanitizing SQL comment.
|
|
||||||
* Add pluggable GSSAPI support (Oliver Tan)
|
|
||||||
|
|
||||||
# 4.15.0 (February 7, 2022)
|
`*PgConn.ReceiveResults` removed. Use pipeline mode instead.
|
||||||
|
|
||||||
* Upgrade to pgconn v1.11.0
|
`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error.
|
||||||
* Upgrade to pgtype v1.10.0
|
|
||||||
* Upgrade puddle to v1.2.1
|
|
||||||
* Make BatchResults.Close safe to be called multiple times
|
|
||||||
|
|
||||||
# 4.14.1 (November 28, 2021)
|
## pgxpool
|
||||||
|
|
||||||
* Upgrade pgtype to v1.9.1 (fixes unintentional change to timestamp binary decoding)
|
`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
|
||||||
* Start pgxpool background health check after initial connections
|
|
||||||
|
|
||||||
# 4.14.0 (November 20, 2021)
|
## pgtype
|
||||||
|
|
||||||
* Upgrade pgconn to v1.10.1
|
The `pgtype` package has been significantly changed.
|
||||||
* Upgrade pgproto3 to v2.2.0
|
|
||||||
* Upgrade pgtype to v1.9.0
|
|
||||||
* Upgrade puddle to v1.2.0
|
|
||||||
* Add QueryFunc to BatchResults
|
|
||||||
* Add context options to zerologadapter (Thomas Frössman)
|
|
||||||
* Add zerologadapter.NewContextLogger (urso)
|
|
||||||
* Eager initialize minpoolsize on connect (Daniel)
|
|
||||||
* Unpin memory used by large queries immediately after use
|
|
||||||
|
|
||||||
# 4.13.0 (July 24, 2021)
|
### NULL Representation
|
||||||
|
|
||||||
* Trimmed pseudo-dependencies in Go modules from other packages tests
|
Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a
|
||||||
* Upgrade pgconn -- context cancellation no longer will return a net.Error
|
`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable.
|
||||||
* Support time durations for simple protocol (Michael Darr)
|
|
||||||
|
|
||||||
# 4.12.0 (July 10, 2021)
|
### Codec and Value Split
|
||||||
|
|
||||||
* ResetSession hook is called before a connection is reused from pool for another query (Dmytro Haranzha)
|
Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled
|
||||||
* stdlib: Add RandomizeHostOrderFunc (dkinder)
|
encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when
|
||||||
* stdlib: add OptionBeforeConnect (dkinder)
|
there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types For example, scanning a
|
||||||
* stdlib: Do not reuse ConnConfig strings (Andrew Kimball)
|
PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). This
|
||||||
* stdlib: implement Conn.ResetSession (Jonathan Amsterdam)
|
concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are
|
||||||
* Upgrade pgconn to v1.9.0
|
generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and
|
||||||
* Upgrade pgtype to v1.8.0
|
`PointValuer` for the PostgreSQL `point` type).
|
||||||
|
|
||||||
# 4.11.0 (March 25, 2021)
|
### Array Types
|
||||||
|
|
||||||
* Add BeforeConnect callback to pgxpool.Config (Robert Froehlich)
|
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also
|
||||||
* Add Ping method to pgxpool.Conn (davidsbond)
|
means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional
|
||||||
* Added a kitlog level log adapter (Fabrice Aneche)
|
arrays.
|
||||||
* Make ScanArgError public to allow identification of offending column (Pau Sanchez)
|
|
||||||
* Add *pgxpool.AcquireFunc
|
|
||||||
* Add BeginFunc and BeginTxFunc
|
|
||||||
* Add prefer_simple_protocol to connection string
|
|
||||||
* Add logging on CopyFrom (Patrick Hemmer)
|
|
||||||
* Add comment support when sanitizing SQL queries (Rusakow Andrew)
|
|
||||||
* Do not panic on double close of pgxpool.Pool (Matt Schultz)
|
|
||||||
* Avoid panic on SendBatch on closed Tx (Matt Schultz)
|
|
||||||
* Update pgconn to v1.8.1
|
|
||||||
* Update pgtype to v1.7.0
|
|
||||||
|
|
||||||
# 4.10.1 (December 19, 2020)
|
### Composite Types
|
||||||
|
|
||||||
* Fix panic on Query error with nil stmtcache.
|
Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite
|
||||||
|
values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite.
|
||||||
|
|
||||||
# 4.10.0 (December 3, 2020)
|
### Range Types
|
||||||
|
|
||||||
* Add CopyFromSlice to simplify CopyFrom usage (Egon Elbre)
|
Range types are now handled with types `RangeCodec` and `Range[T]`. This allows additional user defined range types to
|
||||||
* Remove broken prepared statements from stmtcache (Ethan Pailes)
|
easily be handled. Multirange types are handled similarly with `MultirangeCodec` and `Multirange[T]`.
|
||||||
* stdlib: consider any Ping error as fatal
|
|
||||||
* Update puddle to v1.1.3 - this fixes an issue where concurrent Acquires can hang when a connection cannot be established
|
|
||||||
* Update pgtype to v1.6.2
|
|
||||||
|
|
||||||
# 4.9.2 (November 3, 2020)
|
### pgxtype
|
||||||
|
|
||||||
The underlying library updates fix an issue where appending to a scanned slice could corrupt other data.
|
`LoadDataType` moved to `*Conn` as `LoadType`.
|
||||||
|
|
||||||
* Update pgconn to v1.7.2
|
### Bytea
|
||||||
* Update pgproto3 to v2.0.6
|
|
||||||
|
|
||||||
# 4.9.1 (October 31, 2020)
|
The `Bytea` and `GenericBinary` types have been replaced. Use the following instead:
|
||||||
|
|
||||||
* Update pgconn to v1.7.1
|
* `[]byte` - For normal usage directly use `[]byte`.
|
||||||
* Update pgtype to v1.6.1
|
* `DriverBytes` - Uses driver memory only available until next database method call. Avoids a copy and an allocation.
|
||||||
* Fix SendBatch of all prepared statements with statement cache disabled
|
* `PreallocBytes` - Uses preallocated byte slice to avoid an allocation.
|
||||||
|
* `UndecodedBytes` - Avoids any decoding. Allows working with raw bytes.
|
||||||
|
|
||||||
# 4.9.0 (September 26, 2020)
|
### Dropped lib/pq Support
|
||||||
|
|
||||||
* pgxpool now waits for connection cleanup to finish before making room in pool for another connection. This prevents temporarily exceeding max pool size.
|
`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work
|
||||||
* Fix when scanning a column to nil to skip it on the first row but scanning it to a real value on a subsequent row.
|
in most cases this is no longer supported.
|
||||||
* Fix prefer simple protocol with prepared statements. (Jinzhu)
|
|
||||||
* Fix FieldDescriptions not being available on Rows before calling Next the first time.
|
|
||||||
* Various minor fixes in updated versions of pgconn, pgtype, and puddle.
|
|
||||||
|
|
||||||
# 4.8.1 (July 29, 2020)
|
### database/sql Scan
|
||||||
|
|
||||||
* Update pgconn to v1.6.4
|
Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now
|
||||||
* Fix deadlock on error after CommandComplete but before ReadyForQuery
|
only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by
|
||||||
* Fix panic on parsing DSN with trailing '='
|
considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with
|
||||||
|
`pgx`. The previous behavior was only necessary for `lib/pq` compatibility.
|
||||||
|
|
||||||
# 4.8.0 (July 22, 2020)
|
Added `*Map.SQLScanner` to create a `sql.Scanner` for types such as `[]int32` and `Range[T]` that do not implement
|
||||||
|
`sql.Scanner` directly.
|
||||||
|
|
||||||
* All argument types supported by native pgx should now also work through database/sql
|
### Number Type Fields Include Bit size
|
||||||
* Update pgconn to v1.6.3
|
|
||||||
* Update pgtype to v1.4.2
|
|
||||||
|
|
||||||
# 4.7.2 (July 14, 2020)
|
`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`.
|
||||||
|
This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and
|
||||||
|
`sql.NullInt64` the structures are identical. This means they can be directly converted one to another.
|
||||||
|
|
||||||
* Improve performance of Columns() (zikaeroh)
|
### 3rd Party Type Integrations
|
||||||
* Fix fatal Commit() failure not being considered fatal
|
|
||||||
* Update pgconn to v1.6.2
|
|
||||||
* Update pgtype to v1.4.1
|
|
||||||
|
|
||||||
# 4.7.1 (June 29, 2020)
|
* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to
|
||||||
|
https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims
|
||||||
|
the pgx dependency tree.
|
||||||
|
|
||||||
* Fix stdlib decoding error with certain order and combination of fields
|
### Other Changes
|
||||||
|
|
||||||
# 4.7.0 (June 27, 2020)
|
* `Bit` and `Varbit` are both replaced by the `Bits` type.
|
||||||
|
* `CID`, `OID`, `OIDValue`, and `XID` are replaced by the `Uint32` type.
|
||||||
|
* `Hstore` is now defined as `map[string]*string`.
|
||||||
|
* `JSON` and `JSONB` types removed. Use `[]byte` or `string` directly.
|
||||||
|
* `QChar` type removed. Use `rune` or `byte` directly.
|
||||||
|
* `Inet` and `Cidr` types removed. Use `netip.Addr` and `netip.Prefix` directly. These types are more memory efficient than the previous `net.IPNet`.
|
||||||
|
* `Macaddr` type removed. Use `net.HardwareAddr` directly.
|
||||||
|
* Renamed `pgtype.ConnInfo` to `pgtype.Map`.
|
||||||
|
* Renamed `pgtype.DataType` to `pgtype.Type`.
|
||||||
|
* Renamed `pgtype.None` to `pgtype.Finite`.
|
||||||
|
* `RegisterType` now accepts a `*Type` instead of `Type`.
|
||||||
|
* Assorted array helper methods and types made private.
|
||||||
|
|
||||||
* Update pgtype to v1.4.0
|
## stdlib
|
||||||
* Update pgconn to v1.6.1
|
|
||||||
* Update puddle to v1.1.1
|
|
||||||
* Fix context propagation with Tx commit and Rollback (georgysavva)
|
|
||||||
* Add lazy connect option to pgxpool (georgysavva)
|
|
||||||
* Fix connection leak if pgxpool.BeginTx() fail (Jean-Baptiste Bronisz)
|
|
||||||
* Add native Go slice support for strings and numbers to simple protocol
|
|
||||||
* stdlib add default timeouts for Conn.Close() and Stmt.Close() (georgysavva)
|
|
||||||
* Assorted performance improvements especially with large result sets
|
|
||||||
* Fix close pool on not lazy connect failure (Yegor Myskin)
|
|
||||||
* Add Config copy (georgysavva)
|
|
||||||
* Support SendBatch with Simple Protocol (Jordan Lewis)
|
|
||||||
* Better error logging on rows close (Igor V. Kozinov)
|
|
||||||
* Expose stdlib.Conn.Conn() to enable database/sql.Conn.Raw()
|
|
||||||
* Improve unknown type support for database/sql
|
|
||||||
* Fix transaction commit failure closing connection
|
|
||||||
|
|
||||||
# 4.6.0 (March 30, 2020)
|
* Removed `AcquireConn` and `ReleaseConn` as that functionality has been built in since Go 1.13.
|
||||||
|
|
||||||
* stdlib: Bail early if preloading rows.Next() results in rows.Err() (Bas van Beek)
|
## Reduced Memory Usage by Reusing Read Buffers
|
||||||
* Sanitize time to microsecond accuracy (Andrew Nicoll)
|
|
||||||
* Update pgtype to v1.3.0
|
|
||||||
* Update pgconn to v1.5.0
|
|
||||||
* Update golang.org/x/crypto for security fix
|
|
||||||
* Implement "verify-ca" SSL mode
|
|
||||||
|
|
||||||
# 4.5.0 (March 7, 2020)
|
Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed
|
||||||
|
transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy.
|
||||||
|
However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large
|
||||||
|
chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now
|
||||||
|
ownership remains with the read buffer and anything needing to retain a value must make a copy.
|
||||||
|
|
||||||
* Update to pgconn v1.4.0
|
## Query Execution Modes
|
||||||
* Fixes QueryRow with empty SQL
|
|
||||||
* Adds PostgreSQL service file support
|
|
||||||
* Add Len() to *pgx.Batch (WGH)
|
|
||||||
* Better logging for individual batch items (Ben Bader)
|
|
||||||
|
|
||||||
# 4.4.1 (February 14, 2020)
|
Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode.
|
||||||
|
See documentation for `QueryExecMode`.
|
||||||
|
|
||||||
* Update pgconn to v1.3.2 - better default read buffer size
|
## QueryRewriter Interface and NamedArgs
|
||||||
* Fix race in CopyFrom
|
|
||||||
|
|
||||||
# 4.4.0 (February 5, 2020)
|
pgx now supports named arguments with the `NamedArgs` type. This is implemented via the new `QueryRewriter` interface which
|
||||||
|
allows arbitrary rewriting of query SQL and arguments.
|
||||||
|
|
||||||
* Update puddle to v1.1.0 - fixes possible deadlock when acquire is cancelled
|
## RowScanner Interface
|
||||||
* Update pgconn to v1.3.1 - fixes CopyFrom deadlock when multiple NoticeResponse received during copy
|
|
||||||
* Update pgtype to v1.2.0
|
|
||||||
* Add MaxConnIdleTime to pgxpool (Patrick Ellul)
|
|
||||||
* Add MinConns to pgxpool (Patrick Ellul)
|
|
||||||
* Fix: stdlib.ReleaseConn closes connections left in invalid state
|
|
||||||
|
|
||||||
# 4.3.0 (January 23, 2020)
|
The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row.
|
||||||
|
|
||||||
* Fix Rows.Values panic when unable to decode
|
## Rows Result Helpers
|
||||||
* Add Rows.Values support for unknown types
|
|
||||||
* Add DriverContext support for stdlib (Alex Gaynor)
|
|
||||||
* Update pgproto3 to v2.0.1 to never return an io.EOF as it would be misinterpreted by database/sql. Instead return io.UnexpectedEOF.
|
|
||||||
|
|
||||||
# 4.2.1 (January 13, 2020)
|
* `CollectRows` and `RowTo*` functions simplify collecting results into a slice.
|
||||||
|
* `CollectOneRow` collects one row using `RowTo*` functions.
|
||||||
|
* `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`.
|
||||||
|
|
||||||
* Update pgconn to v1.2.1 (fixes context cancellation data race introduced in v1.2.0))
|
## Tx Helpers
|
||||||
|
|
||||||
# 4.2.0 (January 11, 2020)
|
Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and
|
||||||
|
`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`.
|
||||||
|
|
||||||
* Update pgconn to v1.2.0.
|
## Improved Batch Query Ergonomics
|
||||||
* Update pgtype to v1.1.0.
|
|
||||||
* Return error instead of panic when wrong number of arguments passed to Exec. (malstoun)
|
|
||||||
* Fix large objects functionality when PreferSimpleProtocol = true.
|
|
||||||
* Restore GetDefaultDriver which existed in v3. (Johan Brandhorst)
|
|
||||||
* Add RegisterConnConfig to stdlib which replaces the removed RegisterDriverConfig from v3.
|
|
||||||
|
|
||||||
# 4.1.2 (October 22, 2019)
|
Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the
|
||||||
|
results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code
|
||||||
|
to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can
|
||||||
|
be used to register a callback function that will handle the result. Callback functions are called automatically when
|
||||||
|
`BatchResults.Close` is called.
|
||||||
|
|
||||||
* Fix dbSavepoint.Begin recursive self call
|
## SendBatch Uses Pipeline Mode When Appropriate
|
||||||
* Upgrade pgtype to v1.0.2 - fix scan pointer to pointer
|
|
||||||
|
|
||||||
# 4.1.1 (October 21, 2019)
|
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
|
||||||
|
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
|
||||||
|
in a single network round trip. So it would only take 2 round trips.
|
||||||
|
|
||||||
* Fix pgxpool Rows.CommandTag() infinite loop / typo
|
## Tracing and Logging
|
||||||
|
|
||||||
# 4.1.0 (October 12, 2019)
|
Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer.
|
||||||
|
|
||||||
## Potentially Breaking Changes
|
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
|
||||||
|
tree.
|
||||||
Technically, two changes are breaking changes, but in practice these are extremely unlikely to break existing code.
|
|
||||||
|
|
||||||
* Conn.Begin and Conn.BeginTx return a Tx interface instead of the internal dbTx struct. This is necessary for the Conn.Begin method to signature as other methods that begin a transaction.
|
|
||||||
* Add Conn() to Tx interface. This is necessary to allow code using a Tx to access the *Conn (and pgconn.PgConn) on which the Tx is executing.
|
|
||||||
|
|
||||||
## Fixes
|
|
||||||
|
|
||||||
* Releasing a busy connection closes the connection instead of returning an unusable connection to the pool
|
|
||||||
* Do not mutate config.Config.OnNotification in connect
|
|
||||||
|
|
||||||
# 4.0.1 (September 19, 2019)
|
|
||||||
|
|
||||||
* Fix statement cache cleanup.
|
|
||||||
* Corrected daterange OID.
|
|
||||||
* Fix Tx when committing or rolling back multiple times in certain cases.
|
|
||||||
* Improve documentation.
|
|
||||||
|
|
||||||
# 4.0.0 (September 14, 2019)
|
|
||||||
|
|
||||||
v4 is a major release with many significant changes some of which are breaking changes. The most significant are
|
|
||||||
included below.
|
|
||||||
|
|
||||||
* Simplified establishing a connection with a connection string.
|
|
||||||
* All potentially blocking operations now require a context.Context. The non-context aware functions have been removed.
|
|
||||||
* OIDs are hard-coded for known types. This saves the query on connection.
|
|
||||||
* Context cancellations while network activity is in progress is now always fatal. Previously, it was sometimes recoverable. This led to increased complexity in pgx itself and in application code.
|
|
||||||
* Go modules are required.
|
|
||||||
* Errors are now implemented in the Go 1.13 style.
|
|
||||||
* `Rows` and `Tx` are now interfaces.
|
|
||||||
* The connection pool as been decoupled from pgx and is now a separate, included package (github.com/jackc/pgx/v4/pgxpool).
|
|
||||||
* pgtype has been spun off to a separate package (github.com/jackc/pgtype).
|
|
||||||
* pgproto3 has been spun off to a separate package (github.com/jackc/pgproto3/v2).
|
|
||||||
* Logical replication support has been spun off to a separate package (github.com/jackc/pglogrepl).
|
|
||||||
* Lower level PostgreSQL functionality is now implemented in a separate package (github.com/jackc/pgconn).
|
|
||||||
* Tests are now configured with environment variables.
|
|
||||||
* Conn has an automatic statement cache by default.
|
|
||||||
* Batch interface has been simplified.
|
|
||||||
* QueryArgs has been removed.
|
|
||||||
|
|||||||
@@ -1,25 +1,17 @@
|
|||||||
[](https://pkg.go.dev/github.com/jackc/pgx/v4)
|
[](https://pkg.go.dev/github.com/jackc/pgx/v5)
|
||||||
[](https://travis-ci.org/jackc/pgx)
|

|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
This is the stable `v4` release. `v5` is now in beta testing with final release expected in September. See https://github.com/jackc/pgx/issues/1273 for more information. Please consider testing `v5`.
|
|
||||||
|
|
||||||
---
|
|
||||||
# pgx - PostgreSQL Driver and Toolkit
|
# pgx - PostgreSQL Driver and Toolkit
|
||||||
|
|
||||||
pgx is a pure Go driver and toolkit for PostgreSQL.
|
pgx is a pure Go driver and toolkit for PostgreSQL.
|
||||||
|
|
||||||
pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for.
|
The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` /
|
||||||
|
`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface.
|
||||||
The driver component of pgx can be used alongside the standard `database/sql` package.
|
|
||||||
|
|
||||||
The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol
|
The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol
|
||||||
and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers,
|
and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers,
|
||||||
proxies, load balancers, logical replication clients, etc.
|
proxies, load balancers, logical replication clients, etc.
|
||||||
|
|
||||||
The current release of `pgx v4` requires Go modules. To use the previous version, checkout and vendor the `v3` branch.
|
|
||||||
|
|
||||||
## Example Usage
|
## Example Usage
|
||||||
|
|
||||||
```go
|
```go
|
||||||
@@ -30,7 +22,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -56,52 +48,39 @@ func main() {
|
|||||||
|
|
||||||
See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
|
See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
|
||||||
|
|
||||||
## Choosing Between the pgx and database/sql Interfaces
|
|
||||||
|
|
||||||
It is recommended to use the pgx interface if:
|
|
||||||
1. The application only targets PostgreSQL.
|
|
||||||
2. No other libraries that require `database/sql` are in use.
|
|
||||||
|
|
||||||
The pgx interface is faster and exposes more features.
|
|
||||||
|
|
||||||
The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`,
|
|
||||||
`float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the
|
|
||||||
`database/sql.Scanner` and the `database/sql/driver/driver.Valuer` interfaces which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses.
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
pgx supports many features beyond what is available through `database/sql`:
|
|
||||||
|
|
||||||
* Support for approximately 70 different PostgreSQL types
|
* Support for approximately 70 different PostgreSQL types
|
||||||
* Automatic statement preparation and caching
|
* Automatic statement preparation and caching
|
||||||
* Batch queries
|
* Batch queries
|
||||||
* Single-round trip query mode
|
* Single-round trip query mode
|
||||||
* Full TLS connection control
|
* Full TLS connection control
|
||||||
* Binary format support for custom types (allows for much quicker encoding/decoding)
|
* Binary format support for custom types (allows for much quicker encoding/decoding)
|
||||||
* COPY protocol support for faster bulk data loads
|
* `COPY` protocol support for faster bulk data loads
|
||||||
* Extendable logging support including built-in support for `log15adapter`, [`logrus`](https://github.com/sirupsen/logrus), [`zap`](https://github.com/uber-go/zap), and [`zerolog`](https://github.com/rs/zerolog)
|
* Tracing and logging support
|
||||||
* Connection pool with after-connect hook for arbitrary connection setup
|
* Connection pool with after-connect hook for arbitrary connection setup
|
||||||
* Listen / notify
|
* `LISTEN` / `NOTIFY`
|
||||||
* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
|
* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
|
||||||
* Hstore support
|
* `hstore` support
|
||||||
* JSON and JSONB support
|
* `json` and `jsonb` support
|
||||||
* Maps `inet` and `cidr` PostgreSQL types to `net.IPNet` and `net.IP`
|
* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix`
|
||||||
* Large object support
|
* Large object support
|
||||||
* NULL mapping to Null* struct or pointer to pointer
|
* NULL mapping to pointer to pointer
|
||||||
* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
|
* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
|
||||||
* Notice response handling
|
* Notice response handling
|
||||||
* Simulated nested transactions with savepoints
|
* Simulated nested transactions with savepoints
|
||||||
|
|
||||||
## Performance
|
## Choosing Between the pgx and database/sql Interfaces
|
||||||
|
|
||||||
There are three areas in particular where pgx can provide a significant performance advantage over the standard
|
The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available
|
||||||
`database/sql` interface and other drivers:
|
through the `database/sql` interface.
|
||||||
|
|
||||||
1. PostgreSQL specific types - Types such as arrays can be parsed much quicker because pgx uses the binary format.
|
The pgx interface is recommended when:
|
||||||
2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide an
|
|
||||||
significant free improvement to code that does not explicitly use prepared statements. Under certain workloads, it can
|
1. The application only targets PostgreSQL.
|
||||||
perform nearly 3x the number of queries per second.
|
2. No other libraries that require `database/sql` are in use.
|
||||||
3. Batched queries - Multiple queries can be batched together to minimize network round trips.
|
|
||||||
|
It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed.
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
@@ -134,37 +113,14 @@ In addition, there are tests specific for PgBouncer that will be executed if `PG
|
|||||||
|
|
||||||
## Supported Go and PostgreSQL Versions
|
## Supported Go and PostgreSQL Versions
|
||||||
|
|
||||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.16 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.18 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||||
|
|
||||||
## Version Policy
|
## Version Policy
|
||||||
|
|
||||||
pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version.
|
pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version.
|
||||||
|
|
||||||
## PGX Family Libraries
|
## PGX Family Libraries
|
||||||
|
|
||||||
pgx is the head of a family of PostgreSQL libraries. Many of these can be used independently. Many can also be accessed
|
|
||||||
from pgx for lower-level control.
|
|
||||||
|
|
||||||
### [github.com/jackc/pgconn](https://github.com/jackc/pgconn)
|
|
||||||
|
|
||||||
`pgconn` is a lower-level PostgreSQL database driver that operates at nearly the same level as the C library `libpq`.
|
|
||||||
|
|
||||||
### [github.com/jackc/pgx/v4/pgxpool](https://github.com/jackc/pgx/tree/master/pgxpool)
|
|
||||||
|
|
||||||
`pgxpool` is a connection pool for pgx. pgx is entirely decoupled from its default pool implementation. This means that pgx can be used with a different pool or without any pool at all.
|
|
||||||
|
|
||||||
### [github.com/jackc/pgx/v4/stdlib](https://github.com/jackc/pgx/tree/master/stdlib)
|
|
||||||
|
|
||||||
This is a `database/sql` compatibility layer for pgx. pgx can be used as a normal `database/sql` driver, but at any time, the native interface can be acquired for more performance or PostgreSQL specific functionality.
|
|
||||||
|
|
||||||
### [github.com/jackc/pgtype](https://github.com/jackc/pgtype)
|
|
||||||
|
|
||||||
Over 70 PostgreSQL types are supported including `uuid`, `hstore`, `json`, `bytea`, `numeric`, `interval`, `inet`, and arrays. These types support `database/sql` interfaces and are usable outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver.
|
|
||||||
|
|
||||||
### [github.com/jackc/pgproto3](https://github.com/jackc/pgproto3)
|
|
||||||
|
|
||||||
pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling.
|
|
||||||
|
|
||||||
### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl)
|
### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl)
|
||||||
|
|
||||||
pglogrepl provides functionality to act as a client for PostgreSQL logical replication.
|
pglogrepl provides functionality to act as a client for PostgreSQL logical replication.
|
||||||
@@ -181,6 +137,22 @@ tern is a stand-alone SQL migration system.
|
|||||||
|
|
||||||
pgerrcode contains constants for the PostgreSQL error codes.
|
pgerrcode contains constants for the PostgreSQL error codes.
|
||||||
|
|
||||||
|
## Adapters for 3rd Party Types
|
||||||
|
|
||||||
|
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
|
||||||
|
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
|
||||||
|
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
|
||||||
|
|
||||||
|
## Adapters for 3rd Party Loggers
|
||||||
|
|
||||||
|
These adapters can be used with the tracelog package.
|
||||||
|
|
||||||
|
* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log)
|
||||||
|
* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15)
|
||||||
|
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
|
||||||
|
* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)
|
||||||
|
* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog)
|
||||||
|
|
||||||
## 3rd Party Libraries with PGX Support
|
## 3rd Party Libraries with PGX Support
|
||||||
|
|
||||||
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
|
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
|
||||||
@@ -190,7 +162,3 @@ Library for scanning data from a database into Go structs and more.
|
|||||||
### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
|
### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
|
||||||
|
|
||||||
Adds GSSAPI / Kerberos authentication support.
|
Adds GSSAPI / Kerberos authentication support.
|
||||||
|
|
||||||
### [https://github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
|
|
||||||
|
|
||||||
Adds support for [`github.com/google/uuid`](https://github.com/google/uuid).
|
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
require "erb"
|
||||||
|
|
||||||
|
rule '.go' => '.go.erb' do |task|
|
||||||
|
erb = ERB.new(File.read(task.source))
|
||||||
|
File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
|
||||||
|
sh "goimports", "-w", task.name
|
||||||
|
end
|
||||||
|
|
||||||
|
generated_code_files = [
|
||||||
|
"pgtype/int.go",
|
||||||
|
"pgtype/int_test.go",
|
||||||
|
"pgtype/integration_benchmark_test.go",
|
||||||
|
"pgtype/zeronull/int.go",
|
||||||
|
"pgtype/zeronull/int_test.go"
|
||||||
|
]
|
||||||
|
|
||||||
|
desc "Generate code"
|
||||||
|
task generate: generated_code_files
|
||||||
@@ -5,69 +5,123 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type batchItem struct {
|
// QueuedQuery is a query that has been queued for execution via a Batch.
|
||||||
|
type QueuedQuery struct {
|
||||||
query string
|
query string
|
||||||
arguments []interface{}
|
arguments []any
|
||||||
|
fn batchItemFunc
|
||||||
|
sd *pgconn.StatementDescription
|
||||||
|
}
|
||||||
|
|
||||||
|
type batchItemFunc func(br BatchResults) error
|
||||||
|
|
||||||
|
// Query sets fn to be called when the response to qq is received.
|
||||||
|
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
|
||||||
|
qq.fn = func(br BatchResults) error {
|
||||||
|
rows, err := br.Query()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
err = fn(rows)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rows.Close()
|
||||||
|
|
||||||
|
return rows.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query sets fn to be called when the response to qq is received.
|
||||||
|
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
|
||||||
|
qq.fn = func(br BatchResults) error {
|
||||||
|
row := br.QueryRow()
|
||||||
|
return fn(row)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec sets fn to be called when the response to qq is received.
|
||||||
|
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
||||||
|
qq.fn = func(br BatchResults) error {
|
||||||
|
ct, err := br.Exec()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fn(ct)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Batch queries are a way of bundling multiple queries together to avoid
|
// Batch queries are a way of bundling multiple queries together to avoid
|
||||||
// unnecessary network round trips.
|
// unnecessary network round trips. A Batch must only be sent once.
|
||||||
type Batch struct {
|
type Batch struct {
|
||||||
items []*batchItem
|
queuedQueries []*QueuedQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
|
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
|
||||||
func (b *Batch) Queue(query string, arguments ...interface{}) {
|
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
|
||||||
b.items = append(b.items, &batchItem{
|
qq := &QueuedQuery{
|
||||||
query: query,
|
query: query,
|
||||||
arguments: arguments,
|
arguments: arguments,
|
||||||
})
|
}
|
||||||
|
b.queuedQueries = append(b.queuedQueries, qq)
|
||||||
|
return qq
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns number of queries that have been queued so far.
|
// Len returns number of queries that have been queued so far.
|
||||||
func (b *Batch) Len() int {
|
func (b *Batch) Len() int {
|
||||||
return len(b.items)
|
return len(b.queuedQueries)
|
||||||
}
|
}
|
||||||
|
|
||||||
type BatchResults interface {
|
type BatchResults interface {
|
||||||
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec.
|
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
|
||||||
|
// calling Exec on the QueuedQuery.
|
||||||
Exec() (pgconn.CommandTag, error)
|
Exec() (pgconn.CommandTag, error)
|
||||||
|
|
||||||
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query.
|
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
|
||||||
|
// calling Query on the QueuedQuery.
|
||||||
Query() (Rows, error)
|
Query() (Rows, error)
|
||||||
|
|
||||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
|
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
|
||||||
|
// Prefer calling QueryRow on the QueuedQuery.
|
||||||
QueryRow() Row
|
QueryRow() Row
|
||||||
|
|
||||||
// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc.
|
// Close closes the batch operation. All unread results are read and any callback functions registered with
|
||||||
QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error)
|
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
|
||||||
|
// error or the batch encounters an error subsequent callback functions will not be called.
|
||||||
// Close closes the batch operation. This must be called before the underlying connection can be used again. Any error
|
//
|
||||||
// that occurred during a batch operation may have made it impossible to resyncronize the connection with the server.
|
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch
|
||||||
// In this case the underlying connection will have been closed. Close is safe to call multiple times.
|
// operation may have made it impossible to resyncronize the connection with the server. In this case the underlying
|
||||||
|
// connection will have been closed.
|
||||||
|
//
|
||||||
|
// Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback
|
||||||
|
// functions will not be rerun.
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type batchResults struct {
|
type batchResults struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
conn *Conn
|
conn *Conn
|
||||||
mrr *pgconn.MultiResultReader
|
mrr *pgconn.MultiResultReader
|
||||||
err error
|
err error
|
||||||
b *Batch
|
b *Batch
|
||||||
ix int
|
qqIdx int
|
||||||
closed bool
|
closed bool
|
||||||
|
endTraced bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
|
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||||
func (br *batchResults) Exec() (pgconn.CommandTag, error) {
|
func (br *batchResults) Exec() (pgconn.CommandTag, error) {
|
||||||
if br.err != nil {
|
if br.err != nil {
|
||||||
return nil, br.err
|
return pgconn.CommandTag{}, br.err
|
||||||
}
|
}
|
||||||
if br.closed {
|
if br.closed {
|
||||||
return nil, fmt.Errorf("batch already closed")
|
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
query, arguments, _ := br.nextQueryAndArgs()
|
query, arguments, _ := br.nextQueryAndArgs()
|
||||||
@@ -77,35 +131,29 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
err = errors.New("no result")
|
err = errors.New("no result")
|
||||||
}
|
}
|
||||||
if br.conn.shouldLog(LogLevelError) {
|
if br.conn.batchTracer != nil {
|
||||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{
|
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||||
"sql": query,
|
SQL: query,
|
||||||
"args": logQueryArgs(arguments),
|
Args: arguments,
|
||||||
"err": err,
|
Err: err,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return nil, err
|
return pgconn.CommandTag{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
commandTag, err := br.mrr.ResultReader().Close()
|
commandTag, err := br.mrr.ResultReader().Close()
|
||||||
|
br.err = err
|
||||||
|
|
||||||
if err != nil {
|
if br.conn.batchTracer != nil {
|
||||||
if br.conn.shouldLog(LogLevelError) {
|
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{
|
SQL: query,
|
||||||
"sql": query,
|
Args: arguments,
|
||||||
"args": logQueryArgs(arguments),
|
CommandTag: commandTag,
|
||||||
"err": err,
|
Err: br.err,
|
||||||
})
|
|
||||||
}
|
|
||||||
} else if br.conn.shouldLog(LogLevelInfo) {
|
|
||||||
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{}{
|
|
||||||
"sql": query,
|
|
||||||
"args": logQueryArgs(arguments),
|
|
||||||
"commandTag": commandTag,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return commandTag, err
|
return commandTag, br.err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
||||||
@@ -116,15 +164,16 @@ func (br *batchResults) Query() (Rows, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if br.err != nil {
|
if br.err != nil {
|
||||||
return &connRows{err: br.err, closed: true}, br.err
|
return &baseRows{err: br.err, closed: true}, br.err
|
||||||
}
|
}
|
||||||
|
|
||||||
if br.closed {
|
if br.closed {
|
||||||
alreadyClosedErr := fmt.Errorf("batch already closed")
|
alreadyClosedErr := fmt.Errorf("batch already closed")
|
||||||
return &connRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
||||||
}
|
}
|
||||||
|
|
||||||
rows := br.conn.getRows(br.ctx, query, arguments)
|
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||||
|
rows.batchTracer = br.conn.batchTracer
|
||||||
|
|
||||||
if !br.mrr.NextResult() {
|
if !br.mrr.NextResult() {
|
||||||
rows.err = br.mrr.Close()
|
rows.err = br.mrr.Close()
|
||||||
@@ -133,11 +182,11 @@ func (br *batchResults) Query() (Rows, error) {
|
|||||||
}
|
}
|
||||||
rows.closed = true
|
rows.closed = true
|
||||||
|
|
||||||
if br.conn.shouldLog(LogLevelError) {
|
if br.conn.batchTracer != nil {
|
||||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{}{
|
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||||
"sql": query,
|
SQL: query,
|
||||||
"args": logQueryArgs(arguments),
|
Args: arguments,
|
||||||
"err": rows.err,
|
Err: rows.err,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,47 +197,25 @@ func (br *batchResults) Query() (Rows, error) {
|
|||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc.
|
|
||||||
func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
|
|
||||||
if br.closed {
|
|
||||||
return nil, fmt.Errorf("batch already closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := br.Query()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
err = rows.Scan(scans...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = f(rows)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return rows.CommandTag(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||||
func (br *batchResults) QueryRow() Row {
|
func (br *batchResults) QueryRow() Row {
|
||||||
rows, _ := br.Query()
|
rows, _ := br.Query()
|
||||||
return (*connRow)(rows.(*connRows))
|
return (*connRow)(rows.(*baseRows))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||||
func (br *batchResults) Close() error {
|
func (br *batchResults) Close() error {
|
||||||
|
defer func() {
|
||||||
|
if !br.endTraced {
|
||||||
|
if br.conn != nil && br.conn.batchTracer != nil {
|
||||||
|
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
|
||||||
|
}
|
||||||
|
br.endTraced = true
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if br.err != nil {
|
if br.err != nil {
|
||||||
return br.err
|
return br.err
|
||||||
}
|
}
|
||||||
@@ -196,33 +223,213 @@ func (br *batchResults) Close() error {
|
|||||||
if br.closed {
|
if br.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
br.closed = true
|
|
||||||
|
|
||||||
// log any queries that haven't yet been logged by Exec or Query
|
// Read and run fn for all remaining items
|
||||||
for {
|
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||||
query, args, ok := br.nextQueryAndArgs()
|
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||||
if !ok {
|
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||||
break
|
if err != nil && br.err == nil {
|
||||||
}
|
br.err = err
|
||||||
|
}
|
||||||
if br.conn.shouldLog(LogLevelInfo) {
|
} else {
|
||||||
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{}{
|
br.Exec()
|
||||||
"sql": query,
|
|
||||||
"args": logQueryArgs(args),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return br.mrr.Close()
|
br.closed = true
|
||||||
|
|
||||||
|
err := br.mrr.Close()
|
||||||
|
if br.err == nil {
|
||||||
|
br.err = err
|
||||||
|
}
|
||||||
|
|
||||||
|
return br.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (br *batchResults) nextQueryAndArgs() (query string, args []interface{}, ok bool) {
|
func (br *batchResults) earlyError() error {
|
||||||
if br.b != nil && br.ix < len(br.b.items) {
|
return br.err
|
||||||
bi := br.b.items[br.ix]
|
}
|
||||||
|
|
||||||
|
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||||
|
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||||
|
bi := br.b.queuedQueries[br.qqIdx]
|
||||||
query = bi.query
|
query = bi.query
|
||||||
args = bi.arguments
|
args = bi.arguments
|
||||||
ok = true
|
ok = true
|
||||||
br.ix++
|
br.qqIdx++
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type pipelineBatchResults struct {
|
||||||
|
ctx context.Context
|
||||||
|
conn *Conn
|
||||||
|
pipeline *pgconn.Pipeline
|
||||||
|
lastRows *baseRows
|
||||||
|
err error
|
||||||
|
b *Batch
|
||||||
|
qqIdx int
|
||||||
|
closed bool
|
||||||
|
endTraced bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||||
|
func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
|
||||||
|
if br.err != nil {
|
||||||
|
return pgconn.CommandTag{}, br.err
|
||||||
|
}
|
||||||
|
if br.closed {
|
||||||
|
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
|
||||||
|
}
|
||||||
|
if br.lastRows != nil && br.lastRows.err != nil {
|
||||||
|
return pgconn.CommandTag{}, br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
query, arguments, _ := br.nextQueryAndArgs()
|
||||||
|
|
||||||
|
results, err := br.pipeline.GetResults()
|
||||||
|
if err != nil {
|
||||||
|
br.err = err
|
||||||
|
return pgconn.CommandTag{}, err
|
||||||
|
}
|
||||||
|
var commandTag pgconn.CommandTag
|
||||||
|
switch results := results.(type) {
|
||||||
|
case *pgconn.ResultReader:
|
||||||
|
commandTag, br.err = results.Close()
|
||||||
|
default:
|
||||||
|
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.conn.batchTracer != nil {
|
||||||
|
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||||
|
SQL: query,
|
||||||
|
Args: arguments,
|
||||||
|
CommandTag: commandTag,
|
||||||
|
Err: br.err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return commandTag, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
||||||
|
func (br *pipelineBatchResults) Query() (Rows, error) {
|
||||||
|
if br.err != nil {
|
||||||
|
return &baseRows{err: br.err, closed: true}, br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.closed {
|
||||||
|
alreadyClosedErr := fmt.Errorf("batch already closed")
|
||||||
|
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.lastRows != nil && br.lastRows.err != nil {
|
||||||
|
br.err = br.lastRows.err
|
||||||
|
return &baseRows{err: br.err, closed: true}, br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
query, arguments, ok := br.nextQueryAndArgs()
|
||||||
|
if !ok {
|
||||||
|
query = "batch query"
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||||
|
rows.batchTracer = br.conn.batchTracer
|
||||||
|
br.lastRows = rows
|
||||||
|
|
||||||
|
results, err := br.pipeline.GetResults()
|
||||||
|
if err != nil {
|
||||||
|
br.err = err
|
||||||
|
rows.err = err
|
||||||
|
rows.closed = true
|
||||||
|
|
||||||
|
if br.conn.batchTracer != nil {
|
||||||
|
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||||
|
SQL: query,
|
||||||
|
Args: arguments,
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch results := results.(type) {
|
||||||
|
case *pgconn.ResultReader:
|
||||||
|
rows.resultReader = results
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unexpected pipeline result: %T", results)
|
||||||
|
br.err = err
|
||||||
|
rows.err = err
|
||||||
|
rows.closed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows, rows.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||||
|
func (br *pipelineBatchResults) QueryRow() Row {
|
||||||
|
rows, _ := br.Query()
|
||||||
|
return (*connRow)(rows.(*baseRows))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||||
|
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||||
|
func (br *pipelineBatchResults) Close() error {
|
||||||
|
defer func() {
|
||||||
|
if !br.endTraced {
|
||||||
|
if br.conn.batchTracer != nil {
|
||||||
|
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
|
||||||
|
}
|
||||||
|
br.endTraced = true
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if br.err != nil {
|
||||||
|
return br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.lastRows != nil && br.lastRows.err != nil {
|
||||||
|
br.err = br.lastRows.err
|
||||||
|
return br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and run fn for all remaining items
|
||||||
|
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||||
|
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||||
|
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||||
|
if err != nil && br.err == nil {
|
||||||
|
br.err = err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
br.Exec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
br.closed = true
|
||||||
|
|
||||||
|
err := br.pipeline.Close()
|
||||||
|
if br.err == nil {
|
||||||
|
br.err = err
|
||||||
|
}
|
||||||
|
|
||||||
|
return br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *pipelineBatchResults) earlyError() error {
|
||||||
|
return br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||||
|
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||||
|
bi := br.b.queuedQueries[br.qqIdx]
|
||||||
|
query = bi.query
|
||||||
|
args = bi.arguments
|
||||||
|
ok = true
|
||||||
|
br.qqIdx++
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
+599
-523
File diff suppressed because it is too large
Load Diff
+63
-220
@@ -12,16 +12,32 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgconn/stmtcache"
|
"github.com/jackc/pgx/v5/internal/nbconn"
|
||||||
"github.com/jackc/pgtype"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func BenchmarkConnectClose(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.Close(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) {
|
func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) {
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
config.BuildStatementCache = nil
|
config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
|
||||||
|
config.StatementCacheCapacity = 0
|
||||||
|
config.DescriptionCacheCapacity = 0
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
conn := mustConnect(b, config)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -43,9 +59,9 @@ func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) {
|
func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) {
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
config.StatementCacheCapacity = 0
|
||||||
}
|
config.DescriptionCacheCapacity = 32
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
conn := mustConnect(b, config)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -67,9 +83,9 @@ func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B
|
|||||||
|
|
||||||
func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) {
|
func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) {
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
config.StatementCacheCapacity = 32
|
||||||
}
|
config.DescriptionCacheCapacity = 0
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
conn := mustConnect(b, config)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -268,123 +284,6 @@ func BenchmarkPointerPointerWithPresentValues(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSelectWithoutLogging(b *testing.B) {
|
|
||||||
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
benchmarkSelectWithLog(b, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
type discardLogger struct{}
|
|
||||||
|
|
||||||
func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
|
|
||||||
var logger discardLogger
|
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
config.Logger = logger
|
|
||||||
config.LogLevel = pgx.LogLevelTrace
|
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
benchmarkSelectWithLog(b, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) {
|
|
||||||
var logger discardLogger
|
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
config.Logger = logger
|
|
||||||
config.LogLevel = pgx.LogLevelDebug
|
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
benchmarkSelectWithLog(b, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) {
|
|
||||||
var logger discardLogger
|
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
config.Logger = logger
|
|
||||||
config.LogLevel = pgx.LogLevelInfo
|
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
benchmarkSelectWithLog(b, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) {
|
|
||||||
var logger discardLogger
|
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
config.Logger = logger
|
|
||||||
config.LogLevel = pgx.LogLevelError
|
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
benchmarkSelectWithLog(b, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
|
|
||||||
_, err := conn.Prepare(context.Background(), "test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
var record struct {
|
|
||||||
id int32
|
|
||||||
userName string
|
|
||||||
email string
|
|
||||||
name string
|
|
||||||
sex string
|
|
||||||
birthDate time.Time
|
|
||||||
lastLoginTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
err = conn.QueryRow(context.Background(), "test").Scan(
|
|
||||||
&record.id,
|
|
||||||
&record.userName,
|
|
||||||
&record.email,
|
|
||||||
&record.name,
|
|
||||||
&record.sex,
|
|
||||||
&record.birthDate,
|
|
||||||
&record.lastLoginTime,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// These checks both ensure that the correct data was returned
|
|
||||||
// and provide a benchmark of accessing the returned values.
|
|
||||||
if record.id != 1 {
|
|
||||||
b.Fatalf("bad value for id: %v", record.id)
|
|
||||||
}
|
|
||||||
if record.userName != "johnsmith" {
|
|
||||||
b.Fatalf("bad value for userName: %v", record.userName)
|
|
||||||
}
|
|
||||||
if record.email != "johnsmith@example.com" {
|
|
||||||
b.Fatalf("bad value for email: %v", record.email)
|
|
||||||
}
|
|
||||||
if record.name != "John Smith" {
|
|
||||||
b.Fatalf("bad value for name: %v", record.name)
|
|
||||||
}
|
|
||||||
if record.sex != "male" {
|
|
||||||
b.Fatalf("bad value for sex: %v", record.sex)
|
|
||||||
}
|
|
||||||
if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) {
|
|
||||||
b.Fatalf("bad value for birthDate: %v", record.birthDate)
|
|
||||||
}
|
|
||||||
if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
|
|
||||||
b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const benchmarkWriteTableCreateSQL = `drop table if exists t;
|
const benchmarkWriteTableCreateSQL = `drop table if exists t;
|
||||||
|
|
||||||
create table t(
|
create table t(
|
||||||
@@ -437,7 +336,7 @@ const benchmarkWriteTableInsertSQL = `insert into t(
|
|||||||
type benchmarkWriteTableCopyFromSrc struct {
|
type benchmarkWriteTableCopyFromSrc struct {
|
||||||
count int
|
count int
|
||||||
idx int
|
idx int
|
||||||
row []interface{}
|
row []any
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
|
func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
|
||||||
@@ -445,7 +344,7 @@ func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
|
|||||||
return s.idx < s.count
|
return s.idx < s.count
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) {
|
func (s *benchmarkWriteTableCopyFromSrc) Values() ([]any, error) {
|
||||||
return s.row, nil
|
return s.row, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -456,15 +355,15 @@ func (s *benchmarkWriteTableCopyFromSrc) Err() error {
|
|||||||
func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource {
|
func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource {
|
||||||
return &benchmarkWriteTableCopyFromSrc{
|
return &benchmarkWriteTableCopyFromSrc{
|
||||||
count: count,
|
count: count,
|
||||||
row: []interface{}{
|
row: []any{
|
||||||
"varchar_1",
|
"varchar_1",
|
||||||
"varchar_2",
|
"varchar_2",
|
||||||
&pgtype.Text{Status: pgtype.Null},
|
&pgtype.Text{},
|
||||||
time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
|
time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
|
||||||
&pgtype.Date{Status: pgtype.Null},
|
&pgtype.Date{},
|
||||||
1,
|
1,
|
||||||
2,
|
2,
|
||||||
&pgtype.Int4{Status: pgtype.Null},
|
&pgtype.Int4{},
|
||||||
time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
|
time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
|
||||||
time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
|
time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
|
||||||
true,
|
true,
|
||||||
@@ -508,9 +407,9 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type queryArgs []interface{}
|
type queryArgs []any
|
||||||
|
|
||||||
func (qa *queryArgs) Append(v interface{}) string {
|
func (qa *queryArgs) Append(v any) string {
|
||||||
*qa = append(*qa, v)
|
*qa = append(*qa, v)
|
||||||
return "$" + strconv.Itoa(len(*qa))
|
return "$" + strconv.Itoa(len(*qa))
|
||||||
}
|
}
|
||||||
@@ -723,7 +622,9 @@ func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) {
|
func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) {
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
config.BuildStatementCache = nil
|
config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
|
||||||
|
config.StatementCacheCapacity = 0
|
||||||
|
config.DescriptionCacheCapacity = 0
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
conn := mustConnect(b, config)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -733,9 +634,9 @@ func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) {
|
func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) {
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
config.StatementCacheCapacity = 32
|
||||||
}
|
config.DescriptionCacheCapacity = 0
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
conn := mustConnect(b, config)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -745,9 +646,9 @@ func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkMultipleQueriesNonBatchDescribeStatementCache(b *testing.B) {
|
func BenchmarkMultipleQueriesNonBatchDescribeStatementCache(b *testing.B) {
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
config.StatementCacheCapacity = 0
|
||||||
}
|
config.DescriptionCacheCapacity = 32
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
conn := mustConnect(b, config)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -783,7 +684,9 @@ func benchmarkMultipleQueriesNonBatch(b *testing.B, conn *pgx.Conn, queryCount i
|
|||||||
|
|
||||||
func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) {
|
func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) {
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
config.BuildStatementCache = nil
|
config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
|
||||||
|
config.StatementCacheCapacity = 0
|
||||||
|
config.DescriptionCacheCapacity = 0
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
conn := mustConnect(b, config)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -793,9 +696,9 @@ func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) {
|
func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) {
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
config.StatementCacheCapacity = 32
|
||||||
}
|
config.DescriptionCacheCapacity = 0
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
conn := mustConnect(b, config)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -805,9 +708,9 @@ func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkMultipleQueriesBatchDescribeStatementCache(b *testing.B) {
|
func BenchmarkMultipleQueriesBatchDescribeStatementCache(b *testing.B) {
|
||||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
config.StatementCacheCapacity = 0
|
||||||
}
|
config.DescriptionCacheCapacity = 32
|
||||||
|
|
||||||
conn := mustConnect(b, config)
|
conn := mustConnect(b, config)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -918,8 +821,7 @@ func BenchmarkSelectManyRegisteredEnum(b *testing.B) {
|
|||||||
err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid)
|
err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
et := pgtype.NewEnumType("color", []string{"blue", "green", "orange"})
|
conn.TypeMap().RegisterType(&pgtype.Type{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}})
|
||||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "color", OID: oid})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
var x, y, z string
|
var x, y, z string
|
||||||
@@ -1105,73 +1007,6 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSelectRowsExplicitDecoding(b *testing.B) {
|
|
||||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
rowCounts := getSelectRowsCounts(b)
|
|
||||||
|
|
||||||
for _, rowCount := range rowCounts {
|
|
||||||
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
|
|
||||||
br := &BenchRowDecoder{}
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
rawValues := rows.RawValues()
|
|
||||||
|
|
||||||
err = br.ID.DecodeBinary(conn.ConnInfo(), rawValues[0])
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.FirstName.DecodeText(conn.ConnInfo(), rawValues[1])
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.LastName.DecodeText(conn.ConnInfo(), rawValues[2])
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.Sex.DecodeText(conn.ConnInfo(), rawValues[3])
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.BirthDate.DecodeBinary(conn.ConnInfo(), rawValues[4])
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.Weight.DecodeBinary(conn.ConnInfo(), rawValues[5])
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.Height.DecodeBinary(conn.ConnInfo(), rawValues[6])
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.UpdateTime.DecodeBinary(conn.ConnInfo(), rawValues[7])
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.Err() != nil {
|
|
||||||
b.Fatal(rows.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSelectRowsPgConnExecText(b *testing.B) {
|
func BenchmarkSelectRowsPgConnExecText(b *testing.B) {
|
||||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
@@ -1285,7 +1120,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type queryRecorder struct {
|
type queryRecorder struct {
|
||||||
conn net.Conn
|
conn nbconn.Conn
|
||||||
writeBuf []byte
|
writeBuf []byte
|
||||||
readCount int
|
readCount int
|
||||||
}
|
}
|
||||||
@@ -1301,6 +1136,14 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) {
|
|||||||
return qr.conn.Write(b)
|
return qr.conn.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (qr *queryRecorder) BufferReadUntilBlock() error {
|
||||||
|
return qr.conn.BufferReadUntilBlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qr *queryRecorder) Flush() error {
|
||||||
|
return qr.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
func (qr *queryRecorder) Close() error {
|
func (qr *queryRecorder) Close() error {
|
||||||
return qr.conn.Close()
|
return qr.conn.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ then
|
|||||||
echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
|
echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||||
sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf
|
sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||||
if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then
|
if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then
|
||||||
echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
|
echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||||
@@ -24,8 +28,15 @@ then
|
|||||||
psql -U postgres -c 'create database pgx_test'
|
psql -U postgres -c 'create database pgx_test'
|
||||||
psql -U postgres pgx_test -c 'create extension hstore'
|
psql -U postgres pgx_test -c 'create extension hstore'
|
||||||
psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)'
|
psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)'
|
||||||
|
psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
|
||||||
psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
|
psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
|
||||||
|
psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
|
||||||
psql -U postgres -c "create user `whoami`"
|
psql -U postgres -c "create user `whoami`"
|
||||||
|
psql -U postgres -c "create user pgx_replication with replication password 'secret'"
|
||||||
|
|
||||||
|
# The tricky test user, below, has to actually exist so that it can be used in a test
|
||||||
|
# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
|
||||||
|
psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ "${PGVERSION-}" =~ ^cockroach ]]
|
if [[ "${PGVERSION-}" =~ ^cockroach ]]
|
||||||
|
|||||||
+202
-263
@@ -3,17 +3,16 @@ package pgx_test
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgconn/stmtcache"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
"github.com/jackc/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v5/pgxtest"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -83,7 +82,7 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
connConfig.PreferSimpleProtocol = true
|
connConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
||||||
|
|
||||||
conn := mustConnect(t, connConfig)
|
conn := mustConnect(t, connConfig)
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
@@ -93,13 +92,8 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) {
|
|||||||
|
|
||||||
var s pgtype.Text
|
var s pgtype.Text
|
||||||
err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s)
|
err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
require.Equal(t, pgtype.Text{String: "42", Valid: true}, s)
|
||||||
}
|
|
||||||
|
|
||||||
if s.Get() != "42" {
|
|
||||||
t.Fatalf(`expected "42", got %v`, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
@@ -144,91 +138,122 @@ func TestParseConfigExtractsStatementCacheOptions(t *testing.T) {
|
|||||||
|
|
||||||
config, err := pgx.ParseConfig("statement_cache_capacity=0")
|
config, err := pgx.ParseConfig("statement_cache_capacity=0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Nil(t, config.BuildStatementCache)
|
require.EqualValues(t, 0, config.StatementCacheCapacity)
|
||||||
|
|
||||||
config, err = pgx.ParseConfig("statement_cache_capacity=42")
|
config, err = pgx.ParseConfig("statement_cache_capacity=42")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, config.BuildStatementCache)
|
require.EqualValues(t, 42, config.StatementCacheCapacity)
|
||||||
c := config.BuildStatementCache(nil)
|
|
||||||
require.NotNil(t, c)
|
|
||||||
require.Equal(t, 42, c.Cap())
|
|
||||||
require.Equal(t, stmtcache.ModePrepare, c.Mode())
|
|
||||||
|
|
||||||
config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=prepare")
|
config, err = pgx.ParseConfig("description_cache_capacity=0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, config.BuildStatementCache)
|
require.EqualValues(t, 0, config.DescriptionCacheCapacity)
|
||||||
c = config.BuildStatementCache(nil)
|
|
||||||
require.NotNil(t, c)
|
|
||||||
require.Equal(t, 42, c.Cap())
|
|
||||||
require.Equal(t, stmtcache.ModePrepare, c.Mode())
|
|
||||||
|
|
||||||
config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=describe")
|
config, err = pgx.ParseConfig("description_cache_capacity=42")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, config.BuildStatementCache)
|
require.EqualValues(t, 42, config.DescriptionCacheCapacity)
|
||||||
c = config.BuildStatementCache(nil)
|
|
||||||
require.NotNil(t, c)
|
// default_query_exec_mode
|
||||||
require.Equal(t, 42, c.Cap())
|
// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
|
||||||
require.Equal(t, stmtcache.ModeDescribe, c.Mode())
|
|
||||||
|
config, err = pgx.ParseConfig("default_query_exec_mode=cache_statement")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, pgx.QueryExecModeCacheStatement, config.DefaultQueryExecMode)
|
||||||
|
|
||||||
|
config, err = pgx.ParseConfig("default_query_exec_mode=cache_describe")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, pgx.QueryExecModeCacheDescribe, config.DefaultQueryExecMode)
|
||||||
|
|
||||||
|
config, err = pgx.ParseConfig("default_query_exec_mode=describe_exec")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, pgx.QueryExecModeDescribeExec, config.DefaultQueryExecMode)
|
||||||
|
|
||||||
|
config, err = pgx.ParseConfig("default_query_exec_mode=exec")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, pgx.QueryExecModeExec, config.DefaultQueryExecMode)
|
||||||
|
|
||||||
|
config, err = pgx.ParseConfig("default_query_exec_mode=simple_protocol")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, pgx.QueryExecModeSimpleProtocol, config.DefaultQueryExecMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseConfigExtractsPreferSimpleProtocol(t *testing.T) {
|
func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
connString string
|
connString string
|
||||||
preferSimpleProtocol bool
|
defaultQueryExecMode pgx.QueryExecMode
|
||||||
}{
|
}{
|
||||||
{"", false},
|
{"", pgx.QueryExecModeCacheStatement},
|
||||||
{"prefer_simple_protocol=false", false},
|
{"default_query_exec_mode=cache_statement", pgx.QueryExecModeCacheStatement},
|
||||||
{"prefer_simple_protocol=0", false},
|
{"default_query_exec_mode=cache_describe", pgx.QueryExecModeCacheDescribe},
|
||||||
{"prefer_simple_protocol=true", true},
|
{"default_query_exec_mode=describe_exec", pgx.QueryExecModeDescribeExec},
|
||||||
{"prefer_simple_protocol=1", true},
|
{"default_query_exec_mode=exec", pgx.QueryExecModeExec},
|
||||||
|
{"default_query_exec_mode=simple_protocol", pgx.QueryExecModeSimpleProtocol},
|
||||||
} {
|
} {
|
||||||
config, err := pgx.ParseConfig(tt.connString)
|
config, err := pgx.ParseConfig(tt.connString)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equalf(t, tt.preferSimpleProtocol, config.PreferSimpleProtocol, "connString: `%s`", tt.connString)
|
require.Equalf(t, tt.defaultQueryExecMode, config.DefaultQueryExecMode, "connString: `%s`", tt.connString)
|
||||||
require.Empty(t, config.RuntimeParams["prefer_simple_protocol"])
|
require.Empty(t, config.RuntimeParams["default_query_exec_mode"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExec(t *testing.T) {
|
func TestExec(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" {
|
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" {
|
||||||
t.Error("Unexpected results from Exec")
|
t.Error("Unexpected results from Exec")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept parameters
|
// Accept parameters
|
||||||
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" {
|
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results.String() != "INSERT 0 1" {
|
||||||
t.Errorf("Unexpected results from Exec: %v", results)
|
t.Errorf("Unexpected results from Exec: %v", results)
|
||||||
}
|
}
|
||||||
|
|
||||||
if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" {
|
if results := mustExec(t, conn, "drop table foo;"); results.String() != "DROP TABLE" {
|
||||||
t.Error("Unexpected results from Exec")
|
t.Error("Unexpected results from Exec")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Multiple statements can be executed -- last command tag is returned
|
// Multiple statements can be executed -- last command tag is returned
|
||||||
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" {
|
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results.String() != "DROP TABLE" {
|
||||||
t.Error("Unexpected results from Exec")
|
t.Error("Unexpected results from Exec")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Can execute longer SQL strings than sharedBufferSize
|
// Can execute longer SQL strings than sharedBufferSize
|
||||||
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" {
|
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results.String() != "SELECT 1" {
|
||||||
t.Errorf("Unexpected results from Exec: %v", results)
|
t.Errorf("Unexpected results from Exec: %v", results)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec no-op which does not return a command tag
|
// Exec no-op which does not return a command tag
|
||||||
if results := mustExec(t, conn, "--;"); string(results) != "" {
|
if results := mustExec(t, conn, "--;"); results.String() != "" {
|
||||||
t.Errorf("Unexpected results from Exec: %v", results)
|
t.Errorf("Unexpected results from Exec: %v", results)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testQueryRewriter struct {
|
||||||
|
sql string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||||
|
return qr.sql, qr.args
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecWithQueryRewriter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
|
||||||
|
_, err := conn.Exec(ctx, "should be replaced", &qr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecFailure(t *testing.T) {
|
func TestExecFailure(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
if _, err := conn.Exec(context.Background(), "selct;"); err == nil {
|
if _, err := conn.Exec(context.Background(), "selct;"); err == nil {
|
||||||
t.Fatal("Expected SQL syntax error")
|
t.Fatal("Expected SQL syntax error")
|
||||||
}
|
}
|
||||||
@@ -244,7 +269,7 @@ func TestExecFailure(t *testing.T) {
|
|||||||
func TestExecFailureWithArguments(t *testing.T) {
|
func TestExecFailureWithArguments(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
_, err := conn.Exec(context.Background(), "selct $1;", 1)
|
_, err := conn.Exec(context.Background(), "selct $1;", 1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Expected SQL syntax error")
|
t.Fatal("Expected SQL syntax error")
|
||||||
@@ -259,7 +284,7 @@ func TestExecFailureWithArguments(t *testing.T) {
|
|||||||
func TestExecContextWithoutCancelation(t *testing.T) {
|
func TestExecContextWithoutCancelation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
@@ -267,7 +292,7 @@ func TestExecContextWithoutCancelation(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if string(commandTag) != "CREATE TABLE" {
|
if commandTag.String() != "CREATE TABLE" {
|
||||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||||
}
|
}
|
||||||
assert.False(t, pgconn.SafeToRetry(err))
|
assert.False(t, pgconn.SafeToRetry(err))
|
||||||
@@ -277,7 +302,7 @@ func TestExecContextWithoutCancelation(t *testing.T) {
|
|||||||
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
@@ -299,7 +324,7 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
|||||||
func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
@@ -322,56 +347,6 @@ func TestExecFailureCloseBefore(t *testing.T) {
|
|||||||
assert.True(t, pgconn.SafeToRetry(err))
|
assert.True(t, pgconn.SafeToRetry(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecStatementCacheModes(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
buildStatementCache pgx.BuildStatementCacheFunc
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "disabled",
|
|
||||||
buildStatementCache: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "prepare",
|
|
||||||
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
|
|
||||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "describe",
|
|
||||||
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
|
|
||||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
func() {
|
|
||||||
config.BuildStatementCache = tt.buildStatementCache
|
|
||||||
conn := mustConnect(t, config)
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
commandTag, err := conn.Exec(context.Background(), "select 1")
|
|
||||||
assert.NoError(t, err, tt.name)
|
|
||||||
assert.Equal(t, "SELECT 1", string(commandTag), tt.name)
|
|
||||||
|
|
||||||
commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1")
|
|
||||||
assert.NoError(t, err, tt.name)
|
|
||||||
assert.Equal(t, "SELECT 2", string(commandTag), tt.name)
|
|
||||||
|
|
||||||
commandTag, err = conn.Exec(context.Background(), "select 1")
|
|
||||||
assert.NoError(t, err, tt.name)
|
|
||||||
assert.Equal(t, "SELECT 1", string(commandTag), tt.name)
|
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExecPerQuerySimpleProtocol(t *testing.T) {
|
func TestExecPerQuerySimpleProtocol(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -385,19 +360,19 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if string(commandTag) != "CREATE TABLE" {
|
if commandTag.String() != "CREATE TABLE" {
|
||||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||||
}
|
}
|
||||||
|
|
||||||
commandTag, err = conn.Exec(ctx,
|
commandTag, err = conn.Exec(ctx,
|
||||||
"insert into foo(name) values($1);",
|
"insert into foo(name) values($1);",
|
||||||
pgx.QuerySimpleProtocol(true),
|
pgx.QueryExecModeSimpleProtocol,
|
||||||
"bar'; drop table foo;--",
|
"bar'; drop table foo;--",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if string(commandTag) != "INSERT 0 1" {
|
if commandTag.String() != "INSERT 0 1" {
|
||||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -501,45 +476,15 @@ func TestPrepareIdempotency(t *testing.T) {
|
|||||||
func TestPrepareStatementCacheModes(t *testing.T) {
|
func TestPrepareStatementCacheModes(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
_, err := conn.Prepare(context.Background(), "test", "select $1::text")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
var s string
|
||||||
name string
|
err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s)
|
||||||
buildStatementCache pgx.BuildStatementCacheFunc
|
require.NoError(t, err)
|
||||||
}{
|
require.Equal(t, "hello", s)
|
||||||
{
|
})
|
||||||
name: "disabled",
|
|
||||||
buildStatementCache: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "prepare",
|
|
||||||
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
|
|
||||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "describe",
|
|
||||||
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
|
|
||||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
config.BuildStatementCache = tt.buildStatementCache
|
|
||||||
conn := mustConnect(t, config)
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
_, err := conn.Prepare(context.Background(), "test", "select $1::text")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
var s string
|
|
||||||
err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "hello", s)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListenNotify(t *testing.T) {
|
func TestListenNotify(t *testing.T) {
|
||||||
@@ -595,7 +540,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
|
|||||||
func() {
|
func() {
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
listenerDone := make(chan bool)
|
listenerDone := make(chan bool)
|
||||||
@@ -671,7 +616,7 @@ func TestListenNotifySelfNotification(t *testing.T) {
|
|||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
||||||
|
|
||||||
mustExec(t, conn, "listen self")
|
mustExec(t, conn, "listen self")
|
||||||
|
|
||||||
@@ -706,7 +651,7 @@ func TestFatalRxError(t *testing.T) {
|
|||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -745,7 +690,7 @@ func TestFatalTxError(t *testing.T) {
|
|||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
||||||
|
|
||||||
otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer otherConn.Close(context.Background())
|
defer otherConn.Close(context.Background())
|
||||||
@@ -770,13 +715,13 @@ func TestFatalTxError(t *testing.T) {
|
|||||||
func TestInsertBoolArray(t *testing.T) {
|
func TestInsertBoolArray(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" {
|
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" {
|
||||||
t.Error("Unexpected results from Exec")
|
t.Error("Unexpected results from Exec")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept parameters
|
// Accept parameters
|
||||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" {
|
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results.String() != "INSERT 0 1" {
|
||||||
t.Errorf("Unexpected results from Exec: %v", results)
|
t.Errorf("Unexpected results from Exec: %v", results)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -785,91 +730,18 @@ func TestInsertBoolArray(t *testing.T) {
|
|||||||
func TestInsertTimestampArray(t *testing.T) {
|
func TestInsertTimestampArray(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" {
|
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" {
|
||||||
t.Error("Unexpected results from Exec")
|
t.Error("Unexpected results from Exec")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept parameters
|
// Accept parameters
|
||||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" {
|
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results.String() != "INSERT 0 1" {
|
||||||
t.Errorf("Unexpected results from Exec: %v", results)
|
t.Errorf("Unexpected results from Exec: %v", results)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type testLog struct {
|
|
||||||
lvl pgx.LogLevel
|
|
||||||
msg string
|
|
||||||
data map[string]interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type testLogger struct {
|
|
||||||
logs []testLog
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
data["ctxdata"] = ctx.Value("ctxdata")
|
|
||||||
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogPassesContext(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
l1 := &testLogger{}
|
|
||||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
config.Logger = l1
|
|
||||||
|
|
||||||
conn := mustConnect(t, config)
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
|
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), "ctxdata", "foo")
|
|
||||||
|
|
||||||
if _, err := conn.Exec(ctx, ";"); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(l1.logs) != 1 {
|
|
||||||
t.Fatal("Expected logger to be called once, but it wasn't")
|
|
||||||
}
|
|
||||||
|
|
||||||
if l1.logs[0].data["ctxdata"] != "foo" {
|
|
||||||
t.Fatal("Expected context data to be passed to logger, but it wasn't")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerFunc(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
const testMsg = "foo"
|
|
||||||
|
|
||||||
buf := bytes.Buffer{}
|
|
||||||
logger := log.New(&buf, "", 0)
|
|
||||||
|
|
||||||
createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc {
|
|
||||||
return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
logger.Printf("%s", testMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
config.Logger = createAdapterFn(logger)
|
|
||||||
|
|
||||||
conn := mustConnect(t, config)
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
buf.Reset() // Clear logs written when establishing connection
|
|
||||||
|
|
||||||
if _, err := conn.Exec(context.TODO(), ";"); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.TrimSpace(buf.String()) != testMsg {
|
|
||||||
t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIdentifierSanitize(t *testing.T) {
|
func TestIdentifierSanitize(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -911,7 +783,7 @@ func TestIdentifierSanitize(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnInitConnInfo(t *testing.T) {
|
func TestConnInitTypeMap(t *testing.T) {
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
@@ -923,11 +795,11 @@ func TestConnInitConnInfo(t *testing.T) {
|
|||||||
"text": pgtype.TextOID,
|
"text": pgtype.TextOID,
|
||||||
}
|
}
|
||||||
for name, oid := range nameOIDs {
|
for name, oid := range nameOIDs {
|
||||||
dtByName, ok := conn.ConnInfo().DataTypeForName(name)
|
dtByName, ok := conn.TypeMap().TypeForName(name)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Expected type named %v to be present", name)
|
t.Fatalf("Expected type named %v to be present", name)
|
||||||
}
|
}
|
||||||
dtByOID, ok := conn.ConnInfo().DataTypeForOID(oid)
|
dtByOID, ok := conn.TypeMap().TypeForOID(oid)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Expected type OID %v to be present", oid)
|
t.Fatalf("Expected type OID %v to be present", oid)
|
||||||
}
|
}
|
||||||
@@ -940,8 +812,8 @@ func TestConnInitConnInfo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
|
func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
||||||
|
|
||||||
var n uint64
|
var n uint64
|
||||||
err := conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n)
|
err := conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n)
|
||||||
@@ -956,32 +828,30 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDomainType(t *testing.T) {
|
func TestDomainType(t *testing.T) {
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
||||||
|
|
||||||
var n uint64
|
|
||||||
|
|
||||||
// Domain type uint64 is a PostgreSQL domain of underlying type numeric.
|
// Domain type uint64 is a PostgreSQL domain of underlying type numeric.
|
||||||
|
|
||||||
err := conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
|
// In the extended protocol preparing "select $1::uint64" appears to create a statement that expects a param OID of
|
||||||
|
// uint64 but a result OID of the underlying numeric.
|
||||||
|
|
||||||
|
var s string
|
||||||
|
err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "24", s)
|
||||||
|
|
||||||
// A string can be used. But a string cannot be the result because the describe result from the PostgreSQL server gives
|
// Register type
|
||||||
// the underlying type of numeric.
|
|
||||||
err = conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if n != 42 {
|
|
||||||
t.Fatalf("Expected n to be 42, but was %v", n)
|
|
||||||
}
|
|
||||||
|
|
||||||
var uint64OID uint32
|
var uint64OID uint32
|
||||||
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID)
|
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("did not find uint64 OID, %v", err)
|
t.Fatalf("did not find uint64 OID, %v", err)
|
||||||
}
|
}
|
||||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID})
|
conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}})
|
||||||
|
|
||||||
|
var n uint64
|
||||||
|
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// String is still an acceptable argument after registration
|
// String is still an acceptable argument after registration
|
||||||
err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n)
|
err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n)
|
||||||
@@ -991,15 +861,48 @@ func TestDomainType(t *testing.T) {
|
|||||||
if n != 7 {
|
if n != 7 {
|
||||||
t.Fatalf("Expected n to be 7, but was %v", n)
|
t.Fatalf("Expected n to be 7, but was %v", n)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// But a uint64 is acceptable
|
func TestLoadTypeSameNameInDifferentSchemas(t *testing.T) {
|
||||||
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
if err != nil {
|
pgxtest.SkipCockroachDB(t, conn, "Server does support composite types (https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||||
t.Fatal(err)
|
|
||||||
|
tx, err := conn.Begin(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
_, err = tx.Exec(ctx, `create schema pgx_a;
|
||||||
|
create type pgx_a.point as (a text, b text);
|
||||||
|
create schema pgx_b;
|
||||||
|
create type pgx_b.point as (c text);
|
||||||
|
`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Register types
|
||||||
|
for _, typename := range []string{"pgx_a.point", "pgx_b.point"} {
|
||||||
|
// Obviously using conn while a tx is in use and registering a type after the connection has been established are
|
||||||
|
// really bad practices, but for the sake of convenience we do it in the test here.
|
||||||
|
dt, err := conn.LoadType(ctx, typename)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn.TypeMap().RegisterType(dt)
|
||||||
}
|
}
|
||||||
if n != 24 {
|
|
||||||
t.Fatalf("Expected n to be 24, but was %v", n)
|
type aPoint struct {
|
||||||
|
A string
|
||||||
|
B string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type bPoint struct {
|
||||||
|
C string
|
||||||
|
}
|
||||||
|
|
||||||
|
var a aPoint
|
||||||
|
var b bPoint
|
||||||
|
err = tx.QueryRow(ctx, `select '(foo,bar)'::pgx_a.point, '(baz)'::pgx_b.point`).Scan(&a, &b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, aPoint{"foo", "bar"}, a)
|
||||||
|
require.Equal(t, bPoint{"baz"}, b)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1028,6 +931,7 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
|
|||||||
rows, err := conn.Query(ctx, getSQL, 1)
|
rows, err := conn.Query(ctx, getSQL, 1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
|
||||||
// Now, change the schema of the table out from under the statement, making it invalid.
|
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||||
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||||
@@ -1045,10 +949,10 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
|
|||||||
rows.Close()
|
rows.Close()
|
||||||
for _, err := range []error{nextErr, rows.Err()} {
|
for _, err := range []error{nextErr, rows.Err()} {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected InvalidCachedStatementPlanError: no error")
|
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||||
}
|
}
|
||||||
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||||
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
|
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1070,6 +974,10 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
|||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
||||||
|
t.Skip("Server has non-standard prepare in errored transaction behavior (https://github.com/cockroachdb/cockroach/issues/84140)")
|
||||||
|
}
|
||||||
|
|
||||||
// create a table and fill it with some data
|
// create a table and fill it with some data
|
||||||
_, err := conn.Exec(ctx, `
|
_, err := conn.Exec(ctx, `
|
||||||
DROP TABLE IF EXISTS drop_cols;
|
DROP TABLE IF EXISTS drop_cols;
|
||||||
@@ -1092,6 +1000,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
|||||||
rows, err := tx.Query(ctx, getSQL, 1)
|
rows, err := tx.Query(ctx, getSQL, 1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
|
||||||
// Now, change the schema of the table out from under the statement, making it invalid.
|
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||||
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||||
@@ -1109,18 +1018,17 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
|||||||
rows.Close()
|
rows.Close()
|
||||||
for _, err := range []error{nextErr, rows.Err()} {
|
for _, err := range []error{nextErr, rows.Err()} {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected InvalidCachedStatementPlanError: no error")
|
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||||
}
|
}
|
||||||
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||||
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
|
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err = tx.Query(ctx, getSQL, 1)
|
rows, _ = tx.Query(ctx, getSQL, 1)
|
||||||
require.NoError(t, err) // error does not pop up immediately
|
rows.Close()
|
||||||
rows.Next()
|
|
||||||
err = rows.Err()
|
err = rows.Err()
|
||||||
// Retries within the same transaction are errors (really anything except a rollbakc
|
// Retries within the same transaction are errors (really anything except a rollback
|
||||||
// will be an error in this transaction).
|
// will be an error in this transaction).
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
rows.Close()
|
rows.Close()
|
||||||
@@ -1140,7 +1048,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInsertDurationInterval(t *testing.T) {
|
func TestInsertDurationInterval(t *testing.T) {
|
||||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
_, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)")
|
_, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -1151,3 +1059,34 @@ func TestInsertDurationInterval(t *testing.T) {
|
|||||||
require.EqualValues(t, 1, n)
|
require.EqualValues(t, 1, n)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
var buf []byte
|
||||||
|
|
||||||
|
rows, err := conn.Query(ctx, `select 1::int`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
buf = rows.RawValues()[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
|
||||||
|
original := make([]byte, len(buf))
|
||||||
|
copy(original, buf)
|
||||||
|
|
||||||
|
for i := 0; i < 1_000_000; i++ {
|
||||||
|
rows, err := conn.Query(ctx, `select $1::int`, i)
|
||||||
|
require.NoError(t, err)
|
||||||
|
rows.Close()
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
|
||||||
|
if bytes.Compare(original, buf) != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
+23
-22
@@ -5,20 +5,19 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
"github.com/jackc/pgio"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
||||||
// making it usable by *Conn.CopyFrom.
|
// making it usable by *Conn.CopyFrom.
|
||||||
func CopyFromRows(rows [][]interface{}) CopyFromSource {
|
func CopyFromRows(rows [][]any) CopyFromSource {
|
||||||
return ©FromRows{rows: rows, idx: -1}
|
return ©FromRows{rows: rows, idx: -1}
|
||||||
}
|
}
|
||||||
|
|
||||||
type copyFromRows struct {
|
type copyFromRows struct {
|
||||||
rows [][]interface{}
|
rows [][]any
|
||||||
idx int
|
idx int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,7 +26,7 @@ func (ctr *copyFromRows) Next() bool {
|
|||||||
return ctr.idx < len(ctr.rows)
|
return ctr.idx < len(ctr.rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctr *copyFromRows) Values() ([]interface{}, error) {
|
func (ctr *copyFromRows) Values() ([]any, error) {
|
||||||
return ctr.rows[ctr.idx], nil
|
return ctr.rows[ctr.idx], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,12 +36,12 @@ func (ctr *copyFromRows) Err() error {
|
|||||||
|
|
||||||
// CopyFromSlice returns a CopyFromSource interface over a dynamic func
|
// CopyFromSlice returns a CopyFromSource interface over a dynamic func
|
||||||
// making it usable by *Conn.CopyFrom.
|
// making it usable by *Conn.CopyFrom.
|
||||||
func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource {
|
func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource {
|
||||||
return ©FromSlice{next: next, idx: -1, len: length}
|
return ©FromSlice{next: next, idx: -1, len: length}
|
||||||
}
|
}
|
||||||
|
|
||||||
type copyFromSlice struct {
|
type copyFromSlice struct {
|
||||||
next func(int) ([]interface{}, error)
|
next func(int) ([]any, error)
|
||||||
idx int
|
idx int
|
||||||
len int
|
len int
|
||||||
err error
|
err error
|
||||||
@@ -53,7 +52,7 @@ func (cts *copyFromSlice) Next() bool {
|
|||||||
return cts.idx < cts.len
|
return cts.idx < cts.len
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cts *copyFromSlice) Values() ([]interface{}, error) {
|
func (cts *copyFromSlice) Values() ([]any, error) {
|
||||||
values, err := cts.next(cts.idx)
|
values, err := cts.next(cts.idx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cts.err = err
|
cts.err = err
|
||||||
@@ -73,7 +72,7 @@ type CopyFromSource interface {
|
|||||||
Next() bool
|
Next() bool
|
||||||
|
|
||||||
// Values returns the values for the current row.
|
// Values returns the values for the current row.
|
||||||
Values() ([]interface{}, error)
|
Values() ([]any, error)
|
||||||
|
|
||||||
// Err returns any error that has been encountered by the CopyFromSource. If
|
// Err returns any error that has been encountered by the CopyFromSource. If
|
||||||
// this is not nil *Conn.CopyFrom will abort the copy.
|
// this is not nil *Conn.CopyFrom will abort the copy.
|
||||||
@@ -89,6 +88,13 @@ type copyFrom struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||||
|
if ct.conn.copyFromTracer != nil {
|
||||||
|
ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
|
||||||
|
TableName: ct.tableName,
|
||||||
|
ColumnNames: ct.columnNames,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
quotedTableName := ct.tableName.Sanitize()
|
quotedTableName := ct.tableName.Sanitize()
|
||||||
cbuf := &bytes.Buffer{}
|
cbuf := &bytes.Buffer{}
|
||||||
for i, cn := range ct.columnNames {
|
for i, cn := range ct.columnNames {
|
||||||
@@ -145,24 +151,19 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
|||||||
w.Close()
|
w.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||||
|
|
||||||
r.Close()
|
r.Close()
|
||||||
<-doneChan
|
<-doneChan
|
||||||
|
|
||||||
rowsAffected := commandTag.RowsAffected()
|
if ct.conn.copyFromTracer != nil {
|
||||||
endTime := time.Now()
|
ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
|
||||||
if err == nil {
|
CommandTag: commandTag,
|
||||||
if ct.conn.shouldLog(LogLevelInfo) {
|
Err: err,
|
||||||
ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected})
|
})
|
||||||
}
|
|
||||||
} else if ct.conn.shouldLog(LogLevelError) {
|
|
||||||
ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return rowsAffected, err
|
return commandTag.RowsAffected(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
|
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
|
||||||
@@ -178,7 +179,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
|
|||||||
|
|
||||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||||
for i, val := range values {
|
for i, val := range values {
|
||||||
buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val)
|
buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, nil, err
|
return false, nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
+72
-32
@@ -8,8 +8,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgxtest"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,7 +32,7 @@ func TestConnCopyFromSmall(t *testing.T) {
|
|||||||
|
|
||||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||||
|
|
||||||
inputRows := [][]interface{}{
|
inputRows := [][]any{
|
||||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||||
{nil, nil, nil, nil, nil, nil, nil},
|
{nil, nil, nil, nil, nil, nil, nil},
|
||||||
}
|
}
|
||||||
@@ -49,7 +50,7 @@ func TestConnCopyFromSmall(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error for Query: %v", err)
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputRows [][]interface{}
|
var outputRows [][]any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row, err := rows.Values()
|
row, err := rows.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -87,13 +88,13 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
|
|||||||
|
|
||||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||||
|
|
||||||
inputRows := [][]interface{}{
|
inputRows := [][]any{
|
||||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||||
{nil, nil, nil, nil, nil, nil, nil},
|
{nil, nil, nil, nil, nil, nil, nil},
|
||||||
}
|
}
|
||||||
|
|
||||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
|
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
|
||||||
pgx.CopyFromSlice(len(inputRows), func(i int) ([]interface{}, error) {
|
pgx.CopyFromSlice(len(inputRows), func(i int) ([]any, error) {
|
||||||
return inputRows[i], nil
|
return inputRows[i], nil
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -108,7 +109,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error for Query: %v", err)
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputRows [][]interface{}
|
var outputRows [][]any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row, err := rows.Values()
|
row, err := rows.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -134,7 +135,7 @@ func TestConnCopyFromLarge(t *testing.T) {
|
|||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)")
|
pgxtest.SkipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)")
|
||||||
|
|
||||||
mustExec(t, conn, `create temporary table foo(
|
mustExec(t, conn, `create temporary table foo(
|
||||||
a int2,
|
a int2,
|
||||||
@@ -149,10 +150,10 @@ func TestConnCopyFromLarge(t *testing.T) {
|
|||||||
|
|
||||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||||
|
|
||||||
inputRows := [][]interface{}{}
|
inputRows := [][]any{}
|
||||||
|
|
||||||
for i := 0; i < 10000; i++ {
|
for i := 0; i < 10000; i++ {
|
||||||
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
|
inputRows = append(inputRows, []any{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
|
||||||
}
|
}
|
||||||
|
|
||||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
|
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
|
||||||
@@ -168,7 +169,7 @@ func TestConnCopyFromLarge(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error for Query: %v", err)
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputRows [][]interface{}
|
var outputRows [][]any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row, err := rows.Values()
|
row, err := rows.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -211,6 +212,14 @@ func TestConnCopyFromEnum(t *testing.T) {
|
|||||||
_, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`)
|
_, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Obviously using conn while a tx is in use and registering a type after the connection has been established are
|
||||||
|
// really bad practices, but for the sake of convenience we do it in the test here.
|
||||||
|
for _, name := range []string{"fruit", "color"} {
|
||||||
|
typ, err := conn.LoadType(ctx, name)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn.TypeMap().RegisterType(typ)
|
||||||
|
}
|
||||||
|
|
||||||
_, err = tx.Exec(ctx, `create table foo(
|
_, err = tx.Exec(ctx, `create table foo(
|
||||||
a text,
|
a text,
|
||||||
b color,
|
b color,
|
||||||
@@ -221,7 +230,7 @@ func TestConnCopyFromEnum(t *testing.T) {
|
|||||||
)`)
|
)`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
inputRows := [][]interface{}{
|
inputRows := [][]any{
|
||||||
{"abc", "blue", "grape", "orange", "orange", "def"},
|
{"abc", "blue", "grape", "orange", "orange", "def"},
|
||||||
{nil, nil, nil, nil, nil, nil},
|
{nil, nil, nil, nil, nil, nil},
|
||||||
}
|
}
|
||||||
@@ -233,7 +242,7 @@ func TestConnCopyFromEnum(t *testing.T) {
|
|||||||
rows, err := conn.Query(ctx, "select * from foo")
|
rows, err := conn.Query(ctx, "select * from foo")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var outputRows [][]interface{}
|
var outputRows [][]any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row, err := rows.Values()
|
row, err := rows.Values()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -256,7 +265,7 @@ func TestConnCopyFromJSON(t *testing.T) {
|
|||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
for _, typeName := range []string{"json", "jsonb"} {
|
for _, typeName := range []string{"json", "jsonb"} {
|
||||||
if _, ok := conn.ConnInfo().DataTypeForName(typeName); !ok {
|
if _, ok := conn.TypeMap().TypeForName(typeName); !ok {
|
||||||
return // No JSON/JSONB type -- must be running against old PostgreSQL
|
return // No JSON/JSONB type -- must be running against old PostgreSQL
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -266,8 +275,8 @@ func TestConnCopyFromJSON(t *testing.T) {
|
|||||||
b jsonb
|
b jsonb
|
||||||
)`)
|
)`)
|
||||||
|
|
||||||
inputRows := [][]interface{}{
|
inputRows := [][]any{
|
||||||
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
|
{map[string]any{"foo": "bar"}, map[string]any{"bar": "quz"}},
|
||||||
{nil, nil},
|
{nil, nil},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,7 +293,7 @@ func TestConnCopyFromJSON(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error for Query: %v", err)
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputRows [][]interface{}
|
var outputRows [][]any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row, err := rows.Values()
|
row, err := rows.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -314,12 +323,12 @@ func (cfs *clientFailSource) Next() bool {
|
|||||||
return cfs.count < 100
|
return cfs.count < 100
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfs *clientFailSource) Values() ([]interface{}, error) {
|
func (cfs *clientFailSource) Values() ([]any, error) {
|
||||||
if cfs.count == 3 {
|
if cfs.count == 3 {
|
||||||
cfs.err = fmt.Errorf("client error")
|
cfs.err = fmt.Errorf("client error")
|
||||||
return nil, cfs.err
|
return nil, cfs.err
|
||||||
}
|
}
|
||||||
return []interface{}{make([]byte, 100000)}, nil
|
return []any{make([]byte, 100000)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfs *clientFailSource) Err() error {
|
func (cfs *clientFailSource) Err() error {
|
||||||
@@ -337,7 +346,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
|||||||
b varchar not null
|
b varchar not null
|
||||||
)`)
|
)`)
|
||||||
|
|
||||||
inputRows := [][]interface{}{
|
inputRows := [][]any{
|
||||||
{int32(1), "abc"},
|
{int32(1), "abc"},
|
||||||
{int32(2), nil}, // this row should trigger a failure
|
{int32(2), nil}, // this row should trigger a failure
|
||||||
{int32(3), "def"},
|
{int32(3), "def"},
|
||||||
@@ -359,7 +368,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error for Query: %v", err)
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputRows [][]interface{}
|
var outputRows [][]any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row, err := rows.Values()
|
row, err := rows.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -391,11 +400,11 @@ func (fs *failSource) Next() bool {
|
|||||||
return fs.count < 100
|
return fs.count < 100
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *failSource) Values() ([]interface{}, error) {
|
func (fs *failSource) Values() ([]any, error) {
|
||||||
if fs.count == 3 {
|
if fs.count == 3 {
|
||||||
return []interface{}{nil}, nil
|
return []any{nil}, nil
|
||||||
}
|
}
|
||||||
return []interface{}{make([]byte, 100000)}, nil
|
return []any{make([]byte, 100000)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *failSource) Err() error {
|
func (fs *failSource) Err() error {
|
||||||
@@ -408,6 +417,8 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
|||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
pgxtest.SkipCockroachDB(t, conn, "Server copy error does not fail fast")
|
||||||
|
|
||||||
mustExec(t, conn, `create temporary table foo(
|
mustExec(t, conn, `create temporary table foo(
|
||||||
a bytea not null
|
a bytea not null
|
||||||
)`)
|
)`)
|
||||||
@@ -436,7 +447,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error for Query: %v", err)
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputRows [][]interface{}
|
var outputRows [][]any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row, err := rows.Values()
|
row, err := rows.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -466,11 +477,11 @@ func (fs *slowFailRaceSource) Next() bool {
|
|||||||
return fs.count < 1000
|
return fs.count < 1000
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *slowFailRaceSource) Values() ([]interface{}, error) {
|
func (fs *slowFailRaceSource) Values() ([]any, error) {
|
||||||
if fs.count == 500 {
|
if fs.count == 500 {
|
||||||
return []interface{}{nil, nil}, nil
|
return []any{nil, nil}, nil
|
||||||
}
|
}
|
||||||
return []interface{}{1, make([]byte, 1000)}, nil
|
return []any{1, make([]byte, 1000)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *slowFailRaceSource) Err() error {
|
func (fs *slowFailRaceSource) Err() error {
|
||||||
@@ -525,7 +536,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error for Query: %v", err)
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputRows [][]interface{}
|
var outputRows [][]any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row, err := rows.Values()
|
row, err := rows.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -554,8 +565,8 @@ func (cfs *clientFinalErrSource) Next() bool {
|
|||||||
return cfs.count < 5
|
return cfs.count < 5
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfs *clientFinalErrSource) Values() ([]interface{}, error) {
|
func (cfs *clientFinalErrSource) Values() ([]any, error) {
|
||||||
return []interface{}{make([]byte, 100000)}, nil
|
return []any{make([]byte, 100000)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfs *clientFinalErrSource) Err() error {
|
func (cfs *clientFinalErrSource) Err() error {
|
||||||
@@ -585,7 +596,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error for Query: %v", err)
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputRows [][]interface{}
|
var outputRows [][]any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row, err := rows.Values()
|
row, err := rows.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -604,3 +615,32 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
|||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromAutomaticStringConversion(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a int8
|
||||||
|
)`)
|
||||||
|
|
||||||
|
inputRows := [][]interface{}{
|
||||||
|
{"42"},
|
||||||
|
{"7"},
|
||||||
|
{8},
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, len(inputRows), copyCount)
|
||||||
|
|
||||||
|
rows, _ := conn.Query(context.Background(), "select * from foo")
|
||||||
|
nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, []int64{42, 7, 8}, nums)
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
// Package pgx is a PostgreSQL database driver.
|
// Package pgx is a PostgreSQL database driver.
|
||||||
/*
|
/*
|
||||||
pgx provides lower level access to PostgreSQL than the standard database/sql. It remains as similar to the database/sql
|
pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar
|
||||||
interface as possible while providing better speed and access to PostgreSQL specific features. Import
|
to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use
|
||||||
github.com/jackc/pgx/v4/stdlib to use pgx as a database/sql compatible driver.
|
github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for
|
||||||
|
details.
|
||||||
|
|
||||||
Establishing a Connection
|
Establishing a Connection
|
||||||
|
|
||||||
@@ -12,57 +13,41 @@ The primary way of establishing a connection is with `pgx.Connect`.
|
|||||||
|
|
||||||
The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
|
The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
|
||||||
here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with
|
here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with
|
||||||
`ConnectConfig`.
|
`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string.
|
||||||
|
|
||||||
config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
|
|
||||||
if err != nil {
|
|
||||||
// ...
|
|
||||||
}
|
|
||||||
config.Logger = log15adapter.NewLogger(log.New("module", "pgx"))
|
|
||||||
|
|
||||||
conn, err := pgx.ConnectConfig(context.Background(), config)
|
|
||||||
|
|
||||||
Connection Pool
|
Connection Pool
|
||||||
|
|
||||||
`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use sub-package pgxpool for a
|
`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package
|
||||||
concurrency safe connection pool.
|
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
|
||||||
|
|
||||||
Query Interface
|
Query Interface
|
||||||
|
|
||||||
pgx implements Query and Scan in the familiar database/sql style.
|
pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
|
||||||
|
ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and
|
||||||
|
rows.Err().
|
||||||
|
|
||||||
var sum int32
|
CollectRows can be used collect all returned rows into a slice.
|
||||||
|
|
||||||
// Send the query to the server. The returned rows MUST be closed
|
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5)
|
||||||
// before conn can be used again.
|
numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32])
|
||||||
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// numbers => [1 2 3 4 5]
|
||||||
|
|
||||||
// rows.Close is called by rows.Next when all rows are read
|
ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows
|
||||||
// or an error occurs in Next or Scan. So it may optionally be
|
directly.
|
||||||
// omitted if nothing in the rows.Next loop can panic. It is
|
|
||||||
// safe to close rows multiple times.
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// Iterate through the result set
|
var sum, n int32
|
||||||
for rows.Next() {
|
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
||||||
var n int32
|
_, err := pgx.ForEachRow(rows, []any{&n}, func(pgx.QueryFuncRow) error {
|
||||||
err = rows.Scan(&n)
|
sum += n
|
||||||
if err != nil {
|
return nil
|
||||||
return err
|
})
|
||||||
}
|
if err != nil {
|
||||||
sum += n
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Any errors encountered by rows.Next or rows.Scan will be returned here
|
|
||||||
if rows.Err() != nil {
|
|
||||||
return rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// No errors found - do something with sum
|
|
||||||
|
|
||||||
pgx also implements QueryRow in the same style as database/sql.
|
pgx also implements QueryRow in the same style as database/sql.
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
@@ -82,148 +67,10 @@ Use Exec to execute a query that does not return a result set.
|
|||||||
return errors.New("No row found to delete")
|
return errors.New("No row found to delete")
|
||||||
}
|
}
|
||||||
|
|
||||||
QueryFunc can be used to execute a callback function for every row. This is often easier to use than Query.
|
PostgreSQL Data Types
|
||||||
|
|
||||||
var sum, n int32
|
The package pgtype provides extensive and customizable support for converting Go values to and from PostgreSQL values
|
||||||
_, err = conn.QueryFunc(
|
including array and composite types. See that package's documentation for details.
|
||||||
context.Background(),
|
|
||||||
"select generate_series(1,$1)",
|
|
||||||
[]interface{}{10},
|
|
||||||
[]interface{}{&n},
|
|
||||||
func(pgx.QueryFuncRow) error {
|
|
||||||
sum += n
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
Base Type Mapping
|
|
||||||
|
|
||||||
pgx maps between all common base types directly between Go and PostgreSQL. In particular:
|
|
||||||
|
|
||||||
Go PostgreSQL
|
|
||||||
-----------------------
|
|
||||||
string varchar
|
|
||||||
text
|
|
||||||
|
|
||||||
// Integers are automatically be converted to any other integer type if
|
|
||||||
// it can be done without overflow or underflow.
|
|
||||||
int8
|
|
||||||
int16 smallint
|
|
||||||
int32 int
|
|
||||||
int64 bigint
|
|
||||||
int
|
|
||||||
uint8
|
|
||||||
uint16
|
|
||||||
uint32
|
|
||||||
uint64
|
|
||||||
uint
|
|
||||||
|
|
||||||
// Floats are strict and do not automatically convert like integers.
|
|
||||||
float32 float4
|
|
||||||
float64 float8
|
|
||||||
|
|
||||||
time.Time date
|
|
||||||
timestamp
|
|
||||||
timestamptz
|
|
||||||
|
|
||||||
[]byte bytea
|
|
||||||
|
|
||||||
|
|
||||||
Null Mapping
|
|
||||||
|
|
||||||
pgx can map nulls in two ways. The first is package pgtype provides types that have a data field and a status field.
|
|
||||||
They work in a similar fashion to database/sql. The second is to use a pointer to a pointer.
|
|
||||||
|
|
||||||
var foo pgtype.Varchar
|
|
||||||
var bar *string
|
|
||||||
err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
Array Mapping
|
|
||||||
|
|
||||||
pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type.
|
|
||||||
Go slices of native types do not support nulls, so if a PostgreSQL array that contains a null is read into a native Go
|
|
||||||
slice an error will occur. The pgtype package includes many more array types for PostgreSQL types that do not directly
|
|
||||||
map to native Go types.
|
|
||||||
|
|
||||||
JSON and JSONB Mapping
|
|
||||||
|
|
||||||
pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB.
|
|
||||||
|
|
||||||
Inet and CIDR Mapping
|
|
||||||
|
|
||||||
pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In addition, as a convenience pgx will encode
|
|
||||||
from a net.IP; it will assume a /32 netmask for IPv4 and a /128 for IPv6.
|
|
||||||
|
|
||||||
Custom Type Support
|
|
||||||
|
|
||||||
pgx includes support for the common data types like integers, floats, strings, dates, and times that have direct
|
|
||||||
mappings between Go and SQL. In addition, pgx uses the github.com/jackc/pgtype library to support more types. See
|
|
||||||
documention for that library for instructions on how to implement custom types.
|
|
||||||
|
|
||||||
See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type.
|
|
||||||
|
|
||||||
pgx also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer
|
|
||||||
interfaces.
|
|
||||||
|
|
||||||
If pgx does cannot natively encode a type and that type is a renamed type (e.g. type MyTime time.Time) pgx will attempt
|
|
||||||
to encode the underlying type. While this is usually desired behavior it can produce surprising behavior if one the
|
|
||||||
underlying type and the renamed type each implement database/sql interfaces and the other implements pgx interfaces. It
|
|
||||||
is recommended that this situation be avoided by implementing pgx interfaces on the renamed type.
|
|
||||||
|
|
||||||
Composite types and row values
|
|
||||||
|
|
||||||
Row values and composite types are represented as pgtype.Record (https://pkg.go.dev/github.com/jackc/pgtype?tab=doc#Record).
|
|
||||||
It is possible to get values of your custom type by implementing DecodeBinary interface. Decoding into
|
|
||||||
pgtype.Record first can simplify process by avoiding dealing with raw protocol directly.
|
|
||||||
|
|
||||||
For example:
|
|
||||||
|
|
||||||
type MyType struct {
|
|
||||||
a int // NULL will cause decoding error
|
|
||||||
b *string // there can be NULL in this position in SQL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
|
|
||||||
r := pgtype.Record{
|
|
||||||
Fields: []pgtype.Value{&pgtype.Int4{}, &pgtype.Text{}},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.DecodeBinary(ci, src); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Status != pgtype.Present {
|
|
||||||
return errors.New("BUG: decoding should not be called on NULL value")
|
|
||||||
}
|
|
||||||
|
|
||||||
a := r.Fields[0].(*pgtype.Int4)
|
|
||||||
b := r.Fields[1].(*pgtype.Text)
|
|
||||||
|
|
||||||
// type compatibility is checked by AssignTo
|
|
||||||
// only lossless assignments will succeed
|
|
||||||
if err := a.AssignTo(&t.a); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// AssignTo also deals with null value handling
|
|
||||||
if err := b.AssignTo(&t.b); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result := MyType{}
|
|
||||||
err := conn.QueryRow(context.Background(), "select row(1, 'foo'::text)", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&r)
|
|
||||||
|
|
||||||
Raw Bytes Mapping
|
|
||||||
|
|
||||||
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified to PostgreSQL.
|
|
||||||
|
|
||||||
Transactions
|
Transactions
|
||||||
|
|
||||||
@@ -250,12 +97,13 @@ Transactions are started by calling Begin.
|
|||||||
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
|
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
|
||||||
These are internally implemented with savepoints.
|
These are internally implemented with savepoints.
|
||||||
|
|
||||||
Use BeginTx to control the transaction mode.
|
Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of
|
||||||
|
a pseudo nested transaction.
|
||||||
|
|
||||||
BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the
|
BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
|
||||||
transaction depending on the return value of the function. These can be simpler and less error prone to use.
|
transaction depending on the return value of the function. These can be simpler and less error prone to use.
|
||||||
|
|
||||||
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
|
err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
|
||||||
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
@@ -273,10 +121,10 @@ for information on how to customize or disable the statement cache.
|
|||||||
Copy Protocol
|
Copy Protocol
|
||||||
|
|
||||||
Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a
|
Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a
|
||||||
CopyFromSource interface. If the data is already in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource
|
CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface.
|
||||||
interface. Or implement CopyFromSource to avoid buffering the entire data set in memory.
|
Or implement CopyFromSource to avoid buffering the entire data set in memory.
|
||||||
|
|
||||||
rows := [][]interface{}{
|
rows := [][]any{
|
||||||
{"John", "Smith", int32(36)},
|
{"John", "Smith", int32(36)},
|
||||||
{"Jane", "Doe", int32(29)},
|
{"Jane", "Doe", int32(29)},
|
||||||
}
|
}
|
||||||
@@ -299,8 +147,8 @@ When you already have a typed array using CopyFromSlice can be more convenient.
|
|||||||
context.Background(),
|
context.Background(),
|
||||||
pgx.Identifier{"people"},
|
pgx.Identifier{"people"},
|
||||||
[]string{"first_name", "last_name", "age"},
|
[]string{"first_name", "last_name", "age"},
|
||||||
pgx.CopyFromSlice(len(rows), func(i int) ([]interface{}, error) {
|
pgx.CopyFromSlice(len(rows), func(i int) ([]any, error) {
|
||||||
return []interface{}{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil
|
return []any{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -321,20 +169,22 @@ notification is received or the context is canceled.
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Logging
|
Tracing and Logging
|
||||||
|
|
||||||
pgx defines a simple logger interface. Connections optionally accept a logger that satisfies this interface. Set
|
pgx supports tracing by setting ConnConfig.Tracer.
|
||||||
LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus,
|
|
||||||
go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory.
|
In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
|
||||||
|
|
||||||
|
For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3.
|
||||||
|
|
||||||
Lower Level PostgreSQL Functionality
|
Lower Level PostgreSQL Functionality
|
||||||
|
|
||||||
pgx is implemented on top of github.com/jackc/pgconn a lower level PostgreSQL driver. The Conn.PgConn() method can be
|
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in
|
||||||
used to access this lower layer.
|
implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer.
|
||||||
|
|
||||||
PgBouncer
|
PgBouncer
|
||||||
|
|
||||||
pgx is compatible with PgBouncer in two modes. One is when the connection has a statement cache in "describe" mode. The
|
By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be
|
||||||
other is when the connection is using the simple protocol. This can be set with the PreferSimpleProtocol config option.
|
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
|
||||||
*/
|
*/
|
||||||
package pgx
|
package pgx
|
||||||
|
|||||||
@@ -1,115 +0,0 @@
|
|||||||
package pgx_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/jackc/pgtype"
|
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
)
|
|
||||||
|
|
||||||
var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`)
|
|
||||||
|
|
||||||
// Point represents a point that may be null.
|
|
||||||
type Point struct {
|
|
||||||
X, Y float64 // Coordinates of point
|
|
||||||
Status pgtype.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dst *Point) Set(src interface{}) error {
|
|
||||||
return fmt.Errorf("cannot convert %v to Point", src)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dst *Point) Get() interface{} {
|
|
||||||
switch dst.Status {
|
|
||||||
case pgtype.Present:
|
|
||||||
return dst
|
|
||||||
case pgtype.Null:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return dst.Status
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Point) AssignTo(dst interface{}) error {
|
|
||||||
return fmt.Errorf("cannot assign %v to %T", src, dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
|
||||||
if src == nil {
|
|
||||||
*dst = Point{Status: pgtype.Null}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
s := string(src)
|
|
||||||
match := pointRegexp.FindStringSubmatch(s)
|
|
||||||
if match == nil {
|
|
||||||
return fmt.Errorf("Received invalid point: %v", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
x, err := strconv.ParseFloat(match[1], 64)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Received invalid point: %v", s)
|
|
||||||
}
|
|
||||||
y, err := strconv.ParseFloat(match[2], 64)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Received invalid point: %v", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
*dst = Point{X: x, Y: y, Status: pgtype.Present}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Point) String() string {
|
|
||||||
if src.Status == pgtype.Null {
|
|
||||||
return "null point"
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("%.1f, %.1f", src.X, src.Y)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Example_CustomType() {
|
|
||||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Unable to establish connection: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close(context.Background())
|
|
||||||
|
|
||||||
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
|
||||||
// Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be
|
|
||||||
// skipped fake success instead.
|
|
||||||
fmt.Println("null point")
|
|
||||||
fmt.Println("1.5, 2.5")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Override registered handler for point
|
|
||||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{
|
|
||||||
Value: &Point{},
|
|
||||||
Name: "point",
|
|
||||||
OID: 600,
|
|
||||||
})
|
|
||||||
|
|
||||||
p := &Point{}
|
|
||||||
err = conn.QueryRow(context.Background(), "select null::point").Scan(p)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Println(p)
|
|
||||||
|
|
||||||
err = conn.QueryRow(context.Background(), "select point(1.5,2.5)").Scan(p)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Println(p)
|
|
||||||
// Output:
|
|
||||||
// null point
|
|
||||||
// 1.5, 2.5
|
|
||||||
}
|
|
||||||
@@ -6,14 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
var pool *pgxpool.Pool
|
var pool *pgxpool.Pool
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var err error
|
var err error
|
||||||
pool, err = pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
pool, err = pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintln(os.Stderr, "Unable to connect to database:", err)
|
fmt.Fprintln(os.Stderr, "Unable to connect to database:", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
var conn *pgx.Conn
|
var conn *pgx.Conn
|
||||||
|
|||||||
@@ -3,13 +3,12 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v4/log/log15adapter"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
"github.com/jackc/pgx/v4/pgxpool"
|
|
||||||
log "gopkg.in/inconshreveable/log15.v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var db *pgxpool.Pool
|
var db *pgxpool.Pool
|
||||||
@@ -71,28 +70,21 @@ func urlHandler(w http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
logger := log15adapter.NewLogger(log.New("module", "pgx"))
|
|
||||||
|
|
||||||
poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))
|
poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Crit("Unable to parse DATABASE_URL", "error", err)
|
log.Fatalln("Unable to parse DATABASE_URL:", err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
poolConfig.ConnConfig.Logger = logger
|
db, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
|
||||||
|
|
||||||
db, err = pgxpool.ConnectConfig(context.Background(), poolConfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Crit("Unable to create connection pool", "error", err)
|
log.Fatalln("Unable to create connection pool:", err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
http.HandleFunc("/", urlHandler)
|
http.HandleFunc("/", urlHandler)
|
||||||
|
|
||||||
log.Info("Starting URL shortener on localhost:8080")
|
log.Println("Starting URL shortener on localhost:8080")
|
||||||
err = http.ListenAndServe("localhost:8080", nil)
|
err = http.ListenAndServe("localhost:8080", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Crit("Unable to start web server", "error", err)
|
log.Fatalln("Unable to start web server:", err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+152
-109
@@ -1,69 +1,118 @@
|
|||||||
package pgx
|
package pgx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/jackc/pgtype"
|
"github.com/jackc/pgx/v5/internal/anynil"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
type extendedQueryBuilder struct {
|
// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result
|
||||||
paramValues [][]byte
|
// formats for an extended query.
|
||||||
|
type ExtendedQueryBuilder struct {
|
||||||
|
ParamValues [][]byte
|
||||||
paramValueBytes []byte
|
paramValueBytes []byte
|
||||||
paramFormats []int16
|
ParamFormats []int16
|
||||||
resultFormats []int16
|
ResultFormats []int16
|
||||||
}
|
}
|
||||||
|
|
||||||
func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error {
|
// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If
|
||||||
f := chooseParameterFormatCode(ci, oid, arg)
|
// sd is nil then QueryExecModeExec behavior will be used.
|
||||||
eqb.paramFormats = append(eqb.paramFormats, f)
|
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
|
||||||
|
eqb.reset()
|
||||||
|
|
||||||
v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg)
|
anynil.NormalizeSlice(args)
|
||||||
if err != nil {
|
|
||||||
return err
|
if sd == nil {
|
||||||
|
return eqb.appendParamsForQueryExecModeExec(m, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sd.ParamOIDs) != len(args) {
|
||||||
|
return fmt.Errorf("mismatched param and argument count")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range args {
|
||||||
|
err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to encode args[%d]: %v", i, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range sd.Fields {
|
||||||
|
eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID))
|
||||||
}
|
}
|
||||||
eqb.paramValues = append(eqb.paramValues, v)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) {
|
// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it
|
||||||
eqb.resultFormats = append(eqb.resultFormats, f)
|
// must be an untyped nil.
|
||||||
|
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
|
||||||
|
if format == -1 {
|
||||||
|
preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
|
||||||
|
preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
|
||||||
|
if preferredErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var otherFormat int16
|
||||||
|
if preferredFormat == TextFormatCode {
|
||||||
|
otherFormat = BinaryFormatCode
|
||||||
|
} else {
|
||||||
|
otherFormat = TextFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
otherErr := eqb.appendParam(m, oid, otherFormat, arg)
|
||||||
|
if otherErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return preferredErr // return the error from the preferred format
|
||||||
|
}
|
||||||
|
|
||||||
|
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
eqb.ParamFormats = append(eqb.ParamFormats, format)
|
||||||
|
eqb.ParamValues = append(eqb.ParamValues, v)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset readies eqb to build another query.
|
// appendResultFormat appends a result format to the query.
|
||||||
func (eqb *extendedQueryBuilder) Reset() {
|
func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) {
|
||||||
eqb.paramValues = eqb.paramValues[0:0]
|
eqb.ResultFormats = append(eqb.ResultFormats, format)
|
||||||
eqb.paramValueBytes = eqb.paramValueBytes[0:0]
|
}
|
||||||
eqb.paramFormats = eqb.paramFormats[0:0]
|
|
||||||
eqb.resultFormats = eqb.resultFormats[0:0]
|
|
||||||
|
|
||||||
if cap(eqb.paramValues) > 64 {
|
// reset readies eqb to build another query.
|
||||||
eqb.paramValues = make([][]byte, 0, 64)
|
func (eqb *ExtendedQueryBuilder) reset() {
|
||||||
|
eqb.ParamValues = eqb.ParamValues[0:0]
|
||||||
|
eqb.paramValueBytes = eqb.paramValueBytes[0:0]
|
||||||
|
eqb.ParamFormats = eqb.ParamFormats[0:0]
|
||||||
|
eqb.ResultFormats = eqb.ResultFormats[0:0]
|
||||||
|
|
||||||
|
if cap(eqb.ParamValues) > 64 {
|
||||||
|
eqb.ParamValues = make([][]byte, 0, 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cap(eqb.paramValueBytes) > 256 {
|
if cap(eqb.paramValueBytes) > 256 {
|
||||||
eqb.paramValueBytes = make([]byte, 0, 256)
|
eqb.paramValueBytes = make([]byte, 0, 256)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cap(eqb.paramFormats) > 64 {
|
if cap(eqb.ParamFormats) > 64 {
|
||||||
eqb.paramFormats = make([]int16, 0, 64)
|
eqb.ParamFormats = make([]int16, 0, 64)
|
||||||
}
|
}
|
||||||
if cap(eqb.resultFormats) > 64 {
|
if cap(eqb.ResultFormats) > 64 {
|
||||||
eqb.resultFormats = make([]int16, 0, 64)
|
eqb.ResultFormats = make([]int16, 0, 64)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) {
|
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
|
||||||
if arg == nil {
|
if anynil.Is(arg) {
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
refVal := reflect.ValueOf(arg)
|
|
||||||
argIsPtr := refVal.Kind() == reflect.Ptr
|
|
||||||
|
|
||||||
if argIsPtr && refVal.IsNil() {
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,91 +120,85 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o
|
|||||||
eqb.paramValueBytes = make([]byte, 0, 128)
|
eqb.paramValueBytes = make([]byte, 0, 128)
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
var buf []byte
|
|
||||||
pos := len(eqb.paramValueBytes)
|
pos := len(eqb.paramValueBytes)
|
||||||
|
|
||||||
if arg, ok := arg.(string); ok {
|
buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
|
||||||
return []byte(arg), nil
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if buf == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
eqb.paramValueBytes = buf
|
||||||
|
return eqb.paramValueBytes[pos:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chooseParameterFormatCode determines the correct format code for an
|
||||||
|
// argument to a prepared statement. It defaults to TextFormatCode if no
|
||||||
|
// determination can be made.
|
||||||
|
func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 {
|
||||||
|
switch arg.(type) {
|
||||||
|
case string, *string:
|
||||||
|
return TextFormatCode
|
||||||
}
|
}
|
||||||
|
|
||||||
if formatCode == TextFormatCode {
|
return m.FormatCodeForOID(oid)
|
||||||
if arg, ok := arg.(pgtype.TextEncoder); ok {
|
}
|
||||||
buf, err = arg.EncodeText(ci, eqb.paramValueBytes)
|
|
||||||
|
// appendParamsForQueryExecModeExec appends the args to eqb.
|
||||||
|
//
|
||||||
|
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
|
||||||
|
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
|
||||||
|
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
|
||||||
|
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
|
||||||
|
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
|
||||||
|
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
|
||||||
|
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
|
||||||
|
// type conversion it takes the date directly and ignores time zone (i.e. it works).
|
||||||
|
//
|
||||||
|
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
|
||||||
|
// no way to safely use binary or to specify the parameter OIDs.
|
||||||
|
func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
|
||||||
|
for _, arg := range args {
|
||||||
|
if arg == nil {
|
||||||
|
err := eqb.appendParam(m, 0, TextFormatCode, arg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
if buf == nil {
|
} else {
|
||||||
return nil, nil
|
dt, ok := m.TypeForValue(arg)
|
||||||
}
|
if !ok {
|
||||||
eqb.paramValueBytes = buf
|
var tv pgtype.TextValuer
|
||||||
return eqb.paramValueBytes[pos:], nil
|
if tv, ok = arg.(pgtype.TextValuer); ok {
|
||||||
}
|
t, err := tv.TextValue()
|
||||||
} else if formatCode == BinaryFormatCode {
|
|
||||||
if arg, ok := arg.(pgtype.BinaryEncoder); ok {
|
|
||||||
buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if buf == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
eqb.paramValueBytes = buf
|
|
||||||
return eqb.paramValueBytes[pos:], nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if argIsPtr {
|
|
||||||
// We have already checked that arg is not pointing to nil,
|
|
||||||
// so it is safe to dereference here.
|
|
||||||
arg = refVal.Elem().Interface()
|
|
||||||
return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dt, ok := ci.DataTypeForOID(oid); ok {
|
|
||||||
value := dt.Value
|
|
||||||
err := value.Set(arg)
|
|
||||||
if err != nil {
|
|
||||||
{
|
|
||||||
if arg, ok := arg.(driver.Valuer); ok {
|
|
||||||
v, err := callValuerValue(arg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dt, ok = m.TypeForOID(pgtype.TextOID)
|
||||||
|
if ok {
|
||||||
|
arg = t
|
||||||
}
|
}
|
||||||
return eqb.encodeExtendedParamValue(ci, oid, formatCode, v)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !ok {
|
||||||
return nil, err
|
var str fmt.Stringer
|
||||||
}
|
if str, ok = arg.(fmt.Stringer); ok {
|
||||||
|
dt, ok = m.TypeForOID(pgtype.TextOID)
|
||||||
return eqb.encodeExtendedParamValue(ci, oid, formatCode, value)
|
if ok {
|
||||||
}
|
arg = str.String()
|
||||||
|
}
|
||||||
// There is no data type registered for the destination OID, but maybe there is data type registered for the arg
|
}
|
||||||
// type. If so use it's text encoder (if available).
|
}
|
||||||
if dt, ok := ci.DataTypeForValue(arg); ok {
|
if !ok {
|
||||||
value := dt.Value
|
return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
|
||||||
if textEncoder, ok := value.(pgtype.TextEncoder); ok {
|
}
|
||||||
err := value.Set(arg)
|
err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if buf == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
eqb.paramValueBytes = buf
|
|
||||||
return eqb.paramValueBytes[pos:], nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
return nil
|
||||||
return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg)
|
|
||||||
}
|
|
||||||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,20 @@
|
|||||||
module github.com/jackc/pgx/v4
|
module github.com/jackc/pgx/v5
|
||||||
|
|
||||||
go 1.13
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Masterminds/semver/v3 v3.1.1
|
github.com/jackc/pgpassfile v1.0.0
|
||||||
github.com/cockroachdb/apd v1.1.0
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b
|
||||||
github.com/go-kit/log v0.1.0
|
github.com/jackc/puddle/v2 v2.0.0
|
||||||
github.com/gofrs/uuid v4.0.0+incompatible
|
|
||||||
github.com/jackc/pgconn v1.13.0
|
|
||||||
github.com/jackc/pgio v1.0.0
|
|
||||||
github.com/jackc/pgproto3/v2 v2.3.1
|
|
||||||
github.com/jackc/pgtype v1.12.0
|
|
||||||
github.com/jackc/puddle v1.3.0
|
|
||||||
github.com/rs/zerolog v1.15.0
|
|
||||||
github.com/shopspring/decimal v1.2.0
|
|
||||||
github.com/sirupsen/logrus v1.4.2
|
|
||||||
github.com/stretchr/testify v1.8.0
|
github.com/stretchr/testify v1.8.0
|
||||||
go.uber.org/zap v1.13.0
|
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90
|
||||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec
|
golang.org/x/text v0.3.7
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/kr/pretty v0.3.0 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,208 +1,42 @@
|
|||||||
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
|
||||||
github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
|
|
||||||
github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
|
|
||||||
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
|
|
||||||
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
|
|
||||||
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
|
||||||
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
|
||||||
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/go-kit/log v0.1.0 h1:DGJh0Sm43HbOeYDNnVZFl8BvcYVvjD5bqYJvp0REbwQ=
|
|
||||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
|
||||||
github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4=
|
|
||||||
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
|
||||||
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
|
|
||||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
|
||||||
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
|
|
||||||
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
|
||||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
|
||||||
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
|
||||||
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
|
||||||
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
|
||||||
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
|
||||||
github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
|
||||||
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
|
|
||||||
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
|
|
||||||
github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
|
|
||||||
github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o=
|
|
||||||
github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY=
|
|
||||||
github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI=
|
|
||||||
github.com/jackc/pgconn v1.13.0 h1:3L1XMNV2Zvca/8BYhzcRFS70Lr0WlDg16Di6SFGAbys=
|
|
||||||
github.com/jackc/pgconn v1.13.0/go.mod h1:AnowpAqO4CMIIJNZl2VJp+KrkAZciAkhEl0W0JIobpI=
|
|
||||||
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
|
||||||
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
|
||||||
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
|
|
||||||
github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c=
|
|
||||||
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc=
|
|
||||||
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
|
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
|
|
||||||
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
|
||||||
github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
|
||||||
github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
|
||||||
github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y4Y=
|
|
||||||
github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
|
||||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
||||||
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
|
github.com/jackc/puddle/v2 v2.0.0 h1:Kwk/AlLigcnZsDssc3Zun1dk1tAtQNPaBBxBHWn0Mjc=
|
||||||
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
|
github.com/jackc/puddle/v2 v2.0.0/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc=
|
||||||
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
|
|
||||||
github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM=
|
|
||||||
github.com/jackc/pgtype v1.12.0 h1:Dlq8Qvcch7kiehm8wPGIW0W3KsCCHJnRacKW0UM8n5w=
|
|
||||||
github.com/jackc/pgtype v1.12.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
|
|
||||||
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
|
|
||||||
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
|
|
||||||
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
|
|
||||||
github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs=
|
|
||||||
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
|
||||||
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
|
||||||
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
|
||||||
github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw=
|
|
||||||
github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
|
||||||
github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0=
|
|
||||||
github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
|
||||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
|
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
|
||||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
|
||||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
|
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||||
|
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||||
|
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
|
|
||||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
|
||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
|
||||||
github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
|
|
||||||
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
|
||||||
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
|
||||||
github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE=
|
|
||||||
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
|
||||||
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
|
||||||
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
|
||||||
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
|
|
||||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
|
||||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
|
||||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||||
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
|
||||||
github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY=
|
|
||||||
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
|
|
||||||
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
|
||||||
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
|
|
||||||
github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
|
|
||||||
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
|
||||||
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
|
|
||||||
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
|
|
||||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
|
||||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
|
||||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 h1:Y/gsMcFOcR+6S6f3YeMKl5g+dZMEWqcz5Czj/GWYbkM=
|
||||||
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
|
||||||
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
|
||||||
go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=
|
|
||||||
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
|
||||||
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
|
||||||
go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4=
|
|
||||||
go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A=
|
|
||||||
go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU=
|
|
||||||
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4=
|
|
||||||
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
|
|
||||||
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
|
||||||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
|
||||||
go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU=
|
|
||||||
go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM=
|
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
|
||||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
|
||||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
|
||||||
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
|
||||||
golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
|
||||||
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
|
||||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
|
|
||||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
|
||||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
|
|
||||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
|
||||||
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
|
||||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
|
||||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
|
||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
|
||||||
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
|
||||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
|
||||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
|
||||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
|
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
||||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
||||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
|
||||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
|
||||||
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
|
||||||
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
|
||||||
golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
|
||||||
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
|
||||||
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
|
||||||
golang.org/x/tools v0.0.0-20200103221440-774c71fcf114 h1:DnSr2mCsxyCE6ZgIkmcWUQY2R5cH/6wL7eIxEmQOMSE=
|
|
||||||
golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
|
||||||
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
|
||||||
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec h1:RlWgLqCMMIYYEVcAR5MDsuHlVkaIPDAF+5Dehzg8L5A=
|
|
||||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
|
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
|
|
||||||
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
|
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
package pgx
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql/driver"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// This file contains code copied from the Go standard library due to the
|
|
||||||
// required function not being public.
|
|
||||||
|
|
||||||
// Copyright (c) 2009 The Go Authors. All rights reserved.
|
|
||||||
|
|
||||||
// Redistribution and use in source and binary forms, with or without
|
|
||||||
// modification, are permitted provided that the following conditions are
|
|
||||||
// met:
|
|
||||||
|
|
||||||
// * Redistributions of source code must retain the above copyright
|
|
||||||
// notice, this list of conditions and the following disclaimer.
|
|
||||||
// * Redistributions in binary form must reproduce the above
|
|
||||||
// copyright notice, this list of conditions and the following disclaimer
|
|
||||||
// in the documentation and/or other materials provided with the
|
|
||||||
// distribution.
|
|
||||||
// * Neither the name of Google Inc. nor the names of its
|
|
||||||
// contributors may be used to endorse or promote products derived from
|
|
||||||
// this software without specific prior written permission.
|
|
||||||
|
|
||||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
||||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
||||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
||||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
||||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
||||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
||||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
||||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
||||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
||||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
|
|
||||||
// From database/sql/convert.go
|
|
||||||
|
|
||||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
|
||||||
|
|
||||||
// callValuerValue returns vr.Value(), with one exception:
|
|
||||||
// If vr.Value is an auto-generated method on a pointer type and the
|
|
||||||
// pointer is nil, it would panic at runtime in the panicwrap
|
|
||||||
// method. Treat it like nil instead.
|
|
||||||
// Issue 8415.
|
|
||||||
//
|
|
||||||
// This is so people can implement driver.Value on value types and
|
|
||||||
// still use nil pointers to those types to mean nil/NULL, just like
|
|
||||||
// string/*string.
|
|
||||||
//
|
|
||||||
// This function is mirrored in the database/sql/driver package.
|
|
||||||
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
|
|
||||||
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
|
|
||||||
rv.IsNil() &&
|
|
||||||
rv.Type().Elem().Implements(valuerReflectType) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return vr.Value()
|
|
||||||
}
|
|
||||||
+17
-52
@@ -7,48 +7,21 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgxtest"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) {
|
var defaultConnTestRunner pgxtest.ConnTestRunner
|
||||||
t.Run("SimpleProto",
|
|
||||||
func(t *testing.T) {
|
|
||||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
config.PreferSimpleProtocol = true
|
func init() {
|
||||||
conn, err := pgx.ConnectConfig(context.Background(), config)
|
defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
|
||||||
require.NoError(t, err)
|
defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
|
||||||
defer func() {
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
err := conn.Close(context.Background())
|
require.NoError(t, err)
|
||||||
require.NoError(t, err)
|
return config
|
||||||
}()
|
}
|
||||||
|
|
||||||
f(t, conn)
|
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
t.Run("DefaultProto",
|
|
||||||
func(t *testing.T) {
|
|
||||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
conn, err := pgx.ConnectConfig(context.Background(), config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
err := conn.Close(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
f(t, conn)
|
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustConnectString(t testing.TB, connString string) *pgx.Conn {
|
func mustConnectString(t testing.TB, connString string) *pgx.Conn {
|
||||||
@@ -80,7 +53,7 @@ func closeConn(t testing.TB, conn *pgx.Conn) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) {
|
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) {
|
||||||
var err error
|
var err error
|
||||||
if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
|
if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
|
||||||
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
|
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
|
||||||
@@ -89,7 +62,7 @@ func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Do a simple query to ensure the connection is still usable
|
// Do a simple query to ensure the connection is still usable
|
||||||
func ensureConnValid(t *testing.T, conn *pgx.Conn) {
|
func ensureConnValid(t testing.TB, conn *pgx.Conn) {
|
||||||
var sum, rowCount int32
|
var sum, rowCount int32
|
||||||
|
|
||||||
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
||||||
@@ -125,13 +98,11 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
|
assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
|
||||||
assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
|
|
||||||
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
|
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
|
||||||
// Can't test function equality, so just test that they are set or not.
|
assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
|
||||||
assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName)
|
assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
|
||||||
assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName)
|
assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName)
|
||||||
|
|
||||||
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
||||||
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
||||||
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
||||||
@@ -165,9 +136,3 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func skipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) {
|
|
||||||
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
|
||||||
t.Skip(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package anynil
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
// Is returns true if value is any type of nil. e.g. nil or []byte(nil).
|
||||||
|
func Is(value any) bool {
|
||||||
|
if value == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
refVal := reflect.ValueOf(value)
|
||||||
|
switch refVal.Kind() {
|
||||||
|
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
|
||||||
|
return refVal.IsNil()
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified.
|
||||||
|
func Normalize(v any) any {
|
||||||
|
if Is(v) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is
|
||||||
|
// mutated in place.
|
||||||
|
func NormalizeSlice(s []any) {
|
||||||
|
for i := range s {
|
||||||
|
if Is(s[i]) {
|
||||||
|
s[i] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
// Package iobufpool implements a global segregated-fit pool of buffers for IO.
|
||||||
|
package iobufpool
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
const minPoolExpOf2 = 8
|
||||||
|
|
||||||
|
var pools [18]*sync.Pool
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
for i := range pools {
|
||||||
|
bufLen := 1 << (minPoolExpOf2 + i)
|
||||||
|
pools[i] = &sync.Pool{New: func() any { return make([]byte, bufLen) }}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get gets a []byte of len size with cap <= size*2.
|
||||||
|
func Get(size int) []byte {
|
||||||
|
i := getPoolIdx(size)
|
||||||
|
if i >= len(pools) {
|
||||||
|
return make([]byte, size)
|
||||||
|
}
|
||||||
|
return pools[i].Get().([]byte)[:size]
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPoolIdx(size int) int {
|
||||||
|
size--
|
||||||
|
size >>= minPoolExpOf2
|
||||||
|
i := 0
|
||||||
|
for size > 0 {
|
||||||
|
size >>= 1
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put returns buf to the pool.
|
||||||
|
func Put(buf []byte) {
|
||||||
|
i := putPoolIdx(cap(buf))
|
||||||
|
if i < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pools[i].Put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func putPoolIdx(size int) int {
|
||||||
|
minPoolSize := 1 << minPoolExpOf2
|
||||||
|
for i := range pools {
|
||||||
|
if size == minPoolSize<<i {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return -1
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package iobufpool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPoolIdx(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
size int
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{size: 0, expected: 0},
|
||||||
|
{size: 1, expected: 0},
|
||||||
|
{size: 255, expected: 0},
|
||||||
|
{size: 256, expected: 0},
|
||||||
|
{size: 257, expected: 1},
|
||||||
|
{size: 511, expected: 1},
|
||||||
|
{size: 512, expected: 1},
|
||||||
|
{size: 513, expected: 2},
|
||||||
|
{size: 1023, expected: 2},
|
||||||
|
{size: 1024, expected: 2},
|
||||||
|
{size: 1025, expected: 3},
|
||||||
|
{size: 2047, expected: 3},
|
||||||
|
{size: 2048, expected: 3},
|
||||||
|
{size: 2049, expected: 4},
|
||||||
|
{size: 8388607, expected: 15},
|
||||||
|
{size: 8388608, expected: 15},
|
||||||
|
{size: 8388609, expected: 16},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
idx := getPoolIdx(tt.size)
|
||||||
|
assert.Equalf(t, tt.expected, idx, "size: %d", tt.size)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
package iobufpool_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetCap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
requestedLen int
|
||||||
|
expectedCap int
|
||||||
|
}{
|
||||||
|
{requestedLen: 0, expectedCap: 256},
|
||||||
|
{requestedLen: 128, expectedCap: 256},
|
||||||
|
{requestedLen: 255, expectedCap: 256},
|
||||||
|
{requestedLen: 256, expectedCap: 256},
|
||||||
|
{requestedLen: 257, expectedCap: 512},
|
||||||
|
{requestedLen: 511, expectedCap: 512},
|
||||||
|
{requestedLen: 512, expectedCap: 512},
|
||||||
|
{requestedLen: 513, expectedCap: 1024},
|
||||||
|
{requestedLen: 1023, expectedCap: 1024},
|
||||||
|
{requestedLen: 1024, expectedCap: 1024},
|
||||||
|
{requestedLen: 33554431, expectedCap: 33554432},
|
||||||
|
{requestedLen: 33554432, expectedCap: 33554432},
|
||||||
|
|
||||||
|
// Above 32 MiB skip the pool and allocate exactly the requested size.
|
||||||
|
{requestedLen: 33554433, expectedCap: 33554433},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
buf := iobufpool.Get(tt.requestedLen)
|
||||||
|
assert.Equalf(t, tt.requestedLen, len(buf), "bad len for requestedLen: %d", len(buf), tt.requestedLen)
|
||||||
|
assert.Equalf(t, tt.expectedCap, cap(buf), "bad cap for requestedLen: %d", tt.requestedLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPutHandlesWrongSizedBuffers(t *testing.T) {
|
||||||
|
for putBufSize := range []int{0, 1, 128, 250, 256, 257, 1023, 1024, 1025, 1 << 28} {
|
||||||
|
putBuf := make([]byte, putBufSize)
|
||||||
|
iobufpool.Put(putBuf)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
requestedLen int
|
||||||
|
expectedCap int
|
||||||
|
}{
|
||||||
|
{requestedLen: 0, expectedCap: 256},
|
||||||
|
{requestedLen: 128, expectedCap: 256},
|
||||||
|
{requestedLen: 255, expectedCap: 256},
|
||||||
|
{requestedLen: 256, expectedCap: 256},
|
||||||
|
{requestedLen: 257, expectedCap: 512},
|
||||||
|
{requestedLen: 511, expectedCap: 512},
|
||||||
|
{requestedLen: 512, expectedCap: 512},
|
||||||
|
{requestedLen: 513, expectedCap: 1024},
|
||||||
|
{requestedLen: 1023, expectedCap: 1024},
|
||||||
|
{requestedLen: 1024, expectedCap: 1024},
|
||||||
|
{requestedLen: 33554431, expectedCap: 33554432},
|
||||||
|
{requestedLen: 33554432, expectedCap: 33554432},
|
||||||
|
|
||||||
|
// Above 32 MiB skip the pool and allocate exactly the requested size.
|
||||||
|
{requestedLen: 33554433, expectedCap: 33554433},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
getBuf := iobufpool.Get(tt.requestedLen)
|
||||||
|
assert.Equalf(t, tt.requestedLen, len(getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
|
||||||
|
assert.Equalf(t, tt.expectedCap, cap(getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPutGetBufferReuse(t *testing.T) {
|
||||||
|
// There is no way to guarantee a buffer will be reused. It should be, but a GC between the Put and the Get will cause
|
||||||
|
// it not to be. So try many times.
|
||||||
|
for i := 0; i < 100000; i++ {
|
||||||
|
buf := iobufpool.Get(4)
|
||||||
|
buf[0] = 1
|
||||||
|
iobufpool.Put(buf)
|
||||||
|
buf = iobufpool.Get(4)
|
||||||
|
if buf[0] == 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Error("buffer was never reused")
|
||||||
|
}
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
package nbconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const minBufferQueueLen = 8
|
||||||
|
|
||||||
|
type bufferQueue struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
queue [][]byte
|
||||||
|
r, w int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bq *bufferQueue) pushBack(buf []byte) {
|
||||||
|
bq.lock.Lock()
|
||||||
|
defer bq.lock.Unlock()
|
||||||
|
|
||||||
|
if bq.w >= len(bq.queue) {
|
||||||
|
bq.growQueue()
|
||||||
|
}
|
||||||
|
bq.queue[bq.w] = buf
|
||||||
|
bq.w++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bq *bufferQueue) pushFront(buf []byte) {
|
||||||
|
bq.lock.Lock()
|
||||||
|
defer bq.lock.Unlock()
|
||||||
|
|
||||||
|
if bq.w >= len(bq.queue) {
|
||||||
|
bq.growQueue()
|
||||||
|
}
|
||||||
|
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
|
||||||
|
bq.queue[bq.r] = buf
|
||||||
|
bq.w++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bq *bufferQueue) popFront() []byte {
|
||||||
|
bq.lock.Lock()
|
||||||
|
defer bq.lock.Unlock()
|
||||||
|
|
||||||
|
if bq.r == bq.w {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := bq.queue[bq.r]
|
||||||
|
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
|
||||||
|
bq.r++
|
||||||
|
|
||||||
|
if bq.r == bq.w {
|
||||||
|
bq.r = 0
|
||||||
|
bq.w = 0
|
||||||
|
if len(bq.queue) > minBufferQueueLen {
|
||||||
|
bq.queue = make([][]byte, minBufferQueueLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bq *bufferQueue) growQueue() {
|
||||||
|
desiredLen := (len(bq.queue) + 1) * 3 / 2
|
||||||
|
if desiredLen < minBufferQueueLen {
|
||||||
|
desiredLen = minBufferQueueLen
|
||||||
|
}
|
||||||
|
|
||||||
|
newQueue := make([][]byte, desiredLen)
|
||||||
|
copy(newQueue, bq.queue)
|
||||||
|
bq.queue = newQueue
|
||||||
|
}
|
||||||
@@ -0,0 +1,476 @@
|
|||||||
|
// Package nbconn implements a non-blocking net.Conn wrapper.
|
||||||
|
//
|
||||||
|
// It is designed to solve three problems.
|
||||||
|
//
|
||||||
|
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
|
||||||
|
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
|
||||||
|
//
|
||||||
|
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
|
||||||
|
//
|
||||||
|
// The third is to efficiently check if a connection has been closed via a non-blocking read.
|
||||||
|
package nbconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errClosed = errors.New("closed")
|
||||||
|
var ErrWouldBlock = new(wouldBlockError)
|
||||||
|
|
||||||
|
const fakeNonblockingWaitDuration = 100 * time.Millisecond
|
||||||
|
|
||||||
|
// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read
|
||||||
|
// mode.
|
||||||
|
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC)
|
||||||
|
|
||||||
|
// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to
|
||||||
|
// ignore all future calls.
|
||||||
|
var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC)
|
||||||
|
|
||||||
|
// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error.
|
||||||
|
type wouldBlockError struct{}
|
||||||
|
|
||||||
|
func (*wouldBlockError) Error() string {
|
||||||
|
return "would block"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*wouldBlockError) Timeout() bool { return true }
|
||||||
|
func (*wouldBlockError) Temporary() bool { return true }
|
||||||
|
|
||||||
|
// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to
|
||||||
|
// the underlying connection.
|
||||||
|
type Conn interface {
|
||||||
|
net.Conn
|
||||||
|
|
||||||
|
// Flush flushes any buffered writes.
|
||||||
|
Flush() error
|
||||||
|
|
||||||
|
// BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block.
|
||||||
|
BufferReadUntilBlock() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||||
|
type NetConn struct {
|
||||||
|
conn net.Conn
|
||||||
|
rawConn syscall.RawConn
|
||||||
|
|
||||||
|
readQueue bufferQueue
|
||||||
|
writeQueue bufferQueue
|
||||||
|
|
||||||
|
readFlushLock sync.Mutex
|
||||||
|
// non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the
|
||||||
|
// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations.
|
||||||
|
nonblockWriteBuf []byte
|
||||||
|
nonblockWriteErr error
|
||||||
|
nonblockWriteN int
|
||||||
|
|
||||||
|
readDeadlineLock sync.Mutex
|
||||||
|
readDeadline time.Time
|
||||||
|
readNonblocking bool
|
||||||
|
|
||||||
|
writeDeadlineLock sync.Mutex
|
||||||
|
writeDeadline time.Time
|
||||||
|
|
||||||
|
// Only access with atomics
|
||||||
|
closed int64 // 0 = not closed, 1 = closed
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
|
||||||
|
nc := &NetConn{
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fakeNonBlockingIO {
|
||||||
|
if sc, ok := conn.(syscall.Conn); ok {
|
||||||
|
if rawConn, err := sc.SyscallConn(); err == nil {
|
||||||
|
nc.rawConn = rawConn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements io.Reader.
|
||||||
|
func (c *NetConn) Read(b []byte) (n int, err error) {
|
||||||
|
if c.isClosed() {
|
||||||
|
return 0, errClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readFlushLock.Lock()
|
||||||
|
defer c.readFlushLock.Unlock()
|
||||||
|
|
||||||
|
err = c.flush()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for n < len(b) {
|
||||||
|
buf := c.readQueue.popFront()
|
||||||
|
if buf == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
copiedN := copy(b[n:], buf)
|
||||||
|
if copiedN < len(buf) {
|
||||||
|
buf = buf[copiedN:]
|
||||||
|
c.readQueue.pushFront(buf)
|
||||||
|
} else {
|
||||||
|
iobufpool.Put(buf)
|
||||||
|
}
|
||||||
|
n += copiedN
|
||||||
|
}
|
||||||
|
|
||||||
|
// If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to
|
||||||
|
// Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block.
|
||||||
|
if n > 0 {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var readNonblocking bool
|
||||||
|
c.readDeadlineLock.Lock()
|
||||||
|
readNonblocking = c.readNonblocking
|
||||||
|
c.readDeadlineLock.Unlock()
|
||||||
|
|
||||||
|
var readN int
|
||||||
|
if readNonblocking {
|
||||||
|
readN, err = c.nonblockingRead(b[n:])
|
||||||
|
} else {
|
||||||
|
readN, err = c.conn.Read(b[n:])
|
||||||
|
}
|
||||||
|
n += readN
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
|
||||||
|
// closed. Call Flush to actually write to the underlying connection.
|
||||||
|
func (c *NetConn) Write(b []byte) (n int, err error) {
|
||||||
|
if c.isClosed() {
|
||||||
|
return 0, errClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := iobufpool.Get(len(b))
|
||||||
|
copy(buf, b)
|
||||||
|
c.writeQueue.pushBack(buf)
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) Close() (err error) {
|
||||||
|
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
|
||||||
|
if !swapped {
|
||||||
|
return errClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
closeErr := c.conn.Close()
|
||||||
|
if err == nil {
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.readFlushLock.Lock()
|
||||||
|
defer c.readFlushLock.Unlock()
|
||||||
|
err = c.flush()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) LocalAddr() net.Addr {
|
||||||
|
return c.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) RemoteAddr() net.Addr {
|
||||||
|
return c.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
|
||||||
|
func (c *NetConn) SetDeadline(t time.Time) error {
|
||||||
|
err := c.SetReadDeadline(t)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
|
||||||
|
func (c *NetConn) SetReadDeadline(t time.Time) error {
|
||||||
|
if c.isClosed() {
|
||||||
|
return errClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readDeadlineLock.Lock()
|
||||||
|
defer c.readDeadlineLock.Unlock()
|
||||||
|
if c.readDeadline == disableSetDeadlineDeadline {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if t == disableSetDeadlineDeadline {
|
||||||
|
c.readDeadline = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if t == NonBlockingDeadline {
|
||||||
|
c.readNonblocking = true
|
||||||
|
t = time.Time{}
|
||||||
|
} else {
|
||||||
|
c.readNonblocking = false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readDeadline = t
|
||||||
|
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
if c.isClosed() {
|
||||||
|
return errClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
c.writeDeadlineLock.Lock()
|
||||||
|
defer c.writeDeadlineLock.Unlock()
|
||||||
|
if c.writeDeadline == disableSetDeadlineDeadline {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if t == disableSetDeadlineDeadline {
|
||||||
|
c.writeDeadline = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.writeDeadline = t
|
||||||
|
|
||||||
|
return c.conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) Flush() error {
|
||||||
|
if c.isClosed() {
|
||||||
|
return errClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readFlushLock.Lock()
|
||||||
|
defer c.readFlushLock.Unlock()
|
||||||
|
return c.flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
|
||||||
|
func (c *NetConn) flush() error {
|
||||||
|
var stopChan chan struct{}
|
||||||
|
var errChan chan error
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if stopChan != nil {
|
||||||
|
select {
|
||||||
|
case stopChan <- struct{}{}:
|
||||||
|
case <-errChan:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
|
||||||
|
remainingBuf := buf
|
||||||
|
for len(remainingBuf) > 0 {
|
||||||
|
n, err := c.nonblockingWrite(remainingBuf)
|
||||||
|
remainingBuf = remainingBuf[n:]
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, ErrWouldBlock) {
|
||||||
|
buf = buf[:len(remainingBuf)]
|
||||||
|
copy(buf, remainingBuf)
|
||||||
|
c.writeQueue.pushFront(buf)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writing was blocked. Reading might unblock it.
|
||||||
|
if stopChan == nil {
|
||||||
|
stopChan, errChan = c.bufferNonblockingRead()
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
stopChan = nil
|
||||||
|
return err
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
iobufpool.Put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) BufferReadUntilBlock() error {
|
||||||
|
for {
|
||||||
|
buf := iobufpool.Get(8 * 1024)
|
||||||
|
n, err := c.nonblockingRead(buf)
|
||||||
|
if n > 0 {
|
||||||
|
buf = buf[:n]
|
||||||
|
c.readQueue.pushBack(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrWouldBlock) {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
|
||||||
|
stopChan = make(chan struct{})
|
||||||
|
errChan = make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
err := c.BufferReadUntilBlock()
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-stopChan:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return stopChan, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) isClosed() bool {
|
||||||
|
closed := atomic.LoadInt64(&c.closed)
|
||||||
|
return closed == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
|
||||||
|
if c.rawConn == nil {
|
||||||
|
return c.fakeNonblockingWrite(b)
|
||||||
|
} else {
|
||||||
|
return c.realNonblockingWrite(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||||
|
c.writeDeadlineLock.Lock()
|
||||||
|
defer c.writeDeadlineLock.Unlock()
|
||||||
|
|
||||||
|
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||||
|
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
|
||||||
|
err = c.conn.SetWriteDeadline(deadline)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||||
|
c.conn.SetWriteDeadline(c.writeDeadline)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
err = ErrWouldBlock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
|
||||||
|
if c.rawConn == nil {
|
||||||
|
return c.fakeNonblockingRead(b)
|
||||||
|
} else {
|
||||||
|
return c.realNonblockingRead(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||||
|
c.readDeadlineLock.Lock()
|
||||||
|
defer c.readDeadlineLock.Unlock()
|
||||||
|
|
||||||
|
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||||
|
if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
|
||||||
|
err = c.conn.SetReadDeadline(deadline)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||||
|
c.conn.SetReadDeadline(c.readDeadline)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
err = ErrWouldBlock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.conn.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// syscall.Conn is interface
|
||||||
|
|
||||||
|
// TLSClient establishes a TLS connection as a client over conn using config.
|
||||||
|
//
|
||||||
|
// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby
|
||||||
|
// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the
|
||||||
|
// *TLSConn is returned.
|
||||||
|
func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) {
|
||||||
|
tc := tls.Client(conn, config)
|
||||||
|
err := tc.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure last written part of Handshake is actually sent.
|
||||||
|
err = conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TLSConn{
|
||||||
|
tlsConn: tc,
|
||||||
|
nbConn: conn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a
|
||||||
|
// tls.Conn.
|
||||||
|
type TLSConn struct {
|
||||||
|
tlsConn *tls.Conn
|
||||||
|
nbConn *NetConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
|
||||||
|
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
|
||||||
|
func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() }
|
||||||
|
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
|
||||||
|
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
|
||||||
|
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
|
||||||
|
|
||||||
|
func (tc *TLSConn) Close() error {
|
||||||
|
// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then
|
||||||
|
// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our
|
||||||
|
// own 5 second deadline then make all set deadlines no-op.
|
||||||
|
tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5))
|
||||||
|
tc.tlsConn.SetDeadline(disableSetDeadlineDeadline)
|
||||||
|
|
||||||
|
return tc.tlsConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) }
|
||||||
|
func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) }
|
||||||
|
func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) }
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !(aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris)
|
||||||
|
|
||||||
|
package nbconn
|
||||||
|
|
||||||
|
// Not using unix build tag for support on Go 1.18.
|
||||||
|
|
||||||
|
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
||||||
|
return c.fakeNonblockingWrite(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||||
|
return c.fakeNonblockingRead(b)
|
||||||
|
}
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
//go:build aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris
|
||||||
|
|
||||||
|
package nbconn
|
||||||
|
|
||||||
|
// Not using unix build tag for support on Go 1.18.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
|
||||||
|
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
||||||
|
c.nonblockWriteBuf = b
|
||||||
|
c.nonblockWriteN = 0
|
||||||
|
c.nonblockWriteErr = nil
|
||||||
|
err = c.rawConn.Write(func(fd uintptr) (done bool) {
|
||||||
|
c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
n = c.nonblockWriteN
|
||||||
|
if err == nil && c.nonblockWriteErr != nil {
|
||||||
|
if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) {
|
||||||
|
err = ErrWouldBlock
|
||||||
|
} else {
|
||||||
|
err = c.nonblockWriteErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// n may be -1 when an error occurs.
|
||||||
|
if n < 0 {
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||||
|
var funcErr error
|
||||||
|
err = c.rawConn.Read(func(fd uintptr) (done bool) {
|
||||||
|
n, funcErr = syscall.Read(int(fd), b)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if err == nil && funcErr != nil {
|
||||||
|
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
|
||||||
|
err = ErrWouldBlock
|
||||||
|
} else {
|
||||||
|
err = funcErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// n may be -1 when an error occurs.
|
||||||
|
if n < 0 {
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// syscall read did not return an error and 0 bytes were read means EOF.
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,584 @@
|
|||||||
|
package nbconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/nbconn"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test keys generated with:
|
||||||
|
//
|
||||||
|
// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost'
|
||||||
|
|
||||||
|
var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE-----
|
||||||
|
MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls
|
||||||
|
b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ
|
||||||
|
BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
||||||
|
ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5
|
||||||
|
yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT
|
||||||
|
caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT
|
||||||
|
0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW
|
||||||
|
c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v
|
||||||
|
7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg
|
||||||
|
Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw
|
||||||
|
HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g
|
||||||
|
TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk
|
||||||
|
D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB
|
||||||
|
hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y
|
||||||
|
E7ZYmaKTMOhvkg==
|
||||||
|
-----END CERTIFICATE-----`)
|
||||||
|
|
||||||
|
// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in
|
||||||
|
// source code.
|
||||||
|
var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY-----
|
||||||
|
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny
|
||||||
|
k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+
|
||||||
|
fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px
|
||||||
|
N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav
|
||||||
|
IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM
|
||||||
|
4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX
|
||||||
|
IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8
|
||||||
|
TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL
|
||||||
|
CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ
|
||||||
|
/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn
|
||||||
|
lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I
|
||||||
|
Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9
|
||||||
|
YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp
|
||||||
|
RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq
|
||||||
|
MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd
|
||||||
|
3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE
|
||||||
|
Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0
|
||||||
|
TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA
|
||||||
|
riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr
|
||||||
|
IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu
|
||||||
|
nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk
|
||||||
|
WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc
|
||||||
|
Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77
|
||||||
|
DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD
|
||||||
|
pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
|
||||||
|
2qWm8jTPeDC3sq+67s2oojHf+Q==
|
||||||
|
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
|
||||||
|
|
||||||
|
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name string
|
||||||
|
makeConns func(t *testing.T) (local, remote net.Conn)
|
||||||
|
useTLS bool
|
||||||
|
fakeNonBlockingIO bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Pipe",
|
||||||
|
makeConns: makePipeConns,
|
||||||
|
useTLS: false,
|
||||||
|
fakeNonBlockingIO: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP with Fake Non-blocking IO",
|
||||||
|
makeConns: makeTCPConns,
|
||||||
|
useTLS: false,
|
||||||
|
fakeNonBlockingIO: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TLS over TCP with Fake Non-blocking IO",
|
||||||
|
makeConns: makeTCPConns,
|
||||||
|
useTLS: true,
|
||||||
|
fakeNonBlockingIO: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP with Real Non-blocking IO",
|
||||||
|
makeConns: makeTCPConns,
|
||||||
|
useTLS: false,
|
||||||
|
fakeNonBlockingIO: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TLS over TCP with Real Non-blocking IO",
|
||||||
|
makeConns: makeTCPConns,
|
||||||
|
useTLS: true,
|
||||||
|
fakeNonBlockingIO: false,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
local, remote := tt.makeConns(t)
|
||||||
|
|
||||||
|
// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
|
||||||
|
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
|
||||||
|
// uses remote it may be garbage collected leading to the connection being closed.
|
||||||
|
defer local.Close()
|
||||||
|
defer remote.Close()
|
||||||
|
|
||||||
|
var conn nbconn.Conn
|
||||||
|
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)
|
||||||
|
|
||||||
|
if tt.useTLS {
|
||||||
|
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tlsServer := tls.Server(remote, &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
})
|
||||||
|
serverTLSHandshakeChan := make(chan error)
|
||||||
|
go func() {
|
||||||
|
err := tlsServer.Handshake()
|
||||||
|
serverTLSHandshakeChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true})
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn = tlsConn
|
||||||
|
|
||||||
|
err = <-serverTLSHandshakeChan
|
||||||
|
require.NoError(t, err)
|
||||||
|
remote = tlsServer
|
||||||
|
} else {
|
||||||
|
conn = netConn
|
||||||
|
}
|
||||||
|
|
||||||
|
f(t, conn, remote)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is
|
||||||
|
// useful for testing an exact sequence of reads and writes with the underlying connection blocking.
|
||||||
|
func makePipeConns(t *testing.T) (local, remote net.Conn) {
|
||||||
|
local, remote = net.Pipe()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
local.Close()
|
||||||
|
remote.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return local, remote
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost.
|
||||||
|
func makeTCPConns(t *testing.T) (local, remote net.Conn) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
type acceptResultT struct {
|
||||||
|
conn net.Conn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
acceptChan := make(chan acceptResultT)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
acceptChan <- acceptResultT{conn: conn, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
local, err = net.Dial("tcp", ln.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
acceptResult := <-acceptChan
|
||||||
|
require.NoError(t, acceptResult.err)
|
||||||
|
|
||||||
|
remote = acceptResult.conn
|
||||||
|
|
||||||
|
return local, remote
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteIsBuffered(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
// net.Pipe is synchronous so the Write would block if not buffered.
|
||||||
|
writeBuf := []byte("test")
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := conn.Flush()
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
readBuf := make([]byte, len(writeBuf))
|
||||||
|
_, err = remote.Read(readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
err := conn.SetWriteDeadline(time.Now())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
writeBuf := []byte("test")
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFlushesWriteBuffer(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
writeBuf := []byte("test")
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
go func() {
|
||||||
|
readBuf := make([]byte, len(writeBuf))
|
||||||
|
_, err := remote.Read(readBuf)
|
||||||
|
errChan <- err
|
||||||
|
|
||||||
|
_, err = remote.Write([]byte("okay"))
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
readBuf := make([]byte, 4)
|
||||||
|
_, err = conn.Read(readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte("okay"), readBuf)
|
||||||
|
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseFlushesWriteBuffer(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
writeBuf := []byte("test")
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
readBuf := make([]byte, len(writeBuf))
|
||||||
|
_, err := remote.Read(readBuf)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with
|
||||||
|
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing
|
||||||
|
// large values.
|
||||||
|
func TestInternalNonBlockingWrite(t *testing.T) {
|
||||||
|
const deadlockSize = 4 * 1024 * 1024
|
||||||
|
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
writeBuf := make([]byte, deadlockSize)
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, deadlockSize, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
remoteWriteBuf := make([]byte, deadlockSize)
|
||||||
|
_, err := remote.Write(remoteWriteBuf)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
readBuf := make([]byte, deadlockSize)
|
||||||
|
_, err = io.ReadFull(remote, readBuf)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
readBuf := make([]byte, deadlockSize)
|
||||||
|
_, err = io.ReadFull(conn, readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
|
||||||
|
const deadlockSize = 4 * 1024 * 1024
|
||||||
|
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
writeBuf := make([]byte, deadlockSize)
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, deadlockSize, n)
|
||||||
|
|
||||||
|
err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Flush()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "i/o timeout")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNonBlockingRead(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
require.ErrorIs(t, err, nbconn.ErrWouldBlock)
|
||||||
|
require.EqualValues(t, 0, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := remote.Write([]byte("okay"))
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = conn.SetReadDeadline(time.Time{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
n, err = conn.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferNonBlockingRead(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
err := conn.BufferReadUntilBlock()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := remote.Write([]byte("okay"))
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
err = conn.BufferReadUntilBlock()
|
||||||
|
if !errors.Is(err, nbconn.ErrWouldBlock) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
require.Equal(t, []byte("okay"), buf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadPreviouslyBuffered(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := func() error {
|
||||||
|
_, err := remote.Write([]byte("alpha"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
readBuf := make([]byte, 4)
|
||||||
|
_, err = remote.Read(readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := conn.Write([]byte("test"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||||
|
err = conn.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readBuf := make([]byte, 5)
|
||||||
|
n, err := conn.Read(readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 5, n)
|
||||||
|
require.Equal(t, []byte("alpha"), readBuf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := func() error {
|
||||||
|
_, err := remote.Write([]byte("alpha"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
readBuf := make([]byte, 4)
|
||||||
|
_, err = remote.Read(readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := conn.Write([]byte("test"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||||
|
err = conn.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readBuf := make([]byte, 10)
|
||||||
|
n, err := conn.Read(readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 5, n)
|
||||||
|
require.Equal(t, []byte("alpha"), readBuf[:n])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := func() error {
|
||||||
|
_, err := remote.Write([]byte("alpha"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
readBuf := make([]byte, 4)
|
||||||
|
_, err = remote.Read(readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := conn.Write([]byte("test"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||||
|
err = conn.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readBuf := make([]byte, 2)
|
||||||
|
n, err := conn.Read(readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 2, n)
|
||||||
|
require.Equal(t, []byte("al"), readBuf)
|
||||||
|
|
||||||
|
readBuf = make([]byte, 3)
|
||||||
|
n, err = conn.Read(readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 3, n)
|
||||||
|
require.Equal(t, []byte("pha"), readBuf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadMultiplePreviouslyBuffered(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := func() error {
|
||||||
|
_, err := remote.Write([]byte("alpha"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = remote.Write([]byte("beta"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
readBuf := make([]byte, 4)
|
||||||
|
_, err = remote.Read(readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := conn.Write([]byte("test"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||||
|
err = conn.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readBuf := make([]byte, 9)
|
||||||
|
n, err := io.ReadFull(conn, readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 9, n)
|
||||||
|
require.Equal(t, []byte("alphabeta"), readBuf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadPreviouslyBufferedAndReadMore(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
|
||||||
|
flushCompleteChan := make(chan struct{})
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := func() error {
|
||||||
|
_, err := remote.Write([]byte("alpha"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
readBuf := make([]byte, 4)
|
||||||
|
_, err = remote.Read(readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
<-flushCompleteChan
|
||||||
|
|
||||||
|
_, err = remote.Write([]byte("beta"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := conn.Write([]byte("test"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||||
|
err = conn.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
close(flushCompleteChan)
|
||||||
|
|
||||||
|
readBuf := make([]byte, 9)
|
||||||
|
|
||||||
|
n, err := io.ReadFull(conn, readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 9, n)
|
||||||
|
require.Equal(t, []byte("alphabeta"), readBuf)
|
||||||
|
|
||||||
|
err = <-errChan
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
# pgio
|
||||||
|
|
||||||
|
Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||||
|
|
||||||
|
pgio provides functions for appending integers to a []byte while doing byte
|
||||||
|
order conversion.
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||||
|
/*
|
||||||
|
pgio provides functions for appending integers to a []byte while doing byte
|
||||||
|
order conversion.
|
||||||
|
*/
|
||||||
|
package pgio
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package pgio
|
||||||
|
|
||||||
|
import "encoding/binary"
|
||||||
|
|
||||||
|
func AppendUint16(buf []byte, n uint16) []byte {
|
||||||
|
wp := len(buf)
|
||||||
|
buf = append(buf, 0, 0)
|
||||||
|
binary.BigEndian.PutUint16(buf[wp:], n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendUint32(buf []byte, n uint32) []byte {
|
||||||
|
wp := len(buf)
|
||||||
|
buf = append(buf, 0, 0, 0, 0)
|
||||||
|
binary.BigEndian.PutUint32(buf[wp:], n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendUint64(buf []byte, n uint64) []byte {
|
||||||
|
wp := len(buf)
|
||||||
|
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||||
|
binary.BigEndian.PutUint64(buf[wp:], n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendInt16(buf []byte, n int16) []byte {
|
||||||
|
return AppendUint16(buf, uint16(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendInt32(buf []byte, n int32) []byte {
|
||||||
|
return AppendUint32(buf, uint32(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendInt64(buf []byte, n int64) []byte {
|
||||||
|
return AppendUint64(buf, uint64(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetInt32(buf []byte, n int32) {
|
||||||
|
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||||
|
}
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
package pgio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAppendUint16NilBuf(t *testing.T) {
|
||||||
|
buf := AppendUint16(nil, 1)
|
||||||
|
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||||
|
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUint16EmptyBuf(t *testing.T) {
|
||||||
|
buf := []byte{}
|
||||||
|
buf = AppendUint16(buf, 1)
|
||||||
|
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||||
|
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||||
|
buf := make([]byte, 0, 4)
|
||||||
|
AppendUint16(buf, 1)
|
||||||
|
buf = buf[0:2]
|
||||||
|
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||||
|
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUint32NilBuf(t *testing.T) {
|
||||||
|
buf := AppendUint32(nil, 1)
|
||||||
|
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||||
|
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUint32EmptyBuf(t *testing.T) {
|
||||||
|
buf := []byte{}
|
||||||
|
buf = AppendUint32(buf, 1)
|
||||||
|
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||||
|
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||||
|
buf := make([]byte, 0, 4)
|
||||||
|
AppendUint32(buf, 1)
|
||||||
|
buf = buf[0:4]
|
||||||
|
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||||
|
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUint64NilBuf(t *testing.T) {
|
||||||
|
buf := AppendUint64(nil, 1)
|
||||||
|
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||||
|
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUint64EmptyBuf(t *testing.T) {
|
||||||
|
buf := []byte{}
|
||||||
|
buf = AppendUint64(buf, 1)
|
||||||
|
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||||
|
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||||
|
buf := make([]byte, 0, 8)
|
||||||
|
AppendUint64(buf, 1)
|
||||||
|
buf = buf[0:8]
|
||||||
|
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||||
|
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,136 @@
|
|||||||
|
// Package pgmock provides the ability to mock a PostgreSQL server.
|
||||||
|
package pgmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Step interface {
|
||||||
|
Step(*pgproto3.Backend) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Script struct {
|
||||||
|
Steps []Step
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Script) Run(backend *pgproto3.Backend) error {
|
||||||
|
for _, step := range s.Steps {
|
||||||
|
err := step.Step(backend)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Script) Step(backend *pgproto3.Backend) error {
|
||||||
|
return s.Run(backend)
|
||||||
|
}
|
||||||
|
|
||||||
|
type expectMessageStep struct {
|
||||||
|
want pgproto3.FrontendMessage
|
||||||
|
any bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
|
||||||
|
msg, err := backend.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(msg, e.want) {
|
||||||
|
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type expectStartupMessageStep struct {
|
||||||
|
want *pgproto3.StartupMessage
|
||||||
|
any bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
|
||||||
|
msg, err := backend.ReceiveStartupMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.any {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(msg, e.want) {
|
||||||
|
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpectMessage(want pgproto3.FrontendMessage) Step {
|
||||||
|
return expectMessage(want, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
|
||||||
|
return expectMessage(want, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
|
||||||
|
if want, ok := want.(*pgproto3.StartupMessage); ok {
|
||||||
|
return &expectStartupMessageStep{want: want, any: any}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &expectMessageStep{want: want, any: any}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sendMessageStep struct {
|
||||||
|
msg pgproto3.BackendMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
||||||
|
backend.Send(e.msg)
|
||||||
|
return backend.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendMessage(msg pgproto3.BackendMessage) Step {
|
||||||
|
return &sendMessageStep{msg: msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
type waitForCloseMessageStep struct{}
|
||||||
|
|
||||||
|
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
|
||||||
|
for {
|
||||||
|
msg, err := backend.Receive()
|
||||||
|
if err == io.EOF {
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := msg.(*pgproto3.Terminate); ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WaitForClose() Step {
|
||||||
|
return &waitForCloseMessageStep{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
||||||
|
return []Step{
|
||||||
|
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||||
|
SendMessage(&pgproto3.AuthenticationOk{}),
|
||||||
|
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||||
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package pgmock_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgmock"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestScript(t *testing.T) {
|
||||||
|
script := &pgmock.Script{
|
||||||
|
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||||
|
}
|
||||||
|
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
|
||||||
|
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
|
||||||
|
Fields: []pgproto3.FieldDescription{
|
||||||
|
pgproto3.FieldDescription{
|
||||||
|
Name: []byte("?column?"),
|
||||||
|
TableOID: 0,
|
||||||
|
TableAttributeNumber: 0,
|
||||||
|
DataTypeOID: 23,
|
||||||
|
DataTypeSize: 4,
|
||||||
|
TypeModifier: -1,
|
||||||
|
Format: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{
|
||||||
|
Values: [][]byte{[]byte("42")},
|
||||||
|
}))
|
||||||
|
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}))
|
||||||
|
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
||||||
|
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{}))
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
serverErrChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(serverErrChan)
|
||||||
|
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
err = conn.SetDeadline(time.Now().Add(time.Second))
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = script.Run(pgproto3.NewBackend(conn, conn))
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
parts := strings.Split(ln.Addr().String(), ":")
|
||||||
|
host := parts[0]
|
||||||
|
port := parts[1]
|
||||||
|
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
pgConn, err := pgconn.Connect(ctx, connStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
results, err := pgConn.Exec(ctx, "select 42").ReadAll()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, results, 1)
|
||||||
|
assert.Nil(t, results[0].Err)
|
||||||
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
||||||
|
assert.Len(t, results[0].Rows, 1)
|
||||||
|
assert.Equal(t, "42", string(results[0].Rows[0][0]))
|
||||||
|
|
||||||
|
pgConn.Close(ctx)
|
||||||
|
|
||||||
|
assert.NoError(t, <-serverErrChan)
|
||||||
|
}
|
||||||
@@ -12,13 +12,13 @@ import (
|
|||||||
|
|
||||||
// Part is either a string or an int. A string is raw SQL. An int is a
|
// Part is either a string or an int. A string is raw SQL. An int is a
|
||||||
// argument placeholder.
|
// argument placeholder.
|
||||||
type Part interface{}
|
type Part any
|
||||||
|
|
||||||
type Query struct {
|
type Query struct {
|
||||||
Parts []Part
|
Parts []Part
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Query) Sanitize(args ...interface{}) (string, error) {
|
func (q *Query) Sanitize(args ...any) (string, error) {
|
||||||
argUse := make([]bool, len(args))
|
argUse := make([]bool, len(args))
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
|
||||||
@@ -295,7 +295,7 @@ func multilineCommentState(l *sqlLexer) stateFn {
|
|||||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||||
// as necessary. This function is only safe when standard_conforming_strings is
|
// as necessary. This function is only safe when standard_conforming_strings is
|
||||||
// on.
|
// on.
|
||||||
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
|
func SanitizeSQL(sql string, args ...any) (string, error) {
|
||||||
query, err := NewQuery(sql)
|
query, err := NewQuery(sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4/internal/sanitize"
|
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewQuery(t *testing.T) {
|
func TestNewQuery(t *testing.T) {
|
||||||
@@ -111,57 +111,57 @@ func TestNewQuery(t *testing.T) {
|
|||||||
func TestQuerySanitize(t *testing.T) {
|
func TestQuerySanitize(t *testing.T) {
|
||||||
successfulTests := []struct {
|
successfulTests := []struct {
|
||||||
query sanitize.Query
|
query sanitize.Query
|
||||||
args []interface{}
|
args []any
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||||
args: []interface{}{},
|
args: []any{},
|
||||||
expected: `select 42`,
|
expected: `select 42`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []interface{}{int64(42)},
|
args: []any{int64(42)},
|
||||||
expected: `select 42`,
|
expected: `select 42`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []interface{}{float64(1.23)},
|
args: []any{float64(1.23)},
|
||||||
expected: `select 1.23`,
|
expected: `select 1.23`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []interface{}{true},
|
args: []any{true},
|
||||||
expected: `select true`,
|
expected: `select true`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
|
args: []any{[]byte{0, 1, 2, 3, 255}},
|
||||||
expected: `select '\x00010203ff'`,
|
expected: `select '\x00010203ff'`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []interface{}{nil},
|
args: []any{nil},
|
||||||
expected: `select null`,
|
expected: `select null`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []interface{}{"foobar"},
|
args: []any{"foobar"},
|
||||||
expected: `select 'foobar'`,
|
expected: `select 'foobar'`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []interface{}{"foo'bar"},
|
args: []any{"foo'bar"},
|
||||||
expected: `select 'foo''bar'`,
|
expected: `select 'foo''bar'`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []interface{}{`foo\'bar`},
|
args: []any{`foo\'bar`},
|
||||||
expected: `select 'foo\''bar'`,
|
expected: `select 'foo\''bar'`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
|
||||||
args: []interface{}{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
|
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
|
||||||
expected: `insert '2020-03-01 23:59:59.999999Z'`,
|
expected: `insert '2020-03-01 23:59:59.999999Z'`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -180,22 +180,22 @@ func TestQuerySanitize(t *testing.T) {
|
|||||||
|
|
||||||
errorTests := []struct {
|
errorTests := []struct {
|
||||||
query sanitize.Query
|
query sanitize.Query
|
||||||
args []interface{}
|
args []any
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
|
||||||
args: []interface{}{int64(42)},
|
args: []any{int64(42)},
|
||||||
expected: `insufficient arguments`,
|
expected: `insufficient arguments`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
|
||||||
args: []interface{}{int64(42)},
|
args: []any{int64(42)},
|
||||||
expected: `unused argument: 0`,
|
expected: `unused argument: 0`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||||
args: []interface{}{42},
|
args: []any{42},
|
||||||
expected: `invalid arg type: int`,
|
expected: `invalid arg type: int`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package stmtcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
|
||||||
|
type LRUCache struct {
|
||||||
|
cap int
|
||||||
|
m map[string]*list.Element
|
||||||
|
l *list.List
|
||||||
|
invalidStmts []*pgconn.StatementDescription
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
|
||||||
|
func NewLRUCache(cap int) *LRUCache {
|
||||||
|
return &LRUCache{
|
||||||
|
cap: cap,
|
||||||
|
m: make(map[string]*list.Element),
|
||||||
|
l: list.New(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the statement description for sql. Returns nil if not found.
|
||||||
|
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
|
||||||
|
if el, ok := c.m[key]; ok {
|
||||||
|
c.l.MoveToFront(el)
|
||||||
|
return el.Value.(*pgconn.StatementDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||||
|
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
|
||||||
|
if sd.SQL == "" {
|
||||||
|
panic("cannot store statement description with empty SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, present := c.m[sd.SQL]; present {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.l.Len() == c.cap {
|
||||||
|
c.invalidateOldest()
|
||||||
|
}
|
||||||
|
|
||||||
|
el := c.l.PushFront(sd)
|
||||||
|
c.m[sd.SQL] = el
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||||
|
func (c *LRUCache) Invalidate(sql string) {
|
||||||
|
if el, ok := c.m[sql]; ok {
|
||||||
|
delete(c.m, sql)
|
||||||
|
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||||
|
c.l.Remove(el)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAll invalidates all statement descriptions.
|
||||||
|
func (c *LRUCache) InvalidateAll() {
|
||||||
|
el := c.l.Front()
|
||||||
|
for el != nil {
|
||||||
|
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||||
|
el = el.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.m = make(map[string]*list.Element)
|
||||||
|
c.l = list.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||||
|
invalidStmts := c.invalidStmts
|
||||||
|
c.invalidStmts = nil
|
||||||
|
return invalidStmts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
func (c *LRUCache) Len() int {
|
||||||
|
return c.l.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
|
func (c *LRUCache) Cap() int {
|
||||||
|
return c.cap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRUCache) invalidateOldest() {
|
||||||
|
oldest := c.l.Back()
|
||||||
|
sd := oldest.Value.(*pgconn.StatementDescription)
|
||||||
|
c.invalidStmts = append(c.invalidStmts, sd)
|
||||||
|
delete(c.m, sd.SQL)
|
||||||
|
c.l.Remove(oldest)
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
// Package stmtcache is a cache for statement descriptions.
|
||||||
|
package stmtcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
var stmtCounter int64
|
||||||
|
|
||||||
|
// NextStatementName returns a statement name that will be unique for the lifetime of the program.
|
||||||
|
func NextStatementName() string {
|
||||||
|
n := atomic.AddInt64(&stmtCounter, 1)
|
||||||
|
return "stmtcache_" + strconv.FormatInt(n, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache caches statement descriptions.
|
||||||
|
type Cache interface {
|
||||||
|
// Get returns the statement description for sql. Returns nil if not found.
|
||||||
|
Get(sql string) *pgconn.StatementDescription
|
||||||
|
|
||||||
|
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||||
|
Put(sd *pgconn.StatementDescription)
|
||||||
|
|
||||||
|
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||||
|
Invalidate(sql string)
|
||||||
|
|
||||||
|
// InvalidateAll invalidates all statement descriptions.
|
||||||
|
InvalidateAll()
|
||||||
|
|
||||||
|
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
|
||||||
|
HandleInvalidated() []*pgconn.StatementDescription
|
||||||
|
|
||||||
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
Len() int
|
||||||
|
|
||||||
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
|
Cap() int
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsStatementInvalid(err error) bool {
|
||||||
|
pgErr, ok := err.(*pgconn.PgError)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgx/issues/1162
|
||||||
|
//
|
||||||
|
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
|
||||||
|
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
|
||||||
|
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
|
||||||
|
// have so it should be safe.
|
||||||
|
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
|
||||||
|
return possibleInvalidCachedPlanError
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package stmtcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UnlimitedCache implements Cache with no capacity limit.
|
||||||
|
type UnlimitedCache struct {
|
||||||
|
m map[string]*pgconn.StatementDescription
|
||||||
|
invalidStmts []*pgconn.StatementDescription
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUnlimitedCache creates a new UnlimitedCache.
|
||||||
|
func NewUnlimitedCache() *UnlimitedCache {
|
||||||
|
return &UnlimitedCache{
|
||||||
|
m: make(map[string]*pgconn.StatementDescription),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the statement description for sql. Returns nil if not found.
|
||||||
|
func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
|
||||||
|
return c.m[sql]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||||
|
func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
|
||||||
|
if sd.SQL == "" {
|
||||||
|
panic("cannot store statement description with empty SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, present := c.m[sd.SQL]; present {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.m[sd.SQL] = sd
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||||
|
func (c *UnlimitedCache) Invalidate(sql string) {
|
||||||
|
if sd, ok := c.m[sql]; ok {
|
||||||
|
delete(c.m, sql)
|
||||||
|
c.invalidStmts = append(c.invalidStmts, sd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAll invalidates all statement descriptions.
|
||||||
|
func (c *UnlimitedCache) InvalidateAll() {
|
||||||
|
for _, sd := range c.m {
|
||||||
|
c.invalidStmts = append(c.invalidStmts, sd)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.m = make(map[string]*pgconn.StatementDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||||
|
invalidStmts := c.invalidStmts
|
||||||
|
c.invalidStmts = nil
|
||||||
|
return invalidStmts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
func (c *UnlimitedCache) Len() int {
|
||||||
|
return len(c.m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
|
func (c *UnlimitedCache) Cap() int {
|
||||||
|
return math.MaxInt
|
||||||
|
}
|
||||||
@@ -7,8 +7,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgxtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLargeObjects(t *testing.T) {
|
func TestLargeObjects(t *testing.T) {
|
||||||
@@ -22,7 +23,7 @@ func TestLargeObjects(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Server does support large objects")
|
pgxtest.SkipCockroachDB(t, conn, "Server does support large objects")
|
||||||
|
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := conn.Begin(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -32,7 +33,7 @@ func TestLargeObjects(t *testing.T) {
|
|||||||
testLargeObjects(t, ctx, tx)
|
testLargeObjects(t, ctx, tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLargeObjectsPreferSimpleProtocol(t *testing.T) {
|
func TestLargeObjectsSimpleProtocol(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
@@ -43,14 +44,14 @@ func TestLargeObjectsPreferSimpleProtocol(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config.PreferSimpleProtocol = true
|
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
||||||
|
|
||||||
conn, err := pgx.ConnectConfig(ctx, config)
|
conn, err := pgx.ConnectConfig(ctx, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Server does support large objects")
|
pgxtest.SkipCockroachDB(t, conn, "Server does support large objects")
|
||||||
|
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := conn.Begin(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -169,7 +170,7 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Server does support large objects")
|
pgxtest.SkipCockroachDB(t, conn, "Server does support large objects")
|
||||||
|
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := conn.Begin(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,39 +0,0 @@
|
|||||||
package kitlogadapter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/go-kit/log"
|
|
||||||
kitlevel "github.com/go-kit/log/level"
|
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
l log.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLogger(l log.Logger) *Logger {
|
|
||||||
return &Logger{l: l}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
logger := l.l
|
|
||||||
for k, v := range data {
|
|
||||||
logger = log.With(logger, k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch level {
|
|
||||||
case pgx.LogLevelTrace:
|
|
||||||
logger.Log("PGX_LOG_LEVEL", level, "msg", msg)
|
|
||||||
case pgx.LogLevelDebug:
|
|
||||||
kitlevel.Debug(logger).Log("msg", msg)
|
|
||||||
case pgx.LogLevelInfo:
|
|
||||||
kitlevel.Info(logger).Log("msg", msg)
|
|
||||||
case pgx.LogLevelWarn:
|
|
||||||
kitlevel.Warn(logger).Log("msg", msg)
|
|
||||||
case pgx.LogLevelError:
|
|
||||||
kitlevel.Error(logger).Log("msg", msg)
|
|
||||||
default:
|
|
||||||
logger.Log("INVALID_PGX_LOG_LEVEL", level, "error", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger
|
|
||||||
// log.
|
|
||||||
package log15adapter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Log15Logger interface defines the subset of
|
|
||||||
// github.com/inconshreveable/log15.Logger that this adapter uses.
|
|
||||||
type Log15Logger interface {
|
|
||||||
Debug(msg string, ctx ...interface{})
|
|
||||||
Info(msg string, ctx ...interface{})
|
|
||||||
Warn(msg string, ctx ...interface{})
|
|
||||||
Error(msg string, ctx ...interface{})
|
|
||||||
Crit(msg string, ctx ...interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
l Log15Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLogger(l Log15Logger) *Logger {
|
|
||||||
return &Logger{l: l}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
logArgs := make([]interface{}, 0, len(data))
|
|
||||||
for k, v := range data {
|
|
||||||
logArgs = append(logArgs, k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch level {
|
|
||||||
case pgx.LogLevelTrace:
|
|
||||||
l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...)
|
|
||||||
case pgx.LogLevelDebug:
|
|
||||||
l.l.Debug(msg, logArgs...)
|
|
||||||
case pgx.LogLevelInfo:
|
|
||||||
l.l.Info(msg, logArgs...)
|
|
||||||
case pgx.LogLevelWarn:
|
|
||||||
l.l.Warn(msg, logArgs...)
|
|
||||||
case pgx.LogLevelError:
|
|
||||||
l.l.Error(msg, logArgs...)
|
|
||||||
default:
|
|
||||||
l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
// Package logrusadapter provides a logger that writes to a github.com/sirupsen/logrus.Logger
|
|
||||||
// log.
|
|
||||||
package logrusadapter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
l logrus.FieldLogger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLogger(l logrus.FieldLogger) *Logger {
|
|
||||||
return &Logger{l: l}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
var logger logrus.FieldLogger
|
|
||||||
if data != nil {
|
|
||||||
logger = l.l.WithFields(data)
|
|
||||||
} else {
|
|
||||||
logger = l.l
|
|
||||||
}
|
|
||||||
|
|
||||||
switch level {
|
|
||||||
case pgx.LogLevelTrace:
|
|
||||||
logger.WithField("PGX_LOG_LEVEL", level).Debug(msg)
|
|
||||||
case pgx.LogLevelDebug:
|
|
||||||
logger.Debug(msg)
|
|
||||||
case pgx.LogLevelInfo:
|
|
||||||
logger.Info(msg)
|
|
||||||
case pgx.LogLevelWarn:
|
|
||||||
logger.Warn(msg)
|
|
||||||
case pgx.LogLevelError:
|
|
||||||
logger.Error(msg)
|
|
||||||
default:
|
|
||||||
logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -6,13 +6,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v5/tracelog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestingLogger interface defines the subset of testing.TB methods used by this
|
// TestingLogger interface defines the subset of testing.TB methods used by this
|
||||||
// adapter.
|
// adapter.
|
||||||
type TestingLogger interface {
|
type TestingLogger interface {
|
||||||
Log(args ...interface{})
|
Log(args ...any)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
@@ -23,8 +23,8 @@ func NewLogger(l TestingLogger) *Logger {
|
|||||||
return &Logger{l: l}
|
return &Logger{l: l}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
func (l *Logger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
|
||||||
logArgs := make([]interface{}, 0, 2+len(data))
|
logArgs := make([]any, 0, 2+len(data))
|
||||||
logArgs = append(logArgs, level, msg)
|
logArgs = append(logArgs, level, msg)
|
||||||
for k, v := range data {
|
for k, v := range data {
|
||||||
logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v))
|
logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v))
|
||||||
|
|||||||
@@ -1,42 +0,0 @@
|
|||||||
// Package zapadapter provides a logger that writes to a go.uber.org/zap.Logger.
|
|
||||||
package zapadapter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
"go.uber.org/zap/zapcore"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
logger *zap.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLogger(logger *zap.Logger) *Logger {
|
|
||||||
return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
fields := make([]zapcore.Field, len(data))
|
|
||||||
i := 0
|
|
||||||
for k, v := range data {
|
|
||||||
fields[i] = zap.Any(k, v)
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
|
|
||||||
switch level {
|
|
||||||
case pgx.LogLevelTrace:
|
|
||||||
pl.logger.Debug(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...)
|
|
||||||
case pgx.LogLevelDebug:
|
|
||||||
pl.logger.Debug(msg, fields...)
|
|
||||||
case pgx.LogLevelInfo:
|
|
||||||
pl.logger.Info(msg, fields...)
|
|
||||||
case pgx.LogLevelWarn:
|
|
||||||
pl.logger.Warn(msg, fields...)
|
|
||||||
case pgx.LogLevelError:
|
|
||||||
pl.logger.Error(msg, fields...)
|
|
||||||
default:
|
|
||||||
pl.logger.Error(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
// Package zerologadapter provides a logger that writes to a github.com/rs/zerolog.
|
|
||||||
package zerologadapter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
logger zerolog.Logger
|
|
||||||
withFunc func(context.Context, zerolog.Context) zerolog.Context
|
|
||||||
fromContext bool
|
|
||||||
skipModule bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// option options for configuring the logger when creating a new logger.
|
|
||||||
type option func(logger *Logger)
|
|
||||||
|
|
||||||
// WithContextFunc adds possibility to get request scoped values from the
|
|
||||||
// ctx.Context before logging lines.
|
|
||||||
func WithContextFunc(withFunc func(context.Context, zerolog.Context) zerolog.Context) option {
|
|
||||||
return func(logger *Logger) {
|
|
||||||
logger.withFunc = withFunc
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithoutPGXModule disables adding module:pgx to the default logger context.
|
|
||||||
func WithoutPGXModule() option {
|
|
||||||
return func(logger *Logger) {
|
|
||||||
logger.skipModule = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewLogger accepts a zerolog.Logger as input and returns a new custom pgx
|
|
||||||
// logging facade as output.
|
|
||||||
func NewLogger(logger zerolog.Logger, options ...option) *Logger {
|
|
||||||
l := Logger{
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
l.init(options)
|
|
||||||
return &l
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewContextLogger creates logger that extracts the zerolog.Logger from the
|
|
||||||
// context.Context by using `zerolog.Ctx`. The zerolog.DefaultContextLogger will
|
|
||||||
// be used if no logger is associated with the context.
|
|
||||||
func NewContextLogger(options ...option) *Logger {
|
|
||||||
l := Logger{
|
|
||||||
fromContext: true,
|
|
||||||
}
|
|
||||||
l.init(options)
|
|
||||||
return &l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pl *Logger) init(options []option) {
|
|
||||||
for _, opt := range options {
|
|
||||||
opt(pl)
|
|
||||||
}
|
|
||||||
if !pl.skipModule {
|
|
||||||
pl.logger = pl.logger.With().Str("module", "pgx").Logger()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
var zlevel zerolog.Level
|
|
||||||
switch level {
|
|
||||||
case pgx.LogLevelNone:
|
|
||||||
zlevel = zerolog.NoLevel
|
|
||||||
case pgx.LogLevelError:
|
|
||||||
zlevel = zerolog.ErrorLevel
|
|
||||||
case pgx.LogLevelWarn:
|
|
||||||
zlevel = zerolog.WarnLevel
|
|
||||||
case pgx.LogLevelInfo:
|
|
||||||
zlevel = zerolog.InfoLevel
|
|
||||||
case pgx.LogLevelDebug:
|
|
||||||
zlevel = zerolog.DebugLevel
|
|
||||||
default:
|
|
||||||
zlevel = zerolog.DebugLevel
|
|
||||||
}
|
|
||||||
|
|
||||||
var zctx zerolog.Context
|
|
||||||
if pl.fromContext {
|
|
||||||
logger := zerolog.Ctx(ctx)
|
|
||||||
zctx = logger.With()
|
|
||||||
} else {
|
|
||||||
zctx = pl.logger.With()
|
|
||||||
}
|
|
||||||
if pl.withFunc != nil {
|
|
||||||
zctx = pl.withFunc(ctx, zctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgxlog := zctx.Logger()
|
|
||||||
event := pgxlog.WithLevel(zlevel)
|
|
||||||
if event.Enabled() {
|
|
||||||
if pl.fromContext && !pl.skipModule {
|
|
||||||
event.Str("module", "pgx")
|
|
||||||
}
|
|
||||||
event.Fields(data).Msg(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
package zerologadapter_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
"github.com/jackc/pgx/v4/log/zerologadapter"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLogger(t *testing.T) {
|
|
||||||
|
|
||||||
t.Run("default", func(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
zlogger := zerolog.New(&buf)
|
|
||||||
logger := zerologadapter.NewLogger(zlogger)
|
|
||||||
logger.Log(context.Background(), pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"})
|
|
||||||
const want = `{"level":"info","module":"pgx","one":"two","message":"hello"}
|
|
||||||
`
|
|
||||||
got := buf.String()
|
|
||||||
if got != want {
|
|
||||||
t.Errorf("%s != %s", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("disable pgx module", func(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
zlogger := zerolog.New(&buf)
|
|
||||||
logger := zerologadapter.NewLogger(zlogger, zerologadapter.WithoutPGXModule())
|
|
||||||
logger.Log(context.Background(), pgx.LogLevelInfo, "hello", nil)
|
|
||||||
const want = `{"level":"info","message":"hello"}
|
|
||||||
`
|
|
||||||
got := buf.String()
|
|
||||||
if got != want {
|
|
||||||
t.Errorf("%s != %s", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("from context", func(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
zlogger := zerolog.New(&buf)
|
|
||||||
ctx := zlogger.WithContext(context.Background())
|
|
||||||
logger := zerologadapter.NewContextLogger()
|
|
||||||
logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"})
|
|
||||||
const want = `{"level":"info","module":"pgx","one":"two","message":"hello"}
|
|
||||||
`
|
|
||||||
|
|
||||||
got := buf.String()
|
|
||||||
if got != want {
|
|
||||||
t.Log(got)
|
|
||||||
t.Log(want)
|
|
||||||
t.Errorf("%s != %s", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
type key string
|
|
||||||
var ck key
|
|
||||||
zlogger := zerolog.New(&buf)
|
|
||||||
logger := zerologadapter.NewLogger(zlogger,
|
|
||||||
zerologadapter.WithContextFunc(func(ctx context.Context, logWith zerolog.Context) zerolog.Context {
|
|
||||||
// You can use zerolog.hlog.IDFromCtx(ctx) or even
|
|
||||||
// zerolog.log.Ctx(ctx) to fetch the whole logger instance from the
|
|
||||||
// context if you want.
|
|
||||||
id, ok := ctx.Value(ck).(string)
|
|
||||||
if ok {
|
|
||||||
logWith = logWith.Str("req_id", id)
|
|
||||||
}
|
|
||||||
return logWith
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
t.Run("no request id", func(t *testing.T) {
|
|
||||||
buf.Reset()
|
|
||||||
ctx := context.Background()
|
|
||||||
logger.Log(ctx, pgx.LogLevelInfo, "hello", nil)
|
|
||||||
const want = `{"level":"info","module":"pgx","message":"hello"}
|
|
||||||
`
|
|
||||||
got := buf.String()
|
|
||||||
if got != want {
|
|
||||||
t.Errorf("%s != %s", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("with request id", func(t *testing.T) {
|
|
||||||
buf.Reset()
|
|
||||||
ctx := context.WithValue(context.Background(), ck, "1")
|
|
||||||
logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"two": "2"})
|
|
||||||
const want = `{"level":"info","module":"pgx","req_id":"1","two":"2","message":"hello"}
|
|
||||||
`
|
|
||||||
got := buf.String()
|
|
||||||
if got != want {
|
|
||||||
t.Errorf("%s != %s", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
package pgx
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// The values for log levels are chosen such that the zero value means that no
|
|
||||||
// log level was specified.
|
|
||||||
const (
|
|
||||||
LogLevelTrace = 6
|
|
||||||
LogLevelDebug = 5
|
|
||||||
LogLevelInfo = 4
|
|
||||||
LogLevelWarn = 3
|
|
||||||
LogLevelError = 2
|
|
||||||
LogLevelNone = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// LogLevel represents the pgx logging level. See LogLevel* constants for
|
|
||||||
// possible values.
|
|
||||||
type LogLevel int
|
|
||||||
|
|
||||||
func (ll LogLevel) String() string {
|
|
||||||
switch ll {
|
|
||||||
case LogLevelTrace:
|
|
||||||
return "trace"
|
|
||||||
case LogLevelDebug:
|
|
||||||
return "debug"
|
|
||||||
case LogLevelInfo:
|
|
||||||
return "info"
|
|
||||||
case LogLevelWarn:
|
|
||||||
return "warn"
|
|
||||||
case LogLevelError:
|
|
||||||
return "error"
|
|
||||||
case LogLevelNone:
|
|
||||||
return "none"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("invalid level %d", ll)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logger is the interface used to get logging from pgx internals.
|
|
||||||
type Logger interface {
|
|
||||||
// Log a message at the given level with data key/value pairs. data may be nil.
|
|
||||||
Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface
|
|
||||||
type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
|
|
||||||
|
|
||||||
// Log delegates the logging request to the wrapped function
|
|
||||||
func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) {
|
|
||||||
f(ctx, level, msg, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogLevelFromString converts log level string to constant
|
|
||||||
//
|
|
||||||
// Valid levels:
|
|
||||||
//
|
|
||||||
// trace
|
|
||||||
// debug
|
|
||||||
// info
|
|
||||||
// warn
|
|
||||||
// error
|
|
||||||
// none
|
|
||||||
func LogLevelFromString(s string) (LogLevel, error) {
|
|
||||||
switch s {
|
|
||||||
case "trace":
|
|
||||||
return LogLevelTrace, nil
|
|
||||||
case "debug":
|
|
||||||
return LogLevelDebug, nil
|
|
||||||
case "info":
|
|
||||||
return LogLevelInfo, nil
|
|
||||||
case "warn":
|
|
||||||
return LogLevelWarn, nil
|
|
||||||
case "error":
|
|
||||||
return LogLevelError, nil
|
|
||||||
case "none":
|
|
||||||
return LogLevelNone, nil
|
|
||||||
default:
|
|
||||||
return 0, errors.New("invalid log level")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func logQueryArgs(args []interface{}) []interface{} {
|
|
||||||
logArgs := make([]interface{}, 0, len(args))
|
|
||||||
|
|
||||||
for _, a := range args {
|
|
||||||
switch v := a.(type) {
|
|
||||||
case []byte:
|
|
||||||
if len(v) < 64 {
|
|
||||||
a = hex.EncodeToString(v)
|
|
||||||
} else {
|
|
||||||
a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64)
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
if len(v) > 64 {
|
|
||||||
a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logArgs = append(logArgs, a)
|
|
||||||
}
|
|
||||||
|
|
||||||
return logArgs
|
|
||||||
}
|
|
||||||
-23
@@ -1,23 +0,0 @@
|
|||||||
package pgx
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql/driver"
|
|
||||||
|
|
||||||
"github.com/jackc/pgtype"
|
|
||||||
)
|
|
||||||
|
|
||||||
func convertDriverValuers(args []interface{}) ([]interface{}, error) {
|
|
||||||
for i, arg := range args {
|
|
||||||
switch arg := arg.(type) {
|
|
||||||
case pgtype.BinaryEncoder:
|
|
||||||
case pgtype.TextEncoder:
|
|
||||||
case driver.Valuer:
|
|
||||||
v, err := callValuerValue(arg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
args[i] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return args, nil
|
|
||||||
}
|
|
||||||
+266
@@ -0,0 +1,266 @@
|
|||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
|
||||||
|
// ordinal placeholder and construct the appropriate arguments.
|
||||||
|
//
|
||||||
|
// For example, the following two queries are equivalent:
|
||||||
|
//
|
||||||
|
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}))
|
||||||
|
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2}))
|
||||||
|
type NamedArgs map[string]any
|
||||||
|
|
||||||
|
// RewriteQuery implements the QueryRewriter interface.
|
||||||
|
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||||
|
l := &sqlLexer{
|
||||||
|
src: sql,
|
||||||
|
stateFn: rawState,
|
||||||
|
nameToOrdinal: make(map[namedArg]int, len(na)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for l.stateFn != nil {
|
||||||
|
l.stateFn = l.stateFn(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
sb := strings.Builder{}
|
||||||
|
for _, p := range l.parts {
|
||||||
|
switch p := p.(type) {
|
||||||
|
case string:
|
||||||
|
sb.WriteString(p)
|
||||||
|
case namedArg:
|
||||||
|
sb.WriteRune('$')
|
||||||
|
sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newArgs = make([]any, len(l.nameToOrdinal))
|
||||||
|
for name, ordinal := range l.nameToOrdinal {
|
||||||
|
newArgs[ordinal-1] = na[string(name)]
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), newArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
type namedArg string
|
||||||
|
|
||||||
|
type sqlLexer struct {
|
||||||
|
src string
|
||||||
|
start int
|
||||||
|
pos int
|
||||||
|
nested int // multiline comment nesting level.
|
||||||
|
stateFn stateFn
|
||||||
|
parts []any
|
||||||
|
|
||||||
|
nameToOrdinal map[namedArg]int
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateFn func(*sqlLexer) stateFn
|
||||||
|
|
||||||
|
func rawState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case 'e', 'E':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '\'' {
|
||||||
|
l.pos += width
|
||||||
|
return escapeStringState
|
||||||
|
}
|
||||||
|
case '\'':
|
||||||
|
return singleQuoteState
|
||||||
|
case '"':
|
||||||
|
return doubleQuoteState
|
||||||
|
case '@':
|
||||||
|
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if isLetter(nextRune) {
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||||
|
}
|
||||||
|
l.start = l.pos
|
||||||
|
return namedArgState
|
||||||
|
}
|
||||||
|
case '-':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '-' {
|
||||||
|
l.pos += width
|
||||||
|
return oneLineCommentState
|
||||||
|
}
|
||||||
|
case '/':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '*' {
|
||||||
|
l.pos += width
|
||||||
|
return multilineCommentState
|
||||||
|
}
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLetter(r rune) bool {
|
||||||
|
return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
|
||||||
|
}
|
||||||
|
|
||||||
|
func namedArgState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
if r == utf8.RuneError {
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
na := namedArg(l.src[l.start:l.pos])
|
||||||
|
if _, found := l.nameToOrdinal[na]; !found {
|
||||||
|
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||||
|
}
|
||||||
|
l.parts = append(l.parts, na)
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') {
|
||||||
|
l.pos -= width
|
||||||
|
na := namedArg(l.src[l.start:l.pos])
|
||||||
|
if _, found := l.nameToOrdinal[na]; !found {
|
||||||
|
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||||
|
}
|
||||||
|
l.parts = append(l.parts, namedArg(na))
|
||||||
|
l.start = l.pos
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleQuoteState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\'':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '\'' {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.pos += width
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '"':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '"' {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.pos += width
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func escapeStringState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\\':
|
||||||
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
case '\'':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '\'' {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.pos += width
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func oneLineCommentState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\\':
|
||||||
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
case '\n', '\r':
|
||||||
|
return rawState
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func multilineCommentState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '/':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '*' {
|
||||||
|
l.pos += width
|
||||||
|
l.nested++
|
||||||
|
}
|
||||||
|
case '*':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '/' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
l.pos += width
|
||||||
|
if l.nested == 0 {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.nested--
|
||||||
|
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
package pgx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for i, tt := range []struct {
|
||||||
|
sql string
|
||||||
|
args []any
|
||||||
|
namedArgs pgx.NamedArgs
|
||||||
|
expectedSQL string
|
||||||
|
expectedArgs []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sql: "select * from users where id = @id",
|
||||||
|
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||||
|
expectedSQL: "select * from users where id = $1",
|
||||||
|
expectedArgs: []any{int32(42)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "select * from t where foo < @abc and baz = @def and bar < @abc",
|
||||||
|
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
|
||||||
|
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
|
||||||
|
expectedArgs: []any{int32(42), int32(1)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "select @a::int, @b::text",
|
||||||
|
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||||
|
expectedSQL: "select $1::int, $2::text",
|
||||||
|
expectedArgs: []any{int32(42), "foo"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "select @Abc::int, @b_4::text",
|
||||||
|
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo"},
|
||||||
|
expectedSQL: "select $1::int, $2::text",
|
||||||
|
expectedArgs: []any{int32(42), "foo"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "at end @",
|
||||||
|
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||||
|
expectedSQL: "at end @",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "ignores without letter after @ foo bar",
|
||||||
|
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||||
|
expectedSQL: "ignores without letter after @ foo bar",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "name must start with letter @1 foo bar",
|
||||||
|
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||||
|
expectedSQL: "name must start with letter @1 foo bar",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: `select *, '@foo' as "@bar" from users where id = @id`,
|
||||||
|
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||||
|
expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`,
|
||||||
|
expectedArgs: []any{int32(42)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: `select * -- @foo
|
||||||
|
from users -- @single line comments
|
||||||
|
where id = @id;`,
|
||||||
|
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||||
|
expectedSQL: `select * -- @foo
|
||||||
|
from users -- @single line comments
|
||||||
|
where id = $1;`,
|
||||||
|
expectedArgs: []any{int32(42)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: `select * /* @multi line
|
||||||
|
@comment
|
||||||
|
*/
|
||||||
|
/* /* with @nesting */ */
|
||||||
|
from users
|
||||||
|
where id = @id;`,
|
||||||
|
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||||
|
expectedSQL: `select * /* @multi line
|
||||||
|
@comment
|
||||||
|
*/
|
||||||
|
/* /* with @nesting */ */
|
||||||
|
from users
|
||||||
|
where id = $1;`,
|
||||||
|
expectedArgs: []any{int32(42)},
|
||||||
|
},
|
||||||
|
|
||||||
|
// test comments and quotes
|
||||||
|
} {
|
||||||
|
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
|
||||||
|
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||||
|
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
+4
-8
@@ -5,9 +5,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgconn/stmtcache"
|
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -19,9 +17,8 @@ func TestPgbouncerStatementCacheDescribe(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
config := mustParseConfig(t, connString)
|
config := mustParseConfig(t, connString)
|
||||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 1024)
|
config.DescriptionCacheCapacity = 1024
|
||||||
}
|
|
||||||
|
|
||||||
testPgbouncer(t, config, 10, 100)
|
testPgbouncer(t, config, 10, 100)
|
||||||
}
|
}
|
||||||
@@ -33,8 +30,7 @@ func TestPgbouncerSimpleProtocol(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
config := mustParseConfig(t, connString)
|
config := mustParseConfig(t, connString)
|
||||||
config.BuildStatementCache = nil
|
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
||||||
config.PreferSimpleProtocol = true
|
|
||||||
|
|
||||||
testPgbouncer(t, config, 10, 100)
|
testPgbouncer(t, config, 10, 100)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
# pgconn
|
||||||
|
|
||||||
|
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
|
||||||
|
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
|
||||||
|
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
|
||||||
|
low-level access to PostgreSQL functionality.
|
||||||
|
|
||||||
|
## Example Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("pgconn failed to connect:", err)
|
||||||
|
}
|
||||||
|
defer pgConn.Close(context.Background())
|
||||||
|
|
||||||
|
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
|
||||||
|
for result.NextRow() {
|
||||||
|
fmt.Println("User 123 has email:", string(result.Values()[0]))
|
||||||
|
}
|
||||||
|
_, err = result.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("failed reading result:", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
|
||||||
|
environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
|
||||||
|
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
|
||||||
|
environment variable handling.
|
||||||
|
|
||||||
|
### Example Test Environment
|
||||||
|
|
||||||
|
Connect to your PostgreSQL server and run:
|
||||||
|
|
||||||
|
```
|
||||||
|
create database pgx_test;
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you can run the tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Connection and Authentication Tests
|
||||||
|
|
||||||
|
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
|
||||||
|
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
|
||||||
|
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
|
||||||
|
authentication code.
|
||||||
@@ -0,0 +1,272 @@
|
|||||||
|
// SCRAM-SHA-256 authentication
|
||||||
|
//
|
||||||
|
// Resources:
|
||||||
|
// https://tools.ietf.org/html/rfc5802
|
||||||
|
// https://tools.ietf.org/html/rfc8265
|
||||||
|
// https://www.postgresql.org/docs/current/sasl-authentication.html
|
||||||
|
//
|
||||||
|
// Inspiration drawn from other implementations:
|
||||||
|
// https://github.com/lib/pq/pull/608
|
||||||
|
// https://github.com/lib/pq/pull/788
|
||||||
|
// https://github.com/lib/pq/pull/833
|
||||||
|
|
||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
|
"golang.org/x/crypto/pbkdf2"
|
||||||
|
"golang.org/x/text/secure/precis"
|
||||||
|
)
|
||||||
|
|
||||||
|
const clientNonceLen = 18
|
||||||
|
|
||||||
|
// Perform SCRAM authentication.
|
||||||
|
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||||
|
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send client-first-message in a SASLInitialResponse
|
||||||
|
saslInitialResponse := &pgproto3.SASLInitialResponse{
|
||||||
|
AuthMechanism: "SCRAM-SHA-256",
|
||||||
|
Data: sc.clientFirstMessage(),
|
||||||
|
}
|
||||||
|
c.frontend.Send(saslInitialResponse)
|
||||||
|
err = c.frontend.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||||
|
saslContinue, err := c.rxSASLContinue()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = sc.recvServerFirstMessage(saslContinue.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send client-final-message in a SASLResponse
|
||||||
|
saslResponse := &pgproto3.SASLResponse{
|
||||||
|
Data: []byte(sc.clientFinalMessage()),
|
||||||
|
}
|
||||||
|
c.frontend.Send(saslResponse)
|
||||||
|
err = c.frontend.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||||
|
saslFinal, err := c.rxSASLFinal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sc.recvServerFinalMessage(saslFinal.Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
|
||||||
|
msg, err := c.receiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch m := msg.(type) {
|
||||||
|
case *pgproto3.AuthenticationSASLContinue:
|
||||||
|
return m, nil
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
return nil, ErrorResponseToPgError(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||||
|
msg, err := c.receiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch m := msg.(type) {
|
||||||
|
case *pgproto3.AuthenticationSASLFinal:
|
||||||
|
return m, nil
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
return nil, ErrorResponseToPgError(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
type scramClient struct {
|
||||||
|
serverAuthMechanisms []string
|
||||||
|
password []byte
|
||||||
|
clientNonce []byte
|
||||||
|
|
||||||
|
clientFirstMessageBare []byte
|
||||||
|
|
||||||
|
serverFirstMessage []byte
|
||||||
|
clientAndServerNonce []byte
|
||||||
|
salt []byte
|
||||||
|
iterations int
|
||||||
|
|
||||||
|
saltedPassword []byte
|
||||||
|
authMessage []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
|
||||||
|
sc := &scramClient{
|
||||||
|
serverAuthMechanisms: serverAuthMechanisms,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure server supports SCRAM-SHA-256
|
||||||
|
hasScramSHA256 := false
|
||||||
|
for _, mech := range sc.serverAuthMechanisms {
|
||||||
|
if mech == "SCRAM-SHA-256" {
|
||||||
|
hasScramSHA256 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasScramSHA256 {
|
||||||
|
return nil, errors.New("server does not support SCRAM-SHA-256")
|
||||||
|
}
|
||||||
|
|
||||||
|
// precis.OpaqueString is equivalent to SASLprep for password.
|
||||||
|
var err error
|
||||||
|
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
|
||||||
|
if err != nil {
|
||||||
|
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
|
||||||
|
sc.password = []byte(password)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, clientNonceLen)
|
||||||
|
_, err = rand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
|
||||||
|
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
|
||||||
|
|
||||||
|
return sc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) clientFirstMessage() []byte {
|
||||||
|
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
|
||||||
|
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
|
||||||
|
sc.serverFirstMessage = serverFirstMessage
|
||||||
|
buf := serverFirstMessage
|
||||||
|
if !bytes.HasPrefix(buf, []byte("r=")) {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
|
||||||
|
}
|
||||||
|
buf = buf[2:]
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(buf, ',')
|
||||||
|
if idx == -1 {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||||
|
}
|
||||||
|
sc.clientAndServerNonce = buf[:idx]
|
||||||
|
buf = buf[idx+1:]
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(buf, []byte("s=")) {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||||
|
}
|
||||||
|
buf = buf[2:]
|
||||||
|
|
||||||
|
idx = bytes.IndexByte(buf, ',')
|
||||||
|
if idx == -1 {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||||
|
}
|
||||||
|
saltStr := buf[:idx]
|
||||||
|
buf = buf[idx+1:]
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(buf, []byte("i=")) {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||||
|
}
|
||||||
|
buf = buf[2:]
|
||||||
|
iterationsStr := buf
|
||||||
|
|
||||||
|
var err error
|
||||||
|
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.iterations, err = strconv.Atoi(string(iterationsStr))
|
||||||
|
if err != nil || sc.iterations <= 0 {
|
||||||
|
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
|
||||||
|
return errors.New("invalid SCRAM nonce: did not start with client nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
|
||||||
|
return errors.New("invalid SCRAM nonce: did not include server nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) clientFinalMessage() string {
|
||||||
|
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
|
||||||
|
|
||||||
|
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
|
||||||
|
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
|
||||||
|
|
||||||
|
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
|
||||||
|
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
|
||||||
|
return errors.New("invalid SCRAM server-final-message received from server")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverSignature := serverFinalMessage[2:]
|
||||||
|
|
||||||
|
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
|
||||||
|
return errors.New("invalid SCRAM ServerSignature received from server")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeHMAC(key, msg []byte) []byte {
|
||||||
|
mac := hmac.New(sha256.New, key)
|
||||||
|
mac.Write(msg)
|
||||||
|
return mac.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
||||||
|
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||||
|
storedKey := sha256.Sum256(clientKey)
|
||||||
|
clientSignature := computeHMAC(storedKey[:], authMessage)
|
||||||
|
|
||||||
|
clientProof := make([]byte, len(clientSignature))
|
||||||
|
for i := 0; i < len(clientSignature); i++ {
|
||||||
|
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
|
||||||
|
base64.StdEncoding.Encode(buf, clientProof)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
||||||
|
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||||
|
serverSignature := computeHMAC(serverKey, authMessage)
|
||||||
|
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||||
|
base64.StdEncoding.Encode(buf, serverSignature)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkCommandTagRowsAffected(b *testing.B) {
|
||||||
|
benchmarks := []struct {
|
||||||
|
commandTag string
|
||||||
|
rowsAffected int64
|
||||||
|
}{
|
||||||
|
{"UPDATE 1", 1},
|
||||||
|
{"UPDATE 123456789", 123456789},
|
||||||
|
{"INSERT 0 1", 1},
|
||||||
|
{"INSERT 0 123456789", 123456789},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
ct := CommandTag{s: bm.commandTag}
|
||||||
|
b.Run(bm.commandTag, func(b *testing.B) {
|
||||||
|
var n int64
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
n = ct.RowsAffected()
|
||||||
|
}
|
||||||
|
if n != bm.rowsAffected {
|
||||||
|
b.Errorf("expected %d got %d", bm.rowsAffected, n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCommandTagTypeFromString(b *testing.B) {
|
||||||
|
ct := CommandTag{s: "UPDATE 1"}
|
||||||
|
|
||||||
|
var update bool
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
update = strings.HasPrefix(ct.String(), "UPDATE")
|
||||||
|
}
|
||||||
|
if !update {
|
||||||
|
b.Error("expected update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCommandTagInsert(b *testing.B) {
|
||||||
|
benchmarks := []struct {
|
||||||
|
commandTag string
|
||||||
|
is bool
|
||||||
|
}{
|
||||||
|
{"INSERT 1", true},
|
||||||
|
{"INSERT 1234567890", true},
|
||||||
|
{"UPDATE 1", false},
|
||||||
|
{"UPDATE 1234567890", false},
|
||||||
|
{"DELETE 1", false},
|
||||||
|
{"DELETE 1234567890", false},
|
||||||
|
{"SELECT 1", false},
|
||||||
|
{"SELECT 1234567890", false},
|
||||||
|
{"UNKNOWN 1234567890", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
ct := CommandTag{s: bm.commandTag}
|
||||||
|
b.Run(bm.commandTag, func(b *testing.B) {
|
||||||
|
var is bool
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
is = ct.Insert()
|
||||||
|
}
|
||||||
|
if is != bm.is {
|
||||||
|
b.Errorf("expected %v got %v", bm.is, is)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,254 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkConnect(b *testing.B) {
|
||||||
|
benchmarks := []struct {
|
||||||
|
name string
|
||||||
|
env string
|
||||||
|
}{
|
||||||
|
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
||||||
|
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
bm := bm
|
||||||
|
b.Run(bm.name, func(b *testing.B) {
|
||||||
|
connString := os.Getenv(bm.env)
|
||||||
|
if connString == "" {
|
||||||
|
b.Skipf("Skipping due to missing environment variable %v", bm.env)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conn, err := pgconn.Connect(context.Background(), connString)
|
||||||
|
require.Nil(b, err)
|
||||||
|
|
||||||
|
err = conn.Close(context.Background())
|
||||||
|
require.Nil(b, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExec(b *testing.B) {
|
||||||
|
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||||
|
benchmarks := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
}{
|
||||||
|
// Using an empty context other than context.Background() to compare
|
||||||
|
// performance
|
||||||
|
{"background context", context.Background()},
|
||||||
|
{"empty context", context.TODO()},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
bm := bm
|
||||||
|
b.Run(bm.name, func(b *testing.B) {
|
||||||
|
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.Nil(b, err)
|
||||||
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
|
||||||
|
|
||||||
|
for mrr.NextResult() {
|
||||||
|
rr := mrr.ResultReader()
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
for rr.NextRow() {
|
||||||
|
rowCount++
|
||||||
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
|
}
|
||||||
|
for i := range rr.Values() {
|
||||||
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = rr.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if rowCount != 1 {
|
||||||
|
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mrr.Close()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExecPossibleToCancel(b *testing.B) {
|
||||||
|
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.Nil(b, err)
|
||||||
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
|
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
|
||||||
|
|
||||||
|
for mrr.NextResult() {
|
||||||
|
rr := mrr.ResultReader()
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
for rr.NextRow() {
|
||||||
|
rowCount++
|
||||||
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
|
}
|
||||||
|
for i := range rr.Values() {
|
||||||
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = rr.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if rowCount != 1 {
|
||||||
|
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mrr.Close()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExecPrepared(b *testing.B) {
|
||||||
|
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||||
|
|
||||||
|
benchmarks := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
}{
|
||||||
|
// Using an empty context other than context.Background() to compare
|
||||||
|
// performance
|
||||||
|
{"background context", context.Background()},
|
||||||
|
{"empty context", context.TODO()},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bm := range benchmarks {
|
||||||
|
bm := bm
|
||||||
|
b.Run(bm.name, func(b *testing.B) {
|
||||||
|
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.Nil(b, err)
|
||||||
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
|
_, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||||
|
require.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil)
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
for rr.NextRow() {
|
||||||
|
rowCount++
|
||||||
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
|
}
|
||||||
|
for i := range rr.Values() {
|
||||||
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = rr.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if rowCount != 1 {
|
||||||
|
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
||||||
|
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.Nil(b, err)
|
||||||
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||||
|
require.Nil(b, err)
|
||||||
|
|
||||||
|
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil)
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
for rr.NextRow() {
|
||||||
|
rowCount += 1
|
||||||
|
if len(rr.Values()) != len(expectedValues) {
|
||||||
|
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||||
|
}
|
||||||
|
for i := range rr.Values() {
|
||||||
|
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||||
|
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = rr.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if rowCount != 1 {
|
||||||
|
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
|
||||||
|
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
// require.Nil(b, err)
|
||||||
|
// defer closeConn(b, conn)
|
||||||
|
|
||||||
|
// ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
// defer cancel()
|
||||||
|
|
||||||
|
// b.ResetTimer()
|
||||||
|
|
||||||
|
// for i := 0; i < b.N; i++ {
|
||||||
|
// conn.ChanToSetDeadline().Watch(ctx)
|
||||||
|
// conn.ChanToSetDeadline().Ignore()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
@@ -0,0 +1,886 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgpassfile"
|
||||||
|
"github.com/jackc/pgservicefile"
|
||||||
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
type GetSSLPasswordFunc func(ctx context.Context) string
|
||||||
|
|
||||||
|
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
|
||||||
|
// manually initialized Config will cause ConnectConfig to panic.
|
||||||
|
type Config struct {
|
||||||
|
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||||
|
Port uint16
|
||||||
|
Database string
|
||||||
|
User string
|
||||||
|
Password string
|
||||||
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
|
ConnectTimeout time.Duration
|
||||||
|
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||||
|
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||||
|
BuildFrontend BuildFrontendFunc
|
||||||
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
|
|
||||||
|
KerberosSrvName string
|
||||||
|
KerberosSpn string
|
||||||
|
Fallbacks []*FallbackConfig
|
||||||
|
|
||||||
|
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||||
|
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||||
|
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||||
|
ValidateConnect ValidateConnectFunc
|
||||||
|
|
||||||
|
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
|
||||||
|
// or prepare statements). If this returns an error the connection attempt fails.
|
||||||
|
AfterConnect AfterConnectFunc
|
||||||
|
|
||||||
|
// OnNotice is a callback function called when a notice response is received.
|
||||||
|
OnNotice NoticeHandler
|
||||||
|
|
||||||
|
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||||
|
OnNotification NotificationHandler
|
||||||
|
|
||||||
|
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
|
||||||
|
type ParseConfigOptions struct {
|
||||||
|
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
|
||||||
|
// PQsetSSLKeyPassHook_OpenSSL.
|
||||||
|
GetSSLPassword GetSSLPasswordFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy returns a deep copy of the config that is safe to use and modify.
|
||||||
|
// The only exception is the TLSConfig field:
|
||||||
|
// according to the tls.Config docs it must not be modified after creation.
|
||||||
|
func (c *Config) Copy() *Config {
|
||||||
|
newConf := new(Config)
|
||||||
|
*newConf = *c
|
||||||
|
if newConf.TLSConfig != nil {
|
||||||
|
newConf.TLSConfig = c.TLSConfig.Clone()
|
||||||
|
}
|
||||||
|
if newConf.RuntimeParams != nil {
|
||||||
|
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
|
||||||
|
for k, v := range c.RuntimeParams {
|
||||||
|
newConf.RuntimeParams[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newConf.Fallbacks != nil {
|
||||||
|
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
|
||||||
|
for i, fallback := range c.Fallbacks {
|
||||||
|
newFallback := new(FallbackConfig)
|
||||||
|
*newFallback = *fallback
|
||||||
|
if newFallback.TLSConfig != nil {
|
||||||
|
newFallback.TLSConfig = fallback.TLSConfig.Clone()
|
||||||
|
}
|
||||||
|
newConf.Fallbacks[i] = newFallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return newConf
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||||
|
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||||
|
type FallbackConfig struct {
|
||||||
|
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||||
|
Port uint16
|
||||||
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAbsolutePath checks if the provided value is an absolute path either
|
||||||
|
// beginning with a forward slash (as on Linux-based systems) or with a capital
|
||||||
|
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
|
||||||
|
func isAbsolutePath(path string) bool {
|
||||||
|
isWindowsPath := func(p string) bool {
|
||||||
|
if len(p) < 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
drive := p[0]
|
||||||
|
colon := p[1]
|
||||||
|
backslash := p[2]
|
||||||
|
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.HasPrefix(path, "/") || isWindowsPath(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||||
|
// net.Dial.
|
||||||
|
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||||
|
if isAbsolutePath(host) {
|
||||||
|
network = "unix"
|
||||||
|
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||||
|
} else {
|
||||||
|
network = "tcp"
|
||||||
|
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||||
|
}
|
||||||
|
return network, address
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
|
||||||
|
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
|
||||||
|
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
|
||||||
|
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
|
||||||
|
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||||
|
//
|
||||||
|
// # Example DSN
|
||||||
|
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||||
|
//
|
||||||
|
// # Example URL
|
||||||
|
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||||
|
//
|
||||||
|
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
|
||||||
|
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
|
||||||
|
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
|
||||||
|
// not be modified individually. They should all be modified or all left unchanged.
|
||||||
|
//
|
||||||
|
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
||||||
|
// values that will be tried in order. This can be used as part of a high availability system. See
|
||||||
|
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||||
|
//
|
||||||
|
// # Example URL
|
||||||
|
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||||
|
//
|
||||||
|
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
|
||||||
|
// via database URL or DSN:
|
||||||
|
//
|
||||||
|
// PGHOST
|
||||||
|
// PGPORT
|
||||||
|
// PGDATABASE
|
||||||
|
// PGUSER
|
||||||
|
// PGPASSWORD
|
||||||
|
// PGPASSFILE
|
||||||
|
// PGSERVICE
|
||||||
|
// PGSERVICEFILE
|
||||||
|
// PGSSLMODE
|
||||||
|
// PGSSLCERT
|
||||||
|
// PGSSLKEY
|
||||||
|
// PGSSLROOTCERT
|
||||||
|
// PGSSLPASSWORD
|
||||||
|
// PGAPPNAME
|
||||||
|
// PGCONNECT_TIMEOUT
|
||||||
|
// PGTARGETSESSIONATTRS
|
||||||
|
//
|
||||||
|
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
||||||
|
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
||||||
|
//
|
||||||
|
// Important Security Notes:
|
||||||
|
//
|
||||||
|
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
||||||
|
// not set.
|
||||||
|
//
|
||||||
|
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
||||||
|
// security each sslmode provides.
|
||||||
|
//
|
||||||
|
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
|
||||||
|
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
|
||||||
|
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
|
||||||
|
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
|
||||||
|
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
|
||||||
|
// TLSConfig.
|
||||||
|
//
|
||||||
|
// Other known differences with libpq:
|
||||||
|
//
|
||||||
|
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
|
||||||
|
// does not.
|
||||||
|
//
|
||||||
|
// In addition, ParseConfig accepts the following options:
|
||||||
|
//
|
||||||
|
// servicefile
|
||||||
|
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
||||||
|
// part of the connection string.
|
||||||
|
func ParseConfig(connString string) (*Config, error) {
|
||||||
|
var parseConfigOptions ParseConfigOptions
|
||||||
|
return ParseConfigWithOptions(connString, parseConfigOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
|
||||||
|
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
|
||||||
|
// get the SSL password.
|
||||||
|
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
|
||||||
|
defaultSettings := defaultSettings()
|
||||||
|
envSettings := parseEnvSettings()
|
||||||
|
|
||||||
|
connStringSettings := make(map[string]string)
|
||||||
|
if connString != "" {
|
||||||
|
var err error
|
||||||
|
// connString may be a database URL or a DSN
|
||||||
|
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||||
|
connStringSettings, err = parseURLSettings(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
connStringSettings, err = parseDSNSettings(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
|
||||||
|
if service, present := settings["service"]; present {
|
||||||
|
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
createdByParseConfig: true,
|
||||||
|
Database: settings["database"],
|
||||||
|
User: settings["user"],
|
||||||
|
Password: settings["password"],
|
||||||
|
RuntimeParams: make(map[string]string),
|
||||||
|
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
|
||||||
|
return pgproto3.NewFrontend(r, w)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
|
||||||
|
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||||
|
}
|
||||||
|
config.ConnectTimeout = connectTimeout
|
||||||
|
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
|
||||||
|
} else {
|
||||||
|
defaultDialer := makeDefaultDialer()
|
||||||
|
config.DialFunc = defaultDialer.DialContext
|
||||||
|
}
|
||||||
|
|
||||||
|
config.LookupFunc = makeDefaultResolver().LookupHost
|
||||||
|
|
||||||
|
notRuntimeParams := map[string]struct{}{
|
||||||
|
"host": {},
|
||||||
|
"port": {},
|
||||||
|
"database": {},
|
||||||
|
"user": {},
|
||||||
|
"password": {},
|
||||||
|
"passfile": {},
|
||||||
|
"connect_timeout": {},
|
||||||
|
"sslmode": {},
|
||||||
|
"sslkey": {},
|
||||||
|
"sslcert": {},
|
||||||
|
"sslrootcert": {},
|
||||||
|
"sslpassword": {},
|
||||||
|
"sslsni": {},
|
||||||
|
"krbspn": {},
|
||||||
|
"krbsrvname": {},
|
||||||
|
"target_session_attrs": {},
|
||||||
|
"service": {},
|
||||||
|
"servicefile": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding kerberos configuration
|
||||||
|
if _, present := settings["krbsrvname"]; present {
|
||||||
|
config.KerberosSrvName = settings["krbsrvname"]
|
||||||
|
}
|
||||||
|
if _, present := settings["krbspn"]; present {
|
||||||
|
config.KerberosSpn = settings["krbspn"]
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range settings {
|
||||||
|
if _, present := notRuntimeParams[k]; present {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
config.RuntimeParams[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbacks := []*FallbackConfig{}
|
||||||
|
|
||||||
|
hosts := strings.Split(settings["host"], ",")
|
||||||
|
ports := strings.Split(settings["port"], ",")
|
||||||
|
|
||||||
|
for i, host := range hosts {
|
||||||
|
var portStr string
|
||||||
|
if i < len(ports) {
|
||||||
|
portStr = ports[i]
|
||||||
|
} else {
|
||||||
|
portStr = ports[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := parsePort(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tlsConfigs []*tls.Config
|
||||||
|
|
||||||
|
// Ignore TLS settings if Unix domain socket like libpq
|
||||||
|
if network, _ := NetworkAddress(host, port); network == "unix" {
|
||||||
|
tlsConfigs = append(tlsConfigs, nil)
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
tlsConfigs, err = configTLS(settings, host, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tlsConfig := range tlsConfigs {
|
||||||
|
fallbacks = append(fallbacks, &FallbackConfig{
|
||||||
|
Host: host,
|
||||||
|
Port: port,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Host = fallbacks[0].Host
|
||||||
|
config.Port = fallbacks[0].Port
|
||||||
|
config.TLSConfig = fallbacks[0].TLSConfig
|
||||||
|
config.Fallbacks = fallbacks[1:]
|
||||||
|
|
||||||
|
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||||
|
if err == nil {
|
||||||
|
if config.Password == "" {
|
||||||
|
host := config.Host
|
||||||
|
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||||
|
host = "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tsa := settings["target_session_attrs"]; tsa {
|
||||||
|
case "read-write":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||||
|
case "read-only":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
|
||||||
|
case "primary":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
|
||||||
|
case "standby":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
|
||||||
|
case "prefer-standby":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
|
||||||
|
case "any":
|
||||||
|
// do nothing
|
||||||
|
default:
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeSettings(settingSets ...map[string]string) map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
for _, s2 := range settingSets {
|
||||||
|
for k, v := range s2 {
|
||||||
|
settings[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEnvSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"PGHOST": "host",
|
||||||
|
"PGPORT": "port",
|
||||||
|
"PGDATABASE": "database",
|
||||||
|
"PGUSER": "user",
|
||||||
|
"PGPASSWORD": "password",
|
||||||
|
"PGPASSFILE": "passfile",
|
||||||
|
"PGAPPNAME": "application_name",
|
||||||
|
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||||
|
"PGSSLMODE": "sslmode",
|
||||||
|
"PGSSLKEY": "sslkey",
|
||||||
|
"PGSSLCERT": "sslcert",
|
||||||
|
"PGSSLSNI": "sslsni",
|
||||||
|
"PGSSLROOTCERT": "sslrootcert",
|
||||||
|
"PGSSLPASSWORD": "sslpassword",
|
||||||
|
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||||
|
"PGSERVICE": "service",
|
||||||
|
"PGSERVICEFILE": "servicefile",
|
||||||
|
}
|
||||||
|
|
||||||
|
for envname, realname := range nameMap {
|
||||||
|
value := os.Getenv(envname)
|
||||||
|
if value != "" {
|
||||||
|
settings[realname] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseURLSettings(connString string) (map[string]string, error) {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
url, err := url.Parse(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.User != nil {
|
||||||
|
settings["user"] = url.User.Username()
|
||||||
|
if password, present := url.User.Password(); present {
|
||||||
|
settings["password"] = password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||||
|
var hosts []string
|
||||||
|
var ports []string
|
||||||
|
for _, host := range strings.Split(url.Host, ",") {
|
||||||
|
if host == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isIPOnly(host) {
|
||||||
|
hosts = append(hosts, strings.Trim(host, "[]"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h, p, err := net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
|
||||||
|
}
|
||||||
|
if h != "" {
|
||||||
|
hosts = append(hosts, h)
|
||||||
|
}
|
||||||
|
if p != "" {
|
||||||
|
ports = append(ports, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(hosts) > 0 {
|
||||||
|
settings["host"] = strings.Join(hosts, ",")
|
||||||
|
}
|
||||||
|
if len(ports) > 0 {
|
||||||
|
settings["port"] = strings.Join(ports, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
database := strings.TrimLeft(url.Path, "/")
|
||||||
|
if database != "" {
|
||||||
|
settings["database"] = database
|
||||||
|
}
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range url.Query() {
|
||||||
|
if k2, present := nameMap[k]; present {
|
||||||
|
k = k2
|
||||||
|
}
|
||||||
|
|
||||||
|
settings[k] = v[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIPOnly(host string) bool {
|
||||||
|
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
|
||||||
|
}
|
||||||
|
|
||||||
|
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||||
|
|
||||||
|
func parseDSNSettings(s string) (map[string]string, error) {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(s) > 0 {
|
||||||
|
var key, val string
|
||||||
|
eqIdx := strings.IndexRune(s, '=')
|
||||||
|
if eqIdx < 0 {
|
||||||
|
return nil, errors.New("invalid dsn")
|
||||||
|
}
|
||||||
|
|
||||||
|
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
||||||
|
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
|
||||||
|
if len(s) == 0 {
|
||||||
|
} else if s[0] != '\'' {
|
||||||
|
end := 0
|
||||||
|
for ; end < len(s); end++ {
|
||||||
|
if asciiSpace[s[end]] == 1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[end] == '\\' {
|
||||||
|
end++
|
||||||
|
if end == len(s) {
|
||||||
|
return nil, errors.New("invalid backslash")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||||
|
if end == len(s) {
|
||||||
|
s = ""
|
||||||
|
} else {
|
||||||
|
s = s[end+1:]
|
||||||
|
}
|
||||||
|
} else { // quoted string
|
||||||
|
s = s[1:]
|
||||||
|
end := 0
|
||||||
|
for ; end < len(s); end++ {
|
||||||
|
if s[end] == '\'' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[end] == '\\' {
|
||||||
|
end++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if end == len(s) {
|
||||||
|
return nil, errors.New("unterminated quoted string in connection info string")
|
||||||
|
}
|
||||||
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||||
|
if end == len(s) {
|
||||||
|
s = ""
|
||||||
|
} else {
|
||||||
|
s = s[end+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if k, ok := nameMap[key]; ok {
|
||||||
|
key = k
|
||||||
|
}
|
||||||
|
|
||||||
|
if key == "" {
|
||||||
|
return nil, errors.New("invalid dsn")
|
||||||
|
}
|
||||||
|
|
||||||
|
settings[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
|
||||||
|
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := servicefile.GetService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to find service: %v", serviceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := make(map[string]string, len(service.Settings))
|
||||||
|
for k, v := range service.Settings {
|
||||||
|
if k2, present := nameMap[k]; present {
|
||||||
|
k = k2
|
||||||
|
}
|
||||||
|
settings[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||||
|
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||||
|
// "prefer" allow fallback.
|
||||||
|
func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
|
||||||
|
host := thisHost
|
||||||
|
sslmode := settings["sslmode"]
|
||||||
|
sslrootcert := settings["sslrootcert"]
|
||||||
|
sslcert := settings["sslcert"]
|
||||||
|
sslkey := settings["sslkey"]
|
||||||
|
sslpassword := settings["sslpassword"]
|
||||||
|
sslsni := settings["sslsni"]
|
||||||
|
|
||||||
|
// Match libpq default behavior
|
||||||
|
if sslmode == "" {
|
||||||
|
sslmode = "prefer"
|
||||||
|
}
|
||||||
|
if sslsni == "" {
|
||||||
|
sslsni = "1"
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "disable":
|
||||||
|
return []*tls.Config{nil}, nil
|
||||||
|
case "allow", "prefer":
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
case "require":
|
||||||
|
// According to PostgreSQL documentation, if a root CA file exists,
|
||||||
|
// the behavior of sslmode=require should be the same as that of verify-ca
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/12/libpq-ssl.html
|
||||||
|
if sslrootcert != "" {
|
||||||
|
goto nextCase
|
||||||
|
}
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
break
|
||||||
|
nextCase:
|
||||||
|
fallthrough
|
||||||
|
case "verify-ca":
|
||||||
|
// Don't perform the default certificate verification because it
|
||||||
|
// will verify the hostname. Instead, verify the server's
|
||||||
|
// certificate chain ourselves in VerifyPeerCertificate and
|
||||||
|
// ignore the server name. This emulates libpq's verify-ca
|
||||||
|
// behavior.
|
||||||
|
//
|
||||||
|
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
|
||||||
|
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
|
||||||
|
// for more info.
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
|
||||||
|
certs := make([]*x509.Certificate, len(certificates))
|
||||||
|
for i, asn1Data := range certificates {
|
||||||
|
cert, err := x509.ParseCertificate(asn1Data)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("failed to parse certificate from server: " + err.Error())
|
||||||
|
}
|
||||||
|
certs[i] = cert
|
||||||
|
}
|
||||||
|
|
||||||
|
// Leave DNSName empty to skip hostname verification.
|
||||||
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: tlsConfig.RootCAs,
|
||||||
|
Intermediates: x509.NewCertPool(),
|
||||||
|
}
|
||||||
|
// Skip the first cert because it's the leaf. All others
|
||||||
|
// are intermediates.
|
||||||
|
for _, cert := range certs[1:] {
|
||||||
|
opts.Intermediates.AddCert(cert)
|
||||||
|
}
|
||||||
|
_, err := certs[0].Verify(opts)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "verify-full":
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
default:
|
||||||
|
return nil, errors.New("sslmode is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslrootcert != "" {
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
|
||||||
|
caPath := sslrootcert
|
||||||
|
caCert, err := ioutil.ReadFile(caPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read CA file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
return nil, errors.New("unable to add CA to cert pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
tlsConfig.ClientCAs = caCertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||||
|
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslcert != "" && sslkey != "" {
|
||||||
|
buf, err := ioutil.ReadFile(sslkey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
||||||
|
}
|
||||||
|
block, _ := pem.Decode(buf)
|
||||||
|
var pemKey []byte
|
||||||
|
var decryptedKey []byte
|
||||||
|
var decryptedError error
|
||||||
|
// If PEM is encrypted, attempt to decrypt using pass phrase
|
||||||
|
if x509.IsEncryptedPEMBlock(block) {
|
||||||
|
// Attempt decryption with pass phrase
|
||||||
|
// NOTE: only supports RSA (PKCS#1)
|
||||||
|
if sslpassword != "" {
|
||||||
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
|
}
|
||||||
|
//if sslpassword not provided or has decryption error when use it
|
||||||
|
//try to find sslpassword with callback function
|
||||||
|
if sslpassword == "" || decryptedError != nil {
|
||||||
|
if parseConfigOptions.GetSSLPassword != nil {
|
||||||
|
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
||||||
|
}
|
||||||
|
if sslpassword == "" {
|
||||||
|
return nil, fmt.Errorf("unable to find sslpassword")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
|
// Should we also provide warning for PKCS#1 needed?
|
||||||
|
if decryptedError != nil {
|
||||||
|
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pemBytes := pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: decryptedKey,
|
||||||
|
}
|
||||||
|
pemKey = pem.EncodeToMemory(&pemBytes)
|
||||||
|
} else {
|
||||||
|
pemKey = pem.EncodeToMemory(block)
|
||||||
|
}
|
||||||
|
certfile, err := ioutil.ReadFile(sslcert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read cert: %w", err)
|
||||||
|
}
|
||||||
|
cert, err := tls.X509KeyPair(certfile, pemKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to load cert: %w", err)
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set Server Name Indication (SNI), if enabled by connection parameters.
|
||||||
|
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
|
||||||
|
// or IPv6).
|
||||||
|
if sslsni == "1" && net.ParseIP(host) == nil {
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "allow":
|
||||||
|
return []*tls.Config{nil, tlsConfig}, nil
|
||||||
|
case "prefer":
|
||||||
|
return []*tls.Config{tlsConfig, nil}, nil
|
||||||
|
case "require", "verify-ca", "verify-full":
|
||||||
|
return []*tls.Config{tlsConfig}, nil
|
||||||
|
default:
|
||||||
|
panic("BUG: bad sslmode should already have been caught")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePort(s string) (uint16, error) {
|
||||||
|
port, err := strconv.ParseUint(s, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if port < 1 || port > math.MaxUint16 {
|
||||||
|
return 0, errors.New("outside range")
|
||||||
|
}
|
||||||
|
return uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDefaultDialer() *net.Dialer {
|
||||||
|
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDefaultResolver() *net.Resolver {
|
||||||
|
return net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
|
||||||
|
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if timeout < 0 {
|
||||||
|
return 0, errors.New("negative timeout")
|
||||||
|
}
|
||||||
|
return time.Duration(timeout) * time.Second, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
|
||||||
|
d := makeDefaultDialer()
|
||||||
|
d.Timeout = timeout
|
||||||
|
return d.DialContext
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=read-write.
|
||||||
|
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) == "on" {
|
||||||
|
return errors.New("read only connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=read-only.
|
||||||
|
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) != "on" {
|
||||||
|
return errors.New("connection is not read only")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=standby.
|
||||||
|
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) != "t" {
|
||||||
|
return errors.New("server is not in hot standby mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=primary.
|
||||||
|
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) == "t" {
|
||||||
|
return errors.New("server is in standby mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=prefer-standby.
|
||||||
|
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) != "t" {
|
||||||
|
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,63 @@
|
|||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
settings["host"] = defaultHost()
|
||||||
|
settings["port"] = "5432"
|
||||||
|
|
||||||
|
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||||
|
// OS. The client application will simply have to specify the user in that
|
||||||
|
// case (which they typically will be doing anyway).
|
||||||
|
user, err := user.Current()
|
||||||
|
if err == nil {
|
||||||
|
settings["user"] = user.Username
|
||||||
|
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||||
|
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||||
|
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
|
||||||
|
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
|
||||||
|
if _, err := os.Stat(sslcert); err == nil {
|
||||||
|
if _, err := os.Stat(sslkey); err == nil {
|
||||||
|
// Both the cert and key must be present to use them, or do not use either
|
||||||
|
settings["sslcert"] = sslcert
|
||||||
|
settings["sslkey"] = sslkey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
|
||||||
|
if _, err := os.Stat(sslrootcert); err == nil {
|
||||||
|
settings["sslrootcert"] = sslrootcert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["target_session_attrs"] = "any"
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||||
|
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||||
|
// checks the existence of common locations.
|
||||||
|
func defaultHost() string {
|
||||||
|
candidatePaths := []string{
|
||||||
|
"/var/run/postgresql", // Debian
|
||||||
|
"/private/tmp", // OSX - homebrew
|
||||||
|
"/tmp", // standard PostgreSQL
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range candidatePaths {
|
||||||
|
if _, err := os.Stat(path); err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "localhost"
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
settings["host"] = defaultHost()
|
||||||
|
settings["port"] = "5432"
|
||||||
|
|
||||||
|
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||||
|
// OS. The client application will simply have to specify the user in that
|
||||||
|
// case (which they typically will be doing anyway).
|
||||||
|
user, err := user.Current()
|
||||||
|
appData := os.Getenv("APPDATA")
|
||||||
|
if err == nil {
|
||||||
|
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
|
||||||
|
// but the libpq default is just the `user` portion, so we strip off the first part.
|
||||||
|
username := user.Username
|
||||||
|
if strings.Contains(username, "\\") {
|
||||||
|
username = username[strings.LastIndex(username, "\\")+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["user"] = username
|
||||||
|
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
|
||||||
|
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||||
|
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
|
||||||
|
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
|
||||||
|
if _, err := os.Stat(sslcert); err == nil {
|
||||||
|
if _, err := os.Stat(sslkey); err == nil {
|
||||||
|
// Both the cert and key must be present to use them, or do not use either
|
||||||
|
settings["sslcert"] = sslcert
|
||||||
|
settings["sslkey"] = sslkey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
|
||||||
|
if _, err := os.Stat(sslrootcert); err == nil {
|
||||||
|
settings["sslrootcert"] = sslrootcert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["target_session_attrs"] = "any"
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||||
|
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||||
|
// checks the existence of common locations.
|
||||||
|
func defaultHost() string {
|
||||||
|
return "localhost"
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
// Package pgconn is a low-level PostgreSQL database driver.
|
||||||
|
/*
|
||||||
|
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
|
||||||
|
nearly the same level is the C library libpq.
|
||||||
|
|
||||||
|
Establishing a Connection
|
||||||
|
|
||||||
|
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
|
||||||
|
libpq style environment variables.
|
||||||
|
|
||||||
|
Executing a Query
|
||||||
|
|
||||||
|
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
|
||||||
|
reads all rows into memory.
|
||||||
|
|
||||||
|
Executing Multiple Queries in a Single Round Trip
|
||||||
|
|
||||||
|
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
|
||||||
|
result. The ReadAll method reads all query results into memory.
|
||||||
|
|
||||||
|
Pipeline Mode
|
||||||
|
|
||||||
|
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows
|
||||||
|
control of exactly how many and when network round trips occur.
|
||||||
|
|
||||||
|
Context Support
|
||||||
|
|
||||||
|
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
|
||||||
|
method immediately returns. In most circumstances, this will close the underlying connection.
|
||||||
|
|
||||||
|
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
|
||||||
|
client to abort.
|
||||||
|
*/
|
||||||
|
package pgconn
|
||||||
@@ -0,0 +1,226 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
|
||||||
|
func SafeToRetry(err error) bool {
|
||||||
|
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
|
||||||
|
return e.SafeToRetry()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
|
||||||
|
// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
|
||||||
|
func Timeout(err error) bool {
|
||||||
|
var timeoutErr *errTimeout
|
||||||
|
return errors.As(err, &timeoutErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PgError represents an error reported by the PostgreSQL server. See
|
||||||
|
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||||
|
// detailed field description.
|
||||||
|
type PgError struct {
|
||||||
|
Severity string
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pe *PgError) Error() string {
|
||||||
|
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLState returns the SQLState of the error.
|
||||||
|
func (pe *PgError) SQLState() string {
|
||||||
|
return pe.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
type connectError struct {
|
||||||
|
config *Config
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connectError) Error() string {
|
||||||
|
sb := &strings.Builder{}
|
||||||
|
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
|
||||||
|
if e.err != nil {
|
||||||
|
fmt.Fprintf(sb, " (%s)", e.err.Error())
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connectError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type connLockError struct {
|
||||||
|
status string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connLockError) SafeToRetry() bool {
|
||||||
|
return true // a lock failure by definition happens before the connection is used.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connLockError) Error() string {
|
||||||
|
return e.status
|
||||||
|
}
|
||||||
|
|
||||||
|
type parseConfigError struct {
|
||||||
|
connString string
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *parseConfigError) Error() string {
|
||||||
|
connString := redactPW(e.connString)
|
||||||
|
if e.err == nil {
|
||||||
|
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *parseConfigError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeTimeoutError(ctx context.Context, err error) error {
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||||
|
if ctx.Err() == context.Canceled {
|
||||||
|
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
|
||||||
|
return context.Canceled
|
||||||
|
} else if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
return &errTimeout{err: ctx.Err()}
|
||||||
|
} else {
|
||||||
|
return &errTimeout{err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type pgconnError struct {
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
safeToRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *pgconnError) Error() string {
|
||||||
|
if e.msg == "" {
|
||||||
|
return e.err.Error()
|
||||||
|
}
|
||||||
|
if e.err == nil {
|
||||||
|
return e.msg
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *pgconnError) SafeToRetry() bool {
|
||||||
|
return e.safeToRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *pgconnError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
|
||||||
|
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
|
||||||
|
type errTimeout struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errTimeout) Error() string {
|
||||||
|
return fmt.Sprintf("timeout: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errTimeout) SafeToRetry() bool {
|
||||||
|
return SafeToRetry(e.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errTimeout) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextAlreadyDoneError struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *contextAlreadyDoneError) Error() string {
|
||||||
|
return fmt.Sprintf("context already done: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *contextAlreadyDoneError) SafeToRetry() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *contextAlreadyDoneError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
|
||||||
|
func newContextAlreadyDoneError(ctx context.Context) (err error) {
|
||||||
|
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func redactPW(connString string) string {
|
||||||
|
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||||
|
if u, err := url.Parse(connString); err == nil {
|
||||||
|
return redactURL(u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
quotedDSN := regexp.MustCompile(`password='[^']*'`)
|
||||||
|
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||||
|
plainDSN := regexp.MustCompile(`password=[^ ]*`)
|
||||||
|
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||||
|
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
|
||||||
|
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
|
||||||
|
return connString
|
||||||
|
}
|
||||||
|
|
||||||
|
func redactURL(u *url.URL) string {
|
||||||
|
if u == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if _, pwSet := u.User.Password(); pwSet {
|
||||||
|
u.User = url.UserPassword(u.User.Username(), "xxxxx")
|
||||||
|
}
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
type NotPreferredError struct {
|
||||||
|
err error
|
||||||
|
safeToRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotPreferredError) Error() string {
|
||||||
|
return fmt.Sprintf("standby server not found: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotPreferredError) SafeToRetry() bool {
|
||||||
|
return e.safeToRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotPreferredError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfigError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
expectedMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "url with password",
|
||||||
|
err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dsn with password unquoted",
|
||||||
|
err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dsn with password quoted",
|
||||||
|
err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "weird url",
|
||||||
|
err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "weird url with slash in password",
|
||||||
|
err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "url without password",
|
||||||
|
err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil),
|
||||||
|
expectedMsg: "cannot parse `postgresql://other@host/db`: msg",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert.EqualError(t, tt.err, tt.expectedMsg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
// File export_test exports some methods for better testing.
|
||||||
|
|
||||||
|
package pgconn
|
||||||
|
|
||||||
|
func NewParseConfigError(conn, msg string, err error) error {
|
||||||
|
return &parseConfigError{
|
||||||
|
connString: conn,
|
||||||
|
msg: msg,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func closeConn(t testing.TB, conn *pgconn.PgConn) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
require.NoError(t, conn.Close(ctx))
|
||||||
|
select {
|
||||||
|
case <-conn.CleanupDone():
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Connection cleanup exceeded maximum time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do a simple query to ensure the connection is still usable
|
||||||
|
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
require.Nil(t, result.Err)
|
||||||
|
assert.Equal(t, 3, len(result.Rows))
|
||||||
|
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||||
|
assert.Equal(t, "2", string(result.Rows[1][0]))
|
||||||
|
assert.Equal(t, "3", string(result.Rows[2][0]))
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package ctxwatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||||
|
// time.
|
||||||
|
type ContextWatcher struct {
|
||||||
|
onCancel func()
|
||||||
|
onUnwatchAfterCancel func()
|
||||||
|
unwatchChan chan struct{}
|
||||||
|
|
||||||
|
lock sync.Mutex
|
||||||
|
watchInProgress bool
|
||||||
|
onCancelWasCalled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||||
|
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
||||||
|
// onCancel called.
|
||||||
|
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
|
||||||
|
cw := &ContextWatcher{
|
||||||
|
onCancel: onCancel,
|
||||||
|
onUnwatchAfterCancel: onUnwatchAfterCancel,
|
||||||
|
unwatchChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return cw
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
|
||||||
|
func (cw *ContextWatcher) Watch(ctx context.Context) {
|
||||||
|
cw.lock.Lock()
|
||||||
|
defer cw.lock.Unlock()
|
||||||
|
|
||||||
|
if cw.watchInProgress {
|
||||||
|
panic("Watch already in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.onCancelWasCalled = false
|
||||||
|
|
||||||
|
if ctx.Done() != nil {
|
||||||
|
cw.watchInProgress = true
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
cw.onCancel()
|
||||||
|
cw.onCancelWasCalled = true
|
||||||
|
<-cw.unwatchChan
|
||||||
|
case <-cw.unwatchChan:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
cw.watchInProgress = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
|
||||||
|
// called then onUnwatchAfterCancel will also be called.
|
||||||
|
func (cw *ContextWatcher) Unwatch() {
|
||||||
|
cw.lock.Lock()
|
||||||
|
defer cw.lock.Unlock()
|
||||||
|
|
||||||
|
if cw.watchInProgress {
|
||||||
|
cw.unwatchChan <- struct{}{}
|
||||||
|
if cw.onCancelWasCalled {
|
||||||
|
cw.onUnwatchAfterCancel()
|
||||||
|
}
|
||||||
|
cw.watchInProgress = false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
package ctxwatch_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextWatcherContextCancelled(t *testing.T) {
|
||||||
|
canceledChan := make(chan struct{})
|
||||||
|
cleanupCalled := false
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {
|
||||||
|
canceledChan <- struct{}{}
|
||||||
|
}, func() {
|
||||||
|
cleanupCalled = true
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-canceledChan:
|
||||||
|
case <-time.NewTimer(time.Second).C:
|
||||||
|
t.Fatal("Timed out waiting for cancel func to be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.Unwatch()
|
||||||
|
|
||||||
|
require.True(t, cleanupCalled, "Cleanup func was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {
|
||||||
|
t.Error("cancel func should not have been called")
|
||||||
|
}, func() {
|
||||||
|
t.Error("cleanup func should not have been called")
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cw.Unwatch()
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
cw.Watch(ctx)
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel2()
|
||||||
|
require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
cw.Unwatch() // unwatch when not / never watching
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cw.Unwatch()
|
||||||
|
cw.Unwatch() // double unwatch
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
cw.Watch(ctx)
|
||||||
|
|
||||||
|
go cw.Unwatch()
|
||||||
|
go cw.Unwatch()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWatcherStress(t *testing.T) {
|
||||||
|
var cancelFuncCalls int64
|
||||||
|
var cleanupFuncCalls int64
|
||||||
|
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {
|
||||||
|
atomic.AddInt64(&cancelFuncCalls, 1)
|
||||||
|
}, func() {
|
||||||
|
atomic.AddInt64(&cleanupFuncCalls, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
cycleCount := 100000
|
||||||
|
|
||||||
|
for i := 0; i < cycleCount; i++ {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cw.Watch(ctx)
|
||||||
|
if i%2 == 0 {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
|
||||||
|
if i%3 == 0 {
|
||||||
|
time.Sleep(time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.Unwatch()
|
||||||
|
if i%2 == 1 {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
|
||||||
|
actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)
|
||||||
|
|
||||||
|
if actualCancelFuncCalls == 0 {
|
||||||
|
t.Fatal("actualCancelFuncCalls == 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
maxCancelFuncCalls := int64(cycleCount) / 2
|
||||||
|
if actualCancelFuncCalls > maxCancelFuncCalls {
|
||||||
|
t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actualCancelFuncCalls != actualCleanupFuncCalls {
|
||||||
|
t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
cw.Watch(context.Background())
|
||||||
|
cw.Unwatch()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContextWatcherCancelled(b *testing.B) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cancel()
|
||||||
|
cw.Unwatch()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContextWatcherCancellable(b *testing.B) {
|
||||||
|
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
cw.Watch(ctx)
|
||||||
|
cw.Unwatch()
|
||||||
|
}
|
||||||
|
}
|
||||||
+100
@@ -0,0 +1,100 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewGSSFunc creates a GSS authentication provider, for use with
|
||||||
|
// RegisterGSSProvider.
|
||||||
|
type NewGSSFunc func() (GSS, error)
|
||||||
|
|
||||||
|
var newGSS NewGSSFunc
|
||||||
|
|
||||||
|
// RegisterGSSProvider registers a GSS authentication provider. For example, if
|
||||||
|
// you need to use Kerberos to authenticate with your server, add this to your
|
||||||
|
// main package:
|
||||||
|
//
|
||||||
|
// import "github.com/otan/gopgkrb5"
|
||||||
|
//
|
||||||
|
// func init() {
|
||||||
|
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
|
||||||
|
// }
|
||||||
|
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
|
||||||
|
newGSS = newGSSArg
|
||||||
|
}
|
||||||
|
|
||||||
|
// GSS provides GSSAPI authentication (e.g., Kerberos).
|
||||||
|
type GSS interface {
|
||||||
|
GetInitToken(host string, service string) ([]byte, error)
|
||||||
|
GetInitTokenFromSPN(spn string) ([]byte, error)
|
||||||
|
Continue(inToken []byte) (done bool, outToken []byte, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) gssAuth() error {
|
||||||
|
if newGSS == nil {
|
||||||
|
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
|
||||||
|
}
|
||||||
|
cli, err := newGSS()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var nextData []byte
|
||||||
|
if c.config.KerberosSpn != "" {
|
||||||
|
// Use the supplied SPN if provided.
|
||||||
|
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
|
||||||
|
} else {
|
||||||
|
// Allow the kerberos service name to be overridden
|
||||||
|
service := "postgres"
|
||||||
|
if c.config.KerberosSrvName != "" {
|
||||||
|
service = c.config.KerberosSrvName
|
||||||
|
}
|
||||||
|
nextData, err = cli.GetInitToken(c.config.Host, service)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
gssResponse := &pgproto3.GSSResponse{
|
||||||
|
Data: nextData,
|
||||||
|
}
|
||||||
|
c.frontend.Send(gssResponse)
|
||||||
|
err = c.frontend.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp, err := c.rxGSSContinue()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var done bool
|
||||||
|
done, nextData, err = cli.Continue(resp.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
|
||||||
|
msg, err := c.receiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m := msg.(type) {
|
||||||
|
case *pgproto3.AuthenticationGSSContinue:
|
||||||
|
return m, nil
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
return nil, ErrorResponseToPgError(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
|
||||||
|
}
|
||||||
+1962
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,41 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCommandTag(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
commandTag CommandTag
|
||||||
|
rowsAffected int64
|
||||||
|
isInsert bool
|
||||||
|
isUpdate bool
|
||||||
|
isDelete bool
|
||||||
|
isSelect bool
|
||||||
|
}{
|
||||||
|
{commandTag: CommandTag{s: "INSERT 0 5"}, rowsAffected: 5, isInsert: true},
|
||||||
|
{commandTag: CommandTag{s: "UPDATE 0"}, rowsAffected: 0, isUpdate: true},
|
||||||
|
{commandTag: CommandTag{s: "UPDATE 1"}, rowsAffected: 1, isUpdate: true},
|
||||||
|
{commandTag: CommandTag{s: "DELETE 0"}, rowsAffected: 0, isDelete: true},
|
||||||
|
{commandTag: CommandTag{s: "DELETE 1"}, rowsAffected: 1, isDelete: true},
|
||||||
|
{commandTag: CommandTag{s: "DELETE 1234567890"}, rowsAffected: 1234567890, isDelete: true},
|
||||||
|
{commandTag: CommandTag{s: "SELECT 1"}, rowsAffected: 1, isSelect: true},
|
||||||
|
{commandTag: CommandTag{s: "SELECT 99999999999"}, rowsAffected: 99999999999, isSelect: true},
|
||||||
|
{commandTag: CommandTag{s: "CREATE TABLE"}, rowsAffected: 0},
|
||||||
|
{commandTag: CommandTag{s: "ALTER TABLE"}, rowsAffected: 0},
|
||||||
|
{commandTag: CommandTag{s: "DROP TABLE"}, rowsAffected: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
ct := tt.commandTag
|
||||||
|
assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
|
||||||
|
assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
|
||||||
|
assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
|
||||||
|
assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
|
||||||
|
assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnStress(t *testing.T) {
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
actionCount := 10000
|
||||||
|
if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
|
||||||
|
stressFactor, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
|
||||||
|
actionCount *= int(stressFactor)
|
||||||
|
}
|
||||||
|
|
||||||
|
setupStressDB(t, pgConn)
|
||||||
|
|
||||||
|
actions := []struct {
|
||||||
|
name string
|
||||||
|
fn func(*pgconn.PgConn) error
|
||||||
|
}{
|
||||||
|
{"Exec Select", stressExecSelect},
|
||||||
|
{"ExecParams Select", stressExecParamsSelect},
|
||||||
|
{"Batch", stressBatch},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < actionCount; i++ {
|
||||||
|
action := actions[rand.Intn(len(actions))]
|
||||||
|
err := action.fn(pgConn)
|
||||||
|
require.Nilf(t, err, "%d: %s", i, action.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled.
|
||||||
|
numGoroutine := runtime.NumGoroutine()
|
||||||
|
require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
|
||||||
|
_, err := pgConn.Exec(context.Background(), `
|
||||||
|
create temporary table widgets(
|
||||||
|
id serial primary key,
|
||||||
|
name varchar not null,
|
||||||
|
description text,
|
||||||
|
creation_time timestamptz default now()
|
||||||
|
);
|
||||||
|
|
||||||
|
insert into widgets(name, description) values
|
||||||
|
('Foo', 'bar'),
|
||||||
|
('baz', 'Something really long Something really long Something really long Something really long Something really long'),
|
||||||
|
('a', 'b')`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func stressExecSelect(pgConn *pgconn.PgConn) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
_, err := pgConn.Exec(ctx, "select * from widgets").ReadAll()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func stressBatch(pgConn *pgconn.PgConn) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
batch := &pgconn.Batch{}
|
||||||
|
|
||||||
|
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
|
||||||
|
batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
|
||||||
|
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||||
|
return err
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,7 @@
|
|||||||
|
# pgproto3
|
||||||
|
|
||||||
|
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||||
|
|
||||||
|
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.
|
||||||
|
|
||||||
|
See example/pgfortune for a playful example of a fake PostgreSQL server.
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
|
||||||
|
type AuthenticationCleartextPassword struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationCleartextPassword) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
||||||
|
if len(src) != 4 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeCleartextPassword {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationCleartextPassword",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthenticationGSS struct{}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) Backend() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeGSS {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 4)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationGSS",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthenticationGSSContinue struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) Backend() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeGSSCont {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Data = src[4:]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||||
|
dst = append(dst, a.Data...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationGSSContinue",
|
||||||
|
Data: a.Data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Data = msg.Data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
|
||||||
|
type AuthenticationMD5Password struct {
|
||||||
|
Salt [4]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationMD5Password) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationMD5Password) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
||||||
|
if len(src) != 8 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeMD5Password {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(dst.Salt[:], src[4:8])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 12)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
||||||
|
dst = append(dst, src.Salt[:]...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Salt [4]byte
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationMD5Password",
|
||||||
|
Salt: src.Salt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
Salt [4]byte
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Salt = msg.Salt
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
|
||||||
|
type AuthenticationOk struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationOk) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationOk) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationOk) Decode(src []byte) error {
|
||||||
|
if len(src) != 4 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeOk {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationOk) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationOK",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
|
||||||
|
type AuthenticationSASL struct {
|
||||||
|
AuthMechanisms []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASL) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationSASL) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASL) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASL {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
authMechanisms := src[4:]
|
||||||
|
for len(authMechanisms) > 1 {
|
||||||
|
idx := bytes.IndexByte(authMechanisms, 0)
|
||||||
|
if idx == -1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"}
|
||||||
|
}
|
||||||
|
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
|
||||||
|
authMechanisms = authMechanisms[idx+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||||
|
|
||||||
|
for _, s := range src.AuthMechanisms {
|
||||||
|
dst = append(dst, []byte(s)...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
}
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationSASL) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
AuthMechanisms []string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationSASL",
|
||||||
|
AuthMechanisms: src.AuthMechanisms,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
|
||||||
|
type AuthenticationSASLContinue struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASLContinue) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASLContinue {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = src[4:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||||
|
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationSASLContinue",
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
|
||||||
|
type AuthenticationSASLFinal struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASLFinal) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASLFinal {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = src[4:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||||
|
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationSASLFinal",
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,262 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Backend acts as a server for the PostgreSQL wire protocol version 3.
|
||||||
|
type Backend struct {
|
||||||
|
cr *chunkReader
|
||||||
|
w io.Writer
|
||||||
|
|
||||||
|
// tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
|
||||||
|
// before it is actually transmitted (i.e. before Flush).
|
||||||
|
tracer *tracer
|
||||||
|
|
||||||
|
wbuf []byte
|
||||||
|
|
||||||
|
// Frontend message flyweights
|
||||||
|
bind Bind
|
||||||
|
cancelRequest CancelRequest
|
||||||
|
_close Close
|
||||||
|
copyFail CopyFail
|
||||||
|
copyData CopyData
|
||||||
|
copyDone CopyDone
|
||||||
|
describe Describe
|
||||||
|
execute Execute
|
||||||
|
flush Flush
|
||||||
|
functionCall FunctionCall
|
||||||
|
gssEncRequest GSSEncRequest
|
||||||
|
parse Parse
|
||||||
|
query Query
|
||||||
|
sslRequest SSLRequest
|
||||||
|
startupMessage StartupMessage
|
||||||
|
sync Sync
|
||||||
|
terminate Terminate
|
||||||
|
|
||||||
|
bodyLen int
|
||||||
|
msgType byte
|
||||||
|
partialMsg bool
|
||||||
|
authType uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
|
||||||
|
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewBackend creates a new Backend.
|
||||||
|
func NewBackend(r io.Reader, w io.Writer) *Backend {
|
||||||
|
cr := newChunkReader(r, 0)
|
||||||
|
return &Backend{cr: cr, w: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
|
||||||
|
// called.
|
||||||
|
func (b *Backend) Send(msg BackendMessage) {
|
||||||
|
prevLen := len(b.wbuf)
|
||||||
|
b.wbuf = msg.Encode(b.wbuf)
|
||||||
|
if b.tracer != nil {
|
||||||
|
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush writes any pending messages to the frontend (i.e. the client).
|
||||||
|
func (b *Backend) Flush() error {
|
||||||
|
n, err := b.w.Write(b.wbuf)
|
||||||
|
|
||||||
|
const maxLen = 1024
|
||||||
|
if len(b.wbuf) > maxLen {
|
||||||
|
b.wbuf = make([]byte, 0, maxLen)
|
||||||
|
} else {
|
||||||
|
b.wbuf = b.wbuf[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return &writeError{err: err, safeToRetry: n == 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function
|
||||||
|
// PQtrace.
|
||||||
|
func (b *Backend) Trace(w io.Writer, options TracerOptions) {
|
||||||
|
b.tracer = &tracer{
|
||||||
|
w: w,
|
||||||
|
buf: &bytes.Buffer{},
|
||||||
|
TracerOptions: options,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Untrace stops tracing.
|
||||||
|
func (b *Backend) Untrace() {
|
||||||
|
b.tracer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
|
||||||
|
// because the initial connection message is "special" and does not include the message type as the first byte. This
|
||||||
|
// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
|
||||||
|
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||||
|
buf, err := b.cr.Next(4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||||
|
|
||||||
|
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
|
||||||
|
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err = b.cr.Next(msgSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := binary.BigEndian.Uint32(buf)
|
||||||
|
|
||||||
|
switch code {
|
||||||
|
case ProtocolVersionNumber:
|
||||||
|
err = b.startupMessage.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.startupMessage, nil
|
||||||
|
case sslRequestNumber:
|
||||||
|
err = b.sslRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.sslRequest, nil
|
||||||
|
case cancelRequestCode:
|
||||||
|
err = b.cancelRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.cancelRequest, nil
|
||||||
|
case gssEncReqNumber:
|
||||||
|
err = b.gssEncRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.gssEncRequest, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown startup message code: %d", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
|
||||||
|
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||||
|
if !b.partialMsg {
|
||||||
|
header, err := b.cr.Next(5)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.msgType = header[0]
|
||||||
|
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||||
|
b.partialMsg = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg FrontendMessage
|
||||||
|
switch b.msgType {
|
||||||
|
case 'B':
|
||||||
|
msg = &b.bind
|
||||||
|
case 'C':
|
||||||
|
msg = &b._close
|
||||||
|
case 'D':
|
||||||
|
msg = &b.describe
|
||||||
|
case 'E':
|
||||||
|
msg = &b.execute
|
||||||
|
case 'F':
|
||||||
|
msg = &b.functionCall
|
||||||
|
case 'f':
|
||||||
|
msg = &b.copyFail
|
||||||
|
case 'd':
|
||||||
|
msg = &b.copyData
|
||||||
|
case 'c':
|
||||||
|
msg = &b.copyDone
|
||||||
|
case 'H':
|
||||||
|
msg = &b.flush
|
||||||
|
case 'P':
|
||||||
|
msg = &b.parse
|
||||||
|
case 'p':
|
||||||
|
switch b.authType {
|
||||||
|
case AuthTypeSASL:
|
||||||
|
msg = &SASLInitialResponse{}
|
||||||
|
case AuthTypeSASLContinue:
|
||||||
|
msg = &SASLResponse{}
|
||||||
|
case AuthTypeSASLFinal:
|
||||||
|
msg = &SASLResponse{}
|
||||||
|
case AuthTypeGSS, AuthTypeGSSCont:
|
||||||
|
msg = &GSSResponse{}
|
||||||
|
case AuthTypeCleartextPassword, AuthTypeMD5Password:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
// to maintain backwards compatability
|
||||||
|
msg = &PasswordMessage{}
|
||||||
|
}
|
||||||
|
case 'Q':
|
||||||
|
msg = &b.query
|
||||||
|
case 'S':
|
||||||
|
msg = &b.sync
|
||||||
|
case 'X':
|
||||||
|
msg = &b.terminate
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgBody, err := b.cr.Next(b.bodyLen)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.partialMsg = false
|
||||||
|
|
||||||
|
err = msg.Decode(msgBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.tracer != nil {
|
||||||
|
b.tracer.traceMessage('F', int32(5+len(msgBody)), msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAuthType sets the authentication type in the backend.
|
||||||
|
// Since multiple message types can start with 'p', SetAuthType allows
|
||||||
|
// contextual identification of FrontendMessages. For example, in the
|
||||||
|
// PG message flow documentation for PasswordMessage:
|
||||||
|
//
|
||||||
|
// Byte1('p')
|
||||||
|
//
|
||||||
|
// Identifies the message as a password response. Note that this is also used for
|
||||||
|
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
|
||||||
|
// the context.
|
||||||
|
//
|
||||||
|
// Since the Frontend does not know about the state of a backend, it is important
|
||||||
|
// to call SetAuthType() after an authentication request is received by the Frontend.
|
||||||
|
func (b *Backend) SetAuthType(authType uint32) error {
|
||||||
|
switch authType {
|
||||||
|
case AuthTypeOk,
|
||||||
|
AuthTypeCleartextPassword,
|
||||||
|
AuthTypeMD5Password,
|
||||||
|
AuthTypeSCMCreds,
|
||||||
|
AuthTypeGSS,
|
||||||
|
AuthTypeGSSCont,
|
||||||
|
AuthTypeSSPI,
|
||||||
|
AuthTypeSASL,
|
||||||
|
AuthTypeSASLContinue,
|
||||||
|
AuthTypeSASLFinal:
|
||||||
|
b.authType = authType
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("authType not recognized: %d", authType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackendKeyData struct {
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*BackendKeyData) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *BackendKeyData) Decode(src []byte) error {
|
||||||
|
if len(src) != 8 {
|
||||||
|
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
||||||
|
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'K')
|
||||||
|
dst = pgio.AppendUint32(dst, 12)
|
||||||
|
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||||
|
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src BackendKeyData) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}{
|
||||||
|
Type: "BackendKeyData",
|
||||||
|
ProcessID: src.ProcessID,
|
||||||
|
SecretKey: src.SecretKey,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package pgproto3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackendReceiveInterrupted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push([]byte{'Q', 0, 0, 0, 6})
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(server, nil)
|
||||||
|
|
||||||
|
msg, err := backend.Receive()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected err")
|
||||||
|
}
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("did not expect msg, but %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.push([]byte{'I', 0})
|
||||||
|
|
||||||
|
msg, err = backend.Receive()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
|
||||||
|
t.Fatalf("unexpected msg: %v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendReceiveUnexpectedEOF(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push([]byte{'Q', 0, 0, 0, 6})
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(server, nil)
|
||||||
|
|
||||||
|
// Receive regular msg
|
||||||
|
msg, err := backend.Receive()
|
||||||
|
assert.Nil(t, msg)
|
||||||
|
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||||
|
|
||||||
|
// Receive StartupMessage msg
|
||||||
|
dst := []byte{}
|
||||||
|
dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read
|
||||||
|
dst = pgio.AppendUint32(dst, 1) // only send 1 byte
|
||||||
|
server.push(dst)
|
||||||
|
|
||||||
|
msg, err = backend.ReceiveStartupMessage()
|
||||||
|
assert.Nil(t, msg)
|
||||||
|
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartupMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("valid StartupMessage", func(t *testing.T) {
|
||||||
|
want := &pgproto3.StartupMessage{
|
||||||
|
ProtocolVersion: pgproto3.ProtocolVersionNumber,
|
||||||
|
Parameters: map[string]string{
|
||||||
|
"username": "tester",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dst := []byte{}
|
||||||
|
dst = want.Encode(dst)
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push(dst)
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(server, nil)
|
||||||
|
|
||||||
|
msg, err := backend.ReceiveStartupMessage()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, want, msg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid packet length", func(t *testing.T) {
|
||||||
|
wantErr := "invalid length of startup packet"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
packetLen uint32
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "large packet length",
|
||||||
|
// Since the StartupMessage contains the "Length of message contents
|
||||||
|
// in bytes, including self", the max startup packet length is actually
|
||||||
|
// 10000+4. Therefore, let's go past the limit with 10005
|
||||||
|
packetLen: 10005,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "short packet length",
|
||||||
|
packetLen: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &interruptReader{}
|
||||||
|
dst := []byte{}
|
||||||
|
dst = pgio.AppendUint32(dst, tt.packetLen)
|
||||||
|
dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
|
||||||
|
server.push(dst)
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(server, nil)
|
||||||
|
|
||||||
|
msg, err := backend.ReceiveStartupMessage()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, msg)
|
||||||
|
require.Contains(t, err.Error(), wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BigEndianBuf [8]byte
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Int16(n int16) []byte {
|
||||||
|
buf := b[0:2]
|
||||||
|
binary.BigEndian.PutUint16(buf, uint16(n))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Uint16(n uint16) []byte {
|
||||||
|
buf := b[0:2]
|
||||||
|
binary.BigEndian.PutUint16(buf, n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Int32(n int32) []byte {
|
||||||
|
buf := b[0:4]
|
||||||
|
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Uint32(n uint32) []byte {
|
||||||
|
buf := b[0:4]
|
||||||
|
binary.BigEndian.PutUint32(buf, n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Int64(n int64) []byte {
|
||||||
|
buf := b[0:8]
|
||||||
|
binary.BigEndian.PutUint64(buf, uint64(n))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
@@ -0,0 +1,216 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bind struct {
|
||||||
|
DestinationPortal string
|
||||||
|
PreparedStatement string
|
||||||
|
ParameterFormatCodes []int16
|
||||||
|
Parameters [][]byte
|
||||||
|
ResultFormatCodes []int16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Bind) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Bind) Decode(src []byte) error {
|
||||||
|
*dst = Bind{}
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
dst.DestinationPortal = string(src[:idx])
|
||||||
|
rp := idx + 1
|
||||||
|
|
||||||
|
idx = bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
dst.PreparedStatement = string(src[rp : rp+idx])
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
if len(src[rp:]) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
if parameterFormatCodeCount > 0 {
|
||||||
|
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
|
||||||
|
|
||||||
|
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
for i := 0; i < parameterFormatCodeCount; i++ {
|
||||||
|
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
if parameterCount > 0 {
|
||||||
|
dst.Parameters = make([][]byte, parameterCount)
|
||||||
|
|
||||||
|
for i := 0; i < parameterCount; i++ {
|
||||||
|
if len(src[rp:]) < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
// null
|
||||||
|
if msgSize == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) < msgSize {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Parameters[i] = src[rp : rp+msgSize]
|
||||||
|
rp += msgSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
|
||||||
|
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
for i := 0; i < resultFormatCodeCount; i++ {
|
||||||
|
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Bind) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'B')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.DestinationPortal...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, src.PreparedStatement...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||||
|
for _, fc := range src.ParameterFormatCodes {
|
||||||
|
dst = pgio.AppendInt16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||||
|
for _, p := range src.Parameters {
|
||||||
|
if p == nil {
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(p)))
|
||||||
|
dst = append(dst, p...)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||||
|
for _, fc := range src.ResultFormatCodes {
|
||||||
|
dst = pgio.AppendInt16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Bind) MarshalJSON() ([]byte, error) {
|
||||||
|
formattedParameters := make([]map[string]string, len(src.Parameters))
|
||||||
|
for i, p := range src.Parameters {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
textFormat := true
|
||||||
|
if len(src.ParameterFormatCodes) == 1 {
|
||||||
|
textFormat = src.ParameterFormatCodes[0] == 0
|
||||||
|
} else if len(src.ParameterFormatCodes) > 1 {
|
||||||
|
textFormat = src.ParameterFormatCodes[i] == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if textFormat {
|
||||||
|
formattedParameters[i] = map[string]string{"text": string(p)}
|
||||||
|
} else {
|
||||||
|
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
DestinationPortal string
|
||||||
|
PreparedStatement string
|
||||||
|
ParameterFormatCodes []int16
|
||||||
|
Parameters []map[string]string
|
||||||
|
ResultFormatCodes []int16
|
||||||
|
}{
|
||||||
|
Type: "Bind",
|
||||||
|
DestinationPortal: src.DestinationPortal,
|
||||||
|
PreparedStatement: src.PreparedStatement,
|
||||||
|
ParameterFormatCodes: src.ParameterFormatCodes,
|
||||||
|
Parameters: formattedParameters,
|
||||||
|
ResultFormatCodes: src.ResultFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *Bind) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
DestinationPortal string
|
||||||
|
PreparedStatement string
|
||||||
|
ParameterFormatCodes []int16
|
||||||
|
Parameters []map[string]string
|
||||||
|
ResultFormatCodes []int16
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(data, &msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.DestinationPortal = msg.DestinationPortal
|
||||||
|
dst.PreparedStatement = msg.PreparedStatement
|
||||||
|
dst.ParameterFormatCodes = msg.ParameterFormatCodes
|
||||||
|
dst.Parameters = make([][]byte, len(msg.Parameters))
|
||||||
|
dst.ResultFormatCodes = msg.ResultFormatCodes
|
||||||
|
for n, parameter := range msg.Parameters {
|
||||||
|
dst.Parameters[n], err = getValueFromJSON(parameter)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot get param %d: %w", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BindComplete struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*BindComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *BindComplete) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, '2', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src BindComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "BindComplete",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const cancelRequestCode = 80877102
|
||||||
|
|
||||||
|
type CancelRequest struct {
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CancelRequest) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *CancelRequest) Decode(src []byte) error {
|
||||||
|
if len(src) != 12 {
|
||||||
|
return errors.New("bad cancel request size")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCode := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if requestCode != cancelRequestCode {
|
||||||
|
return errors.New("bad cancel request code")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
|
||||||
|
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
|
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||||
|
dst = pgio.AppendInt32(dst, 16)
|
||||||
|
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||||
|
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||||
|
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CancelRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}{
|
||||||
|
Type: "CancelRequest",
|
||||||
|
ProcessID: src.ProcessID,
|
||||||
|
SecretKey: src.SecretKey,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
// chunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and
|
||||||
|
// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually
|
||||||
|
// requested. The memory returned via Next is only valid until the next call to Next.
|
||||||
|
//
|
||||||
|
// This is roughly equivalent to a bufio.Reader that only uses Peek and Discard to never copy bytes.
|
||||||
|
type chunkReader struct {
|
||||||
|
r io.Reader
|
||||||
|
|
||||||
|
buf []byte
|
||||||
|
rp, wp int // buf read position and write position
|
||||||
|
|
||||||
|
minBufSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// newChunkReader creates and returns a new chunkReader for r with default configuration. If minBufSize is <= 0 it uses
|
||||||
|
// a default value.
|
||||||
|
func newChunkReader(r io.Reader, minBufSize int) *chunkReader {
|
||||||
|
if minBufSize <= 0 {
|
||||||
|
// By historical reasons Postgres currently has 8KB send buffer inside,
|
||||||
|
// so here we want to have at least the same size buffer.
|
||||||
|
// @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134
|
||||||
|
// @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru
|
||||||
|
//
|
||||||
|
// In addition, testing has found no benefit of any larger buffer.
|
||||||
|
minBufSize = 8192
|
||||||
|
}
|
||||||
|
|
||||||
|
return &chunkReader{
|
||||||
|
r: r,
|
||||||
|
minBufSize: minBufSize,
|
||||||
|
buf: iobufpool.Get(minBufSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next returns buf filled with the next n bytes. buf is only valid until next call of Next. If an error occurs, buf
|
||||||
|
// will be nil.
|
||||||
|
func (r *chunkReader) Next(n int) (buf []byte, err error) {
|
||||||
|
// Reset the buffer if it is empty
|
||||||
|
if r.rp == r.wp {
|
||||||
|
if len(r.buf) != r.minBufSize {
|
||||||
|
iobufpool.Put(r.buf)
|
||||||
|
r.buf = iobufpool.Get(r.minBufSize)
|
||||||
|
}
|
||||||
|
r.rp = 0
|
||||||
|
r.wp = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// n bytes already in buf
|
||||||
|
if (r.wp - r.rp) >= n {
|
||||||
|
buf = r.buf[r.rp : r.rp+n : r.rp+n]
|
||||||
|
r.rp += n
|
||||||
|
return buf, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// buf is smaller than requested number of bytes
|
||||||
|
if len(r.buf) < n {
|
||||||
|
bigBuf := iobufpool.Get(n)
|
||||||
|
r.wp = copy(bigBuf, r.buf[r.rp:r.wp])
|
||||||
|
r.rp = 0
|
||||||
|
iobufpool.Put(r.buf)
|
||||||
|
r.buf = bigBuf
|
||||||
|
}
|
||||||
|
|
||||||
|
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
||||||
|
minReadCount := n - (r.wp - r.rp)
|
||||||
|
if (len(r.buf) - r.wp) < minReadCount {
|
||||||
|
r.wp = copy(r.buf, r.buf[r.rp:r.wp])
|
||||||
|
r.rp = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read at least the required number of bytes from the underlying io.Reader
|
||||||
|
readBytesCount, err := io.ReadAtLeast(r.r, r.buf[r.wp:], minReadCount)
|
||||||
|
r.wp += readBytesCount
|
||||||
|
// fmt.Println("read", n)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = r.buf[r.rp : r.rp+n : r.rp+n]
|
||||||
|
r.rp += n
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
||||||
|
server := &bytes.Buffer{}
|
||||||
|
r := newChunkReader(server, 4)
|
||||||
|
|
||||||
|
src := []byte{1, 2, 3, 4}
|
||||||
|
server.Write(src)
|
||||||
|
|
||||||
|
n1, err := r.Next(2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if bytes.Compare(n1, src[0:2]) != 0 {
|
||||||
|
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
|
||||||
|
}
|
||||||
|
|
||||||
|
n2, err := r.Next(2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if bytes.Compare(n2, src[2:4]) != 0 {
|
||||||
|
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Compare(r.buf[:len(src)], src) != 0 {
|
||||||
|
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = r.Next(0) // Trigger the buffer reset.
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.rp != 0 {
|
||||||
|
t.Fatalf("Expected r.rp to be %v, but it was %v", 0, r.rp)
|
||||||
|
}
|
||||||
|
if r.wp != 0 {
|
||||||
|
t.Fatalf("Expected r.wp to be %v, but it was %v", 0, r.wp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type randomReader struct {
|
||||||
|
rnd *rand.Rand
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads a random number of random bytes.
|
||||||
|
func (r *randomReader) Read(p []byte) (n int, err error) {
|
||||||
|
n = r.rnd.Intn(len(p) + 1)
|
||||||
|
return r.rnd.Read(p[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunkReaderNextFuzz(t *testing.T) {
|
||||||
|
rr := &randomReader{rnd: rand.New(rand.NewSource(1))}
|
||||||
|
r := newChunkReader(rr, 8192)
|
||||||
|
|
||||||
|
randomSizes := rand.New(rand.NewSource(0))
|
||||||
|
|
||||||
|
for i := 0; i < 100000; i++ {
|
||||||
|
size := randomSizes.Intn(16384) + 1
|
||||||
|
buf, err := r.Next(size)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(buf) != size {
|
||||||
|
t.Fatalf("Expected to get %v bytes but got %v bytes", size, len(buf))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Close struct {
|
||||||
|
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Close) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Close) Decode(src []byte) error {
|
||||||
|
if len(src) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Close"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = src[0]
|
||||||
|
rp := 1
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx != len(src[rp:])-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Close"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Name = string(src[rp : len(src)-1])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Close) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'C')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.ObjectType)
|
||||||
|
dst = append(dst, src.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Close) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}{
|
||||||
|
Type: "Close",
|
||||||
|
ObjectType: string(src.ObjectType),
|
||||||
|
Name: src.Name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *Close) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.ObjectType) != 1 {
|
||||||
|
return errors.New("invalid length for Close.ObjectType")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = byte(msg.ObjectType[0])
|
||||||
|
dst.Name = msg.Name
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CloseComplete struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CloseComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CloseComplete) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, '3', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CloseComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "CloseComplete",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CommandComplete struct {
|
||||||
|
CommandTag []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CommandComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CommandComplete) Decode(src []byte) error {
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx == -1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"}
|
||||||
|
}
|
||||||
|
if idx != len(src)-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.CommandTag = src[:idx]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'C')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.CommandTag...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CommandComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
CommandTag string
|
||||||
|
}{
|
||||||
|
Type: "CommandComplete",
|
||||||
|
CommandTag: string(src.CommandTag),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
CommandTag string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.CommandTag = []byte(msg.CommandTag)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyBothResponse struct {
|
||||||
|
OverallFormat byte
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyBothResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
if buf.Len() < 3 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
overallFormat := buf.Next(1)[0]
|
||||||
|
|
||||||
|
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||||
|
if buf.Len() != columnCount*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnFormatCodes := make([]uint16, columnCount)
|
||||||
|
for i := 0; i < columnCount; i++ {
|
||||||
|
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'W')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = append(dst, src.OverallFormat)
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}{
|
||||||
|
Type: "CopyBothResponse",
|
||||||
|
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
OverallFormat string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.OverallFormat) != 1 {
|
||||||
|
return errors.New("invalid length for CopyBothResponse.OverallFormat")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.OverallFormat = msg.OverallFormat[0]
|
||||||
|
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package pgproto3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncodeDecode(t *testing.T) {
|
||||||
|
srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01}
|
||||||
|
dstResp := pgproto3.CopyBothResponse{}
|
||||||
|
err := dstResp.Decode(srcBytes[5:])
|
||||||
|
assert.NoError(t, err, "No errors on decode")
|
||||||
|
dstBytes := []byte{}
|
||||||
|
dstBytes = dstResp.Encode(dstBytes)
|
||||||
|
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
|
||||||
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyData struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyData) Backend() {}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CopyData) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyData) Decode(src []byte) error {
|
||||||
|
dst.Data = src
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyData) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'd')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyData) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "CopyData",
|
||||||
|
Data: hex.EncodeToString(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyData) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user