Merge branch 'v5-dev'
This commit is contained in:
@@ -2,7 +2,7 @@ name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
branches: [ master, v5-dev ]
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
@@ -14,21 +14,52 @@ jobs:
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: [1.16, 1.17]
|
||||
go-version: [1.18]
|
||||
pg-version: [10, 11, 12, 13, 14, cockroachdb]
|
||||
include:
|
||||
- pg-version: 10
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
- pg-version: 11
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
- pg-version: 12
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
- pg-version: 13
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
- pg-version: 14
|
||||
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
|
||||
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
|
||||
- pg-version: cockroachdb
|
||||
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||
pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||
|
||||
steps:
|
||||
|
||||
@@ -49,3 +80,9 @@ jobs:
|
||||
run: go test -race ./...
|
||||
env:
|
||||
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
||||
PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }}
|
||||
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
||||
PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
|
||||
PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
|
||||
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
|
||||
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.x
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
||||
+115
-206
@@ -1,268 +1,177 @@
|
||||
# 4.17.2 (September 3, 2022)
|
||||
# v5.0.0
|
||||
|
||||
* Fix panic when logging batch error (Tom Möller)
|
||||
## Merged Packages
|
||||
|
||||
# 4.17.1 (August 27, 2022)
|
||||
`github.com/jackc/pgtype`, `github.com/jackc/pgconn`, and `github.com/jackc/pgproto3` are now included in the main
|
||||
`github.com/jackc/pgx` repository. Previously there was confusion as to where issues should be reported, additional
|
||||
release work due to releasing multiple packages, and less clear changelogs.
|
||||
|
||||
* Upgrade puddle to v1.3.0 - fixes context failing to cancel Acquire when acquire is creating resource which was introduced in v4.17.0 (James Hartig)
|
||||
* Fix atomic alignment on 32-bit platforms
|
||||
## pgconn
|
||||
|
||||
# 4.17.0 (August 6, 2022)
|
||||
`CommandTag` is now an opaque type instead of directly exposing an underlying `[]byte`.
|
||||
|
||||
* Upgrade pgconn to v1.13.0
|
||||
* Upgrade pgproto3 to v2.3.1
|
||||
* Upgrade pgtype to v1.12.0
|
||||
* Allow background pool connections to continue even if cause is canceled (James Hartig)
|
||||
* Add LoggerFunc (Gabor Szabad)
|
||||
* pgxpool: health check should avoid going below minConns (James Hartig)
|
||||
* Add pgxpool.Conn.Hijack()
|
||||
* Logging improvements (Stepan Rabotkin)
|
||||
The return value `ResultReader.Values()` is no longer safe to retain a reference to after a subsequent call to `NextRow()` or `Close()`.
|
||||
|
||||
# 4.16.1 (May 7, 2022)
|
||||
`Trace()` method adds low level message tracing similar to the `PQtrace` function in `libpq`.
|
||||
|
||||
* Upgrade pgconn to v1.12.1
|
||||
* Fix explicitly prepared statements with describe statement cache mode
|
||||
pgconn now uses non-blocking IO. This is a significant internal restructuring, but it should not cause any visible changes on its own. However, it is important in implementing other new features.
|
||||
|
||||
# 4.16.0 (April 21, 2022)
|
||||
`CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping.
|
||||
|
||||
* Upgrade pgconn to v1.12.0
|
||||
* Upgrade pgproto3 to v2.3.0
|
||||
* Upgrade pgtype to v1.11.0
|
||||
* Fix: Do not panic when context cancelled while getting statement from cache.
|
||||
* Fix: Less memory pinning from old Rows.
|
||||
* Fix: Support '\r' line ending when sanitizing SQL comment.
|
||||
* Add pluggable GSSAPI support (Oliver Tan)
|
||||
pgconn now supports pipeline mode.
|
||||
|
||||
# 4.15.0 (February 7, 2022)
|
||||
`*PgConn.ReceiveResults` removed. Use pipeline mode instead.
|
||||
|
||||
* Upgrade to pgconn v1.11.0
|
||||
* Upgrade to pgtype v1.10.0
|
||||
* Upgrade puddle to v1.2.1
|
||||
* Make BatchResults.Close safe to be called multiple times
|
||||
`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error.
|
||||
|
||||
# 4.14.1 (November 28, 2021)
|
||||
## pgxpool
|
||||
|
||||
* Upgrade pgtype to v1.9.1 (fixes unintentional change to timestamp binary decoding)
|
||||
* Start pgxpool background health check after initial connections
|
||||
`Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
|
||||
|
||||
# 4.14.0 (November 20, 2021)
|
||||
## pgtype
|
||||
|
||||
* Upgrade pgconn to v1.10.1
|
||||
* Upgrade pgproto3 to v2.2.0
|
||||
* Upgrade pgtype to v1.9.0
|
||||
* Upgrade puddle to v1.2.0
|
||||
* Add QueryFunc to BatchResults
|
||||
* Add context options to zerologadapter (Thomas Frössman)
|
||||
* Add zerologadapter.NewContextLogger (urso)
|
||||
* Eager initialize minpoolsize on connect (Daniel)
|
||||
* Unpin memory used by large queries immediately after use
|
||||
The `pgtype` package has been significantly changed.
|
||||
|
||||
# 4.13.0 (July 24, 2021)
|
||||
### NULL Representation
|
||||
|
||||
* Trimmed pseudo-dependencies in Go modules from other packages tests
|
||||
* Upgrade pgconn -- context cancellation no longer will return a net.Error
|
||||
* Support time durations for simple protocol (Michael Darr)
|
||||
Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a
|
||||
`Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable.
|
||||
|
||||
# 4.12.0 (July 10, 2021)
|
||||
### Codec and Value Split
|
||||
|
||||
* ResetSession hook is called before a connection is reused from pool for another query (Dmytro Haranzha)
|
||||
* stdlib: Add RandomizeHostOrderFunc (dkinder)
|
||||
* stdlib: add OptionBeforeConnect (dkinder)
|
||||
* stdlib: Do not reuse ConnConfig strings (Andrew Kimball)
|
||||
* stdlib: implement Conn.ResetSession (Jonathan Amsterdam)
|
||||
* Upgrade pgconn to v1.9.0
|
||||
* Upgrade pgtype to v1.8.0
|
||||
Previously, the type system combined decoding and encoding values with the value types. e.g. Type `Int8` both handled
|
||||
encoding and decoding the PostgreSQL representation and acted as a value object. This caused some difficulties when
|
||||
there was not an exact 1 to 1 relationship between the Go types and the PostgreSQL types For example, scanning a
|
||||
PostgreSQL binary `numeric` into a Go `float64` was awkward (see https://github.com/jackc/pgtype/issues/147). This
|
||||
concepts have been separated. A `Codec` only has responsibility for encoding and decoding values. Value types are
|
||||
generally defined by implementing an interface that a particular `Codec` understands (e.g. `PointScanner` and
|
||||
`PointValuer` for the PostgreSQL `point` type).
|
||||
|
||||
# 4.11.0 (March 25, 2021)
|
||||
### Array Types
|
||||
|
||||
* Add BeforeConnect callback to pgxpool.Config (Robert Froehlich)
|
||||
* Add Ping method to pgxpool.Conn (davidsbond)
|
||||
* Added a kitlog level log adapter (Fabrice Aneche)
|
||||
* Make ScanArgError public to allow identification of offending column (Pau Sanchez)
|
||||
* Add *pgxpool.AcquireFunc
|
||||
* Add BeginFunc and BeginTxFunc
|
||||
* Add prefer_simple_protocol to connection string
|
||||
* Add logging on CopyFrom (Patrick Hemmer)
|
||||
* Add comment support when sanitizing SQL queries (Rusakow Andrew)
|
||||
* Do not panic on double close of pgxpool.Pool (Matt Schultz)
|
||||
* Avoid panic on SendBatch on closed Tx (Matt Schultz)
|
||||
* Update pgconn to v1.8.1
|
||||
* Update pgtype to v1.7.0
|
||||
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also
|
||||
means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional
|
||||
arrays.
|
||||
|
||||
# 4.10.1 (December 19, 2020)
|
||||
### Composite Types
|
||||
|
||||
* Fix panic on Query error with nil stmtcache.
|
||||
Composite types must be registered before use. `CompositeFields` may still be used to construct and destruct composite
|
||||
values, but any type may now implement `CompositeIndexGetter` and `CompositeIndexScanner` to be used as a composite.
|
||||
|
||||
# 4.10.0 (December 3, 2020)
|
||||
### Range Types
|
||||
|
||||
* Add CopyFromSlice to simplify CopyFrom usage (Egon Elbre)
|
||||
* Remove broken prepared statements from stmtcache (Ethan Pailes)
|
||||
* stdlib: consider any Ping error as fatal
|
||||
* Update puddle to v1.1.3 - this fixes an issue where concurrent Acquires can hang when a connection cannot be established
|
||||
* Update pgtype to v1.6.2
|
||||
Range types are now handled with types `RangeCodec` and `Range[T]`. This allows additional user defined range types to
|
||||
easily be handled. Multirange types are handled similarly with `MultirangeCodec` and `Multirange[T]`.
|
||||
|
||||
# 4.9.2 (November 3, 2020)
|
||||
### pgxtype
|
||||
|
||||
The underlying library updates fix an issue where appending to a scanned slice could corrupt other data.
|
||||
`LoadDataType` moved to `*Conn` as `LoadType`.
|
||||
|
||||
* Update pgconn to v1.7.2
|
||||
* Update pgproto3 to v2.0.6
|
||||
### Bytea
|
||||
|
||||
# 4.9.1 (October 31, 2020)
|
||||
The `Bytea` and `GenericBinary` types have been replaced. Use the following instead:
|
||||
|
||||
* Update pgconn to v1.7.1
|
||||
* Update pgtype to v1.6.1
|
||||
* Fix SendBatch of all prepared statements with statement cache disabled
|
||||
* `[]byte` - For normal usage directly use `[]byte`.
|
||||
* `DriverBytes` - Uses driver memory only available until next database method call. Avoids a copy and an allocation.
|
||||
* `PreallocBytes` - Uses preallocated byte slice to avoid an allocation.
|
||||
* `UndecodedBytes` - Avoids any decoding. Allows working with raw bytes.
|
||||
|
||||
# 4.9.0 (September 26, 2020)
|
||||
### Dropped lib/pq Support
|
||||
|
||||
* pgxpool now waits for connection cleanup to finish before making room in pool for another connection. This prevents temporarily exceeding max pool size.
|
||||
* Fix when scanning a column to nil to skip it on the first row but scanning it to a real value on a subsequent row.
|
||||
* Fix prefer simple protocol with prepared statements. (Jinzhu)
|
||||
* Fix FieldDescriptions not being available on Rows before calling Next the first time.
|
||||
* Various minor fixes in updated versions of pgconn, pgtype, and puddle.
|
||||
`pgtype` previously supported and was tested against [lib/pq](https://github.com/lib/pq). While it will continue to work
|
||||
in most cases this is no longer supported.
|
||||
|
||||
# 4.8.1 (July 29, 2020)
|
||||
### database/sql Scan
|
||||
|
||||
* Update pgconn to v1.6.4
|
||||
* Fix deadlock on error after CommandComplete but before ReadyForQuery
|
||||
* Fix panic on parsing DSN with trailing '='
|
||||
Previously, most `Scan` implementations would convert `[]byte` to `string` automatically to decode a text value. Now
|
||||
only `string` is handled. This is to allow the possibility of future binary support in `database/sql` mode by
|
||||
considering `[]byte` to be binary format and `string` text format. This change should have no effect for any use with
|
||||
`pgx`. The previous behavior was only necessary for `lib/pq` compatibility.
|
||||
|
||||
# 4.8.0 (July 22, 2020)
|
||||
Added `*Map.SQLScanner` to create a `sql.Scanner` for types such as `[]int32` and `Range[T]` that do not implement
|
||||
`sql.Scanner` directly.
|
||||
|
||||
* All argument types supported by native pgx should now also work through database/sql
|
||||
* Update pgconn to v1.6.3
|
||||
* Update pgtype to v1.4.2
|
||||
### Number Type Fields Include Bit size
|
||||
|
||||
# 4.7.2 (July 14, 2020)
|
||||
`Int2`, `Int4`, `Int8`, `Float4`, `Float8`, and `Uint32` fields now include bit size. e.g. `Int` is renamed to `Int64`.
|
||||
This matches the convention set by `database/sql`. In addition, for comparable types like `pgtype.Int8` and
|
||||
`sql.NullInt64` the structures are identical. This means they can be directly converted one to another.
|
||||
|
||||
* Improve performance of Columns() (zikaeroh)
|
||||
* Fix fatal Commit() failure not being considered fatal
|
||||
* Update pgconn to v1.6.2
|
||||
* Update pgtype to v1.4.1
|
||||
### 3rd Party Type Integrations
|
||||
|
||||
# 4.7.1 (June 29, 2020)
|
||||
* Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to
|
||||
https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims
|
||||
the pgx dependency tree.
|
||||
|
||||
* Fix stdlib decoding error with certain order and combination of fields
|
||||
### Other Changes
|
||||
|
||||
# 4.7.0 (June 27, 2020)
|
||||
* `Bit` and `Varbit` are both replaced by the `Bits` type.
|
||||
* `CID`, `OID`, `OIDValue`, and `XID` are replaced by the `Uint32` type.
|
||||
* `Hstore` is now defined as `map[string]*string`.
|
||||
* `JSON` and `JSONB` types removed. Use `[]byte` or `string` directly.
|
||||
* `QChar` type removed. Use `rune` or `byte` directly.
|
||||
* `Inet` and `Cidr` types removed. Use `netip.Addr` and `netip.Prefix` directly. These types are more memory efficient than the previous `net.IPNet`.
|
||||
* `Macaddr` type removed. Use `net.HardwareAddr` directly.
|
||||
* Renamed `pgtype.ConnInfo` to `pgtype.Map`.
|
||||
* Renamed `pgtype.DataType` to `pgtype.Type`.
|
||||
* Renamed `pgtype.None` to `pgtype.Finite`.
|
||||
* `RegisterType` now accepts a `*Type` instead of `Type`.
|
||||
* Assorted array helper methods and types made private.
|
||||
|
||||
* Update pgtype to v1.4.0
|
||||
* Update pgconn to v1.6.1
|
||||
* Update puddle to v1.1.1
|
||||
* Fix context propagation with Tx commit and Rollback (georgysavva)
|
||||
* Add lazy connect option to pgxpool (georgysavva)
|
||||
* Fix connection leak if pgxpool.BeginTx() fail (Jean-Baptiste Bronisz)
|
||||
* Add native Go slice support for strings and numbers to simple protocol
|
||||
* stdlib add default timeouts for Conn.Close() and Stmt.Close() (georgysavva)
|
||||
* Assorted performance improvements especially with large result sets
|
||||
* Fix close pool on not lazy connect failure (Yegor Myskin)
|
||||
* Add Config copy (georgysavva)
|
||||
* Support SendBatch with Simple Protocol (Jordan Lewis)
|
||||
* Better error logging on rows close (Igor V. Kozinov)
|
||||
* Expose stdlib.Conn.Conn() to enable database/sql.Conn.Raw()
|
||||
* Improve unknown type support for database/sql
|
||||
* Fix transaction commit failure closing connection
|
||||
## stdlib
|
||||
|
||||
# 4.6.0 (March 30, 2020)
|
||||
* Removed `AcquireConn` and `ReleaseConn` as that functionality has been built in since Go 1.13.
|
||||
|
||||
* stdlib: Bail early if preloading rows.Next() results in rows.Err() (Bas van Beek)
|
||||
* Sanitize time to microsecond accuracy (Andrew Nicoll)
|
||||
* Update pgtype to v1.3.0
|
||||
* Update pgconn to v1.5.0
|
||||
* Update golang.org/x/crypto for security fix
|
||||
* Implement "verify-ca" SSL mode
|
||||
## Reduced Memory Usage by Reusing Read Buffers
|
||||
|
||||
# 4.5.0 (March 7, 2020)
|
||||
Previously, the connection read buffer would allocate large chunks of memory and never reuse them. This allowed
|
||||
transferring ownership to anything such as scanned values without incurring an additional allocation and memory copy.
|
||||
However, this came at the cost of overall increased memory allocation size. But worse it was also possible to pin large
|
||||
chunks of memory by retaining a reference to a small value that originally came directly from the read buffer. Now
|
||||
ownership remains with the read buffer and anything needing to retain a value must make a copy.
|
||||
|
||||
* Update to pgconn v1.4.0
|
||||
* Fixes QueryRow with empty SQL
|
||||
* Adds PostgreSQL service file support
|
||||
* Add Len() to *pgx.Batch (WGH)
|
||||
* Better logging for individual batch items (Ben Bader)
|
||||
## Query Execution Modes
|
||||
|
||||
# 4.4.1 (February 14, 2020)
|
||||
Control over automatic prepared statement caching and simple protocol use are now combined into query execution mode.
|
||||
See documentation for `QueryExecMode`.
|
||||
|
||||
* Update pgconn to v1.3.2 - better default read buffer size
|
||||
* Fix race in CopyFrom
|
||||
## QueryRewriter Interface and NamedArgs
|
||||
|
||||
# 4.4.0 (February 5, 2020)
|
||||
pgx now supports named arguments with the `NamedArgs` type. This is implemented via the new `QueryRewriter` interface which
|
||||
allows arbitrary rewriting of query SQL and arguments.
|
||||
|
||||
* Update puddle to v1.1.0 - fixes possible deadlock when acquire is cancelled
|
||||
* Update pgconn to v1.3.1 - fixes CopyFrom deadlock when multiple NoticeResponse received during copy
|
||||
* Update pgtype to v1.2.0
|
||||
* Add MaxConnIdleTime to pgxpool (Patrick Ellul)
|
||||
* Add MinConns to pgxpool (Patrick Ellul)
|
||||
* Fix: stdlib.ReleaseConn closes connections left in invalid state
|
||||
## RowScanner Interface
|
||||
|
||||
# 4.3.0 (January 23, 2020)
|
||||
The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row.
|
||||
|
||||
* Fix Rows.Values panic when unable to decode
|
||||
* Add Rows.Values support for unknown types
|
||||
* Add DriverContext support for stdlib (Alex Gaynor)
|
||||
* Update pgproto3 to v2.0.1 to never return an io.EOF as it would be misinterpreted by database/sql. Instead return io.UnexpectedEOF.
|
||||
## Rows Result Helpers
|
||||
|
||||
# 4.2.1 (January 13, 2020)
|
||||
* `CollectRows` and `RowTo*` functions simplify collecting results into a slice.
|
||||
* `CollectOneRow` collects one row using `RowTo*` functions.
|
||||
* `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`.
|
||||
|
||||
* Update pgconn to v1.2.1 (fixes context cancellation data race introduced in v1.2.0))
|
||||
## Tx Helpers
|
||||
|
||||
# 4.2.0 (January 11, 2020)
|
||||
Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and
|
||||
`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`.
|
||||
|
||||
* Update pgconn to v1.2.0.
|
||||
* Update pgtype to v1.1.0.
|
||||
* Return error instead of panic when wrong number of arguments passed to Exec. (malstoun)
|
||||
* Fix large objects functionality when PreferSimpleProtocol = true.
|
||||
* Restore GetDefaultDriver which existed in v3. (Johan Brandhorst)
|
||||
* Add RegisterConnConfig to stdlib which replaces the removed RegisterDriverConfig from v3.
|
||||
## Improved Batch Query Ergonomics
|
||||
|
||||
# 4.1.2 (October 22, 2019)
|
||||
Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the
|
||||
results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code
|
||||
to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can
|
||||
be used to register a callback function that will handle the result. Callback functions are called automatically when
|
||||
`BatchResults.Close` is called.
|
||||
|
||||
* Fix dbSavepoint.Begin recursive self call
|
||||
* Upgrade pgtype to v1.0.2 - fix scan pointer to pointer
|
||||
## SendBatch Uses Pipeline Mode When Appropriate
|
||||
|
||||
# 4.1.1 (October 21, 2019)
|
||||
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
|
||||
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
|
||||
in a single network round trip. So it would only take 2 round trips.
|
||||
|
||||
* Fix pgxpool Rows.CommandTag() infinite loop / typo
|
||||
## Tracing and Logging
|
||||
|
||||
# 4.1.0 (October 12, 2019)
|
||||
Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer.
|
||||
|
||||
## Potentially Breaking Changes
|
||||
|
||||
Technically, two changes are breaking changes, but in practice these are extremely unlikely to break existing code.
|
||||
|
||||
* Conn.Begin and Conn.BeginTx return a Tx interface instead of the internal dbTx struct. This is necessary for the Conn.Begin method to signature as other methods that begin a transaction.
|
||||
* Add Conn() to Tx interface. This is necessary to allow code using a Tx to access the *Conn (and pgconn.PgConn) on which the Tx is executing.
|
||||
|
||||
## Fixes
|
||||
|
||||
* Releasing a busy connection closes the connection instead of returning an unusable connection to the pool
|
||||
* Do not mutate config.Config.OnNotification in connect
|
||||
|
||||
# 4.0.1 (September 19, 2019)
|
||||
|
||||
* Fix statement cache cleanup.
|
||||
* Corrected daterange OID.
|
||||
* Fix Tx when committing or rolling back multiple times in certain cases.
|
||||
* Improve documentation.
|
||||
|
||||
# 4.0.0 (September 14, 2019)
|
||||
|
||||
v4 is a major release with many significant changes some of which are breaking changes. The most significant are
|
||||
included below.
|
||||
|
||||
* Simplified establishing a connection with a connection string.
|
||||
* All potentially blocking operations now require a context.Context. The non-context aware functions have been removed.
|
||||
* OIDs are hard-coded for known types. This saves the query on connection.
|
||||
* Context cancellations while network activity is in progress is now always fatal. Previously, it was sometimes recoverable. This led to increased complexity in pgx itself and in application code.
|
||||
* Go modules are required.
|
||||
* Errors are now implemented in the Go 1.13 style.
|
||||
* `Rows` and `Tx` are now interfaces.
|
||||
* The connection pool as been decoupled from pgx and is now a separate, included package (github.com/jackc/pgx/v4/pgxpool).
|
||||
* pgtype has been spun off to a separate package (github.com/jackc/pgtype).
|
||||
* pgproto3 has been spun off to a separate package (github.com/jackc/pgproto3/v2).
|
||||
* Logical replication support has been spun off to a separate package (github.com/jackc/pglogrepl).
|
||||
* Lower level PostgreSQL functionality is now implemented in a separate package (github.com/jackc/pgconn).
|
||||
* Tests are now configured with environment variables.
|
||||
* Conn has an automatic statement cache by default.
|
||||
* Batch interface has been simplified.
|
||||
* QueryArgs has been removed.
|
||||
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
|
||||
tree.
|
||||
|
||||
@@ -1,25 +1,17 @@
|
||||
[](https://pkg.go.dev/github.com/jackc/pgx/v4)
|
||||
[](https://travis-ci.org/jackc/pgx)
|
||||
[](https://pkg.go.dev/github.com/jackc/pgx/v5)
|
||||

|
||||
|
||||
---
|
||||
|
||||
This is the stable `v4` release. `v5` is now in beta testing with final release expected in September. See https://github.com/jackc/pgx/issues/1273 for more information. Please consider testing `v5`.
|
||||
|
||||
---
|
||||
# pgx - PostgreSQL Driver and Toolkit
|
||||
|
||||
pgx is a pure Go driver and toolkit for PostgreSQL.
|
||||
|
||||
pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for.
|
||||
|
||||
The driver component of pgx can be used alongside the standard `database/sql` package.
|
||||
The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` /
|
||||
`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface.
|
||||
|
||||
The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol
|
||||
and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers,
|
||||
proxies, load balancers, logical replication clients, etc.
|
||||
|
||||
The current release of `pgx v4` requires Go modules. To use the previous version, checkout and vendor the `v3` branch.
|
||||
|
||||
## Example Usage
|
||||
|
||||
```go
|
||||
@@ -30,7 +22,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -56,52 +48,39 @@ func main() {
|
||||
|
||||
See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
|
||||
|
||||
## Choosing Between the pgx and database/sql Interfaces
|
||||
|
||||
It is recommended to use the pgx interface if:
|
||||
1. The application only targets PostgreSQL.
|
||||
2. No other libraries that require `database/sql` are in use.
|
||||
|
||||
The pgx interface is faster and exposes more features.
|
||||
|
||||
The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`,
|
||||
`float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the
|
||||
`database/sql.Scanner` and the `database/sql/driver/driver.Valuer` interfaces which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses.
|
||||
|
||||
## Features
|
||||
|
||||
pgx supports many features beyond what is available through `database/sql`:
|
||||
|
||||
* Support for approximately 70 different PostgreSQL types
|
||||
* Automatic statement preparation and caching
|
||||
* Batch queries
|
||||
* Single-round trip query mode
|
||||
* Full TLS connection control
|
||||
* Binary format support for custom types (allows for much quicker encoding/decoding)
|
||||
* COPY protocol support for faster bulk data loads
|
||||
* Extendable logging support including built-in support for `log15adapter`, [`logrus`](https://github.com/sirupsen/logrus), [`zap`](https://github.com/uber-go/zap), and [`zerolog`](https://github.com/rs/zerolog)
|
||||
* `COPY` protocol support for faster bulk data loads
|
||||
* Tracing and logging support
|
||||
* Connection pool with after-connect hook for arbitrary connection setup
|
||||
* Listen / notify
|
||||
* `LISTEN` / `NOTIFY`
|
||||
* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
|
||||
* Hstore support
|
||||
* JSON and JSONB support
|
||||
* Maps `inet` and `cidr` PostgreSQL types to `net.IPNet` and `net.IP`
|
||||
* `hstore` support
|
||||
* `json` and `jsonb` support
|
||||
* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix`
|
||||
* Large object support
|
||||
* NULL mapping to Null* struct or pointer to pointer
|
||||
* NULL mapping to pointer to pointer
|
||||
* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
|
||||
* Notice response handling
|
||||
* Simulated nested transactions with savepoints
|
||||
|
||||
## Performance
|
||||
## Choosing Between the pgx and database/sql Interfaces
|
||||
|
||||
There are three areas in particular where pgx can provide a significant performance advantage over the standard
|
||||
`database/sql` interface and other drivers:
|
||||
The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available
|
||||
through the `database/sql` interface.
|
||||
|
||||
1. PostgreSQL specific types - Types such as arrays can be parsed much quicker because pgx uses the binary format.
|
||||
2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide an
|
||||
significant free improvement to code that does not explicitly use prepared statements. Under certain workloads, it can
|
||||
perform nearly 3x the number of queries per second.
|
||||
3. Batched queries - Multiple queries can be batched together to minimize network round trips.
|
||||
The pgx interface is recommended when:
|
||||
|
||||
1. The application only targets PostgreSQL.
|
||||
2. No other libraries that require `database/sql` are in use.
|
||||
|
||||
It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed.
|
||||
|
||||
## Testing
|
||||
|
||||
@@ -134,37 +113,14 @@ In addition, there are tests specific for PgBouncer that will be executed if `PG
|
||||
|
||||
## Supported Go and PostgreSQL Versions
|
||||
|
||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.16 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.18 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||
|
||||
## Version Policy
|
||||
|
||||
pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version.
|
||||
pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version.
|
||||
|
||||
## PGX Family Libraries
|
||||
|
||||
pgx is the head of a family of PostgreSQL libraries. Many of these can be used independently. Many can also be accessed
|
||||
from pgx for lower-level control.
|
||||
|
||||
### [github.com/jackc/pgconn](https://github.com/jackc/pgconn)
|
||||
|
||||
`pgconn` is a lower-level PostgreSQL database driver that operates at nearly the same level as the C library `libpq`.
|
||||
|
||||
### [github.com/jackc/pgx/v4/pgxpool](https://github.com/jackc/pgx/tree/master/pgxpool)
|
||||
|
||||
`pgxpool` is a connection pool for pgx. pgx is entirely decoupled from its default pool implementation. This means that pgx can be used with a different pool or without any pool at all.
|
||||
|
||||
### [github.com/jackc/pgx/v4/stdlib](https://github.com/jackc/pgx/tree/master/stdlib)
|
||||
|
||||
This is a `database/sql` compatibility layer for pgx. pgx can be used as a normal `database/sql` driver, but at any time, the native interface can be acquired for more performance or PostgreSQL specific functionality.
|
||||
|
||||
### [github.com/jackc/pgtype](https://github.com/jackc/pgtype)
|
||||
|
||||
Over 70 PostgreSQL types are supported including `uuid`, `hstore`, `json`, `bytea`, `numeric`, `interval`, `inet`, and arrays. These types support `database/sql` interfaces and are usable outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver.
|
||||
|
||||
### [github.com/jackc/pgproto3](https://github.com/jackc/pgproto3)
|
||||
|
||||
pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling.
|
||||
|
||||
### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl)
|
||||
|
||||
pglogrepl provides functionality to act as a client for PostgreSQL logical replication.
|
||||
@@ -181,6 +137,22 @@ tern is a stand-alone SQL migration system.
|
||||
|
||||
pgerrcode contains constants for the PostgreSQL error codes.
|
||||
|
||||
## Adapters for 3rd Party Types
|
||||
|
||||
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
|
||||
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
|
||||
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
|
||||
|
||||
## Adapters for 3rd Party Loggers
|
||||
|
||||
These adapters can be used with the tracelog package.
|
||||
|
||||
* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log)
|
||||
* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15)
|
||||
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
|
||||
* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)
|
||||
* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog)
|
||||
|
||||
## 3rd Party Libraries with PGX Support
|
||||
|
||||
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
|
||||
@@ -190,7 +162,3 @@ Library for scanning data from a database into Go structs and more.
|
||||
### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
|
||||
|
||||
Adds GSSAPI / Kerberos authentication support.
|
||||
|
||||
### [https://github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
|
||||
|
||||
Adds support for [`github.com/google/uuid`](https://github.com/google/uuid).
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
require "erb"
|
||||
|
||||
rule '.go' => '.go.erb' do |task|
|
||||
erb = ERB.new(File.read(task.source))
|
||||
File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
|
||||
sh "goimports", "-w", task.name
|
||||
end
|
||||
|
||||
generated_code_files = [
|
||||
"pgtype/int.go",
|
||||
"pgtype/int_test.go",
|
||||
"pgtype/integration_benchmark_test.go",
|
||||
"pgtype/zeronull/int.go",
|
||||
"pgtype/zeronull/int_test.go"
|
||||
]
|
||||
|
||||
desc "Generate code"
|
||||
task generate: generated_code_files
|
||||
@@ -5,69 +5,123 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type batchItem struct {
|
||||
// QueuedQuery is a query that has been queued for execution via a Batch.
|
||||
type QueuedQuery struct {
|
||||
query string
|
||||
arguments []interface{}
|
||||
arguments []any
|
||||
fn batchItemFunc
|
||||
sd *pgconn.StatementDescription
|
||||
}
|
||||
|
||||
type batchItemFunc func(br BatchResults) error
|
||||
|
||||
// Query sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
|
||||
qq.fn = func(br BatchResults) error {
|
||||
rows, err := br.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
err = fn(rows)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Query sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
|
||||
qq.fn = func(br BatchResults) error {
|
||||
row := br.QueryRow()
|
||||
return fn(row)
|
||||
}
|
||||
}
|
||||
|
||||
// Exec sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
||||
qq.fn = func(br BatchResults) error {
|
||||
ct, err := br.Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fn(ct)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips.
|
||||
// unnecessary network round trips. A Batch must only be sent once.
|
||||
type Batch struct {
|
||||
items []*batchItem
|
||||
queuedQueries []*QueuedQuery
|
||||
}
|
||||
|
||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
|
||||
func (b *Batch) Queue(query string, arguments ...interface{}) {
|
||||
b.items = append(b.items, &batchItem{
|
||||
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
|
||||
qq := &QueuedQuery{
|
||||
query: query,
|
||||
arguments: arguments,
|
||||
})
|
||||
}
|
||||
b.queuedQueries = append(b.queuedQueries, qq)
|
||||
return qq
|
||||
}
|
||||
|
||||
// Len returns number of queries that have been queued so far.
|
||||
func (b *Batch) Len() int {
|
||||
return len(b.items)
|
||||
return len(b.queuedQueries)
|
||||
}
|
||||
|
||||
type BatchResults interface {
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec.
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
|
||||
// calling Exec on the QueuedQuery.
|
||||
Exec() (pgconn.CommandTag, error)
|
||||
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query.
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
|
||||
// calling Query on the QueuedQuery.
|
||||
Query() (Rows, error)
|
||||
|
||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
|
||||
// Prefer calling QueryRow on the QueuedQuery.
|
||||
QueryRow() Row
|
||||
|
||||
// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc.
|
||||
QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error)
|
||||
|
||||
// Close closes the batch operation. This must be called before the underlying connection can be used again. Any error
|
||||
// that occurred during a batch operation may have made it impossible to resyncronize the connection with the server.
|
||||
// In this case the underlying connection will have been closed. Close is safe to call multiple times.
|
||||
// Close closes the batch operation. All unread results are read and any callback functions registered with
|
||||
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
|
||||
// error or the batch encounters an error subsequent callback functions will not be called.
|
||||
//
|
||||
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch
|
||||
// operation may have made it impossible to resyncronize the connection with the server. In this case the underlying
|
||||
// connection will have been closed.
|
||||
//
|
||||
// Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback
|
||||
// functions will not be rerun.
|
||||
Close() error
|
||||
}
|
||||
|
||||
type batchResults struct {
|
||||
ctx context.Context
|
||||
conn *Conn
|
||||
mrr *pgconn.MultiResultReader
|
||||
err error
|
||||
b *Batch
|
||||
ix int
|
||||
closed bool
|
||||
ctx context.Context
|
||||
conn *Conn
|
||||
mrr *pgconn.MultiResultReader
|
||||
err error
|
||||
b *Batch
|
||||
qqIdx int
|
||||
closed bool
|
||||
endTraced bool
|
||||
}
|
||||
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||
func (br *batchResults) Exec() (pgconn.CommandTag, error) {
|
||||
if br.err != nil {
|
||||
return nil, br.err
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
if br.closed {
|
||||
return nil, fmt.Errorf("batch already closed")
|
||||
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
|
||||
}
|
||||
|
||||
query, arguments, _ := br.nextQueryAndArgs()
|
||||
@@ -77,35 +131,29 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
|
||||
if err == nil {
|
||||
err = errors.New("no result")
|
||||
}
|
||||
if br.conn.shouldLog(LogLevelError) {
|
||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"err": err,
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
|
||||
commandTag, err := br.mrr.ResultReader().Close()
|
||||
br.err = err
|
||||
|
||||
if err != nil {
|
||||
if br.conn.shouldLog(LogLevelError) {
|
||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"err": err,
|
||||
})
|
||||
}
|
||||
} else if br.conn.shouldLog(LogLevelInfo) {
|
||||
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{}{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"commandTag": commandTag,
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
CommandTag: commandTag,
|
||||
Err: br.err,
|
||||
})
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
return commandTag, br.err
|
||||
}
|
||||
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
||||
@@ -116,15 +164,16 @@ func (br *batchResults) Query() (Rows, error) {
|
||||
}
|
||||
|
||||
if br.err != nil {
|
||||
return &connRows{err: br.err, closed: true}, br.err
|
||||
return &baseRows{err: br.err, closed: true}, br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
alreadyClosedErr := fmt.Errorf("batch already closed")
|
||||
return &connRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
||||
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
||||
}
|
||||
|
||||
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||
rows.batchTracer = br.conn.batchTracer
|
||||
|
||||
if !br.mrr.NextResult() {
|
||||
rows.err = br.mrr.Close()
|
||||
@@ -133,11 +182,11 @@ func (br *batchResults) Query() (Rows, error) {
|
||||
}
|
||||
rows.closed = true
|
||||
|
||||
if br.conn.shouldLog(LogLevelError) {
|
||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{}{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"err": rows.err,
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
Err: rows.err,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -148,47 +197,25 @@ func (br *batchResults) Query() (Rows, error) {
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc.
|
||||
func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
|
||||
if br.closed {
|
||||
return nil, fmt.Errorf("batch already closed")
|
||||
}
|
||||
|
||||
rows, err := br.Query()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.Scan(scans...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = f(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rows.CommandTag(), nil
|
||||
}
|
||||
|
||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||
func (br *batchResults) QueryRow() Row {
|
||||
rows, _ := br.Query()
|
||||
return (*connRow)(rows.(*connRows))
|
||||
return (*connRow)(rows.(*baseRows))
|
||||
|
||||
}
|
||||
|
||||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||
func (br *batchResults) Close() error {
|
||||
defer func() {
|
||||
if !br.endTraced {
|
||||
if br.conn != nil && br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
|
||||
}
|
||||
br.endTraced = true
|
||||
}
|
||||
}()
|
||||
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
@@ -196,33 +223,213 @@ func (br *batchResults) Close() error {
|
||||
if br.closed {
|
||||
return nil
|
||||
}
|
||||
br.closed = true
|
||||
|
||||
// log any queries that haven't yet been logged by Exec or Query
|
||||
for {
|
||||
query, args, ok := br.nextQueryAndArgs()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if br.conn.shouldLog(LogLevelInfo) {
|
||||
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{}{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(args),
|
||||
})
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||
if err != nil && br.err == nil {
|
||||
br.err = err
|
||||
}
|
||||
} else {
|
||||
br.Exec()
|
||||
}
|
||||
}
|
||||
|
||||
return br.mrr.Close()
|
||||
br.closed = true
|
||||
|
||||
err := br.mrr.Close()
|
||||
if br.err == nil {
|
||||
br.err = err
|
||||
}
|
||||
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *batchResults) nextQueryAndArgs() (query string, args []interface{}, ok bool) {
|
||||
if br.b != nil && br.ix < len(br.b.items) {
|
||||
bi := br.b.items[br.ix]
|
||||
func (br *batchResults) earlyError() error {
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
bi := br.b.queuedQueries[br.qqIdx]
|
||||
query = bi.query
|
||||
args = bi.arguments
|
||||
ok = true
|
||||
br.ix++
|
||||
br.qqIdx++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type pipelineBatchResults struct {
|
||||
ctx context.Context
|
||||
conn *Conn
|
||||
pipeline *pgconn.Pipeline
|
||||
lastRows *baseRows
|
||||
err error
|
||||
b *Batch
|
||||
qqIdx int
|
||||
closed bool
|
||||
endTraced bool
|
||||
}
|
||||
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||
func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
|
||||
if br.err != nil {
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
if br.closed {
|
||||
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
|
||||
}
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
|
||||
query, arguments, _ := br.nextQueryAndArgs()
|
||||
|
||||
results, err := br.pipeline.GetResults()
|
||||
if err != nil {
|
||||
br.err = err
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
var commandTag pgconn.CommandTag
|
||||
switch results := results.(type) {
|
||||
case *pgconn.ResultReader:
|
||||
commandTag, br.err = results.Close()
|
||||
default:
|
||||
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
|
||||
}
|
||||
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
CommandTag: commandTag,
|
||||
Err: br.err,
|
||||
})
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
||||
func (br *pipelineBatchResults) Query() (Rows, error) {
|
||||
if br.err != nil {
|
||||
return &baseRows{err: br.err, closed: true}, br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
alreadyClosedErr := fmt.Errorf("batch already closed")
|
||||
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
||||
}
|
||||
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
br.err = br.lastRows.err
|
||||
return &baseRows{err: br.err, closed: true}, br.err
|
||||
}
|
||||
|
||||
query, arguments, ok := br.nextQueryAndArgs()
|
||||
if !ok {
|
||||
query = "batch query"
|
||||
}
|
||||
|
||||
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||
rows.batchTracer = br.conn.batchTracer
|
||||
br.lastRows = rows
|
||||
|
||||
results, err := br.pipeline.GetResults()
|
||||
if err != nil {
|
||||
br.err = err
|
||||
rows.err = err
|
||||
rows.closed = true
|
||||
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
switch results := results.(type) {
|
||||
case *pgconn.ResultReader:
|
||||
rows.resultReader = results
|
||||
default:
|
||||
err = fmt.Errorf("unexpected pipeline result: %T", results)
|
||||
br.err = err
|
||||
rows.err = err
|
||||
rows.closed = true
|
||||
}
|
||||
}
|
||||
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||
func (br *pipelineBatchResults) QueryRow() Row {
|
||||
rows, _ := br.Query()
|
||||
return (*connRow)(rows.(*baseRows))
|
||||
|
||||
}
|
||||
|
||||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||
func (br *pipelineBatchResults) Close() error {
|
||||
defer func() {
|
||||
if !br.endTraced {
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
|
||||
}
|
||||
br.endTraced = true
|
||||
}
|
||||
}()
|
||||
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
br.err = br.lastRows.err
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||
if err != nil && br.err == nil {
|
||||
br.err = err
|
||||
}
|
||||
} else {
|
||||
br.Exec()
|
||||
}
|
||||
}
|
||||
|
||||
br.closed = true
|
||||
|
||||
err := br.pipeline.Close()
|
||||
if br.err == nil {
|
||||
br.err = err
|
||||
}
|
||||
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) earlyError() error {
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
bi := br.b.queuedQueries[br.qqIdx]
|
||||
query = bi.query
|
||||
args = bi.arguments
|
||||
ok = true
|
||||
br.qqIdx++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
+599
-523
File diff suppressed because it is too large
Load Diff
+63
-220
@@ -12,16 +12,32 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgconn/stmtcache"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/internal/nbconn"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func BenchmarkConnectClose(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
err = conn.Close(context.Background())
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) {
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.BuildStatementCache = nil
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
|
||||
config.StatementCacheCapacity = 0
|
||||
config.DescriptionCacheCapacity = 0
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
@@ -43,9 +59,9 @@ func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) {
|
||||
|
||||
func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) {
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
||||
}
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||
config.StatementCacheCapacity = 0
|
||||
config.DescriptionCacheCapacity = 32
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
@@ -67,9 +83,9 @@ func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B
|
||||
|
||||
func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) {
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
||||
}
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||
config.StatementCacheCapacity = 32
|
||||
config.DescriptionCacheCapacity = 0
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
@@ -268,123 +284,6 @@ func BenchmarkPointerPointerWithPresentValues(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithoutLogging(b *testing.B) {
|
||||
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
type discardLogger struct{}
|
||||
|
||||
func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
|
||||
var logger discardLogger
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = logger
|
||||
config.LogLevel = pgx.LogLevelTrace
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) {
|
||||
var logger discardLogger
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = logger
|
||||
config.LogLevel = pgx.LogLevelDebug
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) {
|
||||
var logger discardLogger
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = logger
|
||||
config.LogLevel = pgx.LogLevelInfo
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) {
|
||||
var logger discardLogger
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = logger
|
||||
config.LogLevel = pgx.LogLevelError
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
|
||||
_, err := conn.Prepare(context.Background(), "test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var record struct {
|
||||
id int32
|
||||
userName string
|
||||
email string
|
||||
name string
|
||||
sex string
|
||||
birthDate time.Time
|
||||
lastLoginTime time.Time
|
||||
}
|
||||
|
||||
err = conn.QueryRow(context.Background(), "test").Scan(
|
||||
&record.id,
|
||||
&record.userName,
|
||||
&record.email,
|
||||
&record.name,
|
||||
&record.sex,
|
||||
&record.birthDate,
|
||||
&record.lastLoginTime,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// These checks both ensure that the correct data was returned
|
||||
// and provide a benchmark of accessing the returned values.
|
||||
if record.id != 1 {
|
||||
b.Fatalf("bad value for id: %v", record.id)
|
||||
}
|
||||
if record.userName != "johnsmith" {
|
||||
b.Fatalf("bad value for userName: %v", record.userName)
|
||||
}
|
||||
if record.email != "johnsmith@example.com" {
|
||||
b.Fatalf("bad value for email: %v", record.email)
|
||||
}
|
||||
if record.name != "John Smith" {
|
||||
b.Fatalf("bad value for name: %v", record.name)
|
||||
}
|
||||
if record.sex != "male" {
|
||||
b.Fatalf("bad value for sex: %v", record.sex)
|
||||
}
|
||||
if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) {
|
||||
b.Fatalf("bad value for birthDate: %v", record.birthDate)
|
||||
}
|
||||
if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
|
||||
b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const benchmarkWriteTableCreateSQL = `drop table if exists t;
|
||||
|
||||
create table t(
|
||||
@@ -437,7 +336,7 @@ const benchmarkWriteTableInsertSQL = `insert into t(
|
||||
type benchmarkWriteTableCopyFromSrc struct {
|
||||
count int
|
||||
idx int
|
||||
row []interface{}
|
||||
row []any
|
||||
}
|
||||
|
||||
func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
|
||||
@@ -445,7 +344,7 @@ func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
|
||||
return s.idx < s.count
|
||||
}
|
||||
|
||||
func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) {
|
||||
func (s *benchmarkWriteTableCopyFromSrc) Values() ([]any, error) {
|
||||
return s.row, nil
|
||||
}
|
||||
|
||||
@@ -456,15 +355,15 @@ func (s *benchmarkWriteTableCopyFromSrc) Err() error {
|
||||
func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource {
|
||||
return &benchmarkWriteTableCopyFromSrc{
|
||||
count: count,
|
||||
row: []interface{}{
|
||||
row: []any{
|
||||
"varchar_1",
|
||||
"varchar_2",
|
||||
&pgtype.Text{Status: pgtype.Null},
|
||||
&pgtype.Text{},
|
||||
time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
&pgtype.Date{Status: pgtype.Null},
|
||||
&pgtype.Date{},
|
||||
1,
|
||||
2,
|
||||
&pgtype.Int4{Status: pgtype.Null},
|
||||
&pgtype.Int4{},
|
||||
time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
true,
|
||||
@@ -508,9 +407,9 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
|
||||
}
|
||||
}
|
||||
|
||||
type queryArgs []interface{}
|
||||
type queryArgs []any
|
||||
|
||||
func (qa *queryArgs) Append(v interface{}) string {
|
||||
func (qa *queryArgs) Append(v any) string {
|
||||
*qa = append(*qa, v)
|
||||
return "$" + strconv.Itoa(len(*qa))
|
||||
}
|
||||
@@ -723,7 +622,9 @@ func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
|
||||
|
||||
func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) {
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.BuildStatementCache = nil
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
|
||||
config.StatementCacheCapacity = 0
|
||||
config.DescriptionCacheCapacity = 0
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
@@ -733,9 +634,9 @@ func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) {
|
||||
|
||||
func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) {
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
||||
}
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||
config.StatementCacheCapacity = 32
|
||||
config.DescriptionCacheCapacity = 0
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
@@ -745,9 +646,9 @@ func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) {
|
||||
|
||||
func BenchmarkMultipleQueriesNonBatchDescribeStatementCache(b *testing.B) {
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
||||
}
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||
config.StatementCacheCapacity = 0
|
||||
config.DescriptionCacheCapacity = 32
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
@@ -783,7 +684,9 @@ func benchmarkMultipleQueriesNonBatch(b *testing.B, conn *pgx.Conn, queryCount i
|
||||
|
||||
func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) {
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.BuildStatementCache = nil
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
|
||||
config.StatementCacheCapacity = 0
|
||||
config.DescriptionCacheCapacity = 0
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
@@ -793,9 +696,9 @@ func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) {
|
||||
|
||||
func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) {
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
||||
}
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||
config.StatementCacheCapacity = 32
|
||||
config.DescriptionCacheCapacity = 0
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
@@ -805,9 +708,9 @@ func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) {
|
||||
|
||||
func BenchmarkMultipleQueriesBatchDescribeStatementCache(b *testing.B) {
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
||||
}
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||
config.StatementCacheCapacity = 0
|
||||
config.DescriptionCacheCapacity = 32
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
@@ -918,8 +821,7 @@ func BenchmarkSelectManyRegisteredEnum(b *testing.B) {
|
||||
err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid)
|
||||
require.NoError(b, err)
|
||||
|
||||
et := pgtype.NewEnumType("color", []string{"blue", "green", "orange"})
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "color", OID: oid})
|
||||
conn.TypeMap().RegisterType(&pgtype.Type{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}})
|
||||
|
||||
b.ResetTimer()
|
||||
var x, y, z string
|
||||
@@ -1105,73 +1007,6 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectRowsExplicitDecoding(b *testing.B) {
|
||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(b, conn)
|
||||
|
||||
rowCounts := getSelectRowsCounts(b)
|
||||
|
||||
for _, rowCount := range rowCounts {
|
||||
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
|
||||
br := &BenchRowDecoder{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
rawValues := rows.RawValues()
|
||||
|
||||
err = br.ID.DecodeBinary(conn.ConnInfo(), rawValues[0])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
err = br.FirstName.DecodeText(conn.ConnInfo(), rawValues[1])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
err = br.LastName.DecodeText(conn.ConnInfo(), rawValues[2])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
err = br.Sex.DecodeText(conn.ConnInfo(), rawValues[3])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
err = br.BirthDate.DecodeBinary(conn.ConnInfo(), rawValues[4])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
err = br.Weight.DecodeBinary(conn.ConnInfo(), rawValues[5])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
err = br.Height.DecodeBinary(conn.ConnInfo(), rawValues[6])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
err = br.UpdateTime.DecodeBinary(conn.ConnInfo(), rawValues[7])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
b.Fatal(rows.Err())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectRowsPgConnExecText(b *testing.B) {
|
||||
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(b, conn)
|
||||
@@ -1285,7 +1120,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
|
||||
}
|
||||
|
||||
type queryRecorder struct {
|
||||
conn net.Conn
|
||||
conn nbconn.Conn
|
||||
writeBuf []byte
|
||||
readCount int
|
||||
}
|
||||
@@ -1301,6 +1136,14 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) {
|
||||
return qr.conn.Write(b)
|
||||
}
|
||||
|
||||
func (qr *queryRecorder) BufferReadUntilBlock() error {
|
||||
return qr.conn.BufferReadUntilBlock()
|
||||
}
|
||||
|
||||
func (qr *queryRecorder) Flush() error {
|
||||
return qr.conn.Flush()
|
||||
}
|
||||
|
||||
func (qr *queryRecorder) Close() error {
|
||||
return qr.conn.Close()
|
||||
}
|
||||
|
||||
@@ -13,6 +13,10 @@ then
|
||||
echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
|
||||
sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||
if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then
|
||||
echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
|
||||
@@ -24,8 +28,15 @@ then
|
||||
psql -U postgres -c 'create database pgx_test'
|
||||
psql -U postgres pgx_test -c 'create extension hstore'
|
||||
psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)'
|
||||
psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
|
||||
psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
|
||||
psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
|
||||
psql -U postgres -c "create user `whoami`"
|
||||
psql -U postgres -c "create user pgx_replication with replication password 'secret'"
|
||||
|
||||
# The tricky test user, below, has to actually exist so that it can be used in a test
|
||||
# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
|
||||
psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
|
||||
fi
|
||||
|
||||
if [[ "${PGVERSION-}" =~ ^cockroach ]]
|
||||
|
||||
+202
-263
@@ -3,17 +3,16 @@ package pgx_test
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgconn/stmtcache"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -83,7 +82,7 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
connConfig.PreferSimpleProtocol = true
|
||||
connConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
||||
|
||||
conn := mustConnect(t, connConfig)
|
||||
defer closeConn(t, conn)
|
||||
@@ -93,13 +92,8 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) {
|
||||
|
||||
var s pgtype.Text
|
||||
err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s.Get() != "42" {
|
||||
t.Fatalf(`expected "42", got %v`, s)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pgtype.Text{String: "42", Valid: true}, s)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
@@ -144,91 +138,122 @@ func TestParseConfigExtractsStatementCacheOptions(t *testing.T) {
|
||||
|
||||
config, err := pgx.ParseConfig("statement_cache_capacity=0")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, config.BuildStatementCache)
|
||||
require.EqualValues(t, 0, config.StatementCacheCapacity)
|
||||
|
||||
config, err = pgx.ParseConfig("statement_cache_capacity=42")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config.BuildStatementCache)
|
||||
c := config.BuildStatementCache(nil)
|
||||
require.NotNil(t, c)
|
||||
require.Equal(t, 42, c.Cap())
|
||||
require.Equal(t, stmtcache.ModePrepare, c.Mode())
|
||||
require.EqualValues(t, 42, config.StatementCacheCapacity)
|
||||
|
||||
config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=prepare")
|
||||
config, err = pgx.ParseConfig("description_cache_capacity=0")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config.BuildStatementCache)
|
||||
c = config.BuildStatementCache(nil)
|
||||
require.NotNil(t, c)
|
||||
require.Equal(t, 42, c.Cap())
|
||||
require.Equal(t, stmtcache.ModePrepare, c.Mode())
|
||||
require.EqualValues(t, 0, config.DescriptionCacheCapacity)
|
||||
|
||||
config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=describe")
|
||||
config, err = pgx.ParseConfig("description_cache_capacity=42")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config.BuildStatementCache)
|
||||
c = config.BuildStatementCache(nil)
|
||||
require.NotNil(t, c)
|
||||
require.Equal(t, 42, c.Cap())
|
||||
require.Equal(t, stmtcache.ModeDescribe, c.Mode())
|
||||
require.EqualValues(t, 42, config.DescriptionCacheCapacity)
|
||||
|
||||
// default_query_exec_mode
|
||||
// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
|
||||
|
||||
config, err = pgx.ParseConfig("default_query_exec_mode=cache_statement")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pgx.QueryExecModeCacheStatement, config.DefaultQueryExecMode)
|
||||
|
||||
config, err = pgx.ParseConfig("default_query_exec_mode=cache_describe")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pgx.QueryExecModeCacheDescribe, config.DefaultQueryExecMode)
|
||||
|
||||
config, err = pgx.ParseConfig("default_query_exec_mode=describe_exec")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pgx.QueryExecModeDescribeExec, config.DefaultQueryExecMode)
|
||||
|
||||
config, err = pgx.ParseConfig("default_query_exec_mode=exec")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pgx.QueryExecModeExec, config.DefaultQueryExecMode)
|
||||
|
||||
config, err = pgx.ParseConfig("default_query_exec_mode=simple_protocol")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pgx.QueryExecModeSimpleProtocol, config.DefaultQueryExecMode)
|
||||
}
|
||||
|
||||
func TestParseConfigExtractsPreferSimpleProtocol(t *testing.T) {
|
||||
func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range []struct {
|
||||
connString string
|
||||
preferSimpleProtocol bool
|
||||
defaultQueryExecMode pgx.QueryExecMode
|
||||
}{
|
||||
{"", false},
|
||||
{"prefer_simple_protocol=false", false},
|
||||
{"prefer_simple_protocol=0", false},
|
||||
{"prefer_simple_protocol=true", true},
|
||||
{"prefer_simple_protocol=1", true},
|
||||
{"", pgx.QueryExecModeCacheStatement},
|
||||
{"default_query_exec_mode=cache_statement", pgx.QueryExecModeCacheStatement},
|
||||
{"default_query_exec_mode=cache_describe", pgx.QueryExecModeCacheDescribe},
|
||||
{"default_query_exec_mode=describe_exec", pgx.QueryExecModeDescribeExec},
|
||||
{"default_query_exec_mode=exec", pgx.QueryExecModeExec},
|
||||
{"default_query_exec_mode=simple_protocol", pgx.QueryExecModeSimpleProtocol},
|
||||
} {
|
||||
config, err := pgx.ParseConfig(tt.connString)
|
||||
require.NoError(t, err)
|
||||
require.Equalf(t, tt.preferSimpleProtocol, config.PreferSimpleProtocol, "connString: `%s`", tt.connString)
|
||||
require.Empty(t, config.RuntimeParams["prefer_simple_protocol"])
|
||||
require.Equalf(t, tt.defaultQueryExecMode, config.DefaultQueryExecMode, "connString: `%s`", tt.connString)
|
||||
require.Empty(t, config.RuntimeParams["default_query_exec_mode"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Accept parameters
|
||||
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" {
|
||||
if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); results.String() != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
|
||||
if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" {
|
||||
if results := mustExec(t, conn, "drop table foo;"); results.String() != "DROP TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Multiple statements can be executed -- last command tag is returned
|
||||
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" {
|
||||
if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results.String() != "DROP TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Can execute longer SQL strings than sharedBufferSize
|
||||
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" {
|
||||
if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); results.String() != "SELECT 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
|
||||
// Exec no-op which does not return a command tag
|
||||
if results := mustExec(t, conn, "--;"); string(results) != "" {
|
||||
if results := mustExec(t, conn, "--;"); results.String() != "" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type testQueryRewriter struct {
|
||||
sql string
|
||||
args []any
|
||||
}
|
||||
|
||||
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||
return qr.sql, qr.args
|
||||
}
|
||||
|
||||
func TestExecWithQueryRewriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
|
||||
_, err := conn.Exec(ctx, "should be replaced", &qr)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExecFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
if _, err := conn.Exec(context.Background(), "selct;"); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
@@ -244,7 +269,7 @@ func TestExecFailure(t *testing.T) {
|
||||
func TestExecFailureWithArguments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(context.Background(), "selct $1;", 1)
|
||||
if err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
@@ -259,7 +284,7 @@ func TestExecFailureWithArguments(t *testing.T) {
|
||||
func TestExecContextWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
@@ -267,7 +292,7 @@ func TestExecContextWithoutCancelation(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(commandTag) != "CREATE TABLE" {
|
||||
if commandTag.String() != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||
}
|
||||
assert.False(t, pgconn.SafeToRetry(err))
|
||||
@@ -277,7 +302,7 @@ func TestExecContextWithoutCancelation(t *testing.T) {
|
||||
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
@@ -299,7 +324,7 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||
func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
@@ -322,56 +347,6 @@ func TestExecFailureCloseBefore(t *testing.T) {
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
}
|
||||
|
||||
func TestExecStatementCacheModes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
buildStatementCache pgx.BuildStatementCacheFunc
|
||||
}{
|
||||
{
|
||||
name: "disabled",
|
||||
buildStatementCache: nil,
|
||||
},
|
||||
{
|
||||
name: "prepare",
|
||||
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "describe",
|
||||
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
func() {
|
||||
config.BuildStatementCache = tt.buildStatementCache
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
commandTag, err := conn.Exec(context.Background(), "select 1")
|
||||
assert.NoError(t, err, tt.name)
|
||||
assert.Equal(t, "SELECT 1", string(commandTag), tt.name)
|
||||
|
||||
commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1")
|
||||
assert.NoError(t, err, tt.name)
|
||||
assert.Equal(t, "SELECT 2", string(commandTag), tt.name)
|
||||
|
||||
commandTag, err = conn.Exec(context.Background(), "select 1")
|
||||
assert.NoError(t, err, tt.name)
|
||||
assert.Equal(t, "SELECT 1", string(commandTag), tt.name)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecPerQuerySimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -385,19 +360,19 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(commandTag) != "CREATE TABLE" {
|
||||
if commandTag.String() != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||
}
|
||||
|
||||
commandTag, err = conn.Exec(ctx,
|
||||
"insert into foo(name) values($1);",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
pgx.QueryExecModeSimpleProtocol,
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(commandTag) != "INSERT 0 1" {
|
||||
if commandTag.String() != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||
}
|
||||
|
||||
@@ -501,45 +476,15 @@ func TestPrepareIdempotency(t *testing.T) {
|
||||
func TestPrepareStatementCacheModes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Prepare(context.Background(), "test", "select $1::text")
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
buildStatementCache pgx.BuildStatementCacheFunc
|
||||
}{
|
||||
{
|
||||
name: "disabled",
|
||||
buildStatementCache: nil,
|
||||
},
|
||||
{
|
||||
name: "prepare",
|
||||
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModePrepare, 32)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "describe",
|
||||
buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config.BuildStatementCache = tt.buildStatementCache
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
_, err := conn.Prepare(context.Background(), "test", "select $1::text")
|
||||
require.NoError(t, err)
|
||||
|
||||
var s string
|
||||
err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello", s)
|
||||
})
|
||||
}
|
||||
var s string
|
||||
err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello", s)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListenNotify(t *testing.T) {
|
||||
@@ -595,7 +540,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
|
||||
func() {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
||||
}()
|
||||
|
||||
listenerDone := make(chan bool)
|
||||
@@ -671,7 +616,7 @@ func TestListenNotifySelfNotification(t *testing.T) {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
||||
|
||||
mustExec(t, conn, "listen self")
|
||||
|
||||
@@ -706,7 +651,7 @@ func TestFatalRxError(t *testing.T) {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@@ -745,7 +690,7 @@ func TestFatalTxError(t *testing.T) {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
||||
|
||||
otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer otherConn.Close(context.Background())
|
||||
@@ -770,13 +715,13 @@ func TestFatalTxError(t *testing.T) {
|
||||
func TestInsertBoolArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Accept parameters
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" {
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); results.String() != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
})
|
||||
@@ -785,91 +730,18 @@ func TestInsertBoolArray(t *testing.T) {
|
||||
func TestInsertTimestampArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" {
|
||||
t.Error("Unexpected results from Exec")
|
||||
}
|
||||
|
||||
// Accept parameters
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" {
|
||||
if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); results.String() != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type testLog struct {
|
||||
lvl pgx.LogLevel
|
||||
msg string
|
||||
data map[string]interface{}
|
||||
}
|
||||
|
||||
type testLogger struct {
|
||||
logs []testLog
|
||||
}
|
||||
|
||||
func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
data["ctxdata"] = ctx.Value("ctxdata")
|
||||
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
|
||||
}
|
||||
|
||||
func TestLogPassesContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l1 := &testLogger{}
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = l1
|
||||
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
|
||||
|
||||
ctx := context.WithValue(context.Background(), "ctxdata", "foo")
|
||||
|
||||
if _, err := conn.Exec(ctx, ";"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(l1.logs) != 1 {
|
||||
t.Fatal("Expected logger to be called once, but it wasn't")
|
||||
}
|
||||
|
||||
if l1.logs[0].data["ctxdata"] != "foo" {
|
||||
t.Fatal("Expected context data to be passed to logger, but it wasn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const testMsg = "foo"
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
logger := log.New(&buf, "", 0)
|
||||
|
||||
createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc {
|
||||
return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logger.Printf("%s", testMsg)
|
||||
}
|
||||
}
|
||||
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = createAdapterFn(logger)
|
||||
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
buf.Reset() // Clear logs written when establishing connection
|
||||
|
||||
if _, err := conn.Exec(context.TODO(), ";"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(buf.String()) != testMsg {
|
||||
t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentifierSanitize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -911,7 +783,7 @@ func TestIdentifierSanitize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnInitConnInfo(t *testing.T) {
|
||||
func TestConnInitTypeMap(t *testing.T) {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
@@ -923,11 +795,11 @@ func TestConnInitConnInfo(t *testing.T) {
|
||||
"text": pgtype.TextOID,
|
||||
}
|
||||
for name, oid := range nameOIDs {
|
||||
dtByName, ok := conn.ConnInfo().DataTypeForName(name)
|
||||
dtByName, ok := conn.TypeMap().TypeForName(name)
|
||||
if !ok {
|
||||
t.Fatalf("Expected type named %v to be present", name)
|
||||
}
|
||||
dtByOID, ok := conn.ConnInfo().DataTypeForOID(oid)
|
||||
dtByOID, ok := conn.TypeMap().TypeForOID(oid)
|
||||
if !ok {
|
||||
t.Fatalf("Expected type OID %v to be present", oid)
|
||||
}
|
||||
@@ -940,8 +812,8 @@ func TestConnInitConnInfo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
||||
|
||||
var n uint64
|
||||
err := conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n)
|
||||
@@ -956,32 +828,30 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDomainType(t *testing.T) {
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
||||
|
||||
var n uint64
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")
|
||||
|
||||
// Domain type uint64 is a PostgreSQL domain of underlying type numeric.
|
||||
|
||||
err := conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
|
||||
// In the extended protocol preparing "select $1::uint64" appears to create a statement that expects a param OID of
|
||||
// uint64 but a result OID of the underlying numeric.
|
||||
|
||||
var s string
|
||||
err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "24", s)
|
||||
|
||||
// A string can be used. But a string cannot be the result because the describe result from the PostgreSQL server gives
|
||||
// the underlying type of numeric.
|
||||
err = conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 42 {
|
||||
t.Fatalf("Expected n to be 42, but was %v", n)
|
||||
}
|
||||
|
||||
// Register type
|
||||
var uint64OID uint32
|
||||
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID)
|
||||
if err != nil {
|
||||
t.Fatalf("did not find uint64 OID, %v", err)
|
||||
}
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID})
|
||||
conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}})
|
||||
|
||||
var n uint64
|
||||
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
|
||||
require.NoError(t, err)
|
||||
|
||||
// String is still an acceptable argument after registration
|
||||
err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n)
|
||||
@@ -991,15 +861,48 @@ func TestDomainType(t *testing.T) {
|
||||
if n != 7 {
|
||||
t.Fatalf("Expected n to be 7, but was %v", n)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// But a uint64 is acceptable
|
||||
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
func TestLoadTypeSameNameInDifferentSchemas(t *testing.T) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support composite types (https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, `create schema pgx_a;
|
||||
create type pgx_a.point as (a text, b text);
|
||||
create schema pgx_b;
|
||||
create type pgx_b.point as (c text);
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Register types
|
||||
for _, typename := range []string{"pgx_a.point", "pgx_b.point"} {
|
||||
// Obviously using conn while a tx is in use and registering a type after the connection has been established are
|
||||
// really bad practices, but for the sake of convenience we do it in the test here.
|
||||
dt, err := conn.LoadType(ctx, typename)
|
||||
require.NoError(t, err)
|
||||
conn.TypeMap().RegisterType(dt)
|
||||
}
|
||||
if n != 24 {
|
||||
t.Fatalf("Expected n to be 24, but was %v", n)
|
||||
|
||||
type aPoint struct {
|
||||
A string
|
||||
B string
|
||||
}
|
||||
|
||||
type bPoint struct {
|
||||
C string
|
||||
}
|
||||
|
||||
var a aPoint
|
||||
var b bPoint
|
||||
err = tx.QueryRow(ctx, `select '(foo,bar)'::pgx_a.point, '(baz)'::pgx_b.point`).Scan(&a, &b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, aPoint{"foo", "bar"}, a)
|
||||
require.Equal(t, bPoint{"baz"}, b)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1028,6 +931,7 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
|
||||
rows, err := conn.Query(ctx, getSQL, 1)
|
||||
require.NoError(t, err)
|
||||
rows.Close()
|
||||
require.NoError(t, rows.Err())
|
||||
|
||||
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||
@@ -1045,10 +949,10 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
|
||||
rows.Close()
|
||||
for _, err := range []error{nextErr, rows.Err()} {
|
||||
if err == nil {
|
||||
t.Fatal("expected InvalidCachedStatementPlanError: no error")
|
||||
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
|
||||
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1070,6 +974,10 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
||||
t.Skip("Server has non-standard prepare in errored transaction behavior (https://github.com/cockroachdb/cockroach/issues/84140)")
|
||||
}
|
||||
|
||||
// create a table and fill it with some data
|
||||
_, err := conn.Exec(ctx, `
|
||||
DROP TABLE IF EXISTS drop_cols;
|
||||
@@ -1092,6 +1000,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
||||
rows, err := tx.Query(ctx, getSQL, 1)
|
||||
require.NoError(t, err)
|
||||
rows.Close()
|
||||
require.NoError(t, rows.Err())
|
||||
|
||||
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||
@@ -1109,18 +1018,17 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
||||
rows.Close()
|
||||
for _, err := range []error{nextErr, rows.Err()} {
|
||||
if err == nil {
|
||||
t.Fatal("expected InvalidCachedStatementPlanError: no error")
|
||||
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
|
||||
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
rows, err = tx.Query(ctx, getSQL, 1)
|
||||
require.NoError(t, err) // error does not pop up immediately
|
||||
rows.Next()
|
||||
rows, _ = tx.Query(ctx, getSQL, 1)
|
||||
rows.Close()
|
||||
err = rows.Err()
|
||||
// Retries within the same transaction are errors (really anything except a rollbakc
|
||||
// Retries within the same transaction are errors (really anything except a rollback
|
||||
// will be an error in this transaction).
|
||||
require.Error(t, err)
|
||||
rows.Close()
|
||||
@@ -1140,7 +1048,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInsertDurationInterval(t *testing.T) {
|
||||
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1151,3 +1059,34 @@ func TestInsertDurationInterval(t *testing.T) {
|
||||
require.EqualValues(t, 1, n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
var buf []byte
|
||||
|
||||
rows, err := conn.Query(ctx, `select 1::int`)
|
||||
require.NoError(t, err)
|
||||
|
||||
for rows.Next() {
|
||||
buf = rows.RawValues()[0]
|
||||
}
|
||||
|
||||
require.NoError(t, rows.Err())
|
||||
|
||||
original := make([]byte, len(buf))
|
||||
copy(original, buf)
|
||||
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
rows, err := conn.Query(ctx, `select $1::int`, i)
|
||||
require.NoError(t, err)
|
||||
rows.Close()
|
||||
require.NoError(t, rows.Err())
|
||||
|
||||
if bytes.Compare(original, buf) != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
|
||||
})
|
||||
}
|
||||
|
||||
+23
-22
@@ -5,20 +5,19 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
||||
// making it usable by *Conn.CopyFrom.
|
||||
func CopyFromRows(rows [][]interface{}) CopyFromSource {
|
||||
func CopyFromRows(rows [][]any) CopyFromSource {
|
||||
return ©FromRows{rows: rows, idx: -1}
|
||||
}
|
||||
|
||||
type copyFromRows struct {
|
||||
rows [][]interface{}
|
||||
rows [][]any
|
||||
idx int
|
||||
}
|
||||
|
||||
@@ -27,7 +26,7 @@ func (ctr *copyFromRows) Next() bool {
|
||||
return ctr.idx < len(ctr.rows)
|
||||
}
|
||||
|
||||
func (ctr *copyFromRows) Values() ([]interface{}, error) {
|
||||
func (ctr *copyFromRows) Values() ([]any, error) {
|
||||
return ctr.rows[ctr.idx], nil
|
||||
}
|
||||
|
||||
@@ -37,12 +36,12 @@ func (ctr *copyFromRows) Err() error {
|
||||
|
||||
// CopyFromSlice returns a CopyFromSource interface over a dynamic func
|
||||
// making it usable by *Conn.CopyFrom.
|
||||
func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource {
|
||||
func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource {
|
||||
return ©FromSlice{next: next, idx: -1, len: length}
|
||||
}
|
||||
|
||||
type copyFromSlice struct {
|
||||
next func(int) ([]interface{}, error)
|
||||
next func(int) ([]any, error)
|
||||
idx int
|
||||
len int
|
||||
err error
|
||||
@@ -53,7 +52,7 @@ func (cts *copyFromSlice) Next() bool {
|
||||
return cts.idx < cts.len
|
||||
}
|
||||
|
||||
func (cts *copyFromSlice) Values() ([]interface{}, error) {
|
||||
func (cts *copyFromSlice) Values() ([]any, error) {
|
||||
values, err := cts.next(cts.idx)
|
||||
if err != nil {
|
||||
cts.err = err
|
||||
@@ -73,7 +72,7 @@ type CopyFromSource interface {
|
||||
Next() bool
|
||||
|
||||
// Values returns the values for the current row.
|
||||
Values() ([]interface{}, error)
|
||||
Values() ([]any, error)
|
||||
|
||||
// Err returns any error that has been encountered by the CopyFromSource. If
|
||||
// this is not nil *Conn.CopyFrom will abort the copy.
|
||||
@@ -89,6 +88,13 @@ type copyFrom struct {
|
||||
}
|
||||
|
||||
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
if ct.conn.copyFromTracer != nil {
|
||||
ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
|
||||
TableName: ct.tableName,
|
||||
ColumnNames: ct.columnNames,
|
||||
})
|
||||
}
|
||||
|
||||
quotedTableName := ct.tableName.Sanitize()
|
||||
cbuf := &bytes.Buffer{}
|
||||
for i, cn := range ct.columnNames {
|
||||
@@ -145,24 +151,19 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
|
||||
r.Close()
|
||||
<-doneChan
|
||||
|
||||
rowsAffected := commandTag.RowsAffected()
|
||||
endTime := time.Now()
|
||||
if err == nil {
|
||||
if ct.conn.shouldLog(LogLevelInfo) {
|
||||
ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected})
|
||||
}
|
||||
} else if ct.conn.shouldLog(LogLevelError) {
|
||||
ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)})
|
||||
if ct.conn.copyFromTracer != nil {
|
||||
ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
|
||||
CommandTag: commandTag,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
|
||||
return rowsAffected, err
|
||||
return commandTag.RowsAffected(), err
|
||||
}
|
||||
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
|
||||
@@ -178,7 +179,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val)
|
||||
buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
+72
-32
@@ -8,8 +8,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -31,7 +32,7 @@ func TestConnCopyFromSmall(t *testing.T) {
|
||||
|
||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
inputRows := [][]any{
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
@@ -49,7 +50,7 @@ func TestConnCopyFromSmall(t *testing.T) {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
@@ -87,13 +88,13 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
|
||||
|
||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
inputRows := [][]any{
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
|
||||
pgx.CopyFromSlice(len(inputRows), func(i int) ([]interface{}, error) {
|
||||
pgx.CopyFromSlice(len(inputRows), func(i int) ([]any, error) {
|
||||
return inputRows[i], nil
|
||||
}))
|
||||
if err != nil {
|
||||
@@ -108,7 +109,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
@@ -134,7 +135,7 @@ func TestConnCopyFromLarge(t *testing.T) {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
skipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)")
|
||||
pgxtest.SkipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)")
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
@@ -149,10 +150,10 @@ func TestConnCopyFromLarge(t *testing.T) {
|
||||
|
||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||
|
||||
inputRows := [][]interface{}{}
|
||||
inputRows := [][]any{}
|
||||
|
||||
for i := 0; i < 10000; i++ {
|
||||
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
|
||||
inputRows = append(inputRows, []any{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
|
||||
@@ -168,7 +169,7 @@ func TestConnCopyFromLarge(t *testing.T) {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
@@ -211,6 +212,14 @@ func TestConnCopyFromEnum(t *testing.T) {
|
||||
_, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Obviously using conn while a tx is in use and registering a type after the connection has been established are
|
||||
// really bad practices, but for the sake of convenience we do it in the test here.
|
||||
for _, name := range []string{"fruit", "color"} {
|
||||
typ, err := conn.LoadType(ctx, name)
|
||||
require.NoError(t, err)
|
||||
conn.TypeMap().RegisterType(typ)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `create table foo(
|
||||
a text,
|
||||
b color,
|
||||
@@ -221,7 +230,7 @@ func TestConnCopyFromEnum(t *testing.T) {
|
||||
)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
inputRows := [][]any{
|
||||
{"abc", "blue", "grape", "orange", "orange", "def"},
|
||||
{nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
@@ -233,7 +242,7 @@ func TestConnCopyFromEnum(t *testing.T) {
|
||||
rows, err := conn.Query(ctx, "select * from foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
var outputRows [][]interface{}
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
require.NoError(t, err)
|
||||
@@ -256,7 +265,7 @@ func TestConnCopyFromJSON(t *testing.T) {
|
||||
defer closeConn(t, conn)
|
||||
|
||||
for _, typeName := range []string{"json", "jsonb"} {
|
||||
if _, ok := conn.ConnInfo().DataTypeForName(typeName); !ok {
|
||||
if _, ok := conn.TypeMap().TypeForName(typeName); !ok {
|
||||
return // No JSON/JSONB type -- must be running against old PostgreSQL
|
||||
}
|
||||
}
|
||||
@@ -266,8 +275,8 @@ func TestConnCopyFromJSON(t *testing.T) {
|
||||
b jsonb
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
|
||||
inputRows := [][]any{
|
||||
{map[string]any{"foo": "bar"}, map[string]any{"bar": "quz"}},
|
||||
{nil, nil},
|
||||
}
|
||||
|
||||
@@ -284,7 +293,7 @@ func TestConnCopyFromJSON(t *testing.T) {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
@@ -314,12 +323,12 @@ func (cfs *clientFailSource) Next() bool {
|
||||
return cfs.count < 100
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Values() ([]interface{}, error) {
|
||||
func (cfs *clientFailSource) Values() ([]any, error) {
|
||||
if cfs.count == 3 {
|
||||
cfs.err = fmt.Errorf("client error")
|
||||
return nil, cfs.err
|
||||
}
|
||||
return []interface{}{make([]byte, 100000)}, nil
|
||||
return []any{make([]byte, 100000)}, nil
|
||||
}
|
||||
|
||||
func (cfs *clientFailSource) Err() error {
|
||||
@@ -337,7 +346,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
||||
b varchar not null
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
inputRows := [][]any{
|
||||
{int32(1), "abc"},
|
||||
{int32(2), nil}, // this row should trigger a failure
|
||||
{int32(3), "def"},
|
||||
@@ -359,7 +368,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
@@ -391,11 +400,11 @@ func (fs *failSource) Next() bool {
|
||||
return fs.count < 100
|
||||
}
|
||||
|
||||
func (fs *failSource) Values() ([]interface{}, error) {
|
||||
func (fs *failSource) Values() ([]any, error) {
|
||||
if fs.count == 3 {
|
||||
return []interface{}{nil}, nil
|
||||
return []any{nil}, nil
|
||||
}
|
||||
return []interface{}{make([]byte, 100000)}, nil
|
||||
return []any{make([]byte, 100000)}, nil
|
||||
}
|
||||
|
||||
func (fs *failSource) Err() error {
|
||||
@@ -408,6 +417,8 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server copy error does not fail fast")
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a bytea not null
|
||||
)`)
|
||||
@@ -436,7 +447,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
@@ -466,11 +477,11 @@ func (fs *slowFailRaceSource) Next() bool {
|
||||
return fs.count < 1000
|
||||
}
|
||||
|
||||
func (fs *slowFailRaceSource) Values() ([]interface{}, error) {
|
||||
func (fs *slowFailRaceSource) Values() ([]any, error) {
|
||||
if fs.count == 500 {
|
||||
return []interface{}{nil, nil}, nil
|
||||
return []any{nil, nil}, nil
|
||||
}
|
||||
return []interface{}{1, make([]byte, 1000)}, nil
|
||||
return []any{1, make([]byte, 1000)}, nil
|
||||
}
|
||||
|
||||
func (fs *slowFailRaceSource) Err() error {
|
||||
@@ -525,7 +536,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
@@ -554,8 +565,8 @@ func (cfs *clientFinalErrSource) Next() bool {
|
||||
return cfs.count < 5
|
||||
}
|
||||
|
||||
func (cfs *clientFinalErrSource) Values() ([]interface{}, error) {
|
||||
return []interface{}{make([]byte, 100000)}, nil
|
||||
func (cfs *clientFinalErrSource) Values() ([]any, error) {
|
||||
return []any{make([]byte, 100000)}, nil
|
||||
}
|
||||
|
||||
func (cfs *clientFinalErrSource) Err() error {
|
||||
@@ -585,7 +596,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
@@ -604,3 +615,32 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromAutomaticStringConversion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int8
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{"42"},
|
||||
{"7"},
|
||||
{8},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(inputRows), copyCount)
|
||||
|
||||
rows, _ := conn.Query(context.Background(), "select * from foo")
|
||||
nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, []int64{42, 7, 8}, nums)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// Package pgx is a PostgreSQL database driver.
|
||||
/*
|
||||
pgx provides lower level access to PostgreSQL than the standard database/sql. It remains as similar to the database/sql
|
||||
interface as possible while providing better speed and access to PostgreSQL specific features. Import
|
||||
github.com/jackc/pgx/v4/stdlib to use pgx as a database/sql compatible driver.
|
||||
pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar
|
||||
to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use
|
||||
github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for
|
||||
details.
|
||||
|
||||
Establishing a Connection
|
||||
|
||||
@@ -12,57 +13,41 @@ The primary way of establishing a connection is with `pgx.Connect`.
|
||||
|
||||
The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
|
||||
here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with
|
||||
`ConnectConfig`.
|
||||
|
||||
config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
// ...
|
||||
}
|
||||
config.Logger = log15adapter.NewLogger(log.New("module", "pgx"))
|
||||
|
||||
conn, err := pgx.ConnectConfig(context.Background(), config)
|
||||
`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string.
|
||||
|
||||
Connection Pool
|
||||
|
||||
`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use sub-package pgxpool for a
|
||||
concurrency safe connection pool.
|
||||
`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package
|
||||
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
|
||||
|
||||
Query Interface
|
||||
|
||||
pgx implements Query and Scan in the familiar database/sql style.
|
||||
pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
|
||||
ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and
|
||||
rows.Err().
|
||||
|
||||
var sum int32
|
||||
CollectRows can be used collect all returned rows into a slice.
|
||||
|
||||
// Send the query to the server. The returned rows MUST be closed
|
||||
// before conn can be used again.
|
||||
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
||||
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5)
|
||||
numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32])
|
||||
if err != nil {
|
||||
return err
|
||||
return err
|
||||
}
|
||||
// numbers => [1 2 3 4 5]
|
||||
|
||||
// rows.Close is called by rows.Next when all rows are read
|
||||
// or an error occurs in Next or Scan. So it may optionally be
|
||||
// omitted if nothing in the rows.Next loop can panic. It is
|
||||
// safe to close rows multiple times.
|
||||
defer rows.Close()
|
||||
ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows
|
||||
directly.
|
||||
|
||||
// Iterate through the result set
|
||||
for rows.Next() {
|
||||
var n int32
|
||||
err = rows.Scan(&n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sum += n
|
||||
var sum, n int32
|
||||
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
||||
_, err := pgx.ForEachRow(rows, []any{&n}, func(pgx.QueryFuncRow) error {
|
||||
sum += n
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Any errors encountered by rows.Next or rows.Scan will be returned here
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// No errors found - do something with sum
|
||||
|
||||
pgx also implements QueryRow in the same style as database/sql.
|
||||
|
||||
var name string
|
||||
@@ -82,148 +67,10 @@ Use Exec to execute a query that does not return a result set.
|
||||
return errors.New("No row found to delete")
|
||||
}
|
||||
|
||||
QueryFunc can be used to execute a callback function for every row. This is often easier to use than Query.
|
||||
PostgreSQL Data Types
|
||||
|
||||
var sum, n int32
|
||||
_, err = conn.QueryFunc(
|
||||
context.Background(),
|
||||
"select generate_series(1,$1)",
|
||||
[]interface{}{10},
|
||||
[]interface{}{&n},
|
||||
func(pgx.QueryFuncRow) error {
|
||||
sum += n
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Base Type Mapping
|
||||
|
||||
pgx maps between all common base types directly between Go and PostgreSQL. In particular:
|
||||
|
||||
Go PostgreSQL
|
||||
-----------------------
|
||||
string varchar
|
||||
text
|
||||
|
||||
// Integers are automatically be converted to any other integer type if
|
||||
// it can be done without overflow or underflow.
|
||||
int8
|
||||
int16 smallint
|
||||
int32 int
|
||||
int64 bigint
|
||||
int
|
||||
uint8
|
||||
uint16
|
||||
uint32
|
||||
uint64
|
||||
uint
|
||||
|
||||
// Floats are strict and do not automatically convert like integers.
|
||||
float32 float4
|
||||
float64 float8
|
||||
|
||||
time.Time date
|
||||
timestamp
|
||||
timestamptz
|
||||
|
||||
[]byte bytea
|
||||
|
||||
|
||||
Null Mapping
|
||||
|
||||
pgx can map nulls in two ways. The first is package pgtype provides types that have a data field and a status field.
|
||||
They work in a similar fashion to database/sql. The second is to use a pointer to a pointer.
|
||||
|
||||
var foo pgtype.Varchar
|
||||
var bar *string
|
||||
err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Array Mapping
|
||||
|
||||
pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type.
|
||||
Go slices of native types do not support nulls, so if a PostgreSQL array that contains a null is read into a native Go
|
||||
slice an error will occur. The pgtype package includes many more array types for PostgreSQL types that do not directly
|
||||
map to native Go types.
|
||||
|
||||
JSON and JSONB Mapping
|
||||
|
||||
pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB.
|
||||
|
||||
Inet and CIDR Mapping
|
||||
|
||||
pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In addition, as a convenience pgx will encode
|
||||
from a net.IP; it will assume a /32 netmask for IPv4 and a /128 for IPv6.
|
||||
|
||||
Custom Type Support
|
||||
|
||||
pgx includes support for the common data types like integers, floats, strings, dates, and times that have direct
|
||||
mappings between Go and SQL. In addition, pgx uses the github.com/jackc/pgtype library to support more types. See
|
||||
documention for that library for instructions on how to implement custom types.
|
||||
|
||||
See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type.
|
||||
|
||||
pgx also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer
|
||||
interfaces.
|
||||
|
||||
If pgx does cannot natively encode a type and that type is a renamed type (e.g. type MyTime time.Time) pgx will attempt
|
||||
to encode the underlying type. While this is usually desired behavior it can produce surprising behavior if one the
|
||||
underlying type and the renamed type each implement database/sql interfaces and the other implements pgx interfaces. It
|
||||
is recommended that this situation be avoided by implementing pgx interfaces on the renamed type.
|
||||
|
||||
Composite types and row values
|
||||
|
||||
Row values and composite types are represented as pgtype.Record (https://pkg.go.dev/github.com/jackc/pgtype?tab=doc#Record).
|
||||
It is possible to get values of your custom type by implementing DecodeBinary interface. Decoding into
|
||||
pgtype.Record first can simplify process by avoiding dealing with raw protocol directly.
|
||||
|
||||
For example:
|
||||
|
||||
type MyType struct {
|
||||
a int // NULL will cause decoding error
|
||||
b *string // there can be NULL in this position in SQL
|
||||
}
|
||||
|
||||
func (t *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
|
||||
r := pgtype.Record{
|
||||
Fields: []pgtype.Value{&pgtype.Int4{}, &pgtype.Text{}},
|
||||
}
|
||||
|
||||
if err := r.DecodeBinary(ci, src); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.Status != pgtype.Present {
|
||||
return errors.New("BUG: decoding should not be called on NULL value")
|
||||
}
|
||||
|
||||
a := r.Fields[0].(*pgtype.Int4)
|
||||
b := r.Fields[1].(*pgtype.Text)
|
||||
|
||||
// type compatibility is checked by AssignTo
|
||||
// only lossless assignments will succeed
|
||||
if err := a.AssignTo(&t.a); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// AssignTo also deals with null value handling
|
||||
if err := b.AssignTo(&t.b); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
result := MyType{}
|
||||
err := conn.QueryRow(context.Background(), "select row(1, 'foo'::text)", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&r)
|
||||
|
||||
Raw Bytes Mapping
|
||||
|
||||
[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified to PostgreSQL.
|
||||
The package pgtype provides extensive and customizable support for converting Go values to and from PostgreSQL values
|
||||
including array and composite types. See that package's documentation for details.
|
||||
|
||||
Transactions
|
||||
|
||||
@@ -250,12 +97,13 @@ Transactions are started by calling Begin.
|
||||
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
|
||||
These are internally implemented with savepoints.
|
||||
|
||||
Use BeginTx to control the transaction mode.
|
||||
Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of
|
||||
a pseudo nested transaction.
|
||||
|
||||
BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the
|
||||
BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
|
||||
transaction depending on the return value of the function. These can be simpler and less error prone to use.
|
||||
|
||||
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
|
||||
err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
|
||||
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||
return err
|
||||
})
|
||||
@@ -273,10 +121,10 @@ for information on how to customize or disable the statement cache.
|
||||
Copy Protocol
|
||||
|
||||
Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a
|
||||
CopyFromSource interface. If the data is already in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource
|
||||
interface. Or implement CopyFromSource to avoid buffering the entire data set in memory.
|
||||
CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface.
|
||||
Or implement CopyFromSource to avoid buffering the entire data set in memory.
|
||||
|
||||
rows := [][]interface{}{
|
||||
rows := [][]any{
|
||||
{"John", "Smith", int32(36)},
|
||||
{"Jane", "Doe", int32(29)},
|
||||
}
|
||||
@@ -299,8 +147,8 @@ When you already have a typed array using CopyFromSlice can be more convenient.
|
||||
context.Background(),
|
||||
pgx.Identifier{"people"},
|
||||
[]string{"first_name", "last_name", "age"},
|
||||
pgx.CopyFromSlice(len(rows), func(i int) ([]interface{}, error) {
|
||||
return []interface{}{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil
|
||||
pgx.CopyFromSlice(len(rows), func(i int) ([]any, error) {
|
||||
return []any{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil
|
||||
}),
|
||||
)
|
||||
|
||||
@@ -321,20 +169,22 @@ notification is received or the context is canceled.
|
||||
}
|
||||
|
||||
|
||||
Logging
|
||||
Tracing and Logging
|
||||
|
||||
pgx defines a simple logger interface. Connections optionally accept a logger that satisfies this interface. Set
|
||||
LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus,
|
||||
go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory.
|
||||
pgx supports tracing by setting ConnConfig.Tracer.
|
||||
|
||||
In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
|
||||
|
||||
For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3.
|
||||
|
||||
Lower Level PostgreSQL Functionality
|
||||
|
||||
pgx is implemented on top of github.com/jackc/pgconn a lower level PostgreSQL driver. The Conn.PgConn() method can be
|
||||
used to access this lower layer.
|
||||
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in
|
||||
implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer.
|
||||
|
||||
PgBouncer
|
||||
|
||||
pgx is compatible with PgBouncer in two modes. One is when the connection has a statement cache in "describe" mode. The
|
||||
other is when the connection is using the simple protocol. This can be set with the PreferSimpleProtocol config option.
|
||||
By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be
|
||||
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
|
||||
*/
|
||||
package pgx
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v4"
|
||||
)
|
||||
|
||||
var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`)
|
||||
|
||||
// Point represents a point that may be null.
|
||||
type Point struct {
|
||||
X, Y float64 // Coordinates of point
|
||||
Status pgtype.Status
|
||||
}
|
||||
|
||||
func (dst *Point) Set(src interface{}) error {
|
||||
return fmt.Errorf("cannot convert %v to Point", src)
|
||||
}
|
||||
|
||||
func (dst *Point) Get() interface{} {
|
||||
switch dst.Status {
|
||||
case pgtype.Present:
|
||||
return dst
|
||||
case pgtype.Null:
|
||||
return nil
|
||||
default:
|
||||
return dst.Status
|
||||
}
|
||||
}
|
||||
|
||||
func (src *Point) AssignTo(dst interface{}) error {
|
||||
return fmt.Errorf("cannot assign %v to %T", src, dst)
|
||||
}
|
||||
|
||||
func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Point{Status: pgtype.Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
s := string(src)
|
||||
match := pointRegexp.FindStringSubmatch(s)
|
||||
if match == nil {
|
||||
return fmt.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
|
||||
x, err := strconv.ParseFloat(match[1], 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
y, err := strconv.ParseFloat(match[2], 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
|
||||
*dst = Point{X: x, Y: y, Status: pgtype.Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Point) String() string {
|
||||
if src.Status == pgtype.Null {
|
||||
return "null point"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%.1f, %.1f", src.X, src.Y)
|
||||
}
|
||||
|
||||
func Example_CustomType() {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to establish connection: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
||||
// Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be
|
||||
// skipped fake success instead.
|
||||
fmt.Println("null point")
|
||||
fmt.Println("1.5, 2.5")
|
||||
return
|
||||
}
|
||||
|
||||
// Override registered handler for point
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{
|
||||
Value: &Point{},
|
||||
Name: "point",
|
||||
OID: 600,
|
||||
})
|
||||
|
||||
p := &Point{}
|
||||
err = conn.QueryRow(context.Background(), "select null::point").Scan(p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(p)
|
||||
|
||||
err = conn.QueryRow(context.Background(), "select point(1.5,2.5)").Scan(p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(p)
|
||||
// Output:
|
||||
// null point
|
||||
// 1.5, 2.5
|
||||
}
|
||||
@@ -6,14 +6,14 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var pool *pgxpool.Pool
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
pool, err = pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
pool, err = pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Unable to connect to database:", err)
|
||||
os.Exit(1)
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
var conn *pgx.Conn
|
||||
|
||||
@@ -3,13 +3,12 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/log/log15adapter"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
log "gopkg.in/inconshreveable/log15.v2"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var db *pgxpool.Pool
|
||||
@@ -71,28 +70,21 @@ func urlHandler(w http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
func main() {
|
||||
logger := log15adapter.NewLogger(log.New("module", "pgx"))
|
||||
|
||||
poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
log.Crit("Unable to parse DATABASE_URL", "error", err)
|
||||
os.Exit(1)
|
||||
log.Fatalln("Unable to parse DATABASE_URL:", err)
|
||||
}
|
||||
|
||||
poolConfig.ConnConfig.Logger = logger
|
||||
|
||||
db, err = pgxpool.ConnectConfig(context.Background(), poolConfig)
|
||||
db, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
|
||||
if err != nil {
|
||||
log.Crit("Unable to create connection pool", "error", err)
|
||||
os.Exit(1)
|
||||
log.Fatalln("Unable to create connection pool:", err)
|
||||
}
|
||||
|
||||
http.HandleFunc("/", urlHandler)
|
||||
|
||||
log.Info("Starting URL shortener on localhost:8080")
|
||||
log.Println("Starting URL shortener on localhost:8080")
|
||||
err = http.ListenAndServe("localhost:8080", nil)
|
||||
if err != nil {
|
||||
log.Crit("Unable to start web server", "error", err)
|
||||
os.Exit(1)
|
||||
log.Fatalln("Unable to start web server:", err)
|
||||
}
|
||||
}
|
||||
|
||||
+152
-109
@@ -1,69 +1,118 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v5/internal/anynil"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
type extendedQueryBuilder struct {
|
||||
paramValues [][]byte
|
||||
// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result
|
||||
// formats for an extended query.
|
||||
type ExtendedQueryBuilder struct {
|
||||
ParamValues [][]byte
|
||||
paramValueBytes []byte
|
||||
paramFormats []int16
|
||||
resultFormats []int16
|
||||
ParamFormats []int16
|
||||
ResultFormats []int16
|
||||
}
|
||||
|
||||
func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error {
|
||||
f := chooseParameterFormatCode(ci, oid, arg)
|
||||
eqb.paramFormats = append(eqb.paramFormats, f)
|
||||
// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If
|
||||
// sd is nil then QueryExecModeExec behavior will be used.
|
||||
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
|
||||
eqb.reset()
|
||||
|
||||
v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
anynil.NormalizeSlice(args)
|
||||
|
||||
if sd == nil {
|
||||
return eqb.appendParamsForQueryExecModeExec(m, args)
|
||||
}
|
||||
|
||||
if len(sd.ParamOIDs) != len(args) {
|
||||
return fmt.Errorf("mismatched param and argument count")
|
||||
}
|
||||
|
||||
for i := range args {
|
||||
err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encode args[%d]: %v", i, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for i := range sd.Fields {
|
||||
eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID))
|
||||
}
|
||||
eqb.paramValues = append(eqb.paramValues, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) {
|
||||
eqb.resultFormats = append(eqb.resultFormats, f)
|
||||
// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it
|
||||
// must be an untyped nil.
|
||||
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
|
||||
if format == -1 {
|
||||
preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
|
||||
preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
|
||||
if preferredErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var otherFormat int16
|
||||
if preferredFormat == TextFormatCode {
|
||||
otherFormat = BinaryFormatCode
|
||||
} else {
|
||||
otherFormat = TextFormatCode
|
||||
}
|
||||
|
||||
otherErr := eqb.appendParam(m, oid, otherFormat, arg)
|
||||
if otherErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return preferredErr // return the error from the preferred format
|
||||
}
|
||||
|
||||
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eqb.ParamFormats = append(eqb.ParamFormats, format)
|
||||
eqb.ParamValues = append(eqb.ParamValues, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset readies eqb to build another query.
|
||||
func (eqb *extendedQueryBuilder) Reset() {
|
||||
eqb.paramValues = eqb.paramValues[0:0]
|
||||
eqb.paramValueBytes = eqb.paramValueBytes[0:0]
|
||||
eqb.paramFormats = eqb.paramFormats[0:0]
|
||||
eqb.resultFormats = eqb.resultFormats[0:0]
|
||||
// appendResultFormat appends a result format to the query.
|
||||
func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) {
|
||||
eqb.ResultFormats = append(eqb.ResultFormats, format)
|
||||
}
|
||||
|
||||
if cap(eqb.paramValues) > 64 {
|
||||
eqb.paramValues = make([][]byte, 0, 64)
|
||||
// reset readies eqb to build another query.
|
||||
func (eqb *ExtendedQueryBuilder) reset() {
|
||||
eqb.ParamValues = eqb.ParamValues[0:0]
|
||||
eqb.paramValueBytes = eqb.paramValueBytes[0:0]
|
||||
eqb.ParamFormats = eqb.ParamFormats[0:0]
|
||||
eqb.ResultFormats = eqb.ResultFormats[0:0]
|
||||
|
||||
if cap(eqb.ParamValues) > 64 {
|
||||
eqb.ParamValues = make([][]byte, 0, 64)
|
||||
}
|
||||
|
||||
if cap(eqb.paramValueBytes) > 256 {
|
||||
eqb.paramValueBytes = make([]byte, 0, 256)
|
||||
}
|
||||
|
||||
if cap(eqb.paramFormats) > 64 {
|
||||
eqb.paramFormats = make([]int16, 0, 64)
|
||||
if cap(eqb.ParamFormats) > 64 {
|
||||
eqb.ParamFormats = make([]int16, 0, 64)
|
||||
}
|
||||
if cap(eqb.resultFormats) > 64 {
|
||||
eqb.resultFormats = make([]int16, 0, 64)
|
||||
if cap(eqb.ResultFormats) > 64 {
|
||||
eqb.ResultFormats = make([]int16, 0, 64)
|
||||
}
|
||||
}
|
||||
|
||||
func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) {
|
||||
if arg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(arg)
|
||||
argIsPtr := refVal.Kind() == reflect.Ptr
|
||||
|
||||
if argIsPtr && refVal.IsNil() {
|
||||
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
|
||||
if anynil.Is(arg) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -71,91 +120,85 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o
|
||||
eqb.paramValueBytes = make([]byte, 0, 128)
|
||||
}
|
||||
|
||||
var err error
|
||||
var buf []byte
|
||||
pos := len(eqb.paramValueBytes)
|
||||
|
||||
if arg, ok := arg.(string); ok {
|
||||
return []byte(arg), nil
|
||||
buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
eqb.paramValueBytes = buf
|
||||
return eqb.paramValueBytes[pos:], nil
|
||||
}
|
||||
|
||||
// chooseParameterFormatCode determines the correct format code for an
|
||||
// argument to a prepared statement. It defaults to TextFormatCode if no
|
||||
// determination can be made.
|
||||
func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 {
|
||||
switch arg.(type) {
|
||||
case string, *string:
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
if formatCode == TextFormatCode {
|
||||
if arg, ok := arg.(pgtype.TextEncoder); ok {
|
||||
buf, err = arg.EncodeText(ci, eqb.paramValueBytes)
|
||||
return m.FormatCodeForOID(oid)
|
||||
}
|
||||
|
||||
// appendParamsForQueryExecModeExec appends the args to eqb.
|
||||
//
|
||||
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
|
||||
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
|
||||
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
|
||||
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
|
||||
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
|
||||
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
|
||||
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
|
||||
// type conversion it takes the date directly and ignores time zone (i.e. it works).
|
||||
//
|
||||
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
|
||||
// no way to safely use binary or to specify the parameter OIDs.
|
||||
func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
|
||||
for _, arg := range args {
|
||||
if arg == nil {
|
||||
err := eqb.appendParam(m, 0, TextFormatCode, arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
eqb.paramValueBytes = buf
|
||||
return eqb.paramValueBytes[pos:], nil
|
||||
}
|
||||
} else if formatCode == BinaryFormatCode {
|
||||
if arg, ok := arg.(pgtype.BinaryEncoder); ok {
|
||||
buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
eqb.paramValueBytes = buf
|
||||
return eqb.paramValueBytes[pos:], nil
|
||||
}
|
||||
}
|
||||
|
||||
if argIsPtr {
|
||||
// We have already checked that arg is not pointing to nil,
|
||||
// so it is safe to dereference here.
|
||||
arg = refVal.Elem().Interface()
|
||||
return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg)
|
||||
}
|
||||
|
||||
if dt, ok := ci.DataTypeForOID(oid); ok {
|
||||
value := dt.Value
|
||||
err := value.Set(arg)
|
||||
if err != nil {
|
||||
{
|
||||
if arg, ok := arg.(driver.Valuer); ok {
|
||||
v, err := callValuerValue(arg)
|
||||
} else {
|
||||
dt, ok := m.TypeForValue(arg)
|
||||
if !ok {
|
||||
var tv pgtype.TextValuer
|
||||
if tv, ok = arg.(pgtype.TextValuer); ok {
|
||||
t, err := tv.TextValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
dt, ok = m.TypeForOID(pgtype.TextOID)
|
||||
if ok {
|
||||
arg = t
|
||||
}
|
||||
return eqb.encodeExtendedParamValue(ci, oid, formatCode, v)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return eqb.encodeExtendedParamValue(ci, oid, formatCode, value)
|
||||
}
|
||||
|
||||
// There is no data type registered for the destination OID, but maybe there is data type registered for the arg
|
||||
// type. If so use it's text encoder (if available).
|
||||
if dt, ok := ci.DataTypeForValue(arg); ok {
|
||||
value := dt.Value
|
||||
if textEncoder, ok := value.(pgtype.TextEncoder); ok {
|
||||
err := value.Set(arg)
|
||||
if !ok {
|
||||
var str fmt.Stringer
|
||||
if str, ok = arg.(fmt.Stringer); ok {
|
||||
dt, ok = m.TypeForOID(pgtype.TextOID)
|
||||
if ok {
|
||||
arg = str.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
|
||||
}
|
||||
err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
eqb.paramValueBytes = buf
|
||||
return eqb.paramValueBytes[pos:], nil
|
||||
}
|
||||
}
|
||||
|
||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
||||
return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg)
|
||||
}
|
||||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
module github.com/jackc/pgx/v4
|
||||
module github.com/jackc/pgx/v5
|
||||
|
||||
go 1.13
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/Masterminds/semver/v3 v3.1.1
|
||||
github.com/cockroachdb/apd v1.1.0
|
||||
github.com/go-kit/log v0.1.0
|
||||
github.com/gofrs/uuid v4.0.0+incompatible
|
||||
github.com/jackc/pgconn v1.13.0
|
||||
github.com/jackc/pgio v1.0.0
|
||||
github.com/jackc/pgproto3/v2 v2.3.1
|
||||
github.com/jackc/pgtype v1.12.0
|
||||
github.com/jackc/puddle v1.3.0
|
||||
github.com/rs/zerolog v1.15.0
|
||||
github.com/shopspring/decimal v1.2.0
|
||||
github.com/sirupsen/logrus v1.4.2
|
||||
github.com/jackc/pgpassfile v1.0.0
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b
|
||||
github.com/jackc/puddle/v2 v2.0.0
|
||||
github.com/stretchr/testify v1.8.0
|
||||
go.uber.org/zap v1.13.0
|
||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec
|
||||
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90
|
||||
golang.org/x/text v0.3.7
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/kr/pretty v0.3.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,208 +1,42 @@
|
||||
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
|
||||
github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
|
||||
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
|
||||
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
|
||||
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-kit/log v0.1.0 h1:DGJh0Sm43HbOeYDNnVZFl8BvcYVvjD5bqYJvp0REbwQ=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4=
|
||||
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
||||
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
|
||||
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
||||
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
||||
github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
|
||||
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
|
||||
github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
|
||||
github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o=
|
||||
github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY=
|
||||
github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI=
|
||||
github.com/jackc/pgconn v1.13.0 h1:3L1XMNV2Zvca/8BYhzcRFS70Lr0WlDg16Di6SFGAbys=
|
||||
github.com/jackc/pgconn v1.13.0/go.mod h1:AnowpAqO4CMIIJNZl2VJp+KrkAZciAkhEl0W0JIobpI=
|
||||
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
||||
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
||||
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
|
||||
github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c=
|
||||
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc=
|
||||
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
|
||||
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||
github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
||||
github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
||||
github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y4Y=
|
||||
github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
||||
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
|
||||
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
|
||||
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
|
||||
github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM=
|
||||
github.com/jackc/pgtype v1.12.0 h1:Dlq8Qvcch7kiehm8wPGIW0W3KsCCHJnRacKW0UM8n5w=
|
||||
github.com/jackc/pgtype v1.12.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
|
||||
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
|
||||
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
|
||||
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
|
||||
github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs=
|
||||
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw=
|
||||
github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0=
|
||||
github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/jackc/puddle/v2 v2.0.0 h1:Kwk/AlLigcnZsDssc3Zun1dk1tAtQNPaBBxBHWn0Mjc=
|
||||
github.com/jackc/puddle/v2 v2.0.0/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
|
||||
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
||||
github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE=
|
||||
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||
github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY=
|
||||
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
|
||||
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
|
||||
github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
|
||||
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
|
||||
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
||||
go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=
|
||||
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
||||
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
||||
go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4=
|
||||
go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A=
|
||||
go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU=
|
||||
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4=
|
||||
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
|
||||
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU=
|
||||
go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
|
||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 h1:Y/gsMcFOcR+6S6f3YeMKl5g+dZMEWqcz5Czj/GWYbkM=
|
||||
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200103221440-774c71fcf114 h1:DnSr2mCsxyCE6ZgIkmcWUQY2R5cH/6wL7eIxEmQOMSE=
|
||||
golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec h1:RlWgLqCMMIYYEVcAR5MDsuHlVkaIPDAF+5Dehzg8L5A=
|
||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
|
||||
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// This file contains code copied from the Go standard library due to the
|
||||
// required function not being public.
|
||||
|
||||
// Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// From database/sql/convert.go
|
||||
|
||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
|
||||
// callValuerValue returns vr.Value(), with one exception:
|
||||
// If vr.Value is an auto-generated method on a pointer type and the
|
||||
// pointer is nil, it would panic at runtime in the panicwrap
|
||||
// method. Treat it like nil instead.
|
||||
// Issue 8415.
|
||||
//
|
||||
// This is so people can implement driver.Value on value types and
|
||||
// still use nil pointers to those types to mean nil/NULL, just like
|
||||
// string/*string.
|
||||
//
|
||||
// This function is mirrored in the database/sql/driver package.
|
||||
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
|
||||
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
|
||||
rv.IsNil() &&
|
||||
rv.Type().Elem().Implements(valuerReflectType) {
|
||||
return nil, nil
|
||||
}
|
||||
return vr.Value()
|
||||
}
|
||||
+17
-52
@@ -7,48 +7,21 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) {
|
||||
t.Run("SimpleProto",
|
||||
func(t *testing.T) {
|
||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
var defaultConnTestRunner pgxtest.ConnTestRunner
|
||||
|
||||
config.PreferSimpleProtocol = true
|
||||
conn, err := pgx.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := conn.Close(context.Background())
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
f(t, conn)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
},
|
||||
)
|
||||
|
||||
t.Run("DefaultProto",
|
||||
func(t *testing.T) {
|
||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := pgx.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := conn.Close(context.Background())
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
f(t, conn)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
},
|
||||
)
|
||||
func init() {
|
||||
defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
|
||||
defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
|
||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
return config
|
||||
}
|
||||
}
|
||||
|
||||
func mustConnectString(t testing.TB, connString string) *pgx.Conn {
|
||||
@@ -80,7 +53,7 @@ func closeConn(t testing.TB, conn *pgx.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) {
|
||||
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) {
|
||||
var err error
|
||||
if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
|
||||
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
|
||||
@@ -89,7 +62,7 @@ func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}
|
||||
}
|
||||
|
||||
// Do a simple query to ensure the connection is still usable
|
||||
func ensureConnValid(t *testing.T, conn *pgx.Conn) {
|
||||
func ensureConnValid(t testing.TB, conn *pgx.Conn) {
|
||||
var sum, rowCount int32
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
||||
@@ -125,13 +98,11 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
|
||||
assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
|
||||
assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
|
||||
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
|
||||
// Can't test function equality, so just test that they are set or not.
|
||||
assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName)
|
||||
assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName)
|
||||
|
||||
assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
|
||||
assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
|
||||
assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName)
|
||||
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
||||
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
||||
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
||||
@@ -165,9 +136,3 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func skipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) {
|
||||
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
||||
t.Skip(msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package anynil
|
||||
|
||||
import "reflect"
|
||||
|
||||
// Is returns true if value is any type of nil. e.g. nil or []byte(nil).
|
||||
func Is(value any) bool {
|
||||
if value == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(value)
|
||||
switch refVal.Kind() {
|
||||
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
|
||||
return refVal.IsNil()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified.
|
||||
func Normalize(v any) any {
|
||||
if Is(v) {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is
|
||||
// mutated in place.
|
||||
func NormalizeSlice(s []any) {
|
||||
for i := range s {
|
||||
if Is(s[i]) {
|
||||
s[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
// Package iobufpool implements a global segregated-fit pool of buffers for IO.
|
||||
package iobufpool
|
||||
|
||||
import "sync"
|
||||
|
||||
const minPoolExpOf2 = 8
|
||||
|
||||
var pools [18]*sync.Pool
|
||||
|
||||
func init() {
|
||||
for i := range pools {
|
||||
bufLen := 1 << (minPoolExpOf2 + i)
|
||||
pools[i] = &sync.Pool{New: func() any { return make([]byte, bufLen) }}
|
||||
}
|
||||
}
|
||||
|
||||
// Get gets a []byte of len size with cap <= size*2.
|
||||
func Get(size int) []byte {
|
||||
i := getPoolIdx(size)
|
||||
if i >= len(pools) {
|
||||
return make([]byte, size)
|
||||
}
|
||||
return pools[i].Get().([]byte)[:size]
|
||||
}
|
||||
|
||||
func getPoolIdx(size int) int {
|
||||
size--
|
||||
size >>= minPoolExpOf2
|
||||
i := 0
|
||||
for size > 0 {
|
||||
size >>= 1
|
||||
i++
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// Put returns buf to the pool.
|
||||
func Put(buf []byte) {
|
||||
i := putPoolIdx(cap(buf))
|
||||
if i < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
pools[i].Put(buf)
|
||||
}
|
||||
|
||||
func putPoolIdx(size int) int {
|
||||
minPoolSize := 1 << minPoolExpOf2
|
||||
for i := range pools {
|
||||
if size == minPoolSize<<i {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package iobufpool
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPoolIdx(t *testing.T) {
|
||||
tests := []struct {
|
||||
size int
|
||||
expected int
|
||||
}{
|
||||
{size: 0, expected: 0},
|
||||
{size: 1, expected: 0},
|
||||
{size: 255, expected: 0},
|
||||
{size: 256, expected: 0},
|
||||
{size: 257, expected: 1},
|
||||
{size: 511, expected: 1},
|
||||
{size: 512, expected: 1},
|
||||
{size: 513, expected: 2},
|
||||
{size: 1023, expected: 2},
|
||||
{size: 1024, expected: 2},
|
||||
{size: 1025, expected: 3},
|
||||
{size: 2047, expected: 3},
|
||||
{size: 2048, expected: 3},
|
||||
{size: 2049, expected: 4},
|
||||
{size: 8388607, expected: 15},
|
||||
{size: 8388608, expected: 15},
|
||||
{size: 8388609, expected: 16},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
idx := getPoolIdx(tt.size)
|
||||
assert.Equalf(t, tt.expected, idx, "size: %d", tt.size)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package iobufpool_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetCap(t *testing.T) {
|
||||
tests := []struct {
|
||||
requestedLen int
|
||||
expectedCap int
|
||||
}{
|
||||
{requestedLen: 0, expectedCap: 256},
|
||||
{requestedLen: 128, expectedCap: 256},
|
||||
{requestedLen: 255, expectedCap: 256},
|
||||
{requestedLen: 256, expectedCap: 256},
|
||||
{requestedLen: 257, expectedCap: 512},
|
||||
{requestedLen: 511, expectedCap: 512},
|
||||
{requestedLen: 512, expectedCap: 512},
|
||||
{requestedLen: 513, expectedCap: 1024},
|
||||
{requestedLen: 1023, expectedCap: 1024},
|
||||
{requestedLen: 1024, expectedCap: 1024},
|
||||
{requestedLen: 33554431, expectedCap: 33554432},
|
||||
{requestedLen: 33554432, expectedCap: 33554432},
|
||||
|
||||
// Above 32 MiB skip the pool and allocate exactly the requested size.
|
||||
{requestedLen: 33554433, expectedCap: 33554433},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
buf := iobufpool.Get(tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(buf), "bad len for requestedLen: %d", len(buf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(buf), "bad cap for requestedLen: %d", tt.requestedLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutHandlesWrongSizedBuffers(t *testing.T) {
|
||||
for putBufSize := range []int{0, 1, 128, 250, 256, 257, 1023, 1024, 1025, 1 << 28} {
|
||||
putBuf := make([]byte, putBufSize)
|
||||
iobufpool.Put(putBuf)
|
||||
|
||||
tests := []struct {
|
||||
requestedLen int
|
||||
expectedCap int
|
||||
}{
|
||||
{requestedLen: 0, expectedCap: 256},
|
||||
{requestedLen: 128, expectedCap: 256},
|
||||
{requestedLen: 255, expectedCap: 256},
|
||||
{requestedLen: 256, expectedCap: 256},
|
||||
{requestedLen: 257, expectedCap: 512},
|
||||
{requestedLen: 511, expectedCap: 512},
|
||||
{requestedLen: 512, expectedCap: 512},
|
||||
{requestedLen: 513, expectedCap: 1024},
|
||||
{requestedLen: 1023, expectedCap: 1024},
|
||||
{requestedLen: 1024, expectedCap: 1024},
|
||||
{requestedLen: 33554431, expectedCap: 33554432},
|
||||
{requestedLen: 33554432, expectedCap: 33554432},
|
||||
|
||||
// Above 32 MiB skip the pool and allocate exactly the requested size.
|
||||
{requestedLen: 33554433, expectedCap: 33554433},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
getBuf := iobufpool.Get(tt.requestedLen)
|
||||
assert.Equalf(t, tt.requestedLen, len(getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
|
||||
assert.Equalf(t, tt.expectedCap, cap(getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutGetBufferReuse(t *testing.T) {
|
||||
// There is no way to guarantee a buffer will be reused. It should be, but a GC between the Put and the Get will cause
|
||||
// it not to be. So try many times.
|
||||
for i := 0; i < 100000; i++ {
|
||||
buf := iobufpool.Get(4)
|
||||
buf[0] = 1
|
||||
iobufpool.Put(buf)
|
||||
buf = iobufpool.Get(4)
|
||||
if buf[0] == 1 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Error("buffer was never reused")
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package nbconn
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
const minBufferQueueLen = 8
|
||||
|
||||
type bufferQueue struct {
|
||||
lock sync.Mutex
|
||||
queue [][]byte
|
||||
r, w int
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) pushBack(buf []byte) {
|
||||
bq.lock.Lock()
|
||||
defer bq.lock.Unlock()
|
||||
|
||||
if bq.w >= len(bq.queue) {
|
||||
bq.growQueue()
|
||||
}
|
||||
bq.queue[bq.w] = buf
|
||||
bq.w++
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) pushFront(buf []byte) {
|
||||
bq.lock.Lock()
|
||||
defer bq.lock.Unlock()
|
||||
|
||||
if bq.w >= len(bq.queue) {
|
||||
bq.growQueue()
|
||||
}
|
||||
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
|
||||
bq.queue[bq.r] = buf
|
||||
bq.w++
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) popFront() []byte {
|
||||
bq.lock.Lock()
|
||||
defer bq.lock.Unlock()
|
||||
|
||||
if bq.r == bq.w {
|
||||
return nil
|
||||
}
|
||||
|
||||
buf := bq.queue[bq.r]
|
||||
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
|
||||
bq.r++
|
||||
|
||||
if bq.r == bq.w {
|
||||
bq.r = 0
|
||||
bq.w = 0
|
||||
if len(bq.queue) > minBufferQueueLen {
|
||||
bq.queue = make([][]byte, minBufferQueueLen)
|
||||
}
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func (bq *bufferQueue) growQueue() {
|
||||
desiredLen := (len(bq.queue) + 1) * 3 / 2
|
||||
if desiredLen < minBufferQueueLen {
|
||||
desiredLen = minBufferQueueLen
|
||||
}
|
||||
|
||||
newQueue := make([][]byte, desiredLen)
|
||||
copy(newQueue, bq.queue)
|
||||
bq.queue = newQueue
|
||||
}
|
||||
@@ -0,0 +1,476 @@
|
||||
// Package nbconn implements a non-blocking net.Conn wrapper.
|
||||
//
|
||||
// It is designed to solve three problems.
|
||||
//
|
||||
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
|
||||
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
|
||||
//
|
||||
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
|
||||
//
|
||||
// The third is to efficiently check if a connection has been closed via a non-blocking read.
|
||||
package nbconn
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
)
|
||||
|
||||
var errClosed = errors.New("closed")
|
||||
var ErrWouldBlock = new(wouldBlockError)
|
||||
|
||||
const fakeNonblockingWaitDuration = 100 * time.Millisecond
|
||||
|
||||
// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read
|
||||
// mode.
|
||||
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC)
|
||||
|
||||
// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to
|
||||
// ignore all future calls.
|
||||
var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC)
|
||||
|
||||
// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error.
|
||||
type wouldBlockError struct{}
|
||||
|
||||
func (*wouldBlockError) Error() string {
|
||||
return "would block"
|
||||
}
|
||||
|
||||
func (*wouldBlockError) Timeout() bool { return true }
|
||||
func (*wouldBlockError) Temporary() bool { return true }
|
||||
|
||||
// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to
|
||||
// the underlying connection.
|
||||
type Conn interface {
|
||||
net.Conn
|
||||
|
||||
// Flush flushes any buffered writes.
|
||||
Flush() error
|
||||
|
||||
// BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block.
|
||||
BufferReadUntilBlock() error
|
||||
}
|
||||
|
||||
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||
type NetConn struct {
|
||||
conn net.Conn
|
||||
rawConn syscall.RawConn
|
||||
|
||||
readQueue bufferQueue
|
||||
writeQueue bufferQueue
|
||||
|
||||
readFlushLock sync.Mutex
|
||||
// non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the
|
||||
// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations.
|
||||
nonblockWriteBuf []byte
|
||||
nonblockWriteErr error
|
||||
nonblockWriteN int
|
||||
|
||||
readDeadlineLock sync.Mutex
|
||||
readDeadline time.Time
|
||||
readNonblocking bool
|
||||
|
||||
writeDeadlineLock sync.Mutex
|
||||
writeDeadline time.Time
|
||||
|
||||
// Only access with atomics
|
||||
closed int64 // 0 = not closed, 1 = closed
|
||||
}
|
||||
|
||||
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
|
||||
nc := &NetConn{
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
if !fakeNonBlockingIO {
|
||||
if sc, ok := conn.(syscall.Conn); ok {
|
||||
if rawConn, err := sc.SyscallConn(); err == nil {
|
||||
nc.rawConn = rawConn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (c *NetConn) Read(b []byte) (n int, err error) {
|
||||
if c.isClosed() {
|
||||
return 0, errClosed
|
||||
}
|
||||
|
||||
c.readFlushLock.Lock()
|
||||
defer c.readFlushLock.Unlock()
|
||||
|
||||
err = c.flush()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for n < len(b) {
|
||||
buf := c.readQueue.popFront()
|
||||
if buf == nil {
|
||||
break
|
||||
}
|
||||
copiedN := copy(b[n:], buf)
|
||||
if copiedN < len(buf) {
|
||||
buf = buf[copiedN:]
|
||||
c.readQueue.pushFront(buf)
|
||||
} else {
|
||||
iobufpool.Put(buf)
|
||||
}
|
||||
n += copiedN
|
||||
}
|
||||
|
||||
// If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to
|
||||
// Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block.
|
||||
if n > 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var readNonblocking bool
|
||||
c.readDeadlineLock.Lock()
|
||||
readNonblocking = c.readNonblocking
|
||||
c.readDeadlineLock.Unlock()
|
||||
|
||||
var readN int
|
||||
if readNonblocking {
|
||||
readN, err = c.nonblockingRead(b[n:])
|
||||
} else {
|
||||
readN, err = c.conn.Read(b[n:])
|
||||
}
|
||||
n += readN
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
|
||||
// closed. Call Flush to actually write to the underlying connection.
|
||||
func (c *NetConn) Write(b []byte) (n int, err error) {
|
||||
if c.isClosed() {
|
||||
return 0, errClosed
|
||||
}
|
||||
|
||||
buf := iobufpool.Get(len(b))
|
||||
copy(buf, b)
|
||||
c.writeQueue.pushBack(buf)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *NetConn) Close() (err error) {
|
||||
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
|
||||
if !swapped {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
defer func() {
|
||||
closeErr := c.conn.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
c.readFlushLock.Lock()
|
||||
defer c.readFlushLock.Unlock()
|
||||
err = c.flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NetConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *NetConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
|
||||
func (c *NetConn) SetDeadline(t time.Time) error {
|
||||
err := c.SetReadDeadline(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
|
||||
func (c *NetConn) SetReadDeadline(t time.Time) error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.readDeadlineLock.Lock()
|
||||
defer c.readDeadlineLock.Unlock()
|
||||
if c.readDeadline == disableSetDeadlineDeadline {
|
||||
return nil
|
||||
}
|
||||
if t == disableSetDeadlineDeadline {
|
||||
c.readDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
if t == NonBlockingDeadline {
|
||||
c.readNonblocking = true
|
||||
t = time.Time{}
|
||||
} else {
|
||||
c.readNonblocking = false
|
||||
}
|
||||
|
||||
c.readDeadline = t
|
||||
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *NetConn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.writeDeadlineLock.Lock()
|
||||
defer c.writeDeadlineLock.Unlock()
|
||||
if c.writeDeadline == disableSetDeadlineDeadline {
|
||||
return nil
|
||||
}
|
||||
if t == disableSetDeadlineDeadline {
|
||||
c.writeDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
c.writeDeadline = t
|
||||
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *NetConn) Flush() error {
|
||||
if c.isClosed() {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
c.readFlushLock.Lock()
|
||||
defer c.readFlushLock.Unlock()
|
||||
return c.flush()
|
||||
}
|
||||
|
||||
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
|
||||
func (c *NetConn) flush() error {
|
||||
var stopChan chan struct{}
|
||||
var errChan chan error
|
||||
|
||||
defer func() {
|
||||
if stopChan != nil {
|
||||
select {
|
||||
case stopChan <- struct{}{}:
|
||||
case <-errChan:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
|
||||
remainingBuf := buf
|
||||
for len(remainingBuf) > 0 {
|
||||
n, err := c.nonblockingWrite(remainingBuf)
|
||||
remainingBuf = remainingBuf[n:]
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrWouldBlock) {
|
||||
buf = buf[:len(remainingBuf)]
|
||||
copy(buf, remainingBuf)
|
||||
c.writeQueue.pushFront(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// Writing was blocked. Reading might unblock it.
|
||||
if stopChan == nil {
|
||||
stopChan, errChan = c.bufferNonblockingRead()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
stopChan = nil
|
||||
return err
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
iobufpool.Put(buf)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NetConn) BufferReadUntilBlock() error {
|
||||
for {
|
||||
buf := iobufpool.Get(8 * 1024)
|
||||
n, err := c.nonblockingRead(buf)
|
||||
if n > 0 {
|
||||
buf = buf[:n]
|
||||
c.readQueue.pushBack(buf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrWouldBlock) {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
|
||||
stopChan = make(chan struct{})
|
||||
errChan = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
err := c.BufferReadUntilBlock()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return stopChan, errChan
|
||||
}
|
||||
|
||||
func (c *NetConn) isClosed() bool {
|
||||
closed := atomic.LoadInt64(&c.closed)
|
||||
return closed == 1
|
||||
}
|
||||
|
||||
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingWrite(b)
|
||||
} else {
|
||||
return c.realNonblockingWrite(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||
c.writeDeadlineLock.Lock()
|
||||
defer c.writeDeadlineLock.Unlock()
|
||||
|
||||
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
|
||||
err = c.conn.SetWriteDeadline(deadline)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||
c.conn.SetWriteDeadline(c.writeDeadline)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
err = ErrWouldBlock
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingRead(b)
|
||||
} else {
|
||||
return c.realNonblockingRead(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||
c.readDeadlineLock.Lock()
|
||||
defer c.readDeadlineLock.Unlock()
|
||||
|
||||
deadline := time.Now().Add(fakeNonblockingWaitDuration)
|
||||
if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
|
||||
err = c.conn.SetReadDeadline(deadline)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
|
||||
c.conn.SetReadDeadline(c.readDeadline)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
err = ErrWouldBlock
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.conn.Read(b)
|
||||
}
|
||||
|
||||
// syscall.Conn is interface
|
||||
|
||||
// TLSClient establishes a TLS connection as a client over conn using config.
|
||||
//
|
||||
// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby
|
||||
// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the
|
||||
// *TLSConn is returned.
|
||||
func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) {
|
||||
tc := tls.Client(conn, config)
|
||||
err := tc.Handshake()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure last written part of Handshake is actually sent.
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TLSConn{
|
||||
tlsConn: tc,
|
||||
nbConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a
|
||||
// tls.Conn.
|
||||
type TLSConn struct {
|
||||
tlsConn *tls.Conn
|
||||
nbConn *NetConn
|
||||
}
|
||||
|
||||
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
|
||||
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
|
||||
func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() }
|
||||
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
|
||||
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
|
||||
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
|
||||
|
||||
func (tc *TLSConn) Close() error {
|
||||
// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then
|
||||
// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our
|
||||
// own 5 second deadline then make all set deadlines no-op.
|
||||
tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5))
|
||||
tc.tlsConn.SetDeadline(disableSetDeadlineDeadline)
|
||||
|
||||
return tc.tlsConn.Close()
|
||||
}
|
||||
|
||||
func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) }
|
||||
func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) }
|
||||
func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) }
|
||||
@@ -0,0 +1,13 @@
|
||||
//go:build !(aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris)
|
||||
|
||||
package nbconn
|
||||
|
||||
// Not using unix build tag for support on Go 1.18.
|
||||
|
||||
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
||||
return c.fakeNonblockingWrite(b)
|
||||
}
|
||||
|
||||
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||
return c.fakeNonblockingRead(b)
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
//go:build aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris
|
||||
|
||||
package nbconn
|
||||
|
||||
// Not using unix build tag for support on Go 1.18.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
|
||||
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
||||
c.nonblockWriteBuf = b
|
||||
c.nonblockWriteN = 0
|
||||
c.nonblockWriteErr = nil
|
||||
err = c.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf)
|
||||
return true
|
||||
})
|
||||
n = c.nonblockWriteN
|
||||
if err == nil && c.nonblockWriteErr != nil {
|
||||
if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) {
|
||||
err = ErrWouldBlock
|
||||
} else {
|
||||
err = c.nonblockWriteErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// n may be -1 when an error occurs.
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||
var funcErr error
|
||||
err = c.rawConn.Read(func(fd uintptr) (done bool) {
|
||||
n, funcErr = syscall.Read(int(fd), b)
|
||||
return true
|
||||
})
|
||||
if err == nil && funcErr != nil {
|
||||
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
|
||||
err = ErrWouldBlock
|
||||
} else {
|
||||
err = funcErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// n may be -1 when an error occurs.
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// syscall read did not return an error and 0 bytes were read means EOF.
|
||||
if n == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
@@ -0,0 +1,584 @@
|
||||
package nbconn_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/nbconn"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test keys generated with:
|
||||
//
|
||||
// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost'
|
||||
|
||||
var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls
|
||||
b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ
|
||||
BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
||||
ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5
|
||||
yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT
|
||||
caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT
|
||||
0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW
|
||||
c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v
|
||||
7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg
|
||||
Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw
|
||||
HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g
|
||||
TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk
|
||||
D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB
|
||||
hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y
|
||||
E7ZYmaKTMOhvkg==
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in
|
||||
// source code.
|
||||
var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY-----
|
||||
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny
|
||||
k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+
|
||||
fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px
|
||||
N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav
|
||||
IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM
|
||||
4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX
|
||||
IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8
|
||||
TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL
|
||||
CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ
|
||||
/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn
|
||||
lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I
|
||||
Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9
|
||||
YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp
|
||||
RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq
|
||||
MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd
|
||||
3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE
|
||||
Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0
|
||||
TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA
|
||||
riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr
|
||||
IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu
|
||||
nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk
|
||||
WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc
|
||||
Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77
|
||||
DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD
|
||||
pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
|
||||
2qWm8jTPeDC3sq+67s2oojHf+Q==
|
||||
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
|
||||
|
||||
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
makeConns func(t *testing.T) (local, remote net.Conn)
|
||||
useTLS bool
|
||||
fakeNonBlockingIO bool
|
||||
}{
|
||||
{
|
||||
name: "Pipe",
|
||||
makeConns: makePipeConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TCP with Fake Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TLS over TCP with Fake Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: true,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TCP with Real Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: false,
|
||||
},
|
||||
{
|
||||
name: "TLS over TCP with Real Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: true,
|
||||
fakeNonBlockingIO: false,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
local, remote := tt.makeConns(t)
|
||||
|
||||
// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
|
||||
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
|
||||
// uses remote it may be garbage collected leading to the connection being closed.
|
||||
defer local.Close()
|
||||
defer remote.Close()
|
||||
|
||||
var conn nbconn.Conn
|
||||
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)
|
||||
|
||||
if tt.useTLS {
|
||||
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsServer := tls.Server(remote, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
serverTLSHandshakeChan := make(chan error)
|
||||
go func() {
|
||||
err := tlsServer.Handshake()
|
||||
serverTLSHandshakeChan <- err
|
||||
}()
|
||||
|
||||
tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true})
|
||||
require.NoError(t, err)
|
||||
conn = tlsConn
|
||||
|
||||
err = <-serverTLSHandshakeChan
|
||||
require.NoError(t, err)
|
||||
remote = tlsServer
|
||||
} else {
|
||||
conn = netConn
|
||||
}
|
||||
|
||||
f(t, conn, remote)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is
|
||||
// useful for testing an exact sequence of reads and writes with the underlying connection blocking.
|
||||
func makePipeConns(t *testing.T) (local, remote net.Conn) {
|
||||
local, remote = net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
local.Close()
|
||||
remote.Close()
|
||||
})
|
||||
|
||||
return local, remote
|
||||
}
|
||||
|
||||
// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost.
|
||||
func makeTCPConns(t *testing.T) (local, remote net.Conn) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
type acceptResultT struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
acceptChan := make(chan acceptResultT)
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
acceptChan <- acceptResultT{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
local, err = net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
acceptResult := <-acceptChan
|
||||
require.NoError(t, acceptResult.err)
|
||||
|
||||
remote = acceptResult.conn
|
||||
|
||||
return local, remote
|
||||
}
|
||||
|
||||
func TestWriteIsBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
// net.Pipe is synchronous so the Write would block if not buffered.
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := conn.Flush()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, len(writeBuf))
|
||||
_, err = remote.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetWriteDeadline(time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadFlushesWriteBuffer(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go func() {
|
||||
readBuf := make([]byte, len(writeBuf))
|
||||
_, err := remote.Read(readBuf)
|
||||
errChan <- err
|
||||
|
||||
_, err = remote.Write([]byte("okay"))
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("okay"), readBuf)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseFlushesWriteBuffer(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := []byte("test")
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
readBuf := make([]byte, len(writeBuf))
|
||||
_, err := remote.Read(readBuf)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with
|
||||
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing
|
||||
// large values.
|
||||
func TestInternalNonBlockingWrite(t *testing.T) {
|
||||
const deadlockSize = 4 * 1024 * 1024
|
||||
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := make([]byte, deadlockSize)
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, deadlockSize, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
remoteWriteBuf := make([]byte, deadlockSize)
|
||||
_, err := remote.Write(remoteWriteBuf)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
readBuf := make([]byte, deadlockSize)
|
||||
_, err = io.ReadFull(remote, readBuf)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, deadlockSize)
|
||||
_, err = io.ReadFull(conn, readBuf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
|
||||
const deadlockSize = 4 * 1024 * 1024
|
||||
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := make([]byte, deadlockSize)
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, deadlockSize, n)
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Flush()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "i/o timeout")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNonBlockingRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 4)
|
||||
n, err := conn.Read(buf)
|
||||
require.ErrorIs(t, err, nbconn.ErrWouldBlock)
|
||||
require.EqualValues(t, 0, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := remote.Write([]byte("okay"))
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
err = conn.SetReadDeadline(time.Time{})
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err = conn.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBufferNonBlockingRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.BufferReadUntilBlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := remote.Write([]byte("okay"))
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
err = conn.BufferReadUntilBlock()
|
||||
if !errors.Is(err, nbconn.ErrWouldBlock) {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 4)
|
||||
n, err := conn.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
require.Equal(t, []byte("okay"), buf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 5)
|
||||
n, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 5, n)
|
||||
require.Equal(t, []byte("alpha"), readBuf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 10)
|
||||
n, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 5, n)
|
||||
require.Equal(t, []byte("alpha"), readBuf[:n])
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 2)
|
||||
n, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, n)
|
||||
require.Equal(t, []byte("al"), readBuf)
|
||||
|
||||
readBuf = make([]byte, 3)
|
||||
n, err = conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, n)
|
||||
require.Equal(t, []byte("pha"), readBuf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadMultiplePreviouslyBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = remote.Write([]byte("beta"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuf := make([]byte, 9)
|
||||
n, err := io.ReadFull(conn, readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 9, n)
|
||||
require.Equal(t, []byte("alphabeta"), readBuf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBufferedAndReadMore(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
flushCompleteChan := make(chan struct{})
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := func() error {
|
||||
_, err := remote.Write([]byte("alpha"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readBuf := make([]byte, 4)
|
||||
_, err = remote.Read(readBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-flushCompleteChan
|
||||
|
||||
_, err = remote.Write([]byte("beta"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
|
||||
err = conn.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
close(flushCompleteChan)
|
||||
|
||||
readBuf := make([]byte, 9)
|
||||
|
||||
n, err := io.ReadFull(conn, readBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 9, n)
|
||||
require.Equal(t, []byte("alphabeta"), readBuf)
|
||||
|
||||
err = <-errChan
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
# pgio
|
||||
|
||||
Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||
|
||||
pgio provides functions for appending integers to a []byte while doing byte
|
||||
order conversion.
|
||||
@@ -0,0 +1,6 @@
|
||||
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||
/*
|
||||
pgio provides functions for appending integers to a []byte while doing byte
|
||||
order conversion.
|
||||
*/
|
||||
package pgio
|
||||
@@ -0,0 +1,40 @@
|
||||
package pgio
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func AppendUint16(buf []byte, n uint16) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0)
|
||||
binary.BigEndian.PutUint16(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendUint32(buf []byte, n uint32) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0)
|
||||
binary.BigEndian.PutUint32(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendUint64(buf []byte, n uint64) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
binary.BigEndian.PutUint64(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendInt16(buf []byte, n int16) []byte {
|
||||
return AppendUint16(buf, uint16(n))
|
||||
}
|
||||
|
||||
func AppendInt32(buf []byte, n int32) []byte {
|
||||
return AppendUint32(buf, uint32(n))
|
||||
}
|
||||
|
||||
func AppendInt64(buf []byte, n int64) []byte {
|
||||
return AppendUint64(buf, uint64(n))
|
||||
}
|
||||
|
||||
func SetInt32(buf []byte, n int32) {
|
||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package pgio
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAppendUint16NilBuf(t *testing.T) {
|
||||
buf := AppendUint16(nil, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint16EmptyBuf(t *testing.T) {
|
||||
buf := []byte{}
|
||||
buf = AppendUint16(buf, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||
buf := make([]byte, 0, 4)
|
||||
AppendUint16(buf, 1)
|
||||
buf = buf[0:2]
|
||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint32NilBuf(t *testing.T) {
|
||||
buf := AppendUint32(nil, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint32EmptyBuf(t *testing.T) {
|
||||
buf := []byte{}
|
||||
buf = AppendUint32(buf, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||
buf := make([]byte, 0, 4)
|
||||
AppendUint32(buf, 1)
|
||||
buf = buf[0:4]
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint64NilBuf(t *testing.T) {
|
||||
buf := AppendUint64(nil, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint64EmptyBuf(t *testing.T) {
|
||||
buf := []byte{}
|
||||
buf = AppendUint64(buf, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||
buf := make([]byte, 0, 8)
|
||||
AppendUint64(buf, 1)
|
||||
buf = buf[0:8]
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
// Package pgmock provides the ability to mock a PostgreSQL server.
|
||||
package pgmock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
type Step interface {
|
||||
Step(*pgproto3.Backend) error
|
||||
}
|
||||
|
||||
type Script struct {
|
||||
Steps []Step
|
||||
}
|
||||
|
||||
func (s *Script) Run(backend *pgproto3.Backend) error {
|
||||
for _, step := range s.Steps {
|
||||
err := step.Step(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Script) Step(backend *pgproto3.Backend) error {
|
||||
return s.Run(backend)
|
||||
}
|
||||
|
||||
type expectMessageStep struct {
|
||||
want pgproto3.FrontendMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.Receive()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type expectStartupMessageStep struct {
|
||||
want *pgproto3.StartupMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.ReceiveStartupMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExpectMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, false)
|
||||
}
|
||||
|
||||
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, true)
|
||||
}
|
||||
|
||||
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
|
||||
if want, ok := want.(*pgproto3.StartupMessage); ok {
|
||||
return &expectStartupMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
return &expectMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
type sendMessageStep struct {
|
||||
msg pgproto3.BackendMessage
|
||||
}
|
||||
|
||||
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
backend.Send(e.msg)
|
||||
return backend.Flush()
|
||||
}
|
||||
|
||||
func SendMessage(msg pgproto3.BackendMessage) Step {
|
||||
return &sendMessageStep{msg: msg}
|
||||
}
|
||||
|
||||
type waitForCloseMessageStep struct{}
|
||||
|
||||
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
for {
|
||||
msg, err := backend.Receive()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := msg.(*pgproto3.Terminate); ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WaitForClose() Step {
|
||||
return &waitForCloseMessageStep{}
|
||||
}
|
||||
|
||||
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
||||
return []Step{
|
||||
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||
SendMessage(&pgproto3.AuthenticationOk{}),
|
||||
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package pgmock_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgmock"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestScript(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
|
||||
Fields: []pgproto3.FieldDescription{
|
||||
pgproto3.FieldDescription{
|
||||
Name: []byte("?column?"),
|
||||
TableOID: 0,
|
||||
TableAttributeNumber: 0,
|
||||
DataTypeOID: 23,
|
||||
DataTypeSize: 4,
|
||||
TypeModifier: -1,
|
||||
Format: 0,
|
||||
},
|
||||
},
|
||||
}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{
|
||||
Values: [][]byte{[]byte("42")},
|
||||
}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
||||
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{}))
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(time.Second))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = script.Run(pgproto3.NewBackend(conn, conn))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
parts := strings.Split(ln.Addr().String(), ":")
|
||||
host := parts[0]
|
||||
port := parts[1]
|
||||
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
pgConn, err := pgconn.Connect(ctx, connStr)
|
||||
require.NoError(t, err)
|
||||
results, err := pgConn.Exec(ctx, "select 42").ReadAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, results, 1)
|
||||
assert.Nil(t, results[0].Err)
|
||||
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
||||
assert.Len(t, results[0].Rows, 1)
|
||||
assert.Equal(t, "42", string(results[0].Rows[0][0]))
|
||||
|
||||
pgConn.Close(ctx)
|
||||
|
||||
assert.NoError(t, <-serverErrChan)
|
||||
}
|
||||
@@ -12,13 +12,13 @@ import (
|
||||
|
||||
// Part is either a string or an int. A string is raw SQL. An int is a
|
||||
// argument placeholder.
|
||||
type Part interface{}
|
||||
type Part any
|
||||
|
||||
type Query struct {
|
||||
Parts []Part
|
||||
}
|
||||
|
||||
func (q *Query) Sanitize(args ...interface{}) (string, error) {
|
||||
func (q *Query) Sanitize(args ...any) (string, error) {
|
||||
argUse := make([]bool, len(args))
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
@@ -295,7 +295,7 @@ func multilineCommentState(l *sqlLexer) stateFn {
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
// as necessary. This function is only safe when standard_conforming_strings is
|
||||
// on.
|
||||
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
|
||||
func SanitizeSQL(sql string, args ...any) (string, error) {
|
||||
query, err := NewQuery(sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v4/internal/sanitize"
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
)
|
||||
|
||||
func TestNewQuery(t *testing.T) {
|
||||
@@ -111,57 +111,57 @@ func TestNewQuery(t *testing.T) {
|
||||
func TestQuerySanitize(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
args []any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
args: []interface{}{},
|
||||
args: []any{},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{int64(42)},
|
||||
args: []any{int64(42)},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{float64(1.23)},
|
||||
args: []any{float64(1.23)},
|
||||
expected: `select 1.23`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{true},
|
||||
args: []any{true},
|
||||
expected: `select true`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
|
||||
args: []any{[]byte{0, 1, 2, 3, 255}},
|
||||
expected: `select '\x00010203ff'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{nil},
|
||||
args: []any{nil},
|
||||
expected: `select null`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foobar"},
|
||||
args: []any{"foobar"},
|
||||
expected: `select 'foobar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foo'bar"},
|
||||
args: []any{"foo'bar"},
|
||||
expected: `select 'foo''bar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{`foo\'bar`},
|
||||
args: []any{`foo\'bar`},
|
||||
expected: `select 'foo\''bar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
|
||||
args: []interface{}{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
|
||||
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
|
||||
expected: `insert '2020-03-01 23:59:59.999999Z'`,
|
||||
},
|
||||
}
|
||||
@@ -180,22 +180,22 @@ func TestQuerySanitize(t *testing.T) {
|
||||
|
||||
errorTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
args []any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
|
||||
args: []interface{}{int64(42)},
|
||||
args: []any{int64(42)},
|
||||
expected: `insufficient arguments`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
|
||||
args: []interface{}{int64(42)},
|
||||
args: []any{int64(42)},
|
||||
expected: `unused argument: 0`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{42},
|
||||
args: []any{42},
|
||||
expected: `invalid arg type: int`,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
|
||||
type LRUCache struct {
|
||||
cap int
|
||||
m map[string]*list.Element
|
||||
l *list.List
|
||||
invalidStmts []*pgconn.StatementDescription
|
||||
}
|
||||
|
||||
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
|
||||
func NewLRUCache(cap int) *LRUCache {
|
||||
return &LRUCache{
|
||||
cap: cap,
|
||||
m: make(map[string]*list.Element),
|
||||
l: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the statement description for sql. Returns nil if not found.
|
||||
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
|
||||
if el, ok := c.m[key]; ok {
|
||||
c.l.MoveToFront(el)
|
||||
return el.Value.(*pgconn.StatementDescription)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
|
||||
if sd.SQL == "" {
|
||||
panic("cannot store statement description with empty SQL")
|
||||
}
|
||||
|
||||
if _, present := c.m[sd.SQL]; present {
|
||||
return
|
||||
}
|
||||
|
||||
if c.l.Len() == c.cap {
|
||||
c.invalidateOldest()
|
||||
}
|
||||
|
||||
el := c.l.PushFront(sd)
|
||||
c.m[sd.SQL] = el
|
||||
}
|
||||
|
||||
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||
func (c *LRUCache) Invalidate(sql string) {
|
||||
if el, ok := c.m[sql]; ok {
|
||||
delete(c.m, sql)
|
||||
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||
c.l.Remove(el)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
func (c *LRUCache) InvalidateAll() {
|
||||
el := c.l.Front()
|
||||
for el != nil {
|
||||
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||
el = el.Next()
|
||||
}
|
||||
|
||||
c.m = make(map[string]*list.Element)
|
||||
c.l = list.New()
|
||||
}
|
||||
|
||||
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||
invalidStmts := c.invalidStmts
|
||||
c.invalidStmts = nil
|
||||
return invalidStmts
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *LRUCache) Len() int {
|
||||
return c.l.Len()
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *LRUCache) Cap() int {
|
||||
return c.cap
|
||||
}
|
||||
|
||||
func (c *LRUCache) invalidateOldest() {
|
||||
oldest := c.l.Back()
|
||||
sd := oldest.Value.(*pgconn.StatementDescription)
|
||||
c.invalidStmts = append(c.invalidStmts, sd)
|
||||
delete(c.m, sd.SQL)
|
||||
c.l.Remove(oldest)
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
// Package stmtcache is a cache for statement descriptions.
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
var stmtCounter int64
|
||||
|
||||
// NextStatementName returns a statement name that will be unique for the lifetime of the program.
|
||||
func NextStatementName() string {
|
||||
n := atomic.AddInt64(&stmtCounter, 1)
|
||||
return "stmtcache_" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
// Cache caches statement descriptions.
|
||||
type Cache interface {
|
||||
// Get returns the statement description for sql. Returns nil if not found.
|
||||
Get(sql string) *pgconn.StatementDescription
|
||||
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||
Put(sd *pgconn.StatementDescription)
|
||||
|
||||
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||
Invalidate(sql string)
|
||||
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
InvalidateAll()
|
||||
|
||||
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
|
||||
HandleInvalidated() []*pgconn.StatementDescription
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
Len() int
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
Cap() int
|
||||
}
|
||||
|
||||
func IsStatementInvalid(err error) bool {
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1162
|
||||
//
|
||||
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
|
||||
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
|
||||
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
|
||||
// have so it should be safe.
|
||||
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
|
||||
return possibleInvalidCachedPlanError
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// UnlimitedCache implements Cache with no capacity limit.
|
||||
type UnlimitedCache struct {
|
||||
m map[string]*pgconn.StatementDescription
|
||||
invalidStmts []*pgconn.StatementDescription
|
||||
}
|
||||
|
||||
// NewUnlimitedCache creates a new UnlimitedCache.
|
||||
func NewUnlimitedCache() *UnlimitedCache {
|
||||
return &UnlimitedCache{
|
||||
m: make(map[string]*pgconn.StatementDescription),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the statement description for sql. Returns nil if not found.
|
||||
func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
|
||||
return c.m[sql]
|
||||
}
|
||||
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||
func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
|
||||
if sd.SQL == "" {
|
||||
panic("cannot store statement description with empty SQL")
|
||||
}
|
||||
|
||||
if _, present := c.m[sd.SQL]; present {
|
||||
return
|
||||
}
|
||||
|
||||
c.m[sd.SQL] = sd
|
||||
}
|
||||
|
||||
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||
func (c *UnlimitedCache) Invalidate(sql string) {
|
||||
if sd, ok := c.m[sql]; ok {
|
||||
delete(c.m, sql)
|
||||
c.invalidStmts = append(c.invalidStmts, sd)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
func (c *UnlimitedCache) InvalidateAll() {
|
||||
for _, sd := range c.m {
|
||||
c.invalidStmts = append(c.invalidStmts, sd)
|
||||
}
|
||||
|
||||
c.m = make(map[string]*pgconn.StatementDescription)
|
||||
}
|
||||
|
||||
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||
invalidStmts := c.invalidStmts
|
||||
c.invalidStmts = nil
|
||||
return invalidStmts
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *UnlimitedCache) Len() int {
|
||||
return len(c.m)
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *UnlimitedCache) Cap() int {
|
||||
return math.MaxInt
|
||||
}
|
||||
@@ -7,8 +7,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
)
|
||||
|
||||
func TestLargeObjects(t *testing.T) {
|
||||
@@ -22,7 +23,7 @@ func TestLargeObjects(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
skipCockroachDB(t, conn, "Server does support large objects")
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support large objects")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
@@ -32,7 +33,7 @@ func TestLargeObjects(t *testing.T) {
|
||||
testLargeObjects(t, ctx, tx)
|
||||
}
|
||||
|
||||
func TestLargeObjectsPreferSimpleProtocol(t *testing.T) {
|
||||
func TestLargeObjectsSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
@@ -43,14 +44,14 @@ func TestLargeObjectsPreferSimpleProtocol(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config.PreferSimpleProtocol = true
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
||||
|
||||
conn, err := pgx.ConnectConfig(ctx, config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
skipCockroachDB(t, conn, "Server does support large objects")
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support large objects")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
@@ -169,7 +170,7 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
skipCockroachDB(t, conn, "Server does support large objects")
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server does support large objects")
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package kitlogadapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-kit/log"
|
||||
kitlevel "github.com/go-kit/log/level"
|
||||
"github.com/jackc/pgx/v4"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
l log.Logger
|
||||
}
|
||||
|
||||
func NewLogger(l log.Logger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logger := l.l
|
||||
for k, v := range data {
|
||||
logger = log.With(logger, k, v)
|
||||
}
|
||||
|
||||
switch level {
|
||||
case pgx.LogLevelTrace:
|
||||
logger.Log("PGX_LOG_LEVEL", level, "msg", msg)
|
||||
case pgx.LogLevelDebug:
|
||||
kitlevel.Debug(logger).Log("msg", msg)
|
||||
case pgx.LogLevelInfo:
|
||||
kitlevel.Info(logger).Log("msg", msg)
|
||||
case pgx.LogLevelWarn:
|
||||
kitlevel.Warn(logger).Log("msg", msg)
|
||||
case pgx.LogLevelError:
|
||||
kitlevel.Error(logger).Log("msg", msg)
|
||||
default:
|
||||
logger.Log("INVALID_PGX_LOG_LEVEL", level, "error", msg)
|
||||
}
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger
|
||||
// log.
|
||||
package log15adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
)
|
||||
|
||||
// Log15Logger interface defines the subset of
|
||||
// github.com/inconshreveable/log15.Logger that this adapter uses.
|
||||
type Log15Logger interface {
|
||||
Debug(msg string, ctx ...interface{})
|
||||
Info(msg string, ctx ...interface{})
|
||||
Warn(msg string, ctx ...interface{})
|
||||
Error(msg string, ctx ...interface{})
|
||||
Crit(msg string, ctx ...interface{})
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
l Log15Logger
|
||||
}
|
||||
|
||||
func NewLogger(l Log15Logger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logArgs := make([]interface{}, 0, len(data))
|
||||
for k, v := range data {
|
||||
logArgs = append(logArgs, k, v)
|
||||
}
|
||||
|
||||
switch level {
|
||||
case pgx.LogLevelTrace:
|
||||
l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...)
|
||||
case pgx.LogLevelDebug:
|
||||
l.l.Debug(msg, logArgs...)
|
||||
case pgx.LogLevelInfo:
|
||||
l.l.Info(msg, logArgs...)
|
||||
case pgx.LogLevelWarn:
|
||||
l.l.Warn(msg, logArgs...)
|
||||
case pgx.LogLevelError:
|
||||
l.l.Error(msg, logArgs...)
|
||||
default:
|
||||
l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...)
|
||||
}
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
// Package logrusadapter provides a logger that writes to a github.com/sirupsen/logrus.Logger
|
||||
// log.
|
||||
package logrusadapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
l logrus.FieldLogger
|
||||
}
|
||||
|
||||
func NewLogger(l logrus.FieldLogger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
var logger logrus.FieldLogger
|
||||
if data != nil {
|
||||
logger = l.l.WithFields(data)
|
||||
} else {
|
||||
logger = l.l
|
||||
}
|
||||
|
||||
switch level {
|
||||
case pgx.LogLevelTrace:
|
||||
logger.WithField("PGX_LOG_LEVEL", level).Debug(msg)
|
||||
case pgx.LogLevelDebug:
|
||||
logger.Debug(msg)
|
||||
case pgx.LogLevelInfo:
|
||||
logger.Info(msg)
|
||||
case pgx.LogLevelWarn:
|
||||
logger.Warn(msg)
|
||||
case pgx.LogLevelError:
|
||||
logger.Error(msg)
|
||||
default:
|
||||
logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg)
|
||||
}
|
||||
}
|
||||
@@ -6,13 +6,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5/tracelog"
|
||||
)
|
||||
|
||||
// TestingLogger interface defines the subset of testing.TB methods used by this
|
||||
// adapter.
|
||||
type TestingLogger interface {
|
||||
Log(args ...interface{})
|
||||
Log(args ...any)
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
@@ -23,8 +23,8 @@ func NewLogger(l TestingLogger) *Logger {
|
||||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logArgs := make([]interface{}, 0, 2+len(data))
|
||||
func (l *Logger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
|
||||
logArgs := make([]any, 0, 2+len(data))
|
||||
logArgs = append(logArgs, level, msg)
|
||||
for k, v := range data {
|
||||
logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v))
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
// Package zapadapter provides a logger that writes to a go.uber.org/zap.Logger.
|
||||
package zapadapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewLogger(logger *zap.Logger) *Logger {
|
||||
return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))}
|
||||
}
|
||||
|
||||
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
fields := make([]zapcore.Field, len(data))
|
||||
i := 0
|
||||
for k, v := range data {
|
||||
fields[i] = zap.Any(k, v)
|
||||
i++
|
||||
}
|
||||
|
||||
switch level {
|
||||
case pgx.LogLevelTrace:
|
||||
pl.logger.Debug(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...)
|
||||
case pgx.LogLevelDebug:
|
||||
pl.logger.Debug(msg, fields...)
|
||||
case pgx.LogLevelInfo:
|
||||
pl.logger.Info(msg, fields...)
|
||||
case pgx.LogLevelWarn:
|
||||
pl.logger.Warn(msg, fields...)
|
||||
case pgx.LogLevelError:
|
||||
pl.logger.Error(msg, fields...)
|
||||
default:
|
||||
pl.logger.Error(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...)
|
||||
}
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
// Package zerologadapter provides a logger that writes to a github.com/rs/zerolog.
|
||||
package zerologadapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
logger zerolog.Logger
|
||||
withFunc func(context.Context, zerolog.Context) zerolog.Context
|
||||
fromContext bool
|
||||
skipModule bool
|
||||
}
|
||||
|
||||
// option options for configuring the logger when creating a new logger.
|
||||
type option func(logger *Logger)
|
||||
|
||||
// WithContextFunc adds possibility to get request scoped values from the
|
||||
// ctx.Context before logging lines.
|
||||
func WithContextFunc(withFunc func(context.Context, zerolog.Context) zerolog.Context) option {
|
||||
return func(logger *Logger) {
|
||||
logger.withFunc = withFunc
|
||||
}
|
||||
}
|
||||
|
||||
// WithoutPGXModule disables adding module:pgx to the default logger context.
|
||||
func WithoutPGXModule() option {
|
||||
return func(logger *Logger) {
|
||||
logger.skipModule = true
|
||||
}
|
||||
}
|
||||
|
||||
// NewLogger accepts a zerolog.Logger as input and returns a new custom pgx
|
||||
// logging facade as output.
|
||||
func NewLogger(logger zerolog.Logger, options ...option) *Logger {
|
||||
l := Logger{
|
||||
logger: logger,
|
||||
}
|
||||
l.init(options)
|
||||
return &l
|
||||
}
|
||||
|
||||
// NewContextLogger creates logger that extracts the zerolog.Logger from the
|
||||
// context.Context by using `zerolog.Ctx`. The zerolog.DefaultContextLogger will
|
||||
// be used if no logger is associated with the context.
|
||||
func NewContextLogger(options ...option) *Logger {
|
||||
l := Logger{
|
||||
fromContext: true,
|
||||
}
|
||||
l.init(options)
|
||||
return &l
|
||||
}
|
||||
|
||||
func (pl *Logger) init(options []option) {
|
||||
for _, opt := range options {
|
||||
opt(pl)
|
||||
}
|
||||
if !pl.skipModule {
|
||||
pl.logger = pl.logger.With().Str("module", "pgx").Logger()
|
||||
}
|
||||
}
|
||||
|
||||
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
var zlevel zerolog.Level
|
||||
switch level {
|
||||
case pgx.LogLevelNone:
|
||||
zlevel = zerolog.NoLevel
|
||||
case pgx.LogLevelError:
|
||||
zlevel = zerolog.ErrorLevel
|
||||
case pgx.LogLevelWarn:
|
||||
zlevel = zerolog.WarnLevel
|
||||
case pgx.LogLevelInfo:
|
||||
zlevel = zerolog.InfoLevel
|
||||
case pgx.LogLevelDebug:
|
||||
zlevel = zerolog.DebugLevel
|
||||
default:
|
||||
zlevel = zerolog.DebugLevel
|
||||
}
|
||||
|
||||
var zctx zerolog.Context
|
||||
if pl.fromContext {
|
||||
logger := zerolog.Ctx(ctx)
|
||||
zctx = logger.With()
|
||||
} else {
|
||||
zctx = pl.logger.With()
|
||||
}
|
||||
if pl.withFunc != nil {
|
||||
zctx = pl.withFunc(ctx, zctx)
|
||||
}
|
||||
|
||||
pgxlog := zctx.Logger()
|
||||
event := pgxlog.WithLevel(zlevel)
|
||||
if event.Enabled() {
|
||||
if pl.fromContext && !pl.skipModule {
|
||||
event.Str("module", "pgx")
|
||||
}
|
||||
event.Fields(data).Msg(msg)
|
||||
}
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package zerologadapter_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/log/zerologadapter"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
zlogger := zerolog.New(&buf)
|
||||
logger := zerologadapter.NewLogger(zlogger)
|
||||
logger.Log(context.Background(), pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"})
|
||||
const want = `{"level":"info","module":"pgx","one":"two","message":"hello"}
|
||||
`
|
||||
got := buf.String()
|
||||
if got != want {
|
||||
t.Errorf("%s != %s", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disable pgx module", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
zlogger := zerolog.New(&buf)
|
||||
logger := zerologadapter.NewLogger(zlogger, zerologadapter.WithoutPGXModule())
|
||||
logger.Log(context.Background(), pgx.LogLevelInfo, "hello", nil)
|
||||
const want = `{"level":"info","message":"hello"}
|
||||
`
|
||||
got := buf.String()
|
||||
if got != want {
|
||||
t.Errorf("%s != %s", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("from context", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
zlogger := zerolog.New(&buf)
|
||||
ctx := zlogger.WithContext(context.Background())
|
||||
logger := zerologadapter.NewContextLogger()
|
||||
logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"})
|
||||
const want = `{"level":"info","module":"pgx","one":"two","message":"hello"}
|
||||
`
|
||||
|
||||
got := buf.String()
|
||||
if got != want {
|
||||
t.Log(got)
|
||||
t.Log(want)
|
||||
t.Errorf("%s != %s", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
var buf bytes.Buffer
|
||||
type key string
|
||||
var ck key
|
||||
zlogger := zerolog.New(&buf)
|
||||
logger := zerologadapter.NewLogger(zlogger,
|
||||
zerologadapter.WithContextFunc(func(ctx context.Context, logWith zerolog.Context) zerolog.Context {
|
||||
// You can use zerolog.hlog.IDFromCtx(ctx) or even
|
||||
// zerolog.log.Ctx(ctx) to fetch the whole logger instance from the
|
||||
// context if you want.
|
||||
id, ok := ctx.Value(ck).(string)
|
||||
if ok {
|
||||
logWith = logWith.Str("req_id", id)
|
||||
}
|
||||
return logWith
|
||||
}),
|
||||
)
|
||||
|
||||
t.Run("no request id", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
ctx := context.Background()
|
||||
logger.Log(ctx, pgx.LogLevelInfo, "hello", nil)
|
||||
const want = `{"level":"info","module":"pgx","message":"hello"}
|
||||
`
|
||||
got := buf.String()
|
||||
if got != want {
|
||||
t.Errorf("%s != %s", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with request id", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
ctx := context.WithValue(context.Background(), ck, "1")
|
||||
logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"two": "2"})
|
||||
const want = `{"level":"info","module":"pgx","req_id":"1","two":"2","message":"hello"}
|
||||
`
|
||||
got := buf.String()
|
||||
if got != want {
|
||||
t.Errorf("%s != %s", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// The values for log levels are chosen such that the zero value means that no
|
||||
// log level was specified.
|
||||
const (
|
||||
LogLevelTrace = 6
|
||||
LogLevelDebug = 5
|
||||
LogLevelInfo = 4
|
||||
LogLevelWarn = 3
|
||||
LogLevelError = 2
|
||||
LogLevelNone = 1
|
||||
)
|
||||
|
||||
// LogLevel represents the pgx logging level. See LogLevel* constants for
|
||||
// possible values.
|
||||
type LogLevel int
|
||||
|
||||
func (ll LogLevel) String() string {
|
||||
switch ll {
|
||||
case LogLevelTrace:
|
||||
return "trace"
|
||||
case LogLevelDebug:
|
||||
return "debug"
|
||||
case LogLevelInfo:
|
||||
return "info"
|
||||
case LogLevelWarn:
|
||||
return "warn"
|
||||
case LogLevelError:
|
||||
return "error"
|
||||
case LogLevelNone:
|
||||
return "none"
|
||||
default:
|
||||
return fmt.Sprintf("invalid level %d", ll)
|
||||
}
|
||||
}
|
||||
|
||||
// Logger is the interface used to get logging from pgx internals.
|
||||
type Logger interface {
|
||||
// Log a message at the given level with data key/value pairs. data may be nil.
|
||||
Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
|
||||
}
|
||||
|
||||
// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface
|
||||
type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
|
||||
|
||||
// Log delegates the logging request to the wrapped function
|
||||
func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) {
|
||||
f(ctx, level, msg, data)
|
||||
}
|
||||
|
||||
// LogLevelFromString converts log level string to constant
|
||||
//
|
||||
// Valid levels:
|
||||
//
|
||||
// trace
|
||||
// debug
|
||||
// info
|
||||
// warn
|
||||
// error
|
||||
// none
|
||||
func LogLevelFromString(s string) (LogLevel, error) {
|
||||
switch s {
|
||||
case "trace":
|
||||
return LogLevelTrace, nil
|
||||
case "debug":
|
||||
return LogLevelDebug, nil
|
||||
case "info":
|
||||
return LogLevelInfo, nil
|
||||
case "warn":
|
||||
return LogLevelWarn, nil
|
||||
case "error":
|
||||
return LogLevelError, nil
|
||||
case "none":
|
||||
return LogLevelNone, nil
|
||||
default:
|
||||
return 0, errors.New("invalid log level")
|
||||
}
|
||||
}
|
||||
|
||||
func logQueryArgs(args []interface{}) []interface{} {
|
||||
logArgs := make([]interface{}, 0, len(args))
|
||||
|
||||
for _, a := range args {
|
||||
switch v := a.(type) {
|
||||
case []byte:
|
||||
if len(v) < 64 {
|
||||
a = hex.EncodeToString(v)
|
||||
} else {
|
||||
a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64)
|
||||
}
|
||||
case string:
|
||||
if len(v) > 64 {
|
||||
a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64)
|
||||
}
|
||||
}
|
||||
logArgs = append(logArgs, a)
|
||||
}
|
||||
|
||||
return logArgs
|
||||
}
|
||||
-23
@@ -1,23 +0,0 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
)
|
||||
|
||||
func convertDriverValuers(args []interface{}) ([]interface{}, error) {
|
||||
for i, arg := range args {
|
||||
switch arg := arg.(type) {
|
||||
case pgtype.BinaryEncoder:
|
||||
case pgtype.TextEncoder:
|
||||
case driver.Valuer:
|
||||
v, err := callValuerValue(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args[i] = v
|
||||
}
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
+266
@@ -0,0 +1,266 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
|
||||
// ordinal placeholder and construct the appropriate arguments.
|
||||
//
|
||||
// For example, the following two queries are equivalent:
|
||||
//
|
||||
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}))
|
||||
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2}))
|
||||
type NamedArgs map[string]any
|
||||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
nameToOrdinal: make(map[namedArg]int, len(na)),
|
||||
}
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
for _, p := range l.parts {
|
||||
switch p := p.(type) {
|
||||
case string:
|
||||
sb.WriteString(p)
|
||||
case namedArg:
|
||||
sb.WriteRune('$')
|
||||
sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
|
||||
}
|
||||
}
|
||||
|
||||
newArgs = make([]any, len(l.nameToOrdinal))
|
||||
for name, ordinal := range l.nameToOrdinal {
|
||||
newArgs[ordinal-1] = na[string(name)]
|
||||
}
|
||||
|
||||
return sb.String(), newArgs
|
||||
}
|
||||
|
||||
type namedArg string
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []any
|
||||
|
||||
nameToOrdinal map[namedArg]int
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case 'e', 'E':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '\'' {
|
||||
l.pos += width
|
||||
return escapeStringState
|
||||
}
|
||||
case '\'':
|
||||
return singleQuoteState
|
||||
case '"':
|
||||
return doubleQuoteState
|
||||
case '@':
|
||||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if isLetter(nextRune) {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||
}
|
||||
l.start = l.pos
|
||||
return namedArgState
|
||||
}
|
||||
case '-':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '-' {
|
||||
l.pos += width
|
||||
return oneLineCommentState
|
||||
}
|
||||
case '/':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '*' {
|
||||
l.pos += width
|
||||
return multilineCommentState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isLetter(r rune) bool {
|
||||
return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
|
||||
}
|
||||
|
||||
func namedArgState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
if r == utf8.RuneError {
|
||||
if l.pos-l.start > 0 {
|
||||
na := namedArg(l.src[l.start:l.pos])
|
||||
if _, found := l.nameToOrdinal[na]; !found {
|
||||
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||
}
|
||||
l.parts = append(l.parts, na)
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
} else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') {
|
||||
l.pos -= width
|
||||
na := namedArg(l.src[l.start:l.pos])
|
||||
if _, found := l.nameToOrdinal[na]; !found {
|
||||
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||
}
|
||||
l.parts = append(l.parts, namedArg(na))
|
||||
l.start = l.pos
|
||||
return rawState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func singleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '"' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func escapeStringState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func oneLineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\n', '\r':
|
||||
return rawState
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func multilineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '/':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '*' {
|
||||
l.pos += width
|
||||
l.nested++
|
||||
}
|
||||
case '*':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '/' {
|
||||
continue
|
||||
}
|
||||
|
||||
l.pos += width
|
||||
if l.nested == 0 {
|
||||
return rawState
|
||||
}
|
||||
l.nested--
|
||||
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i, tt := range []struct {
|
||||
sql string
|
||||
args []any
|
||||
namedArgs pgx.NamedArgs
|
||||
expectedSQL string
|
||||
expectedArgs []any
|
||||
}{
|
||||
{
|
||||
sql: "select * from users where id = @id",
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: "select * from users where id = $1",
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: "select * from t where foo < @abc and baz = @def and bar < @abc",
|
||||
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
|
||||
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
|
||||
expectedArgs: []any{int32(42), int32(1)},
|
||||
},
|
||||
{
|
||||
sql: "select @a::int, @b::text",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "select $1::int, $2::text",
|
||||
expectedArgs: []any{int32(42), "foo"},
|
||||
},
|
||||
{
|
||||
sql: "select @Abc::int, @b_4::text",
|
||||
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo"},
|
||||
expectedSQL: "select $1::int, $2::text",
|
||||
expectedArgs: []any{int32(42), "foo"},
|
||||
},
|
||||
{
|
||||
sql: "at end @",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "at end @",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "ignores without letter after @ foo bar",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "ignores without letter after @ foo bar",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "name must start with letter @1 foo bar",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "name must start with letter @1 foo bar",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: `select *, '@foo' as "@bar" from users where id = @id`,
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: `select * -- @foo
|
||||
from users -- @single line comments
|
||||
where id = @id;`,
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: `select * -- @foo
|
||||
from users -- @single line comments
|
||||
where id = $1;`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: `select * /* @multi line
|
||||
@comment
|
||||
*/
|
||||
/* /* with @nesting */ */
|
||||
from users
|
||||
where id = @id;`,
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: `select * /* @multi line
|
||||
@comment
|
||||
*/
|
||||
/* /* with @nesting */ */
|
||||
from users
|
||||
where id = $1;`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
|
||||
// test comments and quotes
|
||||
} {
|
||||
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
|
||||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
||||
+4
-8
@@ -5,9 +5,7 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgconn/stmtcache"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -19,9 +17,8 @@ func TestPgbouncerStatementCacheDescribe(t *testing.T) {
|
||||
}
|
||||
|
||||
config := mustParseConfig(t, connString)
|
||||
config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
|
||||
return stmtcache.New(conn, stmtcache.ModeDescribe, 1024)
|
||||
}
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
|
||||
config.DescriptionCacheCapacity = 1024
|
||||
|
||||
testPgbouncer(t, config, 10, 100)
|
||||
}
|
||||
@@ -33,8 +30,7 @@ func TestPgbouncerSimpleProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
config := mustParseConfig(t, connString)
|
||||
config.BuildStatementCache = nil
|
||||
config.PreferSimpleProtocol = true
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
||||
|
||||
testPgbouncer(t, config, 10, 100)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
# pgconn
|
||||
|
||||
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
|
||||
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
|
||||
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
|
||||
low-level access to PostgreSQL functionality.
|
||||
|
||||
## Example Usage
|
||||
|
||||
```go
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
log.Fatalln("pgconn failed to connect:", err)
|
||||
}
|
||||
defer pgConn.Close(context.Background())
|
||||
|
||||
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
|
||||
for result.NextRow() {
|
||||
fmt.Println("User 123 has email:", string(result.Values()[0]))
|
||||
}
|
||||
_, err = result.Close()
|
||||
if err != nil {
|
||||
log.Fatalln("failed reading result:", err)
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
|
||||
environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
|
||||
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
|
||||
environment variable handling.
|
||||
|
||||
### Example Test Environment
|
||||
|
||||
Connect to your PostgreSQL server and run:
|
||||
|
||||
```
|
||||
create database pgx_test;
|
||||
```
|
||||
|
||||
Now you can run the tests:
|
||||
|
||||
```bash
|
||||
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
|
||||
```
|
||||
|
||||
### Connection and Authentication Tests
|
||||
|
||||
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
|
||||
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
|
||||
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
|
||||
authentication code.
|
||||
@@ -0,0 +1,272 @@
|
||||
// SCRAM-SHA-256 authentication
|
||||
//
|
||||
// Resources:
|
||||
// https://tools.ietf.org/html/rfc5802
|
||||
// https://tools.ietf.org/html/rfc8265
|
||||
// https://www.postgresql.org/docs/current/sasl-authentication.html
|
||||
//
|
||||
// Inspiration drawn from other implementations:
|
||||
// https://github.com/lib/pq/pull/608
|
||||
// https://github.com/lib/pq/pull/788
|
||||
// https://github.com/lib/pq/pull/833
|
||||
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/text/secure/precis"
|
||||
)
|
||||
|
||||
const clientNonceLen = 18
|
||||
|
||||
// Perform SCRAM authentication.
|
||||
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send client-first-message in a SASLInitialResponse
|
||||
saslInitialResponse := &pgproto3.SASLInitialResponse{
|
||||
AuthMechanism: "SCRAM-SHA-256",
|
||||
Data: sc.clientFirstMessage(),
|
||||
}
|
||||
c.frontend.Send(saslInitialResponse)
|
||||
err = c.frontend.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||
saslContinue, err := c.rxSASLContinue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = sc.recvServerFirstMessage(saslContinue.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send client-final-message in a SASLResponse
|
||||
saslResponse := &pgproto3.SASLResponse{
|
||||
Data: []byte(sc.clientFinalMessage()),
|
||||
}
|
||||
c.frontend.Send(saslResponse)
|
||||
err = c.frontend.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||
saslFinal, err := c.rxSASLFinal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sc.recvServerFinalMessage(saslFinal.Data)
|
||||
}
|
||||
|
||||
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationSASLContinue:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
|
||||
}
|
||||
|
||||
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationSASLFinal:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
|
||||
}
|
||||
|
||||
type scramClient struct {
|
||||
serverAuthMechanisms []string
|
||||
password []byte
|
||||
clientNonce []byte
|
||||
|
||||
clientFirstMessageBare []byte
|
||||
|
||||
serverFirstMessage []byte
|
||||
clientAndServerNonce []byte
|
||||
salt []byte
|
||||
iterations int
|
||||
|
||||
saltedPassword []byte
|
||||
authMessage []byte
|
||||
}
|
||||
|
||||
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
|
||||
sc := &scramClient{
|
||||
serverAuthMechanisms: serverAuthMechanisms,
|
||||
}
|
||||
|
||||
// Ensure server supports SCRAM-SHA-256
|
||||
hasScramSHA256 := false
|
||||
for _, mech := range sc.serverAuthMechanisms {
|
||||
if mech == "SCRAM-SHA-256" {
|
||||
hasScramSHA256 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasScramSHA256 {
|
||||
return nil, errors.New("server does not support SCRAM-SHA-256")
|
||||
}
|
||||
|
||||
// precis.OpaqueString is equivalent to SASLprep for password.
|
||||
var err error
|
||||
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
|
||||
if err != nil {
|
||||
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
|
||||
sc.password = []byte(password)
|
||||
}
|
||||
|
||||
buf := make([]byte, clientNonceLen)
|
||||
_, err = rand.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
|
||||
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
|
||||
|
||||
return sc, nil
|
||||
}
|
||||
|
||||
func (sc *scramClient) clientFirstMessage() []byte {
|
||||
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
|
||||
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
|
||||
}
|
||||
|
||||
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
|
||||
sc.serverFirstMessage = serverFirstMessage
|
||||
buf := serverFirstMessage
|
||||
if !bytes.HasPrefix(buf, []byte("r=")) {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
|
||||
}
|
||||
buf = buf[2:]
|
||||
|
||||
idx := bytes.IndexByte(buf, ',')
|
||||
if idx == -1 {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||
}
|
||||
sc.clientAndServerNonce = buf[:idx]
|
||||
buf = buf[idx+1:]
|
||||
|
||||
if !bytes.HasPrefix(buf, []byte("s=")) {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||
}
|
||||
buf = buf[2:]
|
||||
|
||||
idx = bytes.IndexByte(buf, ',')
|
||||
if idx == -1 {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||
}
|
||||
saltStr := buf[:idx]
|
||||
buf = buf[idx+1:]
|
||||
|
||||
if !bytes.HasPrefix(buf, []byte("i=")) {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||
}
|
||||
buf = buf[2:]
|
||||
iterationsStr := buf
|
||||
|
||||
var err error
|
||||
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
|
||||
}
|
||||
|
||||
sc.iterations, err = strconv.Atoi(string(iterationsStr))
|
||||
if err != nil || sc.iterations <= 0 {
|
||||
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
|
||||
return errors.New("invalid SCRAM nonce: did not start with client nonce")
|
||||
}
|
||||
|
||||
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
|
||||
return errors.New("invalid SCRAM nonce: did not include server nonce")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sc *scramClient) clientFinalMessage() string {
|
||||
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
|
||||
|
||||
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
|
||||
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
|
||||
|
||||
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
|
||||
|
||||
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
|
||||
}
|
||||
|
||||
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
|
||||
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
|
||||
return errors.New("invalid SCRAM server-final-message received from server")
|
||||
}
|
||||
|
||||
serverSignature := serverFinalMessage[2:]
|
||||
|
||||
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
|
||||
return errors.New("invalid SCRAM ServerSignature received from server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeHMAC(key, msg []byte) []byte {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(msg)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
||||
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||
storedKey := sha256.Sum256(clientKey)
|
||||
clientSignature := computeHMAC(storedKey[:], authMessage)
|
||||
|
||||
clientProof := make([]byte, len(clientSignature))
|
||||
for i := 0; i < len(clientSignature); i++ {
|
||||
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||
}
|
||||
|
||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
|
||||
base64.StdEncoding.Encode(buf, clientProof)
|
||||
return buf
|
||||
}
|
||||
|
||||
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
||||
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||
serverSignature := computeHMAC(serverKey, authMessage)
|
||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||
base64.StdEncoding.Encode(buf, serverSignature)
|
||||
return buf
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkCommandTagRowsAffected(b *testing.B) {
|
||||
benchmarks := []struct {
|
||||
commandTag string
|
||||
rowsAffected int64
|
||||
}{
|
||||
{"UPDATE 1", 1},
|
||||
{"UPDATE 123456789", 123456789},
|
||||
{"INSERT 0 1", 1},
|
||||
{"INSERT 0 123456789", 123456789},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
ct := CommandTag{s: bm.commandTag}
|
||||
b.Run(bm.commandTag, func(b *testing.B) {
|
||||
var n int64
|
||||
for i := 0; i < b.N; i++ {
|
||||
n = ct.RowsAffected()
|
||||
}
|
||||
if n != bm.rowsAffected {
|
||||
b.Errorf("expected %d got %d", bm.rowsAffected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCommandTagTypeFromString(b *testing.B) {
|
||||
ct := CommandTag{s: "UPDATE 1"}
|
||||
|
||||
var update bool
|
||||
for i := 0; i < b.N; i++ {
|
||||
update = strings.HasPrefix(ct.String(), "UPDATE")
|
||||
}
|
||||
if !update {
|
||||
b.Error("expected update")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCommandTagInsert(b *testing.B) {
|
||||
benchmarks := []struct {
|
||||
commandTag string
|
||||
is bool
|
||||
}{
|
||||
{"INSERT 1", true},
|
||||
{"INSERT 1234567890", true},
|
||||
{"UPDATE 1", false},
|
||||
{"UPDATE 1234567890", false},
|
||||
{"DELETE 1", false},
|
||||
{"DELETE 1234567890", false},
|
||||
{"SELECT 1", false},
|
||||
{"SELECT 1234567890", false},
|
||||
{"UNKNOWN 1234567890", false},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
ct := CommandTag{s: bm.commandTag}
|
||||
b.Run(bm.commandTag, func(b *testing.B) {
|
||||
var is bool
|
||||
for i := 0; i < b.N; i++ {
|
||||
is = ct.Insert()
|
||||
}
|
||||
if is != bm.is {
|
||||
b.Errorf("expected %v got %v", bm.is, is)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func BenchmarkConnect(b *testing.B) {
|
||||
benchmarks := []struct {
|
||||
name string
|
||||
env string
|
||||
}{
|
||||
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
||||
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
bm := bm
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
connString := os.Getenv(bm.env)
|
||||
if connString == "" {
|
||||
b.Skipf("Skipping due to missing environment variable %v", bm.env)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pgconn.Connect(context.Background(), connString)
|
||||
require.Nil(b, err)
|
||||
|
||||
err = conn.Close(context.Background())
|
||||
require.Nil(b, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExec(b *testing.B) {
|
||||
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||
benchmarks := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
}{
|
||||
// Using an empty context other than context.Background() to compare
|
||||
// performance
|
||||
{"background context", context.Background()},
|
||||
{"empty context", context.TODO()},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
bm := bm
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
|
||||
|
||||
for mrr.NextResult() {
|
||||
rr := mrr.ResultReader()
|
||||
|
||||
rowCount := 0
|
||||
for rr.NextRow() {
|
||||
rowCount++
|
||||
if len(rr.Values()) != len(expectedValues) {
|
||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||
}
|
||||
for i := range rr.Values() {
|
||||
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = rr.Close()
|
||||
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if rowCount != 1 {
|
||||
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||
}
|
||||
}
|
||||
|
||||
err := mrr.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecPossibleToCancel(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date")
|
||||
|
||||
for mrr.NextResult() {
|
||||
rr := mrr.ResultReader()
|
||||
|
||||
rowCount := 0
|
||||
for rr.NextRow() {
|
||||
rowCount++
|
||||
if len(rr.Values()) != len(expectedValues) {
|
||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||
}
|
||||
for i := range rr.Values() {
|
||||
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = rr.Close()
|
||||
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if rowCount != 1 {
|
||||
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||
}
|
||||
}
|
||||
|
||||
err := mrr.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecPrepared(b *testing.B) {
|
||||
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||
|
||||
benchmarks := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
}{
|
||||
// Using an empty context other than context.Background() to compare
|
||||
// performance
|
||||
{"background context", context.Background()},
|
||||
{"empty context", context.TODO()},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
bm := bm
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
require.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil)
|
||||
|
||||
rowCount := 0
|
||||
for rr.NextRow() {
|
||||
rowCount++
|
||||
if len(rr.Values()) != len(expectedValues) {
|
||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||
}
|
||||
for i := range rr.Values() {
|
||||
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = rr.Close()
|
||||
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if rowCount != 1 {
|
||||
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
require.Nil(b, err)
|
||||
|
||||
expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil)
|
||||
|
||||
rowCount := 0
|
||||
for rr.NextRow() {
|
||||
rowCount += 1
|
||||
if len(rr.Values()) != len(expectedValues) {
|
||||
b.Fatalf("unexpected number of values: %d", len(rr.Values()))
|
||||
}
|
||||
for i := range rr.Values() {
|
||||
if !bytes.Equal(rr.Values()[i], expectedValues[i]) {
|
||||
b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = rr.Close()
|
||||
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if rowCount != 1 {
|
||||
b.Fatalf("unexpected rowCount: %d", rowCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
|
||||
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
// require.Nil(b, err)
|
||||
// defer closeConn(b, conn)
|
||||
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
// defer cancel()
|
||||
|
||||
// b.ResetTimer()
|
||||
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// conn.ChanToSetDeadline().Watch(ctx)
|
||||
// conn.ChanToSetDeadline().Ignore()
|
||||
// }
|
||||
// }
|
||||
@@ -0,0 +1,886 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgpassfile"
|
||||
"github.com/jackc/pgservicefile"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
type GetSSLPasswordFunc func(ctx context.Context) string
|
||||
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
|
||||
// manually initialized Config will cause ConnectConfig to panic.
|
||||
type Config struct {
|
||||
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
ConnectTimeout time.Duration
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||
BuildFrontend BuildFrontendFunc
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
KerberosSrvName string
|
||||
KerberosSpn string
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
ValidateConnect ValidateConnectFunc
|
||||
|
||||
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
|
||||
// or prepare statements). If this returns an error the connection attempt fails.
|
||||
AfterConnect AfterConnectFunc
|
||||
|
||||
// OnNotice is a callback function called when a notice response is received.
|
||||
OnNotice NoticeHandler
|
||||
|
||||
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||
OnNotification NotificationHandler
|
||||
|
||||
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||
}
|
||||
|
||||
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
|
||||
type ParseConfigOptions struct {
|
||||
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
|
||||
// PQsetSSLKeyPassHook_OpenSSL.
|
||||
GetSSLPassword GetSSLPasswordFunc
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the config that is safe to use and modify.
|
||||
// The only exception is the TLSConfig field:
|
||||
// according to the tls.Config docs it must not be modified after creation.
|
||||
func (c *Config) Copy() *Config {
|
||||
newConf := new(Config)
|
||||
*newConf = *c
|
||||
if newConf.TLSConfig != nil {
|
||||
newConf.TLSConfig = c.TLSConfig.Clone()
|
||||
}
|
||||
if newConf.RuntimeParams != nil {
|
||||
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
|
||||
for k, v := range c.RuntimeParams {
|
||||
newConf.RuntimeParams[k] = v
|
||||
}
|
||||
}
|
||||
if newConf.Fallbacks != nil {
|
||||
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
|
||||
for i, fallback := range c.Fallbacks {
|
||||
newFallback := new(FallbackConfig)
|
||||
*newFallback = *fallback
|
||||
if newFallback.TLSConfig != nil {
|
||||
newFallback.TLSConfig = fallback.TLSConfig.Clone()
|
||||
}
|
||||
newConf.Fallbacks[i] = newFallback
|
||||
}
|
||||
}
|
||||
return newConf
|
||||
}
|
||||
|
||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||
type FallbackConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
}
|
||||
|
||||
// isAbsolutePath checks if the provided value is an absolute path either
|
||||
// beginning with a forward slash (as on Linux-based systems) or with a capital
|
||||
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
|
||||
func isAbsolutePath(path string) bool {
|
||||
isWindowsPath := func(p string) bool {
|
||||
if len(p) < 3 {
|
||||
return false
|
||||
}
|
||||
drive := p[0]
|
||||
colon := p[1]
|
||||
backslash := p[2]
|
||||
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(path, "/") || isWindowsPath(path)
|
||||
}
|
||||
|
||||
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||
// net.Dial.
|
||||
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
if isAbsolutePath(host) {
|
||||
network = "unix"
|
||||
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||
} else {
|
||||
network = "tcp"
|
||||
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
}
|
||||
return network, address
|
||||
}
|
||||
|
||||
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
|
||||
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
|
||||
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
|
||||
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
|
||||
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||
//
|
||||
// # Example DSN
|
||||
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||
//
|
||||
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
|
||||
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
|
||||
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
|
||||
// not be modified individually. They should all be modified or all left unchanged.
|
||||
//
|
||||
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
||||
// values that will be tried in order. This can be used as part of a high availability system. See
|
||||
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||
//
|
||||
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
|
||||
// via database URL or DSN:
|
||||
//
|
||||
// PGHOST
|
||||
// PGPORT
|
||||
// PGDATABASE
|
||||
// PGUSER
|
||||
// PGPASSWORD
|
||||
// PGPASSFILE
|
||||
// PGSERVICE
|
||||
// PGSERVICEFILE
|
||||
// PGSSLMODE
|
||||
// PGSSLCERT
|
||||
// PGSSLKEY
|
||||
// PGSSLROOTCERT
|
||||
// PGSSLPASSWORD
|
||||
// PGAPPNAME
|
||||
// PGCONNECT_TIMEOUT
|
||||
// PGTARGETSESSIONATTRS
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
||||
//
|
||||
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
||||
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
||||
//
|
||||
// Important Security Notes:
|
||||
//
|
||||
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
||||
// not set.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
||||
// security each sslmode provides.
|
||||
//
|
||||
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
|
||||
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
|
||||
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
|
||||
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
|
||||
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
|
||||
// TLSConfig.
|
||||
//
|
||||
// Other known differences with libpq:
|
||||
//
|
||||
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
|
||||
// does not.
|
||||
//
|
||||
// In addition, ParseConfig accepts the following options:
|
||||
//
|
||||
// servicefile
|
||||
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
||||
// part of the connection string.
|
||||
func ParseConfig(connString string) (*Config, error) {
|
||||
var parseConfigOptions ParseConfigOptions
|
||||
return ParseConfigWithOptions(connString, parseConfigOptions)
|
||||
}
|
||||
|
||||
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
|
||||
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
|
||||
// get the SSL password.
|
||||
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
|
||||
defaultSettings := defaultSettings()
|
||||
envSettings := parseEnvSettings()
|
||||
|
||||
connStringSettings := make(map[string]string)
|
||||
if connString != "" {
|
||||
var err error
|
||||
// connString may be a database URL or a DSN
|
||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||
connStringSettings, err = parseURLSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
|
||||
}
|
||||
} else {
|
||||
connStringSettings, err = parseDSNSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
|
||||
if service, present := settings["service"]; present {
|
||||
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
|
||||
}
|
||||
|
||||
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
createdByParseConfig: true,
|
||||
Database: settings["database"],
|
||||
User: settings["user"],
|
||||
Password: settings["password"],
|
||||
RuntimeParams: make(map[string]string),
|
||||
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
|
||||
return pgproto3.NewFrontend(r, w)
|
||||
},
|
||||
}
|
||||
|
||||
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
|
||||
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||
}
|
||||
config.ConnectTimeout = connectTimeout
|
||||
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
|
||||
} else {
|
||||
defaultDialer := makeDefaultDialer()
|
||||
config.DialFunc = defaultDialer.DialContext
|
||||
}
|
||||
|
||||
config.LookupFunc = makeDefaultResolver().LookupHost
|
||||
|
||||
notRuntimeParams := map[string]struct{}{
|
||||
"host": {},
|
||||
"port": {},
|
||||
"database": {},
|
||||
"user": {},
|
||||
"password": {},
|
||||
"passfile": {},
|
||||
"connect_timeout": {},
|
||||
"sslmode": {},
|
||||
"sslkey": {},
|
||||
"sslcert": {},
|
||||
"sslrootcert": {},
|
||||
"sslpassword": {},
|
||||
"sslsni": {},
|
||||
"krbspn": {},
|
||||
"krbsrvname": {},
|
||||
"target_session_attrs": {},
|
||||
"service": {},
|
||||
"servicefile": {},
|
||||
}
|
||||
|
||||
// Adding kerberos configuration
|
||||
if _, present := settings["krbsrvname"]; present {
|
||||
config.KerberosSrvName = settings["krbsrvname"]
|
||||
}
|
||||
if _, present := settings["krbspn"]; present {
|
||||
config.KerberosSpn = settings["krbspn"]
|
||||
}
|
||||
|
||||
for k, v := range settings {
|
||||
if _, present := notRuntimeParams[k]; present {
|
||||
continue
|
||||
}
|
||||
config.RuntimeParams[k] = v
|
||||
}
|
||||
|
||||
fallbacks := []*FallbackConfig{}
|
||||
|
||||
hosts := strings.Split(settings["host"], ",")
|
||||
ports := strings.Split(settings["port"], ",")
|
||||
|
||||
for i, host := range hosts {
|
||||
var portStr string
|
||||
if i < len(ports) {
|
||||
portStr = ports[i]
|
||||
} else {
|
||||
portStr = ports[0]
|
||||
}
|
||||
|
||||
port, err := parsePort(portStr)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
|
||||
}
|
||||
|
||||
var tlsConfigs []*tls.Config
|
||||
|
||||
// Ignore TLS settings if Unix domain socket like libpq
|
||||
if network, _ := NetworkAddress(host, port); network == "unix" {
|
||||
tlsConfigs = append(tlsConfigs, nil)
|
||||
} else {
|
||||
var err error
|
||||
tlsConfigs, err = configTLS(settings, host, options)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tlsConfig := range tlsConfigs {
|
||||
fallbacks = append(fallbacks, &FallbackConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
config.Host = fallbacks[0].Host
|
||||
config.Port = fallbacks[0].Port
|
||||
config.TLSConfig = fallbacks[0].TLSConfig
|
||||
config.Fallbacks = fallbacks[1:]
|
||||
|
||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||
if err == nil {
|
||||
if config.Password == "" {
|
||||
host := config.Host
|
||||
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
||||
}
|
||||
}
|
||||
|
||||
switch tsa := settings["target_session_attrs"]; tsa {
|
||||
case "read-write":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||
case "read-only":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
|
||||
case "primary":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
|
||||
case "standby":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
|
||||
case "prefer-standby":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
|
||||
case "any":
|
||||
// do nothing
|
||||
default:
|
||||
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func mergeSettings(settingSets ...map[string]string) map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
for _, s2 := range settingSets {
|
||||
for k, v := range s2 {
|
||||
settings[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func parseEnvSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
nameMap := map[string]string{
|
||||
"PGHOST": "host",
|
||||
"PGPORT": "port",
|
||||
"PGDATABASE": "database",
|
||||
"PGUSER": "user",
|
||||
"PGPASSWORD": "password",
|
||||
"PGPASSFILE": "passfile",
|
||||
"PGAPPNAME": "application_name",
|
||||
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||
"PGSSLMODE": "sslmode",
|
||||
"PGSSLKEY": "sslkey",
|
||||
"PGSSLCERT": "sslcert",
|
||||
"PGSSLSNI": "sslsni",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
"PGSSLPASSWORD": "sslpassword",
|
||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||
"PGSERVICE": "service",
|
||||
"PGSERVICEFILE": "servicefile",
|
||||
}
|
||||
|
||||
for envname, realname := range nameMap {
|
||||
value := os.Getenv(envname)
|
||||
if value != "" {
|
||||
settings[realname] = value
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func parseURLSettings(connString string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
url, err := url.Parse(connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if url.User != nil {
|
||||
settings["user"] = url.User.Username()
|
||||
if password, present := url.User.Password(); present {
|
||||
settings["password"] = password
|
||||
}
|
||||
}
|
||||
|
||||
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||
var hosts []string
|
||||
var ports []string
|
||||
for _, host := range strings.Split(url.Host, ",") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
if isIPOnly(host) {
|
||||
hosts = append(hosts, strings.Trim(host, "[]"))
|
||||
continue
|
||||
}
|
||||
h, p, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
|
||||
}
|
||||
if h != "" {
|
||||
hosts = append(hosts, h)
|
||||
}
|
||||
if p != "" {
|
||||
ports = append(ports, p)
|
||||
}
|
||||
}
|
||||
if len(hosts) > 0 {
|
||||
settings["host"] = strings.Join(hosts, ",")
|
||||
}
|
||||
if len(ports) > 0 {
|
||||
settings["port"] = strings.Join(ports, ",")
|
||||
}
|
||||
|
||||
database := strings.TrimLeft(url.Path, "/")
|
||||
if database != "" {
|
||||
settings["database"] = database
|
||||
}
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
for k, v := range url.Query() {
|
||||
if k2, present := nameMap[k]; present {
|
||||
k = k2
|
||||
}
|
||||
|
||||
settings[k] = v[0]
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func isIPOnly(host string) bool {
|
||||
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
|
||||
}
|
||||
|
||||
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||
|
||||
func parseDSNSettings(s string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
for len(s) > 0 {
|
||||
var key, val string
|
||||
eqIdx := strings.IndexRune(s, '=')
|
||||
if eqIdx < 0 {
|
||||
return nil, errors.New("invalid dsn")
|
||||
}
|
||||
|
||||
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
||||
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
|
||||
if len(s) == 0 {
|
||||
} else if s[0] != '\'' {
|
||||
end := 0
|
||||
for ; end < len(s); end++ {
|
||||
if asciiSpace[s[end]] == 1 {
|
||||
break
|
||||
}
|
||||
if s[end] == '\\' {
|
||||
end++
|
||||
if end == len(s) {
|
||||
return nil, errors.New("invalid backslash")
|
||||
}
|
||||
}
|
||||
}
|
||||
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||
if end == len(s) {
|
||||
s = ""
|
||||
} else {
|
||||
s = s[end+1:]
|
||||
}
|
||||
} else { // quoted string
|
||||
s = s[1:]
|
||||
end := 0
|
||||
for ; end < len(s); end++ {
|
||||
if s[end] == '\'' {
|
||||
break
|
||||
}
|
||||
if s[end] == '\\' {
|
||||
end++
|
||||
}
|
||||
}
|
||||
if end == len(s) {
|
||||
return nil, errors.New("unterminated quoted string in connection info string")
|
||||
}
|
||||
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||
if end == len(s) {
|
||||
s = ""
|
||||
} else {
|
||||
s = s[end+1:]
|
||||
}
|
||||
}
|
||||
|
||||
if k, ok := nameMap[key]; ok {
|
||||
key = k
|
||||
}
|
||||
|
||||
if key == "" {
|
||||
return nil, errors.New("invalid dsn")
|
||||
}
|
||||
|
||||
settings[key] = val
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
|
||||
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
|
||||
}
|
||||
|
||||
service, err := servicefile.GetService(serviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to find service: %v", serviceName)
|
||||
}
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
settings := make(map[string]string, len(service.Settings))
|
||||
for k, v := range service.Settings {
|
||||
if k2, present := nameMap[k]; present {
|
||||
k = k2
|
||||
}
|
||||
settings[k] = v
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||
// "prefer" allow fallback.
|
||||
func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
|
||||
host := thisHost
|
||||
sslmode := settings["sslmode"]
|
||||
sslrootcert := settings["sslrootcert"]
|
||||
sslcert := settings["sslcert"]
|
||||
sslkey := settings["sslkey"]
|
||||
sslpassword := settings["sslpassword"]
|
||||
sslsni := settings["sslsni"]
|
||||
|
||||
// Match libpq default behavior
|
||||
if sslmode == "" {
|
||||
sslmode = "prefer"
|
||||
}
|
||||
if sslsni == "" {
|
||||
sslsni = "1"
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
switch sslmode {
|
||||
case "disable":
|
||||
return []*tls.Config{nil}, nil
|
||||
case "allow", "prefer":
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
case "require":
|
||||
// According to PostgreSQL documentation, if a root CA file exists,
|
||||
// the behavior of sslmode=require should be the same as that of verify-ca
|
||||
//
|
||||
// See https://www.postgresql.org/docs/12/libpq-ssl.html
|
||||
if sslrootcert != "" {
|
||||
goto nextCase
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
break
|
||||
nextCase:
|
||||
fallthrough
|
||||
case "verify-ca":
|
||||
// Don't perform the default certificate verification because it
|
||||
// will verify the hostname. Instead, verify the server's
|
||||
// certificate chain ourselves in VerifyPeerCertificate and
|
||||
// ignore the server name. This emulates libpq's verify-ca
|
||||
// behavior.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
|
||||
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
|
||||
// for more info.
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
|
||||
certs := make([]*x509.Certificate, len(certificates))
|
||||
for i, asn1Data := range certificates {
|
||||
cert, err := x509.ParseCertificate(asn1Data)
|
||||
if err != nil {
|
||||
return errors.New("failed to parse certificate from server: " + err.Error())
|
||||
}
|
||||
certs[i] = cert
|
||||
}
|
||||
|
||||
// Leave DNSName empty to skip hostname verification.
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: tlsConfig.RootCAs,
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
// Skip the first cert because it's the leaf. All others
|
||||
// are intermediates.
|
||||
for _, cert := range certs[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
_, err := certs[0].Verify(opts)
|
||||
return err
|
||||
}
|
||||
case "verify-full":
|
||||
tlsConfig.ServerName = host
|
||||
default:
|
||||
return nil, errors.New("sslmode is invalid")
|
||||
}
|
||||
|
||||
if sslrootcert != "" {
|
||||
caCertPool := x509.NewCertPool()
|
||||
|
||||
caPath := sslrootcert
|
||||
caCert, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read CA file: %w", err)
|
||||
}
|
||||
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, errors.New("unable to add CA to cert pool")
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
|
||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||
}
|
||||
|
||||
if sslcert != "" && sslkey != "" {
|
||||
buf, err := ioutil.ReadFile(sslkey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
||||
}
|
||||
block, _ := pem.Decode(buf)
|
||||
var pemKey []byte
|
||||
var decryptedKey []byte
|
||||
var decryptedError error
|
||||
// If PEM is encrypted, attempt to decrypt using pass phrase
|
||||
if x509.IsEncryptedPEMBlock(block) {
|
||||
// Attempt decryption with pass phrase
|
||||
// NOTE: only supports RSA (PKCS#1)
|
||||
if sslpassword != "" {
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
}
|
||||
//if sslpassword not provided or has decryption error when use it
|
||||
//try to find sslpassword with callback function
|
||||
if sslpassword == "" || decryptedError != nil {
|
||||
if parseConfigOptions.GetSSLPassword != nil {
|
||||
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
||||
}
|
||||
if sslpassword == "" {
|
||||
return nil, fmt.Errorf("unable to find sslpassword")
|
||||
}
|
||||
}
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
// Should we also provide warning for PKCS#1 needed?
|
||||
if decryptedError != nil {
|
||||
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
||||
}
|
||||
|
||||
pemBytes := pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: decryptedKey,
|
||||
}
|
||||
pemKey = pem.EncodeToMemory(&pemBytes)
|
||||
} else {
|
||||
pemKey = pem.EncodeToMemory(block)
|
||||
}
|
||||
certfile, err := ioutil.ReadFile(sslcert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read cert: %w", err)
|
||||
}
|
||||
cert, err := tls.X509KeyPair(certfile, pemKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load cert: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// Set Server Name Indication (SNI), if enabled by connection parameters.
|
||||
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
|
||||
// or IPv6).
|
||||
if sslsni == "1" && net.ParseIP(host) == nil {
|
||||
tlsConfig.ServerName = host
|
||||
}
|
||||
|
||||
switch sslmode {
|
||||
case "allow":
|
||||
return []*tls.Config{nil, tlsConfig}, nil
|
||||
case "prefer":
|
||||
return []*tls.Config{tlsConfig, nil}, nil
|
||||
case "require", "verify-ca", "verify-full":
|
||||
return []*tls.Config{tlsConfig}, nil
|
||||
default:
|
||||
panic("BUG: bad sslmode should already have been caught")
|
||||
}
|
||||
}
|
||||
|
||||
func parsePort(s string) (uint16, error) {
|
||||
port, err := strconv.ParseUint(s, 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if port < 1 || port > math.MaxUint16 {
|
||||
return 0, errors.New("outside range")
|
||||
}
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
func makeDefaultDialer() *net.Dialer {
|
||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
}
|
||||
|
||||
func makeDefaultResolver() *net.Resolver {
|
||||
return net.DefaultResolver
|
||||
}
|
||||
|
||||
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
|
||||
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if timeout < 0 {
|
||||
return 0, errors.New("negative timeout")
|
||||
}
|
||||
return time.Duration(timeout) * time.Second, nil
|
||||
}
|
||||
|
||||
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
|
||||
d := makeDefaultDialer()
|
||||
d.Timeout = timeout
|
||||
return d.DialContext
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-write.
|
||||
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) == "on" {
|
||||
return errors.New("read only connection")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-only.
|
||||
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "on" {
|
||||
return errors.New("connection is not read only")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=standby.
|
||||
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "t" {
|
||||
return errors.New("server is not in hot standby mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=primary.
|
||||
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) == "t" {
|
||||
return errors.New("server is in standby mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=prefer-standby.
|
||||
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "t" {
|
||||
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,63 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func defaultSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
settings["host"] = defaultHost()
|
||||
settings["port"] = "5432"
|
||||
|
||||
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||
// OS. The client application will simply have to specify the user in that
|
||||
// case (which they typically will be doing anyway).
|
||||
user, err := user.Current()
|
||||
if err == nil {
|
||||
settings["user"] = user.Username
|
||||
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
|
||||
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
|
||||
if _, err := os.Stat(sslcert); err == nil {
|
||||
if _, err := os.Stat(sslkey); err == nil {
|
||||
// Both the cert and key must be present to use them, or do not use either
|
||||
settings["sslcert"] = sslcert
|
||||
settings["sslkey"] = sslkey
|
||||
}
|
||||
}
|
||||
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
|
||||
if _, err := os.Stat(sslrootcert); err == nil {
|
||||
settings["sslrootcert"] = sslrootcert
|
||||
}
|
||||
}
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||
// checks the existence of common locations.
|
||||
func defaultHost() string {
|
||||
candidatePaths := []string{
|
||||
"/var/run/postgresql", // Debian
|
||||
"/private/tmp", // OSX - homebrew
|
||||
"/tmp", // standard PostgreSQL
|
||||
}
|
||||
|
||||
for _, path := range candidatePaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
return "localhost"
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func defaultSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
settings["host"] = defaultHost()
|
||||
settings["port"] = "5432"
|
||||
|
||||
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||
// OS. The client application will simply have to specify the user in that
|
||||
// case (which they typically will be doing anyway).
|
||||
user, err := user.Current()
|
||||
appData := os.Getenv("APPDATA")
|
||||
if err == nil {
|
||||
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
|
||||
// but the libpq default is just the `user` portion, so we strip off the first part.
|
||||
username := user.Username
|
||||
if strings.Contains(username, "\\") {
|
||||
username = username[strings.LastIndex(username, "\\")+1:]
|
||||
}
|
||||
|
||||
settings["user"] = username
|
||||
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
|
||||
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
|
||||
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
|
||||
if _, err := os.Stat(sslcert); err == nil {
|
||||
if _, err := os.Stat(sslkey); err == nil {
|
||||
// Both the cert and key must be present to use them, or do not use either
|
||||
settings["sslcert"] = sslcert
|
||||
settings["sslkey"] = sslkey
|
||||
}
|
||||
}
|
||||
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
|
||||
if _, err := os.Stat(sslrootcert); err == nil {
|
||||
settings["sslrootcert"] = sslrootcert
|
||||
}
|
||||
}
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||
// checks the existence of common locations.
|
||||
func defaultHost() string {
|
||||
return "localhost"
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
// Package pgconn is a low-level PostgreSQL database driver.
|
||||
/*
|
||||
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
|
||||
nearly the same level is the C library libpq.
|
||||
|
||||
Establishing a Connection
|
||||
|
||||
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
|
||||
libpq style environment variables.
|
||||
|
||||
Executing a Query
|
||||
|
||||
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
|
||||
reads all rows into memory.
|
||||
|
||||
Executing Multiple Queries in a Single Round Trip
|
||||
|
||||
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
|
||||
result. The ReadAll method reads all query results into memory.
|
||||
|
||||
Pipeline Mode
|
||||
|
||||
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows
|
||||
control of exactly how many and when network round trips occur.
|
||||
|
||||
Context Support
|
||||
|
||||
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
|
||||
method immediately returns. In most circumstances, this will close the underlying connection.
|
||||
|
||||
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
|
||||
client to abort.
|
||||
*/
|
||||
package pgconn
|
||||
@@ -0,0 +1,226 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
|
||||
func SafeToRetry(err error) bool {
|
||||
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
|
||||
return e.SafeToRetry()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
|
||||
// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
|
||||
func Timeout(err error) bool {
|
||||
var timeoutErr *errTimeout
|
||||
return errors.As(err, &timeoutErr)
|
||||
}
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||
// detailed field description.
|
||||
type PgError struct {
|
||||
Severity string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
}
|
||||
|
||||
func (pe *PgError) Error() string {
|
||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
// SQLState returns the SQLState of the error.
|
||||
func (pe *PgError) SQLState() string {
|
||||
return pe.Code
|
||||
}
|
||||
|
||||
type connectError struct {
|
||||
config *Config
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *connectError) Error() string {
|
||||
sb := &strings.Builder{}
|
||||
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
|
||||
if e.err != nil {
|
||||
fmt.Fprintf(sb, " (%s)", e.err.Error())
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (e *connectError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type connLockError struct {
|
||||
status string
|
||||
}
|
||||
|
||||
func (e *connLockError) SafeToRetry() bool {
|
||||
return true // a lock failure by definition happens before the connection is used.
|
||||
}
|
||||
|
||||
func (e *connLockError) Error() string {
|
||||
return e.status
|
||||
}
|
||||
|
||||
type parseConfigError struct {
|
||||
connString string
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Error() string {
|
||||
connString := redactPW(e.connString)
|
||||
if e.err == nil {
|
||||
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
|
||||
}
|
||||
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
func normalizeTimeoutError(ctx context.Context, err error) error {
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
if ctx.Err() == context.Canceled {
|
||||
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
|
||||
return context.Canceled
|
||||
} else if ctx.Err() == context.DeadlineExceeded {
|
||||
return &errTimeout{err: ctx.Err()}
|
||||
} else {
|
||||
return &errTimeout{err: err}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type pgconnError struct {
|
||||
msg string
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *pgconnError) Error() string {
|
||||
if e.msg == "" {
|
||||
return e.err.Error()
|
||||
}
|
||||
if e.err == nil {
|
||||
return e.msg
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *pgconnError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *pgconnError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
|
||||
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
|
||||
type errTimeout struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *errTimeout) Error() string {
|
||||
return fmt.Sprintf("timeout: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *errTimeout) SafeToRetry() bool {
|
||||
return SafeToRetry(e.err)
|
||||
}
|
||||
|
||||
func (e *errTimeout) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type contextAlreadyDoneError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Error() string {
|
||||
return fmt.Sprintf("context already done: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) SafeToRetry() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
|
||||
func newContextAlreadyDoneError(ctx context.Context) (err error) {
|
||||
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
|
||||
}
|
||||
|
||||
func redactPW(connString string) string {
|
||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||
if u, err := url.Parse(connString); err == nil {
|
||||
return redactURL(u)
|
||||
}
|
||||
}
|
||||
quotedDSN := regexp.MustCompile(`password='[^']*'`)
|
||||
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
plainDSN := regexp.MustCompile(`password=[^ ]*`)
|
||||
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
|
||||
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
|
||||
return connString
|
||||
}
|
||||
|
||||
func redactURL(u *url.URL) string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
if _, pwSet := u.User.Password(); pwSet {
|
||||
u.User = url.UserPassword(u.User.Username(), "xxxxx")
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
type NotPreferredError struct {
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) Error() string {
|
||||
return fmt.Sprintf("standby server not found: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfigError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedMsg string
|
||||
}{
|
||||
{
|
||||
name: "url with password",
|
||||
err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil),
|
||||
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg",
|
||||
},
|
||||
{
|
||||
name: "dsn with password unquoted",
|
||||
err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil),
|
||||
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||
},
|
||||
{
|
||||
name: "dsn with password quoted",
|
||||
err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil),
|
||||
expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
|
||||
},
|
||||
{
|
||||
name: "weird url",
|
||||
err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil),
|
||||
expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg",
|
||||
},
|
||||
{
|
||||
name: "weird url with slash in password",
|
||||
err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil),
|
||||
expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg",
|
||||
},
|
||||
{
|
||||
name: "url without password",
|
||||
err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil),
|
||||
expectedMsg: "cannot parse `postgresql://other@host/db`: msg",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.EqualError(t, tt.err, tt.expectedMsg)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
// File export_test exports some methods for better testing.
|
||||
|
||||
package pgconn
|
||||
|
||||
func NewParseConfigError(conn, msg string, err error) error {
|
||||
return &parseConfigError{
|
||||
connString: conn,
|
||||
msg: msg,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func closeConn(t testing.TB, conn *pgconn.PgConn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, conn.Close(ctx))
|
||||
select {
|
||||
case <-conn.CleanupDone():
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Connection cleanup exceeded maximum time")
|
||||
}
|
||||
}
|
||||
|
||||
// Do a simple query to ensure the connection is still usable
|
||||
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
|
||||
cancel()
|
||||
|
||||
require.Nil(t, result.Err)
|
||||
assert.Equal(t, 3, len(result.Rows))
|
||||
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||
assert.Equal(t, "2", string(result.Rows[1][0]))
|
||||
assert.Equal(t, "3", string(result.Rows[2][0]))
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package ctxwatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||
// time.
|
||||
type ContextWatcher struct {
|
||||
onCancel func()
|
||||
onUnwatchAfterCancel func()
|
||||
unwatchChan chan struct{}
|
||||
|
||||
lock sync.Mutex
|
||||
watchInProgress bool
|
||||
onCancelWasCalled bool
|
||||
}
|
||||
|
||||
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
||||
// onCancel called.
|
||||
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
|
||||
cw := &ContextWatcher{
|
||||
onCancel: onCancel,
|
||||
onUnwatchAfterCancel: onUnwatchAfterCancel,
|
||||
unwatchChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
return cw
|
||||
}
|
||||
|
||||
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
|
||||
func (cw *ContextWatcher) Watch(ctx context.Context) {
|
||||
cw.lock.Lock()
|
||||
defer cw.lock.Unlock()
|
||||
|
||||
if cw.watchInProgress {
|
||||
panic("Watch already in progress")
|
||||
}
|
||||
|
||||
cw.onCancelWasCalled = false
|
||||
|
||||
if ctx.Done() != nil {
|
||||
cw.watchInProgress = true
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cw.onCancel()
|
||||
cw.onCancelWasCalled = true
|
||||
<-cw.unwatchChan
|
||||
case <-cw.unwatchChan:
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
cw.watchInProgress = false
|
||||
}
|
||||
}
|
||||
|
||||
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
|
||||
// called then onUnwatchAfterCancel will also be called.
|
||||
func (cw *ContextWatcher) Unwatch() {
|
||||
cw.lock.Lock()
|
||||
defer cw.lock.Unlock()
|
||||
|
||||
if cw.watchInProgress {
|
||||
cw.unwatchChan <- struct{}{}
|
||||
if cw.onCancelWasCalled {
|
||||
cw.onUnwatchAfterCancel()
|
||||
}
|
||||
cw.watchInProgress = false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package ctxwatch_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestContextWatcherContextCancelled(t *testing.T) {
|
||||
canceledChan := make(chan struct{})
|
||||
cleanupCalled := false
|
||||
cw := ctxwatch.NewContextWatcher(func() {
|
||||
canceledChan <- struct{}{}
|
||||
}, func() {
|
||||
cleanupCalled = true
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cw.Watch(ctx)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-canceledChan:
|
||||
case <-time.NewTimer(time.Second).C:
|
||||
t.Fatal("Timed out waiting for cancel func to be called")
|
||||
}
|
||||
|
||||
cw.Unwatch()
|
||||
|
||||
require.True(t, cleanupCalled, "Cleanup func was not called")
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {
|
||||
t.Error("cancel func should not have been called")
|
||||
}, func() {
|
||||
t.Error("cleanup func should not have been called")
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cw.Watch(ctx)
|
||||
cw.Unwatch()
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
cw.Watch(ctx)
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times")
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw.Unwatch() // unwatch when not / never watching
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
cw.Watch(ctx)
|
||||
cw.Unwatch()
|
||||
cw.Unwatch() // double unwatch
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
cw.Watch(ctx)
|
||||
|
||||
go cw.Unwatch()
|
||||
go cw.Unwatch()
|
||||
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
func TestContextWatcherStress(t *testing.T) {
|
||||
var cancelFuncCalls int64
|
||||
var cleanupFuncCalls int64
|
||||
|
||||
cw := ctxwatch.NewContextWatcher(func() {
|
||||
atomic.AddInt64(&cancelFuncCalls, 1)
|
||||
}, func() {
|
||||
atomic.AddInt64(&cleanupFuncCalls, 1)
|
||||
})
|
||||
|
||||
cycleCount := 100000
|
||||
|
||||
for i := 0; i < cycleCount; i++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cw.Watch(ctx)
|
||||
if i%2 == 0 {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
|
||||
if i%3 == 0 {
|
||||
time.Sleep(time.Nanosecond)
|
||||
}
|
||||
|
||||
cw.Unwatch()
|
||||
if i%2 == 1 {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
|
||||
actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)
|
||||
|
||||
if actualCancelFuncCalls == 0 {
|
||||
t.Fatal("actualCancelFuncCalls == 0")
|
||||
}
|
||||
|
||||
maxCancelFuncCalls := int64(cycleCount) / 2
|
||||
if actualCancelFuncCalls > maxCancelFuncCalls {
|
||||
t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls)
|
||||
}
|
||||
|
||||
if actualCancelFuncCalls != actualCleanupFuncCalls {
|
||||
t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
cw.Watch(context.Background())
|
||||
cw.Unwatch()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContextWatcherCancelled(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cw.Watch(ctx)
|
||||
cancel()
|
||||
cw.Unwatch()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContextWatcherCancellable(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
cw.Watch(ctx)
|
||||
cw.Unwatch()
|
||||
}
|
||||
}
|
||||
+100
@@ -0,0 +1,100 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
// NewGSSFunc creates a GSS authentication provider, for use with
|
||||
// RegisterGSSProvider.
|
||||
type NewGSSFunc func() (GSS, error)
|
||||
|
||||
var newGSS NewGSSFunc
|
||||
|
||||
// RegisterGSSProvider registers a GSS authentication provider. For example, if
|
||||
// you need to use Kerberos to authenticate with your server, add this to your
|
||||
// main package:
|
||||
//
|
||||
// import "github.com/otan/gopgkrb5"
|
||||
//
|
||||
// func init() {
|
||||
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
|
||||
// }
|
||||
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
|
||||
newGSS = newGSSArg
|
||||
}
|
||||
|
||||
// GSS provides GSSAPI authentication (e.g., Kerberos).
|
||||
type GSS interface {
|
||||
GetInitToken(host string, service string) ([]byte, error)
|
||||
GetInitTokenFromSPN(spn string) ([]byte, error)
|
||||
Continue(inToken []byte) (done bool, outToken []byte, err error)
|
||||
}
|
||||
|
||||
func (c *PgConn) gssAuth() error {
|
||||
if newGSS == nil {
|
||||
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
|
||||
}
|
||||
cli, err := newGSS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var nextData []byte
|
||||
if c.config.KerberosSpn != "" {
|
||||
// Use the supplied SPN if provided.
|
||||
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
|
||||
} else {
|
||||
// Allow the kerberos service name to be overridden
|
||||
service := "postgres"
|
||||
if c.config.KerberosSrvName != "" {
|
||||
service = c.config.KerberosSrvName
|
||||
}
|
||||
nextData, err = cli.GetInitToken(c.config.Host, service)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
gssResponse := &pgproto3.GSSResponse{
|
||||
Data: nextData,
|
||||
}
|
||||
c.frontend.Send(gssResponse)
|
||||
err = c.frontend.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.rxGSSContinue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var done bool
|
||||
done, nextData, err = cli.Continue(resp.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationGSSContinue:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
|
||||
}
|
||||
+1962
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,41 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCommandTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var tests = []struct {
|
||||
commandTag CommandTag
|
||||
rowsAffected int64
|
||||
isInsert bool
|
||||
isUpdate bool
|
||||
isDelete bool
|
||||
isSelect bool
|
||||
}{
|
||||
{commandTag: CommandTag{s: "INSERT 0 5"}, rowsAffected: 5, isInsert: true},
|
||||
{commandTag: CommandTag{s: "UPDATE 0"}, rowsAffected: 0, isUpdate: true},
|
||||
{commandTag: CommandTag{s: "UPDATE 1"}, rowsAffected: 1, isUpdate: true},
|
||||
{commandTag: CommandTag{s: "DELETE 0"}, rowsAffected: 0, isDelete: true},
|
||||
{commandTag: CommandTag{s: "DELETE 1"}, rowsAffected: 1, isDelete: true},
|
||||
{commandTag: CommandTag{s: "DELETE 1234567890"}, rowsAffected: 1234567890, isDelete: true},
|
||||
{commandTag: CommandTag{s: "SELECT 1"}, rowsAffected: 1, isSelect: true},
|
||||
{commandTag: CommandTag{s: "SELECT 99999999999"}, rowsAffected: 99999999999, isSelect: true},
|
||||
{commandTag: CommandTag{s: "CREATE TABLE"}, rowsAffected: 0},
|
||||
{commandTag: CommandTag{s: "ALTER TABLE"}, rowsAffected: 0},
|
||||
{commandTag: CommandTag{s: "DROP TABLE"}, rowsAffected: 0},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
ct := tt.commandTag
|
||||
assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
|
||||
assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
|
||||
assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
|
||||
assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
|
||||
assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnStress(t *testing.T) {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
actionCount := 10000
|
||||
if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
|
||||
stressFactor, err := strconv.ParseInt(s, 10, 64)
|
||||
require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
|
||||
actionCount *= int(stressFactor)
|
||||
}
|
||||
|
||||
setupStressDB(t, pgConn)
|
||||
|
||||
actions := []struct {
|
||||
name string
|
||||
fn func(*pgconn.PgConn) error
|
||||
}{
|
||||
{"Exec Select", stressExecSelect},
|
||||
{"ExecParams Select", stressExecParamsSelect},
|
||||
{"Batch", stressBatch},
|
||||
}
|
||||
|
||||
for i := 0; i < actionCount; i++ {
|
||||
action := actions[rand.Intn(len(actions))]
|
||||
err := action.fn(pgConn)
|
||||
require.Nilf(t, err, "%d: %s", i, action.name)
|
||||
}
|
||||
|
||||
// Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled.
|
||||
numGoroutine := runtime.NumGoroutine()
|
||||
require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine)
|
||||
}
|
||||
|
||||
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
|
||||
_, err := pgConn.Exec(context.Background(), `
|
||||
create temporary table widgets(
|
||||
id serial primary key,
|
||||
name varchar not null,
|
||||
description text,
|
||||
creation_time timestamptz default now()
|
||||
);
|
||||
|
||||
insert into widgets(name, description) values
|
||||
('Foo', 'bar'),
|
||||
('baz', 'Something really long Something really long Something really long Something really long Something really long'),
|
||||
('a', 'b')`).ReadAll()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func stressExecSelect(pgConn *pgconn.PgConn) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := pgConn.Exec(ctx, "select * from widgets").ReadAll()
|
||||
return err
|
||||
}
|
||||
|
||||
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
|
||||
return result.Err
|
||||
}
|
||||
|
||||
func stressBatch(pgConn *pgconn.PgConn) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
|
||||
batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
|
||||
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||
return err
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,7 @@
|
||||
# pgproto3
|
||||
|
||||
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||
|
||||
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.
|
||||
|
||||
See example/pgfortune for a playful example of a fake PostgreSQL server.
|
||||
@@ -0,0 +1,52 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
|
||||
type AuthenticationCleartextPassword struct {
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationCleartextPassword) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
||||
if len(src) != 4 {
|
||||
return errors.New("bad authentication message size")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeCleartextPassword {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "AuthenticationCleartextPassword",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type AuthenticationGSS struct{}
|
||||
|
||||
func (a *AuthenticationGSS) Backend() {}
|
||||
|
||||
func (a *AuthenticationGSS) AuthenticationResponse() {}
|
||||
|
||||
func (a *AuthenticationGSS) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeGSS {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 4)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}{
|
||||
Type: "AuthenticationGSS",
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Type string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type AuthenticationGSSContinue struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Backend() {}
|
||||
|
||||
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeGSSCont {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
a.Data = src[4:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||
dst = append(dst, a.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}{
|
||||
Type: "AuthenticationGSSContinue",
|
||||
Data: a.Data,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.Data = msg.Data
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
|
||||
type AuthenticationMD5Password struct {
|
||||
Salt [4]byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationMD5Password) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationMD5Password) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
||||
if len(src) != 8 {
|
||||
return errors.New("bad authentication message size")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeMD5Password {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
copy(dst.Salt[:], src[4:8])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 12)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
||||
dst = append(dst, src.Salt[:]...)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Salt [4]byte
|
||||
}{
|
||||
Type: "AuthenticationMD5Password",
|
||||
Salt: src.Salt,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Type string
|
||||
Salt [4]byte
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Salt = msg.Salt
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
|
||||
type AuthenticationOk struct {
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationOk) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationOk) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationOk) Decode(src []byte) error {
|
||||
if len(src) != 4 {
|
||||
return errors.New("bad authentication message size")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeOk {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationOk) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "AuthenticationOK",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
|
||||
type AuthenticationSASL struct {
|
||||
AuthMechanisms []string
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationSASL) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationSASL) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationSASL) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeSASL {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
authMechanisms := src[4:]
|
||||
for len(authMechanisms) > 1 {
|
||||
idx := bytes.IndexByte(authMechanisms, 0)
|
||||
if idx == -1 {
|
||||
return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"}
|
||||
}
|
||||
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
|
||||
authMechanisms = authMechanisms[idx+1:]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||
|
||||
for _, s := range src.AuthMechanisms {
|
||||
dst = append(dst, []byte(s)...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationSASL) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
AuthMechanisms []string
|
||||
}{
|
||||
Type: "AuthenticationSASL",
|
||||
AuthMechanisms: src.AuthMechanisms,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
|
||||
type AuthenticationSASLContinue struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationSASLContinue) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeSASLContinue {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
dst.Data = src[4:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "AuthenticationSASLContinue",
|
||||
Data: string(src.Data),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Data = []byte(msg.Data)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
|
||||
type AuthenticationSASLFinal struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationSASLFinal) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeSASLFinal {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
dst.Data = src[4:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "AuthenticationSASLFinal",
|
||||
Data: string(src.Data),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Data = []byte(msg.Data)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Backend acts as a server for the PostgreSQL wire protocol version 3.
|
||||
type Backend struct {
|
||||
cr *chunkReader
|
||||
w io.Writer
|
||||
|
||||
// tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
|
||||
// before it is actually transmitted (i.e. before Flush).
|
||||
tracer *tracer
|
||||
|
||||
wbuf []byte
|
||||
|
||||
// Frontend message flyweights
|
||||
bind Bind
|
||||
cancelRequest CancelRequest
|
||||
_close Close
|
||||
copyFail CopyFail
|
||||
copyData CopyData
|
||||
copyDone CopyDone
|
||||
describe Describe
|
||||
execute Execute
|
||||
flush Flush
|
||||
functionCall FunctionCall
|
||||
gssEncRequest GSSEncRequest
|
||||
parse Parse
|
||||
query Query
|
||||
sslRequest SSLRequest
|
||||
startupMessage StartupMessage
|
||||
sync Sync
|
||||
terminate Terminate
|
||||
|
||||
bodyLen int
|
||||
msgType byte
|
||||
partialMsg bool
|
||||
authType uint32
|
||||
}
|
||||
|
||||
const (
|
||||
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
|
||||
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
|
||||
)
|
||||
|
||||
// NewBackend creates a new Backend.
|
||||
func NewBackend(r io.Reader, w io.Writer) *Backend {
|
||||
cr := newChunkReader(r, 0)
|
||||
return &Backend{cr: cr, w: w}
|
||||
}
|
||||
|
||||
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
|
||||
// called.
|
||||
func (b *Backend) Send(msg BackendMessage) {
|
||||
prevLen := len(b.wbuf)
|
||||
b.wbuf = msg.Encode(b.wbuf)
|
||||
if b.tracer != nil {
|
||||
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Flush writes any pending messages to the frontend (i.e. the client).
|
||||
func (b *Backend) Flush() error {
|
||||
n, err := b.w.Write(b.wbuf)
|
||||
|
||||
const maxLen = 1024
|
||||
if len(b.wbuf) > maxLen {
|
||||
b.wbuf = make([]byte, 0, maxLen)
|
||||
} else {
|
||||
b.wbuf = b.wbuf[:0]
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return &writeError{err: err, safeToRetry: n == 0}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function
|
||||
// PQtrace.
|
||||
func (b *Backend) Trace(w io.Writer, options TracerOptions) {
|
||||
b.tracer = &tracer{
|
||||
w: w,
|
||||
buf: &bytes.Buffer{},
|
||||
TracerOptions: options,
|
||||
}
|
||||
}
|
||||
|
||||
// Untrace stops tracing.
|
||||
func (b *Backend) Untrace() {
|
||||
b.tracer = nil
|
||||
}
|
||||
|
||||
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
|
||||
// because the initial connection message is "special" and does not include the message type as the first byte. This
|
||||
// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
|
||||
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||
buf, err := b.cr.Next(4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||
|
||||
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
|
||||
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
|
||||
}
|
||||
|
||||
buf, err = b.cr.Next(msgSize)
|
||||
if err != nil {
|
||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||
}
|
||||
|
||||
code := binary.BigEndian.Uint32(buf)
|
||||
|
||||
switch code {
|
||||
case ProtocolVersionNumber:
|
||||
err = b.startupMessage.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.startupMessage, nil
|
||||
case sslRequestNumber:
|
||||
err = b.sslRequest.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.sslRequest, nil
|
||||
case cancelRequestCode:
|
||||
err = b.cancelRequest.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.cancelRequest, nil
|
||||
case gssEncReqNumber:
|
||||
err = b.gssEncRequest.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.gssEncRequest, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown startup message code: %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
|
||||
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||
if !b.partialMsg {
|
||||
header, err := b.cr.Next(5)
|
||||
if err != nil {
|
||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||
}
|
||||
|
||||
b.msgType = header[0]
|
||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
b.partialMsg = true
|
||||
}
|
||||
|
||||
var msg FrontendMessage
|
||||
switch b.msgType {
|
||||
case 'B':
|
||||
msg = &b.bind
|
||||
case 'C':
|
||||
msg = &b._close
|
||||
case 'D':
|
||||
msg = &b.describe
|
||||
case 'E':
|
||||
msg = &b.execute
|
||||
case 'F':
|
||||
msg = &b.functionCall
|
||||
case 'f':
|
||||
msg = &b.copyFail
|
||||
case 'd':
|
||||
msg = &b.copyData
|
||||
case 'c':
|
||||
msg = &b.copyDone
|
||||
case 'H':
|
||||
msg = &b.flush
|
||||
case 'P':
|
||||
msg = &b.parse
|
||||
case 'p':
|
||||
switch b.authType {
|
||||
case AuthTypeSASL:
|
||||
msg = &SASLInitialResponse{}
|
||||
case AuthTypeSASLContinue:
|
||||
msg = &SASLResponse{}
|
||||
case AuthTypeSASLFinal:
|
||||
msg = &SASLResponse{}
|
||||
case AuthTypeGSS, AuthTypeGSSCont:
|
||||
msg = &GSSResponse{}
|
||||
case AuthTypeCleartextPassword, AuthTypeMD5Password:
|
||||
fallthrough
|
||||
default:
|
||||
// to maintain backwards compatability
|
||||
msg = &PasswordMessage{}
|
||||
}
|
||||
case 'Q':
|
||||
msg = &b.query
|
||||
case 'S':
|
||||
msg = &b.sync
|
||||
case 'X':
|
||||
msg = &b.terminate
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
|
||||
}
|
||||
|
||||
msgBody, err := b.cr.Next(b.bodyLen)
|
||||
if err != nil {
|
||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||
}
|
||||
|
||||
b.partialMsg = false
|
||||
|
||||
err = msg.Decode(msgBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if b.tracer != nil {
|
||||
b.tracer.traceMessage('F', int32(5+len(msgBody)), msg)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// SetAuthType sets the authentication type in the backend.
|
||||
// Since multiple message types can start with 'p', SetAuthType allows
|
||||
// contextual identification of FrontendMessages. For example, in the
|
||||
// PG message flow documentation for PasswordMessage:
|
||||
//
|
||||
// Byte1('p')
|
||||
//
|
||||
// Identifies the message as a password response. Note that this is also used for
|
||||
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
|
||||
// the context.
|
||||
//
|
||||
// Since the Frontend does not know about the state of a backend, it is important
|
||||
// to call SetAuthType() after an authentication request is received by the Frontend.
|
||||
func (b *Backend) SetAuthType(authType uint32) error {
|
||||
switch authType {
|
||||
case AuthTypeOk,
|
||||
AuthTypeCleartextPassword,
|
||||
AuthTypeMD5Password,
|
||||
AuthTypeSCMCreds,
|
||||
AuthTypeGSS,
|
||||
AuthTypeGSSCont,
|
||||
AuthTypeSSPI,
|
||||
AuthTypeSASL,
|
||||
AuthTypeSASLContinue,
|
||||
AuthTypeSASLFinal:
|
||||
b.authType = authType
|
||||
default:
|
||||
return fmt.Errorf("authType not recognized: %d", authType)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type BackendKeyData struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*BackendKeyData) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *BackendKeyData) Decode(src []byte) error {
|
||||
if len(src) != 8 {
|
||||
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, 12)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src BackendKeyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}{
|
||||
Type: "BackendKeyData",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBackendReceiveInterrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push([]byte{'Q', 0, 0, 0, 6})
|
||||
|
||||
backend := pgproto3.NewBackend(server, nil)
|
||||
|
||||
msg, err := backend.Receive()
|
||||
if err == nil {
|
||||
t.Fatal("expected err")
|
||||
}
|
||||
if msg != nil {
|
||||
t.Fatalf("did not expect msg, but %v", msg)
|
||||
}
|
||||
|
||||
server.push([]byte{'I', 0})
|
||||
|
||||
msg, err = backend.Receive()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
|
||||
t.Fatalf("unexpected msg: %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendReceiveUnexpectedEOF(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push([]byte{'Q', 0, 0, 0, 6})
|
||||
|
||||
backend := pgproto3.NewBackend(server, nil)
|
||||
|
||||
// Receive regular msg
|
||||
msg, err := backend.Receive()
|
||||
assert.Nil(t, msg)
|
||||
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||
|
||||
// Receive StartupMessage msg
|
||||
dst := []byte{}
|
||||
dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read
|
||||
dst = pgio.AppendUint32(dst, 1) // only send 1 byte
|
||||
server.push(dst)
|
||||
|
||||
msg, err = backend.ReceiveStartupMessage()
|
||||
assert.Nil(t, msg)
|
||||
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||
}
|
||||
|
||||
func TestStartupMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("valid StartupMessage", func(t *testing.T) {
|
||||
want := &pgproto3.StartupMessage{
|
||||
ProtocolVersion: pgproto3.ProtocolVersionNumber,
|
||||
Parameters: map[string]string{
|
||||
"username": "tester",
|
||||
},
|
||||
}
|
||||
dst := []byte{}
|
||||
dst = want.Encode(dst)
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push(dst)
|
||||
|
||||
backend := pgproto3.NewBackend(server, nil)
|
||||
|
||||
msg, err := backend.ReceiveStartupMessage()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, msg)
|
||||
})
|
||||
|
||||
t.Run("invalid packet length", func(t *testing.T) {
|
||||
wantErr := "invalid length of startup packet"
|
||||
tests := []struct {
|
||||
name string
|
||||
packetLen uint32
|
||||
}{
|
||||
{
|
||||
name: "large packet length",
|
||||
// Since the StartupMessage contains the "Length of message contents
|
||||
// in bytes, including self", the max startup packet length is actually
|
||||
// 10000+4. Therefore, let's go past the limit with 10005
|
||||
packetLen: 10005,
|
||||
},
|
||||
{
|
||||
name: "short packet length",
|
||||
packetLen: 3,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &interruptReader{}
|
||||
dst := []byte{}
|
||||
dst = pgio.AppendUint32(dst, tt.packetLen)
|
||||
dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
|
||||
server.push(dst)
|
||||
|
||||
backend := pgproto3.NewBackend(server, nil)
|
||||
|
||||
msg, err := backend.ReceiveStartupMessage()
|
||||
require.Error(t, err)
|
||||
require.Nil(t, msg)
|
||||
require.Contains(t, err.Error(), wantErr)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type BigEndianBuf [8]byte
|
||||
|
||||
func (b BigEndianBuf) Int16(n int16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, uint16(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint16(n uint16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int32(n int32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint32(n uint32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int64(n int64) []byte {
|
||||
buf := b[0:8]
|
||||
binary.BigEndian.PutUint64(buf, uint64(n))
|
||||
return buf
|
||||
}
|
||||
@@ -0,0 +1,216 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Bind struct {
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters [][]byte
|
||||
ResultFormatCodes []int16
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*Bind) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *Bind) Decode(src []byte) error {
|
||||
*dst = Bind{}
|
||||
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
dst.DestinationPortal = string(src[:idx])
|
||||
rp := idx + 1
|
||||
|
||||
idx = bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
dst.PreparedStatement = string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
if parameterFormatCodeCount > 0 {
|
||||
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
|
||||
|
||||
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < parameterFormatCodeCount; i++ {
|
||||
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
}
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
if parameterCount > 0 {
|
||||
dst.Parameters = make([][]byte, parameterCount)
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(src[rp:]) < msgSize {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
|
||||
dst.Parameters[i] = src[rp : rp+msgSize]
|
||||
rp += msgSize
|
||||
}
|
||||
}
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
|
||||
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < resultFormatCodeCount; i++ {
|
||||
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Bind) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'B')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.DestinationPortal...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.PreparedStatement...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||
for _, fc := range src.ParameterFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||
for _, p := range src.Parameters {
|
||||
if p == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
dst = pgio.AppendInt32(dst, int32(len(p)))
|
||||
dst = append(dst, p...)
|
||||
}
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||
for _, fc := range src.ResultFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src Bind) MarshalJSON() ([]byte, error) {
|
||||
formattedParameters := make([]map[string]string, len(src.Parameters))
|
||||
for i, p := range src.Parameters {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
textFormat := true
|
||||
if len(src.ParameterFormatCodes) == 1 {
|
||||
textFormat = src.ParameterFormatCodes[0] == 0
|
||||
} else if len(src.ParameterFormatCodes) > 1 {
|
||||
textFormat = src.ParameterFormatCodes[i] == 0
|
||||
}
|
||||
|
||||
if textFormat {
|
||||
formattedParameters[i] = map[string]string{"text": string(p)}
|
||||
} else {
|
||||
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters []map[string]string
|
||||
ResultFormatCodes []int16
|
||||
}{
|
||||
Type: "Bind",
|
||||
DestinationPortal: src.DestinationPortal,
|
||||
PreparedStatement: src.PreparedStatement,
|
||||
ParameterFormatCodes: src.ParameterFormatCodes,
|
||||
Parameters: formattedParameters,
|
||||
ResultFormatCodes: src.ResultFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *Bind) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters []map[string]string
|
||||
ResultFormatCodes []int16
|
||||
}
|
||||
err := json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.DestinationPortal = msg.DestinationPortal
|
||||
dst.PreparedStatement = msg.PreparedStatement
|
||||
dst.ParameterFormatCodes = msg.ParameterFormatCodes
|
||||
dst.Parameters = make([][]byte, len(msg.Parameters))
|
||||
dst.ResultFormatCodes = msg.ResultFormatCodes
|
||||
for n, parameter := range msg.Parameters {
|
||||
dst.Parameters[n], err = getValueFromJSON(parameter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get param %d: %w", n, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type BindComplete struct{}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*BindComplete) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *BindComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '2', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src BindComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "BindComplete",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
const cancelRequestCode = 80877102
|
||||
|
||||
type CancelRequest struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*CancelRequest) Frontend() {}
|
||||
|
||||
func (dst *CancelRequest) Decode(src []byte) error {
|
||||
if len(src) != 12 {
|
||||
return errors.New("bad cancel request size")
|
||||
}
|
||||
|
||||
requestCode := binary.BigEndian.Uint32(src)
|
||||
|
||||
if requestCode != cancelRequestCode {
|
||||
return errors.New("bad cancel request code")
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||
dst = pgio.AppendInt32(dst, 16)
|
||||
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CancelRequest) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}{
|
||||
Type: "CancelRequest",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
)
|
||||
|
||||
// chunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and
|
||||
// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually
|
||||
// requested. The memory returned via Next is only valid until the next call to Next.
|
||||
//
|
||||
// This is roughly equivalent to a bufio.Reader that only uses Peek and Discard to never copy bytes.
|
||||
type chunkReader struct {
|
||||
r io.Reader
|
||||
|
||||
buf []byte
|
||||
rp, wp int // buf read position and write position
|
||||
|
||||
minBufSize int
|
||||
}
|
||||
|
||||
// newChunkReader creates and returns a new chunkReader for r with default configuration. If minBufSize is <= 0 it uses
|
||||
// a default value.
|
||||
func newChunkReader(r io.Reader, minBufSize int) *chunkReader {
|
||||
if minBufSize <= 0 {
|
||||
// By historical reasons Postgres currently has 8KB send buffer inside,
|
||||
// so here we want to have at least the same size buffer.
|
||||
// @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134
|
||||
// @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru
|
||||
//
|
||||
// In addition, testing has found no benefit of any larger buffer.
|
||||
minBufSize = 8192
|
||||
}
|
||||
|
||||
return &chunkReader{
|
||||
r: r,
|
||||
minBufSize: minBufSize,
|
||||
buf: iobufpool.Get(minBufSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Next returns buf filled with the next n bytes. buf is only valid until next call of Next. If an error occurs, buf
|
||||
// will be nil.
|
||||
func (r *chunkReader) Next(n int) (buf []byte, err error) {
|
||||
// Reset the buffer if it is empty
|
||||
if r.rp == r.wp {
|
||||
if len(r.buf) != r.minBufSize {
|
||||
iobufpool.Put(r.buf)
|
||||
r.buf = iobufpool.Get(r.minBufSize)
|
||||
}
|
||||
r.rp = 0
|
||||
r.wp = 0
|
||||
}
|
||||
|
||||
// n bytes already in buf
|
||||
if (r.wp - r.rp) >= n {
|
||||
buf = r.buf[r.rp : r.rp+n : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, err
|
||||
}
|
||||
|
||||
// buf is smaller than requested number of bytes
|
||||
if len(r.buf) < n {
|
||||
bigBuf := iobufpool.Get(n)
|
||||
r.wp = copy(bigBuf, r.buf[r.rp:r.wp])
|
||||
r.rp = 0
|
||||
iobufpool.Put(r.buf)
|
||||
r.buf = bigBuf
|
||||
}
|
||||
|
||||
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
||||
minReadCount := n - (r.wp - r.rp)
|
||||
if (len(r.buf) - r.wp) < minReadCount {
|
||||
r.wp = copy(r.buf, r.buf[r.rp:r.wp])
|
||||
r.rp = 0
|
||||
}
|
||||
|
||||
// Read at least the required number of bytes from the underlying io.Reader
|
||||
readBytesCount, err := io.ReadAtLeast(r.r, r.buf[r.wp:], minReadCount)
|
||||
r.wp += readBytesCount
|
||||
// fmt.Println("read", n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf = r.buf[r.rp : r.rp+n : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, nil
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r := newChunkReader(server, 4)
|
||||
|
||||
src := []byte{1, 2, 3, 4}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:2]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
|
||||
}
|
||||
|
||||
n2, err := r.Next(2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n2, src[2:4]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(r.buf[:len(src)], src) != 0 {
|
||||
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
|
||||
}
|
||||
|
||||
_, err = r.Next(0) // Trigger the buffer reset.
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if r.rp != 0 {
|
||||
t.Fatalf("Expected r.rp to be %v, but it was %v", 0, r.rp)
|
||||
}
|
||||
if r.wp != 0 {
|
||||
t.Fatalf("Expected r.wp to be %v, but it was %v", 0, r.wp)
|
||||
}
|
||||
}
|
||||
|
||||
type randomReader struct {
|
||||
rnd *rand.Rand
|
||||
}
|
||||
|
||||
// Read reads a random number of random bytes.
|
||||
func (r *randomReader) Read(p []byte) (n int, err error) {
|
||||
n = r.rnd.Intn(len(p) + 1)
|
||||
return r.rnd.Read(p[:n])
|
||||
}
|
||||
|
||||
func TestChunkReaderNextFuzz(t *testing.T) {
|
||||
rr := &randomReader{rnd: rand.New(rand.NewSource(1))}
|
||||
r := newChunkReader(rr, 8192)
|
||||
|
||||
randomSizes := rand.New(rand.NewSource(0))
|
||||
|
||||
for i := 0; i < 100000; i++ {
|
||||
size := randomSizes.Intn(16384) + 1
|
||||
buf, err := r.Next(size)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(buf) != size {
|
||||
t.Fatalf("Expected to get %v bytes but got %v bytes", size, len(buf))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Close struct {
|
||||
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||
Name string
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*Close) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *Close) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Close"}
|
||||
}
|
||||
|
||||
dst.ObjectType = src[0]
|
||||
rp := 1
|
||||
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx != len(src[rp:])-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Close"}
|
||||
}
|
||||
|
||||
dst.Name = string(src[rp : len(src)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Close) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src Close) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ObjectType string
|
||||
Name string
|
||||
}{
|
||||
Type: "Close",
|
||||
ObjectType: string(src.ObjectType),
|
||||
Name: src.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *Close) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
ObjectType string
|
||||
Name string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.ObjectType) != 1 {
|
||||
return errors.New("invalid length for Close.ObjectType")
|
||||
}
|
||||
|
||||
dst.ObjectType = byte(msg.ObjectType[0])
|
||||
dst.Name = msg.Name
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CloseComplete struct{}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CloseComplete) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CloseComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '3', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CloseComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "CloseComplete",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CommandComplete struct {
|
||||
CommandTag []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CommandComplete) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CommandComplete) Decode(src []byte) error {
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx == -1 {
|
||||
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"}
|
||||
}
|
||||
if idx != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"}
|
||||
}
|
||||
|
||||
dst.CommandTag = src[:idx]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.CommandTag...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CommandComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
CommandTag string
|
||||
}{
|
||||
Type: "CommandComplete",
|
||||
CommandTag: string(src.CommandTag),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
CommandTag string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.CommandTag = []byte(msg.CommandTag)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CopyBothResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CopyBothResponse) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'W')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = append(dst, src.OverallFormat)
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyBothResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
OverallFormat string
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.OverallFormat) != 1 {
|
||||
return errors.New("invalid length for CopyBothResponse.OverallFormat")
|
||||
}
|
||||
|
||||
dst.OverallFormat = msg.OverallFormat[0]
|
||||
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01}
|
||||
dstResp := pgproto3.CopyBothResponse{}
|
||||
err := dstResp.Decode(srcBytes[5:])
|
||||
assert.NoError(t, err, "No errors on decode")
|
||||
dstBytes := []byte{}
|
||||
dstBytes = dstResp.Encode(dstBytes)
|
||||
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CopyData struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CopyData) Backend() {}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*CopyData) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CopyData) Decode(src []byte) error {
|
||||
dst.Data = src
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'd')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
dst = append(dst, src.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CopyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "CopyData",
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyData) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Data = []byte(msg.Data)
|
||||
return nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user